Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c79bf8e
refactor: improve type safety across multiple adapters and unify conn…
cofin Jun 28, 2026
ca04356
fix: restore psqlpy runtime aliases
cofin Jun 28, 2026
56de0d9
chore: bump deps
cofin Jun 28, 2026
5e6c75d
perf(splitter): optimize tokenization hot paths
cofin Jun 28, 2026
a43fbda
perf(parameters): trim conversion hot paths
cofin Jun 28, 2026
c558517
perf(core): streamline compiler detection paths
cofin Jun 28, 2026
578ce6d
perf(core): tighten filter and hashing helpers
cofin Jun 28, 2026
af9d824
perf(cache): reduce lru lock overhead
cofin Jun 28, 2026
a50e017
style: format c2 hot-path changes
cofin Jun 28, 2026
fa94d7e
perf(core): apply c3 mypyc idioms
cofin Jun 28, 2026
0e1ae44
perf(core): replace c3 result transformer idioms
cofin Jun 28, 2026
b6fd92b
perf(parameters): tighten c3 transformer idioms
cofin Jun 28, 2026
958757e
fix(core): use runtime final annotations
cofin Jun 28, 2026
f8b594c
refactor(core): consolidate statement where helpers
cofin Jun 28, 2026
6462003
refactor(core): share statement clone setup
cofin Jun 28, 2026
c9fe088
refactor(core): consolidate filter helpers
cofin Jun 28, 2026
3acea33
refactor(core): consolidate internal helpers
cofin Jun 28, 2026
c32b985
chore(core): mark external helper APIs
cofin Jun 28, 2026
d9896a1
refactor(splitter): consolidate dialect configs
cofin Jun 28, 2026
58a5c26
test(core): remove internal chapter labels
cofin Jun 28, 2026
47ed98d
chore: enhance type definitions and error handling across adapters
cofin Jun 29, 2026
491a3b2
chore(splitter): update special terminators type annotation for MySQL…
cofin Jun 29, 2026
e23600f
fix(core): avoid mypyc stack result crash
cofin Jun 29, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
"ParameterMapping": "sqlspec.core.parameters.ParameterMapping",
"ParameterSequence": "sqlspec.core.parameters.ParameterSequence",
"ParameterPayload": "sqlspec.core.parameters.ParameterPayload",
"DialectType": "sqlglot.dialects.dialect.DialectType",
"Union": "typing.Union",
"Callable": "typing.Callable",
"Any": "typing.Any",
Expand Down
4 changes: 4 additions & 0 deletions docs/reference/core/filters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ Date Filters

.. autoclass:: BeforeAfterFilter
:members:
:inherited-members:
:show-inheritance:

.. autoclass:: OnBeforeAfterFilter
:members:
:inherited-members:
:show-inheritance:

Collection Filters
Expand Down Expand Up @@ -84,10 +86,12 @@ Search

.. autoclass:: SearchFilter
:members:
:inherited-members:
:show-inheritance:

.. autoclass:: NotInSearchFilter
:members:
:inherited-members:
:show-inheritance:

Type Aliases
Expand Down
6 changes: 3 additions & 3 deletions sqlspec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
from typing import TYPE_CHECKING, Any

from sqlspec import adapters, base, builder, core, driver, exceptions, extensions, loader, migrations, typing, utils

