diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index b1b9f99dc..a6ea6671c 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -19,11 +19,11 @@ //! //! Datafusion-python plans can carry references to Python-defined //! objects that the upstream protobuf codecs do not know how to -//! serialize: pure-Python scalar UDFs, Python query-planning -//! extensions, and so on. Their state lives inside `Py<PyAny>` -//! callables and closures rather than being recoverable from a name -//! in the receiver's function registry. To ship a plan across a -//! process boundary (pickle, `multiprocessing`, Ray actor, +//! serialize: pure-Python scalar / aggregate / window UDFs, Python +//! query-planning extensions, and so on. Their state lives inside +//! `Py<PyAny>` callables and closures rather than being recoverable +//! from a name in the receiver's function registry. To ship a plan +//! across a process boundary (pickle, `multiprocessing`, Ray actor, //! `datafusion-distributed`, etc.) those payloads have to be encoded //! into the proto wire format itself. //! @@ -256,7 +256,12 @@ impl PythonLogicalCodec { /// `cloudpickle.loads` on the inline `DFPY*` payload. It does /// **not** make `pickle.loads(untrusted_bytes)` safe; treat every /// `pickle.loads` on untrusted input as unsafe regardless of this - /// setting. + /// setting. See `docs/source/user-guide/io/distributing_work.rst` + /// (Security section) for the full threat model, and Python's + /// [pickle module security warning][1] for why `pickle.loads` is + /// unsafe in general.
+ /// + /// [1]: https://docs.python.org/3/library/pickle.html#module-pickle pub fn with_python_udf_inlining(mut self, enabled: bool) -> Self { self.python_udf_inlining = enabled; self @@ -433,7 +438,7 @@ fn refuse_inline_payload(kind: &str, name: &str) -> datafusion::error::DataFusio /// encoding on this layer too — otherwise a plan with a Python UDF /// would round-trip at the logical level but break at the physical /// level. Both layers reuse the shared payload framing -/// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical. +/// ([`PY_SCALAR_UDF_FAMILY`] et al.) so the wire format is identical. #[derive(Debug)] pub struct PythonPhysicalCodec { inner: Arc, diff --git a/docs/source/user-guide/io/distributing_work.rst b/docs/source/user-guide/io/distributing_work.rst new file mode 100644 index 000000000..104409690 --- /dev/null +++ b/docs/source/user-guide/io/distributing_work.rst @@ -0,0 +1,391 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Distributing DataFusion work +============================ + +Splitting a DataFusion workload across multiple processes — for +throughput, isolation, or to use a worker pool — comes in a few +different shapes depending on what is being split. 
+ +* **Expression-level distribution** ✅ *supported today*. The driver + builds a DataFusion :py:class:`~datafusion.Expr`, sends it to + worker processes, and each worker evaluates the expression against + its own slice of data. Suits embarrassingly-parallel workloads + where the driver decides up front how to partition. +* **Query-level distribution via datafusion-distributed** 🚧 *work in + progress upstream*. A single logical / physical plan is split into + stages and run across worker nodes. The driver writes one SQL or + DataFrame query; the runtime decides partitioning. +* **Query-level distribution via Apache Ballista** 🚧 *work in + progress upstream*. Similar query-level model, with a more + cluster-management-oriented runtime. + +Only the first option is ready for use from datafusion-python today. +The other two are documented below so the surrounding story is in +one place; integration details will land here as those projects +become usable from datafusion-python. + +Expression-level distribution +----------------------------- + +DataFusion expressions support distribution directly: pass one to a +worker process and Python's standard +`pickle <https://docs.python.org/3/library/pickle.html>`_ machinery +serializes it transparently — the same machinery +:py:meth:`multiprocessing.pool.Pool.map`, Ray's ``@ray.remote``, and +similar libraries already use to ship function arguments. Python UDFs +— scalar, aggregate, and window — travel inside the serialized +expression; the receiver does not need to pre-register them. + +Basic worker-pool example +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Define a worker function that takes the expression plus a batch and +returns the evaluated result: + +.. code-block:: python + + import pyarrow as pa + from datafusion import SessionContext + + + def evaluate(expr, batch): + # `expr` arrived here via the pool's automatic pickling — + # no manual serialization needed in user code.
+ ctx = SessionContext() + df = ctx.from_pydict({"a": batch}) + return df.with_column("result", expr).select("result").to_pydict()["result"] + +Then build the expression in the driver and fan it out: + +.. code-block:: python + + import multiprocessing as mp + from datafusion import col, udf + + double = udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], pa.int64(), volatility="immutable", name="double", + ) + expr = double(col("a")) + + mp_ctx = mp.get_context("forkserver") + with mp_ctx.Pool(processes=4) as pool: + results = pool.starmap( + evaluate, + [(expr, [1, 2, 3]), (expr, [10, 20, 30])], + ) + print(results) # [[2, 4, 6], [20, 40, 60]] + +When saved to a ``.py`` file and executed with the ``spawn`` or +``forkserver`` start method, wrap the driver block in +``if __name__ == "__main__":`` so worker processes can re-import the +module without re-running it. + + +What travels with the expression +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* **Built-in functions** (``abs``, ``length``, arithmetic, comparisons, + etc.) — fully portable. Worker needs nothing pre-registered. +* **Python UDFs** — travel inline (subject to the two portability + requirements below). The callable, its signature, and any state + captured in closures travel inside the serialized expression and are + reconstructed on the worker automatically. Applies equally to: + + * **scalar UDFs** (:py:func:`datafusion.udf`) + * **aggregate UDFs** (:py:func:`datafusion.udaf`) + * **window UDFs** (:py:func:`datafusion.udwf`) +* **UDFs imported via the FFI capsule protocol** — travel **by name + only**. The worker must already have a matching registration on its + :py:class:`SessionContext`. Without that registration, evaluation + raises an error. 
+ +Portability requirements for inline Python UDFs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Inline Python UDFs ride on `cloudpickle +<https://github.com/cloudpipe/cloudpickle>`_, which imposes two +requirements on the worker environment: + +* **Matching Python minor version.** A cloudpickle payload serializes + Python bytecode, which is not stable across Python minor versions. A + UDF pickled on Python 3.12 cannot be reconstructed on a 3.11 or 3.13 + worker. The wire format stamps the sender's ``(major, minor)``; a + mismatch raises a clear error naming both versions rather than + failing obscurely deep inside ``cloudpickle.loads``. Align the Python + version on driver and workers. +* **Imported modules must be importable on the worker.** cloudpickle + captures the UDF callable *by value* — bytecode and closure cells are + inlined, so locally-defined functions and lambdas travel whole. But + any name the callable resolves through ``import`` is captured *by + reference* (module path only). If a UDF body does + ``from mylib import transform`` and calls ``transform(...)``, the + worker reconstructs the reference by importing ``mylib`` — which must + therefore be installed on the worker. The same applies to bound + methods of imported classes. Self-contained UDFs (no imports beyond + what the worker already has, e.g. ``pyarrow``) avoid this entirely. + +Session contexts at a glance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There is only one type — :py:class:`SessionContext`. It can occupy +up to four *slots* in a running program: + +..
list-table:: + :header-rows: 1 + :widths: 12 18 40 30 + + * - Slot + - Lifetime + - Purpose + - Set how + * - User-held + - Local variable / attribute + - Build and run queries + - ``ctx = SessionContext(...)`` + * - Global + - Process singleton (lazy-init) + - Backs module-level + :py:func:`~datafusion.io.read_parquet`, + :py:func:`~datafusion.io.read_csv`, + :py:func:`~datafusion.io.read_json`, + :py:func:`~datafusion.io.read_avro`; final fallback for + :py:meth:`Expr.from_bytes` + - Implicit; access via + :py:meth:`SessionContext.global_ctx` + * - Sender + - Thread-local on the driver + - Codec settings for outbound :py:func:`pickle.dumps` / + :py:meth:`Expr.to_bytes` without ``ctx`` + - :py:func:`~datafusion.ipc.set_sender_ctx` + * - Worker + - Thread-local on the worker + - Function registry for inbound :py:func:`pickle.loads` / + :py:meth:`Expr.from_bytes` without ``ctx`` + - :py:func:`~datafusion.ipc.set_worker_ctx` + +The same :py:class:`SessionContext` object may occupy more than one +slot simultaneously — installing it into a slot is a reference, not +a copy. + +**Non-distributed user.** One user-held context. The global slot is +invisible unless you call top-level ``read_*`` helpers. Sender and +worker slots are unused. + +**Distributed user.** Two questions to answer: + +1. *Driver side — what wire format do I want?* The default (Python UDF + inlining on) is self-contained; you do not need a sender context. + To opt into the strict format, call + :py:func:`~datafusion.ipc.set_sender_ctx` + with a session built via + :py:meth:`SessionContext.with_python_udf_inlining(enabled=False) + <datafusion.SessionContext.with_python_udf_inlining>`. + +2. *Worker side — what registrations does decode need?* For built-ins + and inline Python UDFs, nothing. For FFI-capsule UDFs (or + strict-mode round-trips that travel by name), call + :py:func:`~datafusion.ipc.set_worker_ctx` once per worker with a + context that has the relevant registrations.
+ +Resolution order on the worker side is *explicit argument → +worker context → global context.* Explicit ``ctx=`` on +:py:meth:`Expr.from_bytes` always wins; the sender slot is ignored +on decode and the worker slot is ignored on encode. + +Sharp edges: + +* Sender and worker slots are **thread-local**. Background threads + on either side see ``None`` until they install their own. +* Under the ``fork`` start method, the parent's ``threading.local()`` + values are copied into the child by copy-on-write — a forked + worker initially observes whatever sender / worker slot the parent + had set, until the worker writes its own value (or calls the + matching ``clear_*_ctx``). ``spawn`` and ``forkserver`` workers + start with empty thread-local slots. Treat the slot as + uninitialized on worker entry and install (or clear) it explicitly + in the worker initializer; do not rely on inherited state. +* The global slot persists across ``fork`` workers (copy-on-write + memory inherit) but not across ``spawn`` / ``forkserver`` workers + (fresh process — register or install a worker context on + start-up). +* The inlining toggle is per-context state, not a global switch. + Two contexts with different toggles can coexist in one process. + +Registering shared UDFs on workers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When an expression references an FFI capsule UDF (or any UDF the +worker must resolve from its registered functions), set up the +worker's :py:class:`SessionContext` once per process and install it +as the *worker context*: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_worker_ctx + + + def init_worker(): + ctx = SessionContext() + ctx.register_udaf(my_ffi_aggregate) + set_worker_ctx(ctx) + + + with mp.get_context("forkserver").Pool( + processes=4, initializer=init_worker + ) as pool: + ... + +Inside a worker, expressions arriving from the driver resolve their +by-name references against the installed worker context. 
If no worker +context is installed, the global :py:class:`SessionContext` is used — +fine for expressions that only reference built-ins and Python UDFs, +but FFI-capsule-backed registrations must be installed on the global +context to resolve. + +Python 3.14 default change +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Python 3.14 changed the Linux default start method for +:py:mod:`multiprocessing` from ``fork`` to ``forkserver`` (macOS has +defaulted to ``spawn`` since Python 3.8; Windows has always used +``spawn``). With ``fork``, any state set in the parent was visible in +workers via copy-on-write; with ``forkserver`` and ``spawn`` it is +not. The :py:func:`~datafusion.ipc.set_worker_ctx` pattern works on +every start method — prefer it over relying on inherited state. + +Practical considerations +~~~~~~~~~~~~~~~~~~~~~~~~ + +* **Serialized size scales with what travels inline.** A serialized + expression of just built-ins is small (tens of bytes). An + expression carrying a Python UDF is hundreds of bytes (the callable + and its signature). When the same UDF is shipped many times, + registering an equivalent FFI-capsule UDF on each worker via + :py:func:`~datafusion.ipc.set_worker_ctx` and referring to it by + name cuts the per-trip overhead. +* **Closure capture.** When a Python UDF closes over surrounding + state — local variables, module-level objects, file paths — that + state is captured at serialization time. Surprises are possible if + the captured state is large, mutable, or not portable to the + worker's environment. See `Portability requirements for inline + Python UDFs`_ for the Python-version and imported-module rules. + +Disabling Python UDF inlining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For a stricter wire format, call +:py:meth:`SessionContext.with_python_udf_inlining(enabled=False) +<datafusion.SessionContext.with_python_udf_inlining>` on the session +producing or consuming the bytes.
With inlining disabled, Python +UDFs travel by name only — the same way FFI-capsule UDFs do — and +the receiver must have a matching registration. + +Two use cases: + +* **Cross-language portability.** A non-Python decoder cannot + reconstruct a cloudpickled payload. Senders aimed at Java, C++, + or another Rust binary disable inlining and rely on the receiver + having compatible UDF registrations. +* **Untrusted-source decode.** With inlining disabled, + :py:meth:`Expr.from_bytes` never calls ``cloudpickle.loads`` on + the incoming bytes — an inline payload from a misbehaving sender + raises a clear error instead of executing arbitrary Python code. + +Mismatched configurations raise a descriptive error: an inline blob +fed to a strict receiver fails fast rather than silently dropping +into ``cloudpickle.loads``. + +To make the toggle apply through :py:func:`pickle.dumps` (which +calls :py:meth:`Expr.to_bytes` with no context), install the strict +session as the driver's *sender context*: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_sender_ctx + + set_sender_ctx(SessionContext().with_python_udf_inlining(enabled=False)) + # Every subsequent pickle.dumps(expr) on this thread encodes + # without inlining the Python callable. + +Pair with a matching strict worker context +(:py:func:`~datafusion.ipc.set_worker_ctx`) so the ``pickle.loads`` +side also refuses inline payloads. Explicit +:py:meth:`Expr.to_bytes(ctx) <datafusion.Expr.to_bytes>` and +:py:meth:`Expr.from_bytes(blob, ctx=ctx) <datafusion.Expr.from_bytes>` calls +honor the supplied ``ctx`` directly and ignore the sender / worker +contexts. + +The toggle only narrows the :py:meth:`Expr.from_bytes` surface; +:py:func:`pickle.loads` on untrusted bytes remains unsafe regardless +of this setting. See the `Security`_ section below for the full +threat model. + +Security +~~~~~~~~ + +..
warning:: + + Reconstructing an expression containing a Python UDF executes + arbitrary Python code on the receiver — pickle is doing the work + under the hood and pickle is unsafe on untrusted input (see the + `pickle module security warning + <https://docs.python.org/3/library/pickle.html#module-pickle>`_ + in the Python standard library docs). Only accept expressions + from trusted sources. For untrusted-source workflows, disable + Python UDF inlining (see above), restrict senders to built-in + functions and pre-registered Rust-side UDFs, and avoid + :py:func:`pickle.loads` on externally supplied bytes entirely. + +Query-level distribution via datafusion-distributed +--------------------------------------------------- + +🚧 *Work in progress upstream — not yet usable from datafusion-python.* + +`datafusion-distributed <https://github.com/datafusion-contrib/datafusion-distributed>`_ +splits a single physical plan into stages and runs each stage on a +different worker node. The driver writes a SQL or DataFrame query +once; the runtime handles partitioning, shuffles, and reassembly. + +A datafusion-python integration is in development. This section will +document the integration once it lands. In the meantime, the +expression-level approach above covers most use cases that do not +require automatic plan partitioning. + +Query-level distribution via Apache Ballista +-------------------------------------------- + +🚧 *Work in progress upstream — not yet usable from datafusion-python.* + +`Apache Ballista <https://datafusion.apache.org/ballista/>`_ +provides distributed query execution on top of DataFusion with a +scheduler / executor model better suited to long-lived cluster +deployments. A datafusion-python integration is on the roadmap; this +section will fill in once the integration is usable. + +See also +-------- + +* :py:mod:`datafusion.ipc` — worker context API. +* ``examples/multiprocessing_pickle_expr.py`` — runnable + ``multiprocessing.Pool`` example that ships a different parametric + expression to each worker and collects results back. +* ``examples/ray_pickle_expr.py`` — runnable Ray actor example.
diff --git a/docs/source/user-guide/io/index.rst b/docs/source/user-guide/io/index.rst index b885cfeda..73f9babf8 100644 --- a/docs/source/user-guide/io/index.rst +++ b/docs/source/user-guide/io/index.rst @@ -24,6 +24,7 @@ IO arrow avro csv + distributing_work json parquet table_provider diff --git a/examples/README.md b/examples/README.md index 3024c782f..e0e3056d9 100644 --- a/examples/README.md +++ b/examples/README.md @@ -44,6 +44,11 @@ Here is a direct link to the file used in the examples: - [Register a Python UDF with DataFusion](./python-udf.py) - [Register a Python UDAF with DataFusion](./python-udaf.py) +### Distributing DataFusion expressions + +- [Fan out distinct expressions to a multiprocessing pool](./multiprocessing_pickle_expr.py) +- [Distribute expression evaluation across Ray actors](./ray_pickle_expr.py) + ### Substrait Support - [Serialize query plans using Substrait](./substrait.py) diff --git a/examples/datafusion-ffi-example/python/tests/_test_pickle_strict_ffi.py b/examples/datafusion-ffi-example/python/tests/_test_pickle_strict_ffi.py new file mode 100644 index 000000000..67c0b245a --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_pickle_strict_ffi.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Strict-mode Expr round-trip with an FFI-capsule scalar UDF. + +Verifies the by-name path: an FFI-imported UDF (no +``PythonFunctionScalarUDF`` downcast on the codec) serializes by name +and resolves from the receiver's function registry on decode. Covers +both the explicit ``Expr.to_bytes(ctx)`` / ``Expr.from_bytes(ctx=...)`` +API and the ``pickle.dumps`` / ``pickle.loads`` route through the +sender / worker context slots. +""" + +from __future__ import annotations + +import pickle + +import pyarrow as pa +import pytest +from datafusion import Expr, SessionContext, col, udf +from datafusion.ipc import ( + clear_sender_ctx, + clear_worker_ctx, + set_sender_ctx, + set_worker_ctx, +) +from datafusion_ffi_example import IsNullUDF + + +@pytest.fixture(autouse=True) +def _reset_thread_locals(): + """Ensure no sender / worker context leaks across tests.""" + clear_worker_ctx() + clear_sender_ctx() + yield + clear_worker_ctx() + clear_sender_ctx() + + +def _strict_session_with_ffi_udf(): + """Build a strict-mode session with the FFI ``IsNullUDF`` registered.""" + ctx = SessionContext().with_python_udf_inlining(enabled=False) + my_udf = udf(IsNullUDF()) + ctx.register_udf(my_udf) + return ctx, my_udf + + +def test_strict_ffi_udf_expr_roundtrip_via_to_bytes(): + """Strict-mode encode emits a by-name payload; receiver resolves + ``my_custom_is_null`` from its registered functions and the decoded + expression evaluates to the same result as the original.""" + sender, my_udf = _strict_session_with_ffi_udf() + receiver, _ = _strict_session_with_ffi_udf() + + expr = my_udf(col("a")) + blob = expr.to_bytes(sender) + restored = Expr.from_bytes(blob, ctx=receiver) + + assert "my_custom_is_null" in restored.canonical_name() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, None, 4], type=pa.int64())], names=["a"] + ) + receiver.register_record_batches("t", 
[[batch]]) + out = receiver.table("t").select(restored.alias("r")).collect() + expected = pa.array([False, False, True, False], type=pa.bool_()) + assert out[0].column(0) == expected + + +def test_strict_ffi_udf_pickle_roundtrip_via_thread_locals(): + """Driver installs a strict sender context; worker installs a + matching strict receiver. ``pickle.dumps`` / ``pickle.loads`` route + through them and the FFI UDF resolves by name on decode.""" + sender, my_udf = _strict_session_with_ffi_udf() + receiver, _ = _strict_session_with_ffi_udf() + + expr = my_udf(col("a")) + + set_sender_ctx(sender) + try: + blob = pickle.dumps(expr) + finally: + clear_sender_ctx() + + set_worker_ctx(receiver) + try: + restored = pickle.loads(blob) # noqa: S301 + finally: + clear_worker_ctx() + + assert "my_custom_is_null" in restored.canonical_name() + + +def test_strict_ffi_udf_smaller_than_inline_python_udf(): + """Sanity-check the wire size claim: strict-mode FFI UDF bytes are + a small by-name payload, dramatically smaller than the inline form + of a Python UDF with the same arity. Confirms the encode path + actually took the by-name branch instead of falling through to an + inline path.""" + sender, my_udf = _strict_session_with_ffi_udf() + ffi_blob = my_udf(col("a")).to_bytes(sender) + + inline_ctx = SessionContext() + py_udf = udf( + lambda arr: pa.array([v.as_py() is None for v in arr]), + [pa.int64()], + pa.bool_(), + volatility="immutable", + name="py_is_null", + ) + py_blob = py_udf(col("a")).to_bytes(inline_ctx) + + assert len(ffi_blob) < len(py_blob) // 4 diff --git a/examples/multiprocessing_pickle_expr.py b/examples/multiprocessing_pickle_expr.py new file mode 100644 index 000000000..4f78203b7 --- /dev/null +++ b/examples/multiprocessing_pickle_expr.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Distribute different DataFusion expressions to worker processes. + +Builds a list of parametric expressions in the driver — each closing +over a different threshold value — ships one per worker via +``multiprocessing.Pool``, and collects the results back. The closure +state forces the cloudpickle path (a by-name registration would lose +the captured threshold), so this is a real test of the expression- +pickling story rather than a same-expression fan-out. + +Worker layout: + +* Each worker receives a different ``(label, expr)`` task. +* Each worker materializes the shared dataset locally and runs its + own expression against it. +* The result and the worker's PID travel back to the driver, so the + output makes it visible that the work was spread across processes. + +Run: + python examples/multiprocessing_pickle_expr.py +""" + +from __future__ import annotations + +import multiprocessing as mp +import os + +import pyarrow as pa +from datafusion import Expr, SessionContext, col, udaf, udf +from datafusion import functions as F +from datafusion.user_defined import Accumulator, AggregateUDF, ScalarUDF + +# A shared input dataset. In a production pipeline this would live on +# object storage; here we hand-roll a small batch so the example runs +# without any I/O setup. 
+DATASET = { + "value": [3, 17, 42, 5, 88, 21, 9, 56, 4, 73, 12, 31], +} + + +def make_above_threshold_udf(threshold: int) -> ScalarUDF: + """Build a scalar UDF that returns 1 where ``value > threshold`` else 0. + + The threshold is captured in the closure, so cloudpickle has to + walk into the function body to ship the value across processes — + a by-name registration on the worker would collapse every + threshold into the same callable and lose the per-task state. + """ + + def above(arr: pa.Array) -> pa.Array: + # `v.as_py() or 0` coerces nulls to 0 — the demo dataset has no + # nulls, but real-world code should decide explicitly how nulls + # compare against the threshold. + return pa.array([1 if (v.as_py() or 0) > threshold else 0 for v in arr]) + + return udf( + above, + [pa.int64()], + pa.int64(), + volatility="immutable", + name=f"above_{threshold}", + ) + + +class _SumAccumulator(Accumulator): + """Tiny aggregate UDF state used to demonstrate UDAFs travel too.""" + + def __init__(self) -> None: + self._total = 0 + + def state(self) -> list[pa.Scalar]: + return [pa.scalar(self._total, type=pa.int64())] + + def update(self, values: pa.Array) -> None: + for v in values: + self._total += v.as_py() or 0 + + def merge(self, states: list[pa.Array]) -> None: + for s in states: + self._total += s[0].as_py() + + def evaluate(self) -> pa.Scalar: + return pa.scalar(self._total, type=pa.int64()) + + +def _build_sum_udaf() -> AggregateUDF: + return udaf( + _SumAccumulator, + [pa.int64()], + pa.int64(), + [pa.int64()], + "immutable", + name="my_sum", + ) + + +def evaluate_in_worker(task: tuple[str, Expr]) -> tuple[str, int, int]: + """Run one expression against the shared dataset. + + ``task`` arrived here via the pool's automatic pickling. The Python + callable inside the expression (including its captured threshold) + was reconstructed by the codec — the worker did not have to + register anything before this call. 
+ """ + label, expr = task + ctx = SessionContext() + df = ctx.from_pydict(DATASET) + # ``expr`` is an aggregate over the whole batch; ``aggregate`` keeps + # a single row of output, which we read as a Python int. + result_df = df.aggregate([], [expr.alias("result")]) + result = result_df.to_pydict()["result"][0] + return label, result, os.getpid() + + +def build_tasks() -> list[tuple[str, Expr]]: + """Return ``(label, expr)`` pairs — one task per worker invocation. + + Mixes scalar-UDF-in-aggregate and pure-aggregate work to show both + UDF kinds round-tripping through pickle. + """ + sum_udaf = _build_sum_udaf() + tasks: list[tuple[str, Expr]] = [] + + # Three "count values strictly above threshold T" tasks built from + # closure-capturing scalar UDFs. + for threshold in (10, 30, 60): + above_udf = make_above_threshold_udf(threshold) + tasks.append((f"count_above_{threshold}", F.sum(above_udf(col("value"))))) + + # One pure aggregate UDF task. + tasks.append(("custom_sum", sum_udaf(col("value")))) + + return tasks + + +def main() -> None: + tasks = build_tasks() + + # ``forkserver`` works on every POSIX platform and is the Python 3.14 + # default for POSIX. ``spawn`` would also work; ``fork`` is unsafe + # with pyarrow/tokio on macOS. + mp_ctx = mp.get_context("forkserver") + with mp_ctx.Pool(processes=min(4, len(tasks))) as pool: + results = pool.map(evaluate_in_worker, tasks) + + print(f"driver pid: {os.getpid()}") + for label, value, worker_pid in results: + print(f" [{label:>16}] = {value:>6} (worker pid: {worker_pid})") + + +if __name__ == "__main__": + main() diff --git a/examples/ray_pickle_expr.py b/examples/ray_pickle_expr.py new file mode 100644 index 000000000..8ef6140b2 --- /dev/null +++ b/examples/ray_pickle_expr.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Distribute DataFusion expressions to Ray actors. + +Build an expression in the driver, ship it to a pool of Ray actors, and +have each actor evaluate it against its own slice of data. Python UDFs +travel with the shipped expression — no actor-side registration needed. + +Prerequisites: + pip install ray + +Run: + python examples/ray_pickle_expr.py +""" + +import pyarrow as pa +import ray +from datafusion import Expr, SessionContext, col, lit, udf + + +def _build_double_udf(): + """Return the demo UDF used by the driver.""" + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +@ray.remote +class DataFusionWorker: + """A Ray actor with a private :class:`SessionContext`.""" + + def __init__(self) -> None: + self._ctx = SessionContext() + + def evaluate(self, expr: Expr, batch_pylist: list[int]) -> list[int]: + """Run the expression against an in-memory batch.""" + # `expr` arrived here via Ray's automatic argument serialization; + # the Python UDF inside it was reconstructed from the bytes — no + # pre-registration on this actor required. 
+ df = self._ctx.from_pydict({"a": batch_pylist}) + out = df.with_column("result", expr).select("result") + return out.to_pydict()["result"] + + +def main() -> None: + ray.init(ignore_reinit_error=True) + + expr = _build_double_udf()(col("a")) + lit(1) + + workers = [DataFusionWorker.remote() for _ in range(2)] + batches = [[1, 2, 3], [10, 20, 30], [100, 200, 300]] + futures = [ + workers[i % len(workers)].evaluate.remote(expr, batch) + for i, batch in enumerate(batches) + ] + for batch, result in zip(batches, ray.get(futures), strict=True): + print(f"input {batch} -> {result}") + + ray.shutdown() + + +if __name__ == "__main__": + main()