Commit 932ec74

fix(backend/kernel): comparator parity — async retention + intervals_as_string + precision/scale + named params + utils.exc fix
Five small comparator-parity fixes surfaced by the same audit run. All are independently small but ship together because they share the kernel-backend client surface and the audit baseline.

### 1. Async Statement retention

`KernelDatabricksClient.execute_command` closed the parent `Statement` in `finally` regardless of `async_op`. The kernel's `Statement.close()` invalidates child handles (see databricks-sql-kernel `src/statement/validity.rs`), so the async handle died before the user could poll it, breaking `execute_async` → `is_query_pending` → `get_async_execution_result`.

Fix: when `async_op=True`, retain the parent Statement in a new `_async_statements` dict alongside `_async_handles`, and close it from `close_command`, `close_session`, and `get_execution_result` after the executed handle is done.

Comparator outcome: STATEMENT_ASYNC suite 3/3 match (was 0/3).

### 2. `intervals_as_string` wire-through

pyarrow's Python bindings cannot decode Arrow's `month_interval` type (id 21; `KeyError` from `.as_py`, `to_pylist`, `cast(string)`, `to_pandas`). Every kernel-backend `SELECT *` over a table with an `INTERVAL YEAR TO MONTH` column raised `ArrowNotImplementedError`: 32 / 88 audit diffs.

Fix: pass `intervals_as_string=True` to the kernel `Session(...)` constructor unconditionally. The kernel post-processor stringifies `Interval` / `Duration` columns server-side to `Utf8` (kernel PR #64).

Comparator outcome: bucket A (ArrowNotImplementedError) 32 → 0.

### 3. Decimal precision/scale extraction

`description_from_arrow_schema` hard-coded `None` for PEP 249 description-tuple slots 4 and 5. For DECIMAL columns the Arrow schema carries `precision` / `scale` on `Decimal128Type`, but they were silently dropped (88 cell diffs: 44 precision + 44 scale).

Fix: factor out `_precision_scale_for_arrow_type(arrow_type)` and call it from the description builder. Today it only handles decimals; future extensions slot in here.

Comparator outcome: 88 precision+scale diffs → 0.
### 4. Named-parameter binding

`bind_tspark_params` raised `NotSupportedError` for any `TSparkParameter` with `name` set. The canonical SEA proto marks `StatementParameter.name` as `openapi_required=true` (named is the spec-required public form; `ordinal` is `PUBLIC_UNDOCUMENTED`). Kernel PR #65 added a `Statement.bind_named_param` PyO3 API.

Fix: route named bindings via the new API. Positional ordinals self-increment so a named entry in the middle of the list doesn't claim a positional slot.

Comparator outcome: PREPARED_STATEMENT_NAMED 1/1 match (was 0/1). Full thrift-vs-kernel run: 97/132 match (was 96/132).

### 5. `utils.ParamEscaper` `exc.ProgrammingError` import fix

`ParamEscaper.escape_args` and `escape_item` both raised `exc.ProgrammingError(...)`, but `exc` was never imported. On any unsupported parameter shape the user saw `NameError: name 'exc' is not defined` instead of a clean PEP 249 `ProgrammingError`. Surfaced via the same audit harness running INLINE_PARAMS: both backends raised `NameError`, which the comparator's class+message match treated as parity, a false positive that hid both the driver bug and the underlying caller-side type mismatch.

Fix: import `ProgrammingError` directly from `databricks.sql.exc` (matching the pattern used in `session.py`, `client.py`, `result_set.py`, etc.) and replace the two `exc.ProgrammingError(...)` sites with bare `ProgrammingError(...)`.

## Dependencies

Items #2, #3, and #4 require the matching databricks-sql-kernel changes: PR #64 (`intervals_as_string` + empty-result schema fix) and PR #65 (named-param binding). For local testing the comparator harness uses `KERNEL_FREEZE=1` against a kernel checkout of the feature branches.
## Headline comparator results (full thrift-vs-kernel run)

| | Before | After |
|------------------------------------|--------|-------|
| match / diff / skipped | 60 / 88 / 0 | **97 / 34 / 1** |
| Bucket A (ArrowNotImplementedError)| 32 | 0 |
| `decimal_column` precision/scale | 88 | 0 |
| Bucket B1 (named params) | 1 | 0 |
| Suites fully clean | 12 / 30 | **17 / 30** |

Remaining 34 diffs cluster into documented causes (complex types in `fetchall_arrow`, timestamp tz on Arrow path, VOID surface, METADATA pattern semantics), tracked in `~/docs/python-kernel/comparator-diff-tasklist.md`.

## Test plan

- [x] Manual repro of async path: `cur.execute_async("SELECT 1")` → `is_query_pending` → `get_async_execution_result` → `fetchall` succeeds on `use_kernel=True`
- [x] Manual repro of interval path: `cur.execute("SELECT ym_interval_column FROM ...")` returns string-shaped rows matching the Thrift surface
- [x] Manual repro of decimal path: `cur.description` for a `DECIMAL(10,2)` column now reports `precision=10, scale=2`
- [x] Live e2e: `test_parameterized_query_named_params` / `test_parameterized_query_named_param_with_null` (added)
- [x] Unit tests: `bind_tspark_params` named/positional/mixed routing (`tests/unit/test_kernel_type_mapping.py`)
- [x] Manual repro of utils fix: `ParamEscaper().escape_args(object())` raises `ProgrammingError`, not `NameError`
- [x] Full thrift-vs-kernel comparator: 97 / 34 / 1 of 132
- [x] Existing connector unit suite: 752 passed

Co-authored-by: Isaac
1 parent fb55001 commit 932ec74

5 files changed

Lines changed: 195 additions & 53 deletions

src/databricks/sql/backend/kernel/client.py

Lines changed: 81 additions & 10 deletions
@@ -129,6 +129,15 @@ def __init__(
129129
# Guarded by ``_async_handles_lock`` so concurrent cursors on the
130130
# same connection don't race on submit / close / close-session.
131131
self._async_handles: Dict[str, Any] = {}
132+
# Parent ``Statement`` objects kept alive alongside async handles.
133+
# On the kernel, ``Statement.close()`` flips the validity flag on
134+
# the produced executed handle (see kernel
135+
# ``statement::mutable::close``), so we cannot close the
136+
# Statement immediately after ``submit()`` as we do for sync
137+
# ``execute()``. Instead retain it here and close it in
138+
# ``close_command`` / ``close_session`` after the async handle
139+
# has finished its work.
140+
self._async_statements: Dict[str, Any] = {}
132141
# CommandId.guids of async commands that have already been
133142
# closed (via ``close_command`` or ``close_session``). Lets
134143
# ``get_query_state`` report ``CLOSED`` for them rather than
@@ -167,6 +176,16 @@ def open_session(
167176
schema=schema or self._schema,
168177
session_conf=session_conf,
169178
complex_types_as_json=not self._use_arrow_native_complex_types,
179+
# Pyarrow's Python bindings cannot decode Arrow's
180+
# ``month_interval`` type at all (id 21 — raises
181+
# ``KeyError`` from ``.as_py``, ``to_pylist``,
182+
# ``cast(string)``, and ``to_pandas``). Ask the kernel
183+
# to stringify INTERVAL / DURATION columns server-side
184+
# so result sets containing interval columns are
185+
# decodable on the Python side. Matches the Thrift
186+
# backend's surface (interval columns arrive as
187+
# strings).
188+
intervals_as_string=True,
170189
**auth_kwargs,
171190
)
172191
except Exception as exc:
@@ -197,7 +216,9 @@ def close_session(self, session_id: SessionId) -> None:
197216
# server-side CloseStatement before the session goes away.
198217
with self._async_handles_lock:
199218
tracked = list(self._async_handles.items())
219+
tracked_stmts = list(self._async_statements.items())
200220
self._async_handles.clear()
221+
self._async_statements.clear()
201222
for guid, _ in tracked:
202223
self._closed_commands.add(guid)
203224
for _, handle in tracked:
@@ -211,6 +232,16 @@ def close_session(self, session_id: SessionId) -> None:
211232
logger.warning(
212233
"Error closing async handle during session close: %s", exc
213234
)
235+
# Now drop the parent Statements that were keeping those handles
236+
# alive. Same non-fatal close semantics — close errors are not
237+
# actionable at session-close time.
238+
for _, stmt in tracked_stmts:
239+
try:
240+
stmt.close()
241+
except Exception as exc:
242+
logger.warning(
243+
"Error closing async statement during session close: %s", exc
244+
)
214245
try:
215246
self._kernel_session.close()
216247
except Exception as exc:
@@ -249,6 +280,11 @@ def execute_command(
249280
stmt = self._kernel_session.statement()
250281
except Exception as exc:
251282
raise _wrap_kernel_exception("execute_command", exc) from exc
283+
# ``async_op`` keeps ``stmt`` alive (tracked in
284+
# ``_async_statements`` and closed by ``close_command``); the sync
285+
# path drops it in finally. ``close_stmt`` is the post-success
286+
# decision flag — it stays True on sync, flips to False on async.
287+
close_stmt = True
252288
try:
253289
try:
254290
stmt.set_sql(operation)
@@ -262,21 +298,26 @@ def execute_command(
262298
cursor.active_command_id = command_id
263299
with self._async_handles_lock:
264300
self._async_handles[command_id.guid] = async_exec
301+
# Closing the kernel ``Statement`` invalidates the
302+
# async handle (see kernel validity flag). Retain
303+
# the Statement here and close it on
304+
# ``close_command`` / ``close_session``.
305+
self._async_statements[command_id.guid] = stmt
306+
close_stmt = False
265307
return None
266308
executed = stmt.execute()
267309
except Exception as exc:
268310
raise _wrap_kernel_exception("execute_command", exc) from exc
269311
finally:
270-
# ``Statement`` is a lifecycle owner separate from the
271-
# executed handle it produces. Drop it here so the
272-
# parent doesn't keep the handle alive longer than the
273-
# caller expects. Swallow all close errors (including
274-
# PyO3 native exceptions) — a failed stmt.close() is
275-
# not actionable for the caller.
276-
try:
277-
stmt.close()
278-
except Exception:
279-
pass
312+
if close_stmt:
313+
# Sync path: ``Statement`` is a lifecycle owner separate
314+
# from the executed handle. Drop it here so the parent
315+
# doesn't outlive its caller. Swallow close errors —
316+
# they're not actionable.
317+
try:
318+
stmt.close()
319+
except Exception:
320+
pass
280321

281322
command_id = CommandId.from_sea_statement_id(executed.statement_id)
282323
cursor.active_command_id = command_id
@@ -307,17 +348,34 @@ def cancel_command(self, command_id: CommandId) -> None:
307348
def close_command(self, command_id: CommandId) -> None:
308349
with self._async_handles_lock:
309350
handle = self._async_handles.pop(command_id.guid, None)
351+
stmt = self._async_statements.pop(command_id.guid, None)
310352
if handle is not None:
311353
# Record the close so ``get_query_state`` can report
312354
# ``CLOSED`` (not ``SUCCEEDED``) for this command.
313355
self._closed_commands.add(command_id.guid)
314356
if handle is None:
315357
logger.debug("close_command: no tracked handle for %s", command_id)
358+
# Still drop the parent Statement if somehow tracked without
359+
# the handle — keeps the invariant clean even on bookkeeping
360+
# races.
361+
if stmt is not None:
362+
try:
363+
stmt.close()
364+
except Exception:
365+
pass
316366
return
317367
try:
318368
handle.close()
319369
except Exception as exc:
320370
raise _wrap_kernel_exception("close_command", exc) from exc
371+
finally:
372+
# Now safe to close the parent Statement — the executed
373+
# handle has finished its lifecycle.
374+
if stmt is not None:
375+
try:
376+
stmt.close()
377+
except Exception:
378+
pass
321379

322380
def get_query_state(self, command_id: CommandId) -> CommandState:
323381
with self._async_handles_lock:
@@ -378,6 +436,7 @@ def get_execution_result(
378436
# it wraps. Drop tracking and fire-and-forget the close.
379437
with self._async_handles_lock:
380438
self._async_handles.pop(command_id.guid, None)
439+
stmt = self._async_statements.pop(command_id.guid, None)
381440
self._closed_commands.add(command_id.guid)
382441
try:
383442
async_exec.close()
@@ -387,6 +446,18 @@ def get_execution_result(
387446
command_id,
388447
exc,
389448
)
449+
# The parent Statement is no longer needed once the async handle
450+
# has produced its ResultStream. Close to release server-side
451+
# tracking; matches the sync path's eager Statement close.
452+
if stmt is not None:
453+
try:
454+
stmt.close()
455+
except Exception as exc:
456+
logger.warning(
457+
"Error closing async statement after await_result for %s: %s",
458+
command_id,
459+
exc,
460+
)
390461
# ``KernelResultSet.__init__`` calls ``arrow_schema()`` which
391462
# can raise — map that to PEP 249 too.
392463
try:
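
The retention logic in the `client.py` hunks above can be sketched in isolation. This is a minimal illustration, not the connector's code: `FakeStatement`, `FakeHandle`, and `Client` are hypothetical stand-ins, and the only behaviour assumed from the commit is that closing the parent Statement invalidates the child handle.

```python
from typing import Dict


class FakeStatement:
    """Hypothetical stand-in for the kernel Statement (not the real PyO3 class)."""

    def __init__(self) -> None:
        self.closed = False

    def submit(self) -> "FakeHandle":
        return FakeHandle(self)

    def close(self) -> None:
        self.closed = True


class FakeHandle:
    """Child handle that is only valid while its parent Statement is open."""

    def __init__(self, parent: FakeStatement) -> None:
        self._parent = parent

    def poll(self) -> str:
        if self._parent.closed:
            raise RuntimeError("handle invalidated: parent statement closed")
        return "SUCCEEDED"


class Client:
    """Minimal client that retains the parent Statement on the async path."""

    def __init__(self) -> None:
        self.handles: Dict[str, FakeHandle] = {}
        self.statements: Dict[str, FakeStatement] = {}

    def execute(self, guid: str, async_op: bool) -> None:
        stmt = FakeStatement()
        close_stmt = True  # sync default: close the Statement in finally
        try:
            if async_op:
                self.handles[guid] = stmt.submit()
                self.statements[guid] = stmt
                close_stmt = False  # ownership moved to the tracking dict
                return
            stmt.submit().poll()
        finally:
            if close_stmt:
                stmt.close()

    def close_command(self, guid: str) -> None:
        self.handles.pop(guid, None)
        stmt = self.statements.pop(guid, None)
        if stmt is not None:
            stmt.close()


client = Client()
client.execute("cmd-1", async_op=True)
state_while_open = client.handles["cmd-1"].poll()  # parent still open
retained = client.statements["cmd-1"]
client.close_command("cmd-1")  # parent closed only now
```

Using a flag that flips only after the handle is tracked means an exception raised before tracking still closes the Statement in `finally`, which is the same shape the diff uses with `close_stmt`.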

src/databricks/sql/backend/kernel/type_mapping.py

Lines changed: 43 additions & 25 deletions
@@ -21,7 +21,7 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Any, List, Tuple
24+
from typing import Any, List, Optional, Tuple
2525

2626
import pyarrow
2727

@@ -102,21 +102,47 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]:
102102
backend's behaviour; other precise types (``INTERVAL_*``,
103103
``GEOMETRY``, ``GEOGRAPHY``) collapse to their Arrow shape on
104104
both backends and don't need a remap.
105+
106+
``precision`` / ``scale`` are extracted from ``Decimal128Type`` /
107+
``Decimal256Type`` so DECIMAL columns expose the same
108+
``(precision, scale)`` pair the Thrift backend reports. The Arrow
109+
schema carries these on the type itself; without this extraction
110+
the kernel-backend description would silently drop them, breaking
111+
parity for any consumer (SQLAlchemy, pandas-read-sql, etc.) that
112+
reads slots 4/5 to know how to display or round decimal values.
105113
"""
106114
return [
107115
(
108116
field.name,
109117
_databricks_type_for_field(field),
110118
None,
111119
None,
112-
None,
113-
None,
120+
*_precision_scale_for_arrow_type(field.type),
114121
None,
115122
)
116123
for field in schema
117124
]
118125

119126

127+
def _precision_scale_for_arrow_type(
128+
arrow_type: pyarrow.DataType,
129+
) -> Tuple[Optional[int], Optional[int]]:
130+
"""Extract PEP 249 ``(precision, scale)`` from an Arrow type.
131+
132+
Only Arrow's decimal types carry both; every other type collapses
133+
to ``(None, None)`` to match the Thrift backend's behaviour. Future
134+
extensions (e.g. fractional-second precision from
135+
``Time64Type`` / ``Timestamp``) can land here without touching the
136+
description builder above.
137+
"""
138+
if pyarrow.types.is_decimal(arrow_type):
139+
# Decimal128Type / Decimal256Type both expose `.precision` and
140+
# `.scale`. The `type: ignore` is for the type checker: pyarrow's
141+
# `DataType` base type doesn't declare them.
142+
return arrow_type.precision, arrow_type.scale # type: ignore[attr-defined]
143+
return None, None
144+
145+
120146
def _databricks_type_for_field(field: pyarrow.Field) -> str:
121147
"""Pick the PEP 249 type code for a single field.
122148
@@ -173,32 +199,19 @@ def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Any:
173199
def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> None:
174200
"""Bind a list of ``TSparkParameter`` onto a kernel ``Statement``.
175201
176-
The kernel expects positional bindings only (SEA v0 doesn't
177-
accept named bindings on the wire). The connector's
202+
Both positional and named bindings are supported. The connector's
178203
``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means
179-
"treat as positional in source-list order". Named-binding
180-
parameters surface as ``NotSupportedError`` so the user gets a
181-
clear message instead of a server-side rejection.
204+
"treat as positional in source-list order"; a parameter with a
205+
``name`` and without ``ordinal=True`` is bound via ``Statement.bind_named_param``.
182206
183207
Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) build a
184208
``TSparkParameter`` with the payload on ``arguments`` and
185209
``value=None`` — forwarding that would silently bind a typed
186210
NULL. Reject up front with ``NotSupportedError`` so callers get
187211
a clear message instead of silent data loss.
188212
"""
189-
for i, param in enumerate(parameters, start=1):
190-
# ``ordinal`` on connector-native params is a bool (True for
191-
# positional, False for named). Thrift defaults to ``None``;
192-
# treat any non-True value with a name as a named binding so
193-
# a future caller that forgets to set ordinal=True still gets
194-
# rejected instead of silently dropping the name.
195-
name = getattr(param, "name", None)
196-
if name and getattr(param, "ordinal", None) is not True:
197-
raise NotSupportedError(
198-
f"Named parameter binding (got name={name!r}) is not yet "
199-
"supported on the kernel backend; pass parameters positionally."
200-
)
201-
213+
positional_index = 0
214+
for param in parameters:
202215
sql_type = param.type or "STRING"
203216
# Compound types put their payload on ``arguments``, not
204217
# ``value``. The kernel parser doesn't accept them yet, and
@@ -214,7 +227,12 @@ def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) ->
214227
)
215228

216229
value_str = _tspark_param_value_str(param)
217-
# The kernel takes 1-based ordinals; `i` is already that.
218-
# Errors from the kernel side (bad literal, unsupported type,
219-
# etc.) come up as KernelError and bubble through normally.
220-
kernel_stmt.bind_param(i, value_str, sql_type)
230+
# ``ordinal`` on connector-native params is a bool. ``True``
231+
# → positional (assign the next 1-based ordinal). Anything
232+
# else with a name → named binding.
233+
name = getattr(param, "name", None)
234+
if name and getattr(param, "ordinal", None) is not True:
235+
kernel_stmt.bind_named_param(name, value_str, sql_type)
236+
else:
237+
positional_index += 1
238+
kernel_stmt.bind_param(positional_index, value_str, sql_type)
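
The routing rule in the hunk above (named entries go to `bind_named_param`, everything else takes the next 1-based ordinal) can be exercised without a kernel. The sketch below is illustrative only: `Param` and `RecordingStatement` are hypothetical stand-ins for `TSparkParameter` and the kernel `Statement`, and `bind_params` omits the compound-type rejection and value conversion that the real function performs.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for TSparkParameter (only the fields used here)."""

    value: str
    type: str = "STRING"
    name: Optional[str] = None
    ordinal: Optional[bool] = None


class RecordingStatement:
    """Records bind calls instead of talking to a kernel."""

    def __init__(self) -> None:
        self.calls: List[Tuple] = []

    def bind_param(self, index: int, value: str, sql_type: str) -> None:
        self.calls.append(("positional", index, value, sql_type))

    def bind_named_param(self, name: str, value: str, sql_type: str) -> None:
        self.calls.append(("named", name, value, sql_type))


def bind_params(stmt: RecordingStatement, params: List[Param]) -> None:
    positional_index = 0
    for p in params:
        if p.name and p.ordinal is not True:
            stmt.bind_named_param(p.name, p.value, p.type)
        else:
            positional_index += 1  # only positional entries consume a slot
            stmt.bind_param(positional_index, p.value, p.type)


stmt = RecordingStatement()
bind_params(
    stmt,
    [
        Param("1", "INT", ordinal=True),
        Param("alice", name="s"),  # named entry in the middle of the list
        Param("2", "INT", ordinal=True),
    ],
)
```

With `enumerate(parameters, start=1)`, as in the removed code, the third entry would have been bound at ordinal 3; the self-incrementing counter binds it at ordinal 2.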

src/databricks/sql/utils.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
1919
pyarrow = None
2020

2121
from databricks.sql import OperationalError
22+
from databricks.sql.exc import ProgrammingError
2223
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
2324
from databricks.sql.thrift_api.TCLIService.ttypes import (
2425
TRowSet,
@@ -548,7 +549,7 @@ def escape_args(self, parameters):
548549
elif isinstance(parameters, (list, tuple)):
549550
return tuple(self.escape_item(x) for x in parameters)
550551
else:
551-
raise exc.ProgrammingError(
552+
raise ProgrammingError(
552553
"Unsupported param format: {}".format(parameters)
553554
)
554555

@@ -606,7 +607,7 @@ def escape_item(self, item):
606607
elif isinstance(item, Mapping):
607608
return self.escape_mapping(item)
608609
else:
609-
raise exc.ProgrammingError("Unsupported object {}".format(item))
610+
raise ProgrammingError("Unsupported object {}".format(item))
610611

611612

612613
def inject_parameters(operation: str, parameters: Dict[str, str]):

tests/e2e/test_kernel_backend.py

Lines changed: 28 additions & 0 deletions
@@ -241,6 +241,34 @@ def test_parameterized_query_with_null(conn):
241241
assert rows[0][0] is True
242242

243243

244+
def test_parameterized_query_named_params(conn):
245+
"""Named parameter binding via the kernel backend. The
246+
connector passes ``parameters={name: value}`` dicts (DB-API
247+
style); the kernel forwards them through ``bind_named_param``
248+
so the SEA wire payload sets ``StatementParameter.name`` (the
249+
spec-required public form per canonical proto).
250+
"""
251+
with conn.cursor() as cur:
252+
cur.execute(
253+
"SELECT :n AS n, :s AS s, :b AS b",
254+
{"n": 42, "s": "alice", "b": True},
255+
)
256+
rows = cur.fetchall()
257+
assert len(rows) == 1
258+
assert rows[0][0] == 42
259+
assert rows[0][1] == "alice"
260+
assert rows[0][2] is True
261+
262+
263+
def test_parameterized_query_named_param_with_null(conn):
264+
"""``None`` value in a named binding flows through as
265+
VoidParameter → kernel ``TypedValue::Null``."""
266+
with conn.cursor() as cur:
267+
cur.execute("SELECT :x IS NULL AS is_null", {"x": None})
268+
rows = cur.fetchall()
269+
assert rows[0][0] is True
270+
271+
244272
def test_parameterized_query_decimal(conn):
245273
"""DECIMAL parameters carry precision/scale in the SQL type
246274
string ('DECIMAL(p,s)') — the kernel parser extracts them so