if TYPE_CHECKING:
from sqlspec import dialects
from sqlspec.__metadata__ import __version__
from sqlspec.base import SQLSpec
from sqlspec.builder import (
Expand Down Expand Up @@ -86,6 +83,9 @@
from sqlspec.typing import ConnectionT, PoolT, SchemaT, StatementParameters, SupportedSchemaModel
from sqlspec.utils.logging import suppress_erroneous_sqlglot_log_messages

if TYPE_CHECKING:
from sqlspec import dialects

__all__ = (
"SQL",
"ArrowResult",
Expand Down
19 changes: 11 additions & 8 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_statement_config,
is_postgres_dialect,
resolve_dialect_from_config,
resolve_dialect_name,
resolve_driver_connect_func,
resolve_postgres_extension_state,
resolve_runtime_statement_config,
Expand Down Expand Up @@ -194,7 +195,7 @@ def __init__(
self,
*,
connection_config: "AdbcConnectionParams | dict[str, Any] | None" = None,
connection_instance: "Any" = None,
connection_instance: "AdbcConnection | None" = None,
migration_config: "dict[str, Any] | None" = None,
statement_config: StatementConfig | None = None,
driver_features: "AdbcDriverFeatures | dict[str, Any] | None" = None,
Expand Down Expand Up @@ -256,13 +257,15 @@ def create_connection(self) -> AdbcConnection:
"""

try:
connection = resolve_driver_connect_func(
self.connection_config.get("driver_name"), self.connection_config.get("uri")
)(**build_connection_config(self.connection_config))
driver_name = cast("str | None", self.connection_config.get("driver_name"))
uri = cast("str | None", self.connection_config.get("uri"))
connection = resolve_driver_connect_func(driver_name, uri)(
**build_connection_config(self.connection_config)
)
return cast("AdbcConnection", connection)
except Exception as e:
driver_name = self.connection_config.get("driver_name", "Unknown")
msg = f"Could not configure connection using driver '{driver_name}'. Error: {e}"
err_driver_name = self.connection_config.get("driver_name", "Unknown")
msg = f"Could not configure connection using driver '{err_driver_name}'. Error: {e}"
raise ImproperConfigurationError(msg) from e

def _update_dialect_for_extensions(self) -> None:
Expand All @@ -271,7 +274,7 @@ def _update_dialect_for_extensions(self) -> None:
Priority: paradedb > pgvector > postgres (default).
Only switches when current dialect is ``postgres``.
"""
current_dialect = getattr(self.statement_config, "dialect", "postgres")
current_dialect = self.statement_config.dialect or "postgres"
if current_dialect != "postgres":
return

Expand All @@ -289,7 +292,7 @@ def _detect_extensions_if_needed(self) -> None:
if self._pgvector_available is not None:
return

dialect = getattr(self.statement_config, "dialect", "")
dialect = resolve_dialect_name(self.statement_config.dialect)
if not is_postgres_dialect(dialect):
self._pgvector_available = False
self._paradedb_available = False
Expand Down
13 changes: 6 additions & 7 deletions sqlspec/adapters/aiomysql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from collections.abc import Callable, Sized
from typing import TYPE_CHECKING, Any

from aiomysql import SSCursor

from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile
from sqlspec.driver import rows_to_dicts
from sqlspec.exceptions import (
Expand All @@ -24,6 +22,7 @@
TransactionError,
UniqueViolationError,
)
from sqlspec.protocols import HasSqlStateProtocol, HasTypeCodeProtocol
from sqlspec.utils.serializers import from_json, to_json
from sqlspec.utils.text import quote_backtick_identifier, split_qualified_identifier
from sqlspec.utils.type_converters import build_uuid_coercions
Expand Down Expand Up @@ -189,6 +188,8 @@ def __init__(self, driver: Any, sql: str, parameters: Any, chunk_size: int) -> N
self._column_names: list[str] | None = None

async def start(self) -> None:
from aiomysql import SSCursor

handler = self._driver.handle_database_exceptions()
async with handler:
cursor = await self._driver.connection.cursor(SSCursor)
Expand Down Expand Up @@ -324,13 +325,11 @@ def create_mapped_exception(error: Any, *, logger: Any | None = None) -> "SQLSpe
Returns:
True to suppress expected migration errors, or a SQLSpec exception
"""
error_args = getattr(error, "args", ())
error_args = error.args
error_code = error_args[0] if error_args and isinstance(error_args[0], int) else None
sqlstate_attr = getattr(error, "sqlstate", None)
sqlstate = sqlstate_attr if isinstance(sqlstate_attr, str) else None
sqlstate = error.sqlstate if isinstance(error, HasSqlStateProtocol) else None
sqlstate_prefix = sqlstate[:2] if isinstance(sqlstate, str) and sqlstate else None

# Migration-specific errors to suppress
if error_code in _MYSQL_MIGRATION_ERROR_CODES:
if logger is not None:
logger.warning("aiomysql MySQL expected migration error (ignoring): %s", error)
Expand Down Expand Up @@ -402,7 +401,7 @@ def detect_json_columns_from_description(
if isinstance(column, (tuple, list)):
type_code = column[1] if len(column) > 1 else None
else:
type_code = getattr(column, "type_code", None)
type_code = column.type_code if isinstance(column, HasTypeCodeProtocol) else None
if type_code in json_type_codes:
append(index)
return json_indexes
Expand Down
8 changes: 4 additions & 4 deletions sqlspec/adapters/arrow_odbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.protocols import SupportsCloseProtocol
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
Expand Down Expand Up @@ -124,7 +125,7 @@ def __init__(
self,
*,
connection_config: "ArrowOdbcConnectionParams | dict[str, Any] | None" = None,
connection_instance: "Any" = None,
connection_instance: "ArrowOdbcConnection | None" = None,
migration_config: "dict[str, Any] | None" = None,
statement_config: "StatementConfig | None" = None,
driver_features: "ArrowOdbcDriverFeatures | dict[str, Any] | None" = None,
Expand Down Expand Up @@ -204,6 +205,5 @@ def get_event_runtime_hints(self) -> "EventRuntimeHints":

def _close_arrow_odbc_connection(connection: "ArrowOdbcConnection") -> None:
"""Close connection objects from compatible wrappers when they expose close()."""
close = getattr(connection, "close", None)
if close is not None:
close()
if isinstance(connection, SupportsCloseProtocol):
connection.close()
13 changes: 6 additions & 7 deletions sqlspec/adapters/asyncmy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from collections.abc import Callable, Sized
from typing import TYPE_CHECKING, Any

from asyncmy.cursors import SSCursor

from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile
from sqlspec.driver import rows_to_dicts
from sqlspec.exceptions import (
Expand All @@ -24,6 +22,7 @@
TransactionError,
UniqueViolationError,
)
from sqlspec.protocols import HasSqlStateProtocol, HasTypeCodeProtocol
from sqlspec.utils.serializers import from_json, to_json
from sqlspec.utils.text import quote_backtick_identifier, split_qualified_identifier
from sqlspec.utils.type_converters import build_uuid_coercions
Expand Down Expand Up @@ -161,6 +160,8 @@ def __init__(self, driver: Any, sql: str, parameters: Any, chunk_size: int) -> N
self._column_names: list[str] | None = None

async def start(self) -> None:
from asyncmy.cursors import SSCursor

handler = self._driver.handle_database_exceptions()
async with handler:
cursor = self._driver.connection.cursor(SSCursor)
Expand Down Expand Up @@ -296,13 +297,11 @@ def create_mapped_exception(error: Any, *, logger: Any | None = None) -> "SQLSpe
Returns:
True to suppress expected migration errors, or a SQLSpec exception
"""
error_args = getattr(error, "args", ())
error_args = error.args
error_code = error_args[0] if error_args and isinstance(error_args[0], int) else None
sqlstate_attr = getattr(error, "sqlstate", None)
sqlstate = sqlstate_attr if isinstance(sqlstate_attr, str) else None
sqlstate = error.sqlstate if isinstance(error, HasSqlStateProtocol) else None
sqlstate_prefix = sqlstate[:2] if isinstance(sqlstate, str) and sqlstate else None

# Migration-specific errors to suppress
if error_code in _MYSQL_MIGRATION_ERROR_CODES:
if logger is not None:
logger.warning("AsyncMy MySQL expected migration error (ignoring): %s", error)
Expand Down Expand Up @@ -374,7 +373,7 @@ def detect_json_columns_from_description(
if isinstance(column, (tuple, list)):
type_code = column[1] if len(column) > 1 else None
else:
type_code = getattr(column, "type_code", None)
type_code = column.type_code if isinstance(column, HasTypeCodeProtocol) else None
if type_code in json_type_codes:
append(index)
return json_indexes
Expand Down
13 changes: 11 additions & 2 deletions sqlspec/adapters/asyncpg/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import TYPE_CHECKING, Any

from asyncpg import Pool
from asyncpg import Pool, PostgresError
from asyncpg.pool import PoolConnectionProxy
from asyncpg.prepared_stmt import PreparedStatement

Expand All @@ -22,14 +22,23 @@

AsyncpgConnection: TypeAlias = Connection[Record] | PoolConnectionProxy[Record]
AsyncpgPool: TypeAlias = Pool[Record]
AsyncpgPostgresError: TypeAlias = PostgresError
AsyncpgPreparedStatement: TypeAlias = PreparedStatement[Record]

if not TYPE_CHECKING:
AsyncpgConnection = PoolConnectionProxy
AsyncpgPool = Pool
AsyncpgPostgresError = PostgresError
AsyncpgPreparedStatement = PreparedStatement

__all__ = ("AsyncpgConnection", "AsyncpgCursor", "AsyncpgPool", "AsyncpgPreparedStatement", "AsyncpgSessionContext")
__all__ = (
"AsyncpgConnection",
"AsyncpgCursor",
"AsyncpgPool",
"AsyncpgPostgresError",
"AsyncpgPreparedStatement",
"AsyncpgSessionContext",
)


class AsyncpgCursor:
Expand Down
17 changes: 7 additions & 10 deletions sqlspec/adapters/asyncpg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from io import BytesIO
from typing import TYPE_CHECKING, Any, Final, cast

import asyncpg

from sqlspec.adapters.asyncpg._typing import AsyncpgCursor, AsyncpgSessionContext
from sqlspec.adapters.asyncpg._typing import AsyncpgCursor, AsyncpgPostgresError, AsyncpgSessionContext
from sqlspec.adapters.asyncpg.core import (
PREPARED_STATEMENT_CACHE_SIZE,
AsyncpgStreamSource,
Expand Down Expand Up @@ -77,7 +75,7 @@ class AsyncpgExceptionHandler(BaseAsyncExceptionHandler):

def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
_ = exc_type
if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val):
if isinstance(exc_val, AsyncpgPostgresError) or has_sqlstate(exc_val):
self.pending_exception = create_mapped_exception(exc_val)
return True
return False
Expand Down Expand Up @@ -217,23 +215,23 @@ async def begin(self) -> None:
"""Begin a database transaction."""
try:
await self.connection.execute("BEGIN")
except asyncpg.PostgresError as e:
except AsyncpgPostgresError as e:
msg = f"Failed to begin async transaction: {e}"
raise SQLSpecError(msg) from e

async def commit(self) -> None:
"""Commit the current transaction."""
try:
await self.connection.execute("COMMIT")
except asyncpg.PostgresError as e:
except AsyncpgPostgresError as e:
msg = f"Failed to commit async transaction: {e}"
raise SQLSpecError(msg) from e

async def rollback(self) -> None:
"""Rollback the current transaction."""
try:
await self.connection.execute("ROLLBACK")
except asyncpg.PostgresError as e:
except AsyncpgPostgresError as e:
msg = f"Failed to rollback async transaction: {e}"
raise SQLSpecError(msg) from e

Expand Down Expand Up @@ -414,13 +412,12 @@ async def load_from_arrow(
telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
"""Load Arrow data into a PostgreSQL table via COPY."""

self._require_capability("arrow_import_enabled")
arrow_table = self._coerce_arrow_table(source)
if overwrite:
try:
await self.connection.execute(f"TRUNCATE TABLE {table}")
except asyncpg.PostgresError as exc:
except AsyncpgPostgresError as exc:
msg = f"Failed to truncate table '{table}': {exc}"
raise SQLSpecError(msg) from exc
columns, records = self._arrow_table_to_rows(arrow_table)
Expand Down Expand Up @@ -504,7 +501,7 @@ async def _handle_copy_operation(self, cursor: "AsyncpgConnection", statement: "

execution_args = statement.statement_config.execution_args
metadata: dict[str, Any] = dict(execution_args) if execution_args else {}
if getattr(statement, "is_processed", False):
if statement.is_processed:
sql_text = statement.get_processed_state().compiled_sql
else:
sql_text, _ = self._get_compiled_sql(statement, statement.statement_config)
Expand Down
19 changes: 17 additions & 2 deletions sqlspec/adapters/bigquery/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from typing import TYPE_CHECKING, Any

from google.cloud.bigquery import ArrayQueryParameter, Client, QueryJob, ScalarQueryParameter
from google.cloud.exceptions import GoogleCloudError

from sqlspec.typing import import_optional

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -19,12 +22,24 @@

BigQueryConnection: TypeAlias = Client
BigQueryParam: TypeAlias = ArrayQueryParameter | ScalarQueryParameter
BigQueryStorageWriteModule: Any
BigQueryStorageWriteTypes: Any

if not TYPE_CHECKING:
BigQueryConnection = Client
BigQueryParam = ArrayQueryParameter | ScalarQueryParameter

__all__ = ("BigQueryConnection", "BigQueryCursor", "BigQueryParam", "BigQuerySessionContext")
BigQueryStorageWriteModule = import_optional("google.cloud.bigquery_storage_v1")
BigQueryStorageWriteTypes = import_optional("google.cloud.bigquery_storage_v1.types")

__all__ = (
"BigQueryConnection",
"BigQueryCursor",
"BigQueryParam",
"BigQuerySessionContext",
"BigQueryStorageWriteModule",
"BigQueryStorageWriteTypes",
"GoogleCloudError",
)


class BigQueryCursor:
Expand Down
Loading
Loading