Commit 81f065f

fix(backend/kernel): comparator parity — async statement retention + intervals_as_string + precision/scale
Three parity-blocking issues surfaced by the python-comparator audit (tests/comparator-tests/python/), bundled into one PR since they're all small kernel-backend client changes.

## 1. Async Statement retention (original scope of this PR)

`KernelDatabricksClient.execute_command` closed the parent `Statement` in `finally` regardless of `async_op`. The kernel's `Statement.close()` invalidates child handles (see `databricks-sql-kernel/src/statement/validity.rs`), so the async handle was being killed before the user could poll it, breaking the entire async surface (`execute_async` → `is_query_pending` → `get_async_execution_result`).

Fix: when `async_op=True`, retain the parent Statement in a new `_async_statements` dict alongside `_async_handles`, and close it from `close_command`, `close_session`, and `get_execution_result` after the executed handle is done.

Comparator outcome: STATEMENT_ASYNC suite 3/3 match (was 0/3).

## 2. intervals_as_string wire-through (companion to kernel PR #64)

pyarrow's Python bindings cannot decode Arrow's `month_interval` type at all (id 21: raises `KeyError` from `.as_py`, `to_pylist`, `cast(string)`, `to_pandas`). Every kernel-backend `SELECT *` over any table with an INTERVAL YEAR TO MONTH column was throwing `ArrowNotImplementedError` (32/88 audit diffs).

Fix: pass `intervals_as_string=True` to the kernel `Session(...)` constructor unconditionally. The kernel post-processor then stringifies Interval / Duration columns server-side to Utf8 (see kernel PR #64), so pyarrow never sees the unreadable type.

Comparator outcome: bucket A (ArrowNotImplementedError) drops from 32 → 0 diffs.

## 3. Decimal precision/scale extraction (new)

`description_from_arrow_schema` hard-coded `None` for slots 4 and 5 of the PEP 249 description tuple. For DECIMAL columns the Arrow schema carries precision/scale on `Decimal128Type.precision` / `.scale`, but we were silently dropping them, so kernel-backend `cursor.description[i]` returned `('decimal_column', 'decimal', None, None, None, None, None)` while Thrift returned `('decimal_column', 'decimal', None, None, 10, 2, None)`. That diff propagates to any consumer that reads PEP 249 slots 4/5 (SQLAlchemy, pandas-read-sql, etc.) so they can't tell DECIMAL(10,2) from DECIMAL(38,18) on the kernel backend. 88 comparator diffs (44 precision + 44 scale).

Fix: factor out `_precision_scale_for_arrow_type(arrow_type)` and call it from the description builder. Today it only handles decimals; future extensions (e.g. fractional-second precision from `Time64Type`) slot in here without touching the description builder.

Comparator outcome: 88 precision+scale diffs → 0.

## Dependencies

Both intervals_as_string and the empty-result schema fix in kernel PR #64 are required for the parity gains to land. The driver-side fixes here work standalone but the comparator outcome numbers assume PR #64 is also live.

## Headline comparator results

- Before this PR (and PR #64): 60 match / 88 diff (out of 148)
- After this PR + PR #64: 103 match / 45 diff
- 17 of 30 (connection_config × suite) pairs now fully clean.
- Remaining 45 diffs cluster into 4 known causes documented in ~/docs/python-kernel/comparator-diff-tasklist.md (complex types in fetchall_arrow path, timestamp tz, VOID, METADATA pattern matching).

Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent fb55001 commit 81f065f

2 files changed

Lines changed: 110 additions & 13 deletions

src/databricks/sql/backend/kernel/client.py

Lines changed: 81 additions & 10 deletions
@@ -129,6 +129,15 @@ def __init__(
         # Guarded by ``_async_handles_lock`` so concurrent cursors on the
         # same connection don't race on submit / close / close-session.
         self._async_handles: Dict[str, Any] = {}
+        # Parent ``Statement`` objects kept alive alongside async handles.
+        # On the kernel, ``Statement.close()`` flips the validity flag on
+        # the produced executed handle (see kernel
+        # ``statement::mutable::close``), so we cannot close the
+        # Statement immediately after ``submit()`` as we do for sync
+        # ``execute()``. Instead retain it here and close it in
+        # ``close_command`` / ``close_session`` after the async handle
+        # has finished its work.
+        self._async_statements: Dict[str, Any] = {}
         # CommandId.guids of async commands that have already been
         # closed (via ``close_command`` or ``close_session``). Lets
         # ``get_query_state`` report ``CLOSED`` for them rather than
@@ -167,6 +176,16 @@ def open_session(
                 schema=schema or self._schema,
                 session_conf=session_conf,
                 complex_types_as_json=not self._use_arrow_native_complex_types,
+                # Pyarrow's Python bindings cannot decode Arrow's
+                # ``month_interval`` type at all (id 21 — raises
+                # ``KeyError`` from ``.as_py``, ``to_pylist``,
+                # ``cast(string)``, and ``to_pandas``). Ask the kernel
+                # to stringify INTERVAL / DURATION columns server-side
+                # so result sets containing interval columns are
+                # decodable on the Python side. Matches the Thrift
+                # backend's surface (interval columns arrive as
+                # strings).
+                intervals_as_string=True,
                 **auth_kwargs,
             )
         except Exception as exc:
@@ -197,7 +216,9 @@ def close_session(self, session_id: SessionId) -> None:
         # server-side CloseStatement before the session goes away.
         with self._async_handles_lock:
             tracked = list(self._async_handles.items())
+            tracked_stmts = list(self._async_statements.items())
             self._async_handles.clear()
+            self._async_statements.clear()
             for guid, _ in tracked:
                 self._closed_commands.add(guid)
         for _, handle in tracked:
@@ -211,6 +232,16 @@ def close_session(self, session_id: SessionId) -> None:
                 logger.warning(
                     "Error closing async handle during session close: %s", exc
                 )
+        # Now drop the parent Statements that were keeping those handles
+        # alive. Same non-fatal close semantics — close errors are not
+        # actionable at session-close time.
+        for _, stmt in tracked_stmts:
+            try:
+                stmt.close()
+            except Exception as exc:
+                logger.warning(
+                    "Error closing async statement during session close: %s", exc
+                )
         try:
             self._kernel_session.close()
         except Exception as exc:
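The snapshot-then-close shape used by `close_session` can be illustrated in isolation. The following is a minimal sketch under stated assumptions, not the driver's code: `FakeResource` and `Tracker` are hypothetical stand-ins for the kernel handles, the parent Statements, and the client's two tracking dicts. It shows both maps being copied and cleared under one lock, with the closes run outside the lock (handles first, then Statements) and failures logged rather than raised.

```python
import logging
import threading
from typing import Dict, List

logger = logging.getLogger(__name__)


class FakeResource:
    """Hypothetical stand-in for a kernel handle or parent Statement."""

    def __init__(self, name: str, fail: bool = False) -> None:
        self.name = name
        self.fail = fail
        self.closed = False

    def close(self) -> None:
        if self.fail:
            raise RuntimeError(f"close failed for {self.name}")
        self.closed = True


class Tracker:
    """Hypothetical model of the two tracking dicts and their shared lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.handles: Dict[str, FakeResource] = {}
        self.statements: Dict[str, FakeResource] = {}

    def close_all(self) -> List[str]:
        # Snapshot and clear both maps while holding the lock, so no other
        # thread can observe a handle without its parent or vice versa.
        with self._lock:
            handles = list(self.handles.values())
            statements = list(self.statements.values())
            self.handles.clear()
            self.statements.clear()
        failed: List[str] = []
        # Close outside the lock: children first, then their parents.
        # Errors are logged and collected, never raised.
        for resource in handles + statements:
            try:
                resource.close()
            except Exception as exc:
                logger.warning("Error closing %s: %s", resource.name, exc)
                failed.append(resource.name)
        return failed
```

Closing outside the lock keeps a slow or failing `close()` from blocking other cursors that only need the lock for bookkeeping.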
@@ -249,6 +280,11 @@ def execute_command(
             stmt = self._kernel_session.statement()
         except Exception as exc:
             raise _wrap_kernel_exception("execute_command", exc) from exc
+        # ``async_op`` keeps ``stmt`` alive (tracked in
+        # ``_async_statements`` and closed by ``close_command``); the sync
+        # path drops it in finally. ``close_stmt`` is the post-success
+        # decision flag — it stays True on sync, flips to False on async.
+        close_stmt = True
         try:
             try:
                 stmt.set_sql(operation)
@@ -262,21 +298,26 @@ def execute_command(
                     cursor.active_command_id = command_id
                     with self._async_handles_lock:
                         self._async_handles[command_id.guid] = async_exec
+                        # Closing the kernel ``Statement`` invalidates the
+                        # async handle (see kernel validity flag). Retain
+                        # the Statement here and close it on
+                        # ``close_command`` / ``close_session``.
+                        self._async_statements[command_id.guid] = stmt
+                    close_stmt = False
                     return None
                 executed = stmt.execute()
             except Exception as exc:
                 raise _wrap_kernel_exception("execute_command", exc) from exc
         finally:
-            # ``Statement`` is a lifecycle owner separate from the
-            # executed handle it produces. Drop it here so the
-            # parent doesn't keep the handle alive longer than the
-            # caller expects. Swallow all close errors (including
-            # PyO3 native exceptions) — a failed stmt.close() is
-            # not actionable for the caller.
-            try:
-                stmt.close()
-            except Exception:
-                pass
+            if close_stmt:
+                # Sync path: ``Statement`` is a lifecycle owner separate
+                # from the executed handle. Drop it here so the parent
+                # doesn't outlive its caller. Swallow close errors —
+                # they're not actionable.
+                try:
+                    stmt.close()
+                except Exception:
+                    pass

         command_id = CommandId.from_sea_statement_id(executed.statement_id)
         cursor.active_command_id = command_id
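The `close_stmt` flag in the hunk above is a small ownership-transfer idiom: the `finally` block cleans up unless ownership of the Statement was successfully handed to the tracking dict. A minimal sketch of that idiom follows, assuming a hypothetical `FakeStatement` and a plain list in place of `_async_statements`; the `fail` parameter is also an assumption, added only to simulate a submit error.

```python
from typing import List


class FakeStatement:
    """Hypothetical stand-in for the kernel ``Statement``."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def run(
    stmt: FakeStatement,
    async_op: bool,
    retained: List[FakeStatement],
    fail: bool = False,
) -> str:
    # Default to closing; only a successful async hand-off clears the flag.
    close_stmt = True
    try:
        if fail:
            raise RuntimeError("submit failed")
        if async_op:
            retained.append(stmt)   # ownership moves to the tracker
            close_stmt = False      # so ``finally`` must leave it open
            return "submitted"
        return "executed"
    finally:
        if close_stmt:
            try:
                stmt.close()
            except Exception:
                pass                # close errors are not actionable here
```

Because the flag flips only after the hand-off, a failure on the async path before retention still closes the Statement, so nothing leaks when submit raises.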
@@ -307,17 +348,34 @@ def cancel_command(self, command_id: CommandId) -> None:
     def close_command(self, command_id: CommandId) -> None:
         with self._async_handles_lock:
             handle = self._async_handles.pop(command_id.guid, None)
+            stmt = self._async_statements.pop(command_id.guid, None)
             if handle is not None:
                 # Record the close so ``get_query_state`` can report
                 # ``CLOSED`` (not ``SUCCEEDED``) for this command.
                 self._closed_commands.add(command_id.guid)
         if handle is None:
             logger.debug("close_command: no tracked handle for %s", command_id)
+            # Still drop the parent Statement if somehow tracked without
+            # the handle — keeps the invariant clean even on bookkeeping
+            # races.
+            if stmt is not None:
+                try:
+                    stmt.close()
+                except Exception:
+                    pass
             return
         try:
             handle.close()
         except Exception as exc:
             raise _wrap_kernel_exception("close_command", exc) from exc
+        finally:
+            # Now safe to close the parent Statement — the executed
+            # handle has finished its lifecycle.
+            if stmt is not None:
+                try:
+                    stmt.close()
+                except Exception:
+                    pass

     def get_query_state(self, command_id: CommandId) -> CommandState:
         with self._async_handles_lock:
@@ -378,6 +436,7 @@ def get_execution_result(
         # it wraps. Drop tracking and fire-and-forget the close.
         with self._async_handles_lock:
             self._async_handles.pop(command_id.guid, None)
+            stmt = self._async_statements.pop(command_id.guid, None)
             self._closed_commands.add(command_id.guid)
         try:
             async_exec.close()
@@ -387,6 +446,18 @@ def get_execution_result(
                 command_id,
                 exc,
             )
+        # The parent Statement is no longer needed once the async handle
+        # has produced its ResultStream. Close to release server-side
+        # tracking; matches the sync path's eager Statement close.
+        if stmt is not None:
+            try:
+                stmt.close()
+            except Exception as exc:
+                logger.warning(
+                    "Error closing async statement after await_result for %s: %s",
+                    command_id,
+                    exc,
+                )
         # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which
         # can raise — map that to PEP 249 too.
         try:
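To see why the parent Statement has to outlive its async handle, it may help to model the validity flag directly. This sketch is a toy model under stated assumptions: `FakeStatement`, `FakeHandle`, and `Client` are hypothetical, and the shared one-element list merely imitates the kernel's validity flag as the commit message describes it. The `retain=False` branch reproduces the old behaviour, and `retain=True` the fixed one.

```python
from typing import Dict, List


class FakeHandle:
    """Hypothetical async handle that shares a validity flag with its parent."""

    def __init__(self, valid: List[bool]) -> None:
        self._valid = valid

    def status(self) -> str:
        if not self._valid[0]:
            raise RuntimeError("handle invalidated by parent close")
        return "RUNNING"

    def close(self) -> None:
        pass


class FakeStatement:
    """Hypothetical parent; closing it invalidates every handle it produced."""

    def __init__(self) -> None:
        self._valid = [True]

    def submit(self) -> FakeHandle:
        return FakeHandle(self._valid)

    def close(self) -> None:
        self._valid[0] = False


class Client:
    def __init__(self) -> None:
        self.handles: Dict[str, FakeHandle] = {}
        self.statements: Dict[str, FakeStatement] = {}

    def submit(self, guid: str, retain: bool) -> None:
        stmt = FakeStatement()
        self.handles[guid] = stmt.submit()
        if retain:
            self.statements[guid] = stmt  # fixed: keep the parent alive
        else:
            stmt.close()                  # old: parent closed right away

    def poll(self, guid: str) -> str:
        return self.handles[guid].status()

    def close_command(self, guid: str) -> None:
        handle = self.handles.pop(guid, None)
        stmt = self.statements.pop(guid, None)
        try:
            if handle is not None:
                handle.close()
        finally:
            # The parent goes last, once the handle has finished.
            if stmt is not None:
                stmt.close()
```

In this model, polling a command submitted with `retain=False` raises immediately, which corresponds to the 0/3 STATEMENT_ASYNC result reported before the fix.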

src/databricks/sql/backend/kernel/type_mapping.py

Lines changed: 29 additions & 3 deletions
@@ -21,7 +21,7 @@

 from __future__ import annotations

-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple

 import pyarrow

@@ -102,21 +102,47 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]:
     backend's behaviour; other precise types (``INTERVAL_*``,
     ``GEOMETRY``, ``GEOGRAPHY``) collapse to their Arrow shape on
     both backends and don't need a remap.
+
+    ``precision`` / ``scale`` are extracted from ``Decimal128Type`` /
+    ``Decimal256Type`` so DECIMAL columns expose the same
+    ``(precision, scale)`` pair the Thrift backend reports. The Arrow
+    schema carries these on the type itself; without this extraction
+    the kernel-backend description would silently drop them, breaking
+    parity for any consumer (SQLAlchemy, pandas-read-sql, etc.) that
+    reads slots 4/5 to know how to display or round decimal values.
     """
     return [
         (
             field.name,
             _databricks_type_for_field(field),
             None,
             None,
-            None,
-            None,
+            *_precision_scale_for_arrow_type(field.type),
             None,
         )
         for field in schema
     ]


+def _precision_scale_for_arrow_type(
+    arrow_type: pyarrow.DataType,
+) -> Tuple[Optional[int], Optional[int]]:
+    """Extract PEP 249 ``(precision, scale)`` from an Arrow type.
+
+    Only Arrow's decimal types carry both; every other type collapses
+    to ``(None, None)`` to match the Thrift backend's behaviour. Future
+    extensions (e.g. fractional-second precision from
+    ``Time64Type`` / ``Timestamp``) can land here without touching the
+    description builder above.
+    """
+    if pyarrow.types.is_decimal(arrow_type):
+        # Decimal128Type / Decimal256Type both expose `.precision` and
+        # `.scale`. The cast is for the type checker — pyarrow's
+        # `DataType` base type doesn't declare them.
+        return arrow_type.precision, arrow_type.scale  # type: ignore[attr-defined]
+    return None, None
+
+
 def _databricks_type_for_field(field: pyarrow.Field) -> str:
     """Pick the PEP 249 type code for a single field.

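The practical effect of filling slots 4 and 5 is easiest to see from the consumer's side. The sketch below assumes a hypothetical `ColumnDescription` named tuple and `sql_type_label` helper, neither of which is part of the driver. It lays out the seven PEP 249 description fields and shows how a consumer could render a column's SQL type from them, using the two tuples quoted in the commit message.

```python
from typing import Any, NamedTuple, Optional, Sequence


class ColumnDescription(NamedTuple):
    """The seven PEP 249 description fields, in order (hypothetical wrapper)."""

    name: str
    type_code: Any
    display_size: Optional[int]
    internal_size: Optional[int]
    precision: Optional[int]   # slot 4
    scale: Optional[int]       # slot 5
    null_ok: Optional[bool]


def sql_type_label(row: Sequence[Any]) -> str:
    """Render a column's SQL type the way a consumer might (illustrative)."""
    col = ColumnDescription(*row)
    if col.type_code != "decimal":
        return str(col.type_code).upper()
    if col.precision is None or col.scale is None:
        # Without slots 4/5 the consumer cannot recover the exact type.
        return "DECIMAL"
    return f"DECIMAL({col.precision},{col.scale})"


# Tuples as quoted in the commit message: kernel backend before the fix,
# and the Thrift backend.
before = ("decimal_column", "decimal", None, None, None, None, None)
thrift = ("decimal_column", "decimal", None, None, 10, 2, None)

print(sql_type_label(before))  # DECIMAL
print(sql_type_label(thrift))  # DECIMAL(10,2)
```

With the extraction in place, the kernel backend is expected to produce the second tuple for a DECIMAL(10,2) column, so both backends would yield the same label.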