From 061f3ab74ad14edd9554280cb04109faeab2c5a2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 May 2026 14:17:59 -0400 Subject: [PATCH 1/4] feat: pickle support for Expr via inline scalar UDF encoding Adds Python-aware encoding to PythonLogicalCodec/PythonPhysicalCodec so a ScalarUDF defined in Python travels inside the serialized expression (cloudpickled into fun_definition) instead of needing a matching registration on the receiver. With that in place, Expr gains __reduce__ + classmethod from_bytes(buf, ctx=None) so pickle.dumps / pickle.loads work end-to-end on expressions built from col, lit, built-in functions, and Python scalar UDFs. Wire format is framed as [family magic][version byte][cloudpickle payload]; the version byte lets a too-new/too-old payload surface a clean Execution error instead of an opaque cloudpickle unpack failure. Schema serde is via arrow-rs's native IPC (no pyarrow round-trip). Cloudpickle module handle is cached per-interpreter through PyOnceLock. Worker-side context resolution lives in a new datafusion.ipc module: set_worker_ctx / get_worker_ctx / clear_worker_ctx plus a private _resolve_ctx helper consulted by Expr.from_bytes. Priority is explicit ctx > worker ctx > global SessionContext. FFI UDFs still travel by name and require the matching registration on the receiver's context. Aggregate and window UDF inline encoding, the per-session with_python_udf_inlining toggle, sender-side context, and the user-guide docs land in follow-on PRs. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/codec.rs | 428 +++++++++++++++++++++++++++--- crates/core/src/udf.rs | 82 +++++- pyproject.toml | 7 + python/datafusion/__init__.py | 3 +- python/datafusion/expr.py | 58 +++- python/datafusion/ipc.py | 113 ++++++++ python/datafusion/user_defined.py | 10 + python/tests/test_expr.py | 4 +- python/tests/test_pickle_expr.py | 127 +++++++++ uv.lock | 13 +- 10 files changed, 788 insertions(+), 57 deletions(-) create mode 100644 python/datafusion/ipc.py create mode 100644 python/tests/test_pickle_expr.py diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index 088532df2..c95d8cb19 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -19,11 +19,11 @@ //! //! Datafusion-python plans can carry references to Python-defined //! objects that the upstream protobuf codecs do not know how to -//! serialize: pure-Python scalar / aggregate / window UDFs, Python -//! query-planning extensions, and so on. Their state lives inside -//! `Py` callables and closures rather than being recoverable -//! from a name in the receiver's function registry. To ship a plan -//! across a process boundary (pickle, `multiprocessing`, Ray actor, +//! serialize: pure-Python scalar UDFs, Python query-planning +//! extensions, and so on. Their state lives inside `Py` +//! callables and closures rather than being recoverable from a name +//! in the receiver's function registry. To ship a plan across a +//! process boundary (pickle, `multiprocessing`, Ray actor, //! `datafusion-distributed`, etc.) those payloads have to be encoded //! into the proto wire format itself. //! @@ -48,52 +48,121 @@ //! plans to survive a serialization round-trip. Both codecs share //! the same payload framing for that reason. //! -//! Payloads emitted by these codecs are tagged with an 8-byte magic -//! prefix so the decoder can distinguish them from arbitrary bytes -//! 
(empty `fun_definition` from the default codec, user FFI payloads -//! that picked a non-colliding prefix). Dispatch precedence on -//! decode: **Python-inline payload (magic prefix match) → `inner` -//! codec → caller's `FunctionRegistry` fallback.** +//! Payloads emitted by these codecs are framed as +//! `[family magic][version byte][cloudpickle payload]`. The +//! family magic identifies the UDF flavor; the version byte lets the +//! decoder reject too-new or too-old payloads with a clean error +//! instead of falling into an opaque `cloudpickle` tuple-unpack +//! failure when the tuple shape changes. Dispatch precedence on +//! decode: **family match + supported version → `inner` codec → +//! caller's `FunctionRegistry` fallback.** //! -//! ## Wire-format magic prefix registry +//! ## Wire-format family registry //! -//! | Layer + kind | Magic prefix | -//! | ----------------------------- | ------------ | -//! | `PythonLogicalCodec` scalar | `DFPYUDF1` | -//! | `PythonLogicalCodec` agg | `DFPYUDA1` | -//! | `PythonLogicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` scalar | `DFPYUDF1` | -//! | `PythonPhysicalCodec` agg | `DFPYUDA1` | -//! | `PythonPhysicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` expr | `DFPYPE1` | -//! | User FFI extension codec | user-chosen | -//! | Default codec | (none) | +//! | Layer + kind | Family prefix | +//! | ----------------------------- | ------------- | +//! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | User FFI extension codec | user-chosen | +//! | Default codec | (none) | //! -//! Downstream FFI codecs should pick non-colliding prefixes (use a -//! `DF` namespace plus a crate-specific suffix). The codec +//! Aggregate and window UDF families are reserved for follow-on work. +//! +//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported +//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. +//! 
Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape +//! changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support +//! for an older shape. +//! +//! Downstream FFI codecs should pick non-colliding family prefixes +//! (use a `DF` namespace plus a crate-specific suffix). The codec //! implementations in this module currently delegate every method to //! `inner`; the encoder/decoder hooks for each kind are added as the //! corresponding Python-side type becomes serializable. use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::StreamWriter; use datafusion::common::{Result, TableReference}; use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::FileFormatFactory; use datafusion::execution::TaskContext; -use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion::logical_expr::{ + AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, WindowUDF, +}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use pyo3::prelude::*; +use pyo3::sync::PyOnceLock; +use pyo3::types::{PyBytes, PyTuple}; -/// Wire-format prefix that tags a `fun_definition` payload as an -/// inlined Python scalar UDF (cloudpickled tuple of name, callable, -/// input schema, return field, volatility). Defined once here so -/// the encoder and decoder cannot drift. -#[allow(dead_code)] -pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1"; +use crate::udf::PythonFunctionScalarUDF; + +// Wire-format framing for inlined Python UDF payloads. +// +// Layout: ` `. 
+// The family magic identifies the UDF flavor; the version byte lets +// the decoder reject too-new or too-old payloads with a clean error +// instead of falling into an opaque `cloudpickle` tuple-unpack failure +// when the tuple shape changes. Bump [`WIRE_VERSION_CURRENT`] whenever +// the tuple shape changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when +// dropping support for an older shape. + +/// Family prefix for an inlined Python scalar UDF +/// (cloudpickled tuple of name, callable, input schema, return field, +/// volatility). +pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF"; + +/// Wire-format version this build emits. +pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; + +/// Oldest wire-format version this build still decodes. Bump when +/// retiring support for an older payload shape. +pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1; + +/// Tag `buf` with the framing header for `family` at the current +/// wire-format version. Append-only — the caller writes the +/// cloudpickle payload after. +fn write_wire_header(buf: &mut Vec, family: &[u8]) { + buf.extend_from_slice(family); + buf.push(WIRE_VERSION_CURRENT); +} + +/// Inspect the framing on `buf`. +/// +/// * `Ok(None)` — `buf` does not carry `family`. The caller should +/// delegate to its `inner` codec. +/// * `Ok(Some(payload))` — `buf` carries `family` at a version this +/// build accepts; `payload` is the cloudpickle blob. +/// * `Err(_)` — `buf` carries `family` but at a version outside +/// `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. The error +/// names the version and the supported range so an operator can +/// diagnose sender/receiver version drift instead of seeing an +/// opaque cloudpickle tuple-unpack failure. 
+fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result> { + if !buf.starts_with(family) { + return Ok(None); + } + let version_idx = family.len(); + let Some(&version) = buf.get(version_idx) else { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Truncated inline Python {kind} payload: missing wire-format version byte" + ))); + }; + if !(WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT).contains(&version) { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Inline Python {kind} payload wire-format version v{version}; \ + this build supports v{WIRE_VERSION_MIN_SUPPORTED}..=v{WIRE_VERSION_CURRENT}. \ + Align datafusion-python versions on sender and receiver." + ))); + } + Ok(Some(&buf[version_idx + 1..])) +} /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds /// the Python-aware encoding hooks for logical-layer types @@ -177,10 +246,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } self.inner.try_decode_udf(name, buf) } @@ -212,7 +287,7 @@ impl LogicalExtensionCodec for PythonLogicalCodec { /// encoding on this layer too — otherwise a plan with a Python UDF /// would round-trip at the logical level but break at the physical /// level. Both layers reuse the shared payload framing -/// ([`PY_SCALAR_UDF_MAGIC`] et al.) so the wire format is identical. +/// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical. 
#[derive(Debug)] pub struct PythonPhysicalCodec { inner: Arc, @@ -249,10 +324,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } self.inner.try_decode_udf(name, buf) } @@ -284,3 +365,282 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { self.inner.try_decode_udwf(name, buf) } } + +// ============================================================================= +// Shared Python scalar UDF encode / decode helpers +// +// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on +// every `try_encode_udf` / `try_decode_udf` call. Same wire format on +// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan` +// or an `ExecutionPlan` round-trips identically. +// ============================================================================= + +/// Encode a Python scalar UDF inline if `node` is one. Returns +/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte, +/// cloudpickled tuple) was written and the caller should skip its +/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling +/// the caller to delegate to its `inner`. +pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_scalar_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +/// Decode an inline Python scalar UDF payload. 
Returns `Ok(None)` +/// when `buf` does not carry the `DFPYUDF` family prefix, signalling +/// the caller to delegate to its `inner` codec (and eventually the +/// `FunctionRegistry`). +pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_scalar_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) + }) +} + +/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`. +/// +/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes, +/// return_schema_bytes, volatility_str))`. Schema blobs are produced +/// by arrow-rs's native IPC stream writer (no pyarrow round-trip) and +/// decoded with the matching stream reader on the receiver. See +/// [`build_input_schema_bytes`] for what the input blob carries. +fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult> { + let signature = udf.signature(); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionScalarUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_schema_bytes = build_single_field_schema_bytes(udf.return_field().as_ref())?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + udf.name().into_pyobject(py)?.into_any(), + udf.func().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +/// Inverse of [`encode_python_scalar_udf`]. +fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? 
+ .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let func: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_field = read_single_return_field(&return_schema_bytes, "PythonFunctionScalarUDF")?; + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionScalarUDF::from_parts( + name, + func, + input_types, + return_field, + volatility, + )) +} + +/// Serialize a `Schema` to a self-contained IPC stream containing +/// only the schema message (no record batches). Inverse: +/// [`schema_from_ipc_bytes`]. +fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result> { + let mut buf: Vec = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buf, schema)?; + writer.finish()?; + } + Ok(buf) +} + +/// Decode an IPC stream containing only a schema message back into a +/// `Schema`. Inverse: [`schema_to_ipc_bytes`]. +fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result { + let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; + Ok(reader.schema().as_ref().clone()) +} + +/// Extract the per-arg `DataType`s from a `Signature` known to be +/// `TypeSignature::Exact` (all Python-defined UDFs are constructed +/// with `Signature::exact`). Any other variant indicates the impl was +/// not built by this crate's UDF/UDAF/UDWF constructors. 
+fn signature_input_dtypes(signature: &Signature, kind: &str) -> PyResult> { + match &signature.type_signature { + TypeSignature::Exact(types) => Ok(types.clone()), + other => Err(pyo3::exceptions::PyValueError::new_err(format!( + "{kind} expected Signature::Exact, got {other:?}" + ))), + } +} + +/// Wrap per-arg `DataType`s in synthetic `arg_{i}` fields and emit +/// the IPC schema blob the encoder writes into the cloudpickle tuple. +/// +/// The names and `nullable: true` are arbitrary: the underlying +/// `TypeSignature::Exact` carries no per-input nullability or +/// metadata, and the receiver collapses these fields back to +/// `Vec` via [`read_input_dtypes`], so anything set here +/// beyond the data type is discarded on decode. +fn build_input_schema_bytes(dtypes: &[DataType]) -> PyResult> { + let fields: Vec = dtypes + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("arg_{i}"), dt.clone(), true)) + .collect(); + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + +/// Emit a single-field IPC schema blob. Used for return-type and +/// state-field payloads where the receiver needs to recover field +/// metadata (names, nullability, key/value attributes) verbatim. +fn build_single_field_schema_bytes(field: &Field) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) +} + +/// Decode the per-arg `DataType`s the encoder wrote via +/// [`build_input_schema_bytes`]. +fn read_input_dtypes(bytes: &[u8]) -> PyResult> { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + Ok(schema + .fields() + .iter() + .map(|f| f.data_type().clone()) + .collect()) +} + +/// Decode a single-field IPC schema blob and return that field by +/// value. `kind` names the UDF flavor in the error message produced +/// when the blob is empty (should be unreachable for sender-side +/// payloads built via [`build_single_field_schema_bytes`]). 
+fn read_single_return_field(bytes: &[u8], kind: &str) -> PyResult { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + let field = schema.fields().first().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err(format!( + "{kind} return schema must contain exactly one field" + )) + })?; + Ok(field.as_ref().clone()) +} + +fn arrow_to_py_err(e: arrow::error::ArrowError) -> PyErr { + pyo3::exceptions::PyValueError::new_err(format!("{e}")) +} + +fn parse_volatility_str(s: &str) -> PyResult { + datafusion_python_util::parse_volatility(s) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}"))) +} + +/// Stable wire-format string for a `Volatility`. Pinned to the three +/// tokens [`datafusion_python_util::parse_volatility`] accepts, so an +/// upstream change to `Volatility`'s `Debug` repr cannot silently +/// produce bytes the decoder rejects. +fn volatility_wire_str(v: Volatility) -> &'static str { + match v { + Volatility::Immutable => "immutable", + Volatility::Stable => "stable", + Volatility::Volatile => "volatile", + } +} + +/// Cached handle to the `cloudpickle` module. +/// +/// The encode/decode helpers above would otherwise re-resolve the +/// module on every call. `py.import` is backed by `sys.modules` and +/// therefore cheap, but each call still walks a dict and re-binds the +/// result; a plan with many Python UDFs pays that cost per UDF. +/// +/// `PyOnceLock` scopes the cached `Py` to the current +/// interpreter, so the slot drops cleanly on interpreter teardown +/// (relevant under CPython subinterpreters, PEP 684) instead of +/// resurrecting a `Py` rooted in a dead interpreter on the next call. 
+fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { + static CLOUDPICKLE: PyOnceLock> = PyOnceLock::new(); + CLOUDPICKLE + .get_or_try_init(py, || Ok(py.import("cloudpickle")?.unbind().into_any())) + .map(|cached| cached.bind(py).clone()) +} + +#[cfg(test)] +mod wire_header_tests { + use super::*; + + #[test] + fn strip_returns_none_when_family_absent() { + let buf = b"OTHER_PAYLOAD"; + assert!(matches!( + strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF"), + Ok(None) + )); + } + + #[test] + fn strip_errors_on_truncated_version_byte() { + let buf = PY_SCALAR_UDF_FAMILY; + let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + assert!(format!("{err}").contains("missing wire-format version byte")); + } + + #[test] + fn strip_errors_on_too_new_version() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT.saturating_add(1)); + buf.extend_from_slice(b"payload"); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("wire-format version v")); + assert!(msg.contains("supports")); + assert!(msg.contains("Align datafusion-python versions")); + } + + #[test] + fn strip_errors_on_too_old_version() { + if WIRE_VERSION_MIN_SUPPORTED == 0 { + return; + } + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_MIN_SUPPORTED - 1); + buf.extend_from_slice(b"payload"); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").is_err()); + } + + #[test] + fn write_then_strip_round_trips_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(b"scalar-payload"); + + let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF") + .unwrap() + .unwrap(); + assert_eq!(payload, b"scalar-payload"); + } +} diff --git a/crates/core/src/udf.rs b/crates/core/src/udf.rs index c0a39cb47..72cdddba1 100644 --- a/crates/core/src/udf.rs +++ 
b/crates/core/src/udf.rs @@ -43,7 +43,7 @@ use crate::expr::PyExpr; /// This struct holds the Python written function that is a /// ScalarUDF. #[derive(Debug)] -struct PythonFunctionScalarUDF { +pub(crate) struct PythonFunctionScalarUDF { name: String, func: Py, signature: Signature, @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF { return_field: Arc::new(return_field), } } + + /// Stored Python callable. Consumed by the codec to cloudpickle + /// the function body across process boundaries. + pub(crate) fn func(&self) -> &Py { + &self.func + } + + pub(crate) fn return_field(&self) -> &FieldRef { + &self.return_field + } + + /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted + /// by the codec. Inputs collapse to `Vec` because + /// `Signature::exact` cannot carry per-input nullability or + /// metadata — the encoder is free to discard that side of the + /// schema. `return_field` is kept as a `Field` so the post-decode + /// nullability and metadata match the sender's instance. + pub(crate) fn from_parts( + name: String, + func: Py, + input_types: Vec, + return_field: Field, + volatility: Volatility, + ) -> Self { + Self { + name, + func, + signature: Signature::exact(input_types, volatility), + return_field: Arc::new(return_field), + } + } } impl Eq for PythonFunctionScalarUDF {} @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF { self.name == other.name && self.signature == other.signature && self.return_field == other.return_field - && Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false)) + // Identical pointers ⇒ same Python object. Most equality + // checks compare `Arc`-shared clones of the same UDF + // (e.g. expression rewriting), so the pointer match short- + // circuits before touching the GIL. + && (self.func.as_ptr() == other.func.as_ptr() + || Python::attach(|py| { + // Rust's `PartialEq` cannot return `Result`, so we + // have to pick a side when Python `__eq__` raises. 
+ // `false` is the conservative choice — better to + // report two UDFs as distinct than to wrongly + // merge them — but the silent miss can still + // surface as expression-dedup or cache-lookup + // anomalies. Log at `debug` so the failure is + // observable without flooding production logs. + // FIXME: revisit if upstream `ScalarUDFImpl` + // exposes a fallible `PartialEq`. + self.func + .bind(py) + .eq(other.func.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udf", + "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) } } impl Hash for PythonFunctionScalarUDF { fn hash(&self, state: &mut H) { + // Hash only the identifying header (name + signature + return + // field). Skipping `func` is intentional: the Rust `Hash` + // contract requires `a == b ⇒ hash(a) == hash(b)`, not the + // converse, so a coarser hash is sound — `PartialEq` still + // disambiguates two UDFs with the same header but distinct + // callables. Falling back to a sentinel on `py_hash` failure + // (as a prior revision did) silently mapped every unhashable + // closure to the same bucket; that is the worst case for a + // hashmap and is what this rewrite avoids. 
self.name.hash(state); self.signature.hash(state); self.return_field.hash(state); - - Python::attach(|py| { - let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects - - state.write_isize(py_hash); - }); } } @@ -220,4 +281,9 @@ impl PyScalarUDF { fn __repr__(&self) -> PyResult { Ok(format!("ScalarUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/pyproject.toml b/pyproject.toml index 951f7adc3..a02f4608a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ classifiers = [ "Programming Language :: Rust", ] dependencies = [ + # cloudpickle is invoked by the Rust-side PythonLogicalCodec / + # PythonPhysicalCodec via pyo3 to serialize Python UDF callables — + # scalar, aggregate, and window — into the proto wire format. + # Lazy-imported on the encode / decode hot paths (and cached after + # the first import), so users who never serialize a plan or + # expression incur no runtime cost beyond the install footprint. + "cloudpickle>=2.0", "pyarrow>=16.0.0;python_version<'3.14'", "pyarrow>=22.0.0;python_version>='3.14'", "typing-extensions;python_version<'3.13'", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f08b464bb..dfdeef07e 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -65,7 +65,7 @@ import importlib_metadata # type: ignore[import] # Public submodules -from . import functions, object_store, substrait, unparser +from . import functions, ipc, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. 
from ._internal import Config @@ -142,6 +142,7 @@ "configure_formatter", "expr", "functions", + "ipc", "lit", "literal", "object_store", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e0135e3ed..10b011ffb 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -434,23 +434,59 @@ def variant_name(self) -> str: return self.expr.variant_name() def to_bytes(self, ctx: SessionContext | None = None) -> bytes: - """Serialize this expression to protobuf bytes. + """Serialize this expression to bytes for shipping to another process. - When ``ctx`` is supplied, encoding routes through the session's - installed :class:`LogicalExtensionCodec`. Without ``ctx`` a - default codec is used. + Use this — or :func:`pickle.dumps` — to send an expression to a + worker process for distributed evaluation. + + When ``ctx`` is supplied, encoding routes through that session's + installed :class:`LogicalExtensionCodec`. When ``ctx`` is + ``None``, the default codec is used. + + Built-in functions and Python scalar UDFs travel inside the + returned bytes; the worker does not need to pre-register them. + UDFs imported via the FFI capsule protocol travel by name only + and must be registered on the worker. """ ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) - @staticmethod - def from_bytes(ctx: SessionContext, data: bytes) -> Expr: - """Decode an expression from serialized protobuf bytes. - - ``ctx`` provides the function registry for resolving UDF - references and the logical codec for in-band Python payloads. + @classmethod + def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: + """Reconstruct an expression from serialized bytes. + + Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`. + ``ctx`` is the :class:`SessionContext` used to resolve any + function references that travel by name (e.g. FFI UDFs). 
When + ``ctx`` is ``None`` the worker context installed via + :func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker + context is installed, the global :class:`SessionContext` is used + (sufficient for built-ins and Python scalar UDFs, plus any UDFs + registered on the global context). + """ + from datafusion.ipc import _resolve_ctx + + resolved = _resolve_ctx(ctx) + return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf)) + + def __reduce__(self) -> tuple: + """Pickle protocol hook. + + Lets expressions be shipped to worker processes via + :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions + and Python scalar UDFs travel inside the pickle bytes; only + FFI-capsule UDFs require pre-registration on the worker. The + worker's :class:`SessionContext` for resolving those references + is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling + back to the global :class:`SessionContext` if none has been + installed on the worker. """ - return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data)) + return (Expr._reconstruct, (self.to_bytes(),)) + + @classmethod + def _reconstruct(cls, proto_bytes: bytes) -> Expr: + """Internal entry point used by :meth:`__reduce__` on unpickle.""" + return cls.from_bytes(proto_bytes) def __richcmp__(self, other: Expr, op: int) -> Expr: """Comparison operator.""" diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py new file mode 100644 index 000000000..d1867a917 --- /dev/null +++ b/python/datafusion/ipc.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Worker-side setup for distributing DataFusion expressions. + +When a :class:`Expr` is shipped to a worker process (e.g. through +:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the +expression against a :class:`SessionContext`. If the expression references +UDFs imported via the FFI capsule protocol — or any UDF the worker would +otherwise resolve from its registered functions rather than from inside +the shipped expression — install a configured :class:`SessionContext` +once per worker: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_worker_ctx + + def init_worker(): + ctx = SessionContext() + ctx.register_udaf(my_ffi_aggregate) + set_worker_ctx(ctx) + +Built-in functions and Python scalar UDFs travel inside the shipped +expression itself and do not need pre-registration on the worker. +""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datafusion.context import SessionContext + + +__all__ = [ + "clear_worker_ctx", + "get_worker_ctx", + "set_worker_ctx", +] + + +_local = threading.local() + + +def set_worker_ctx(ctx: SessionContext) -> None: + """Install this worker's :class:`SessionContext` for shipped expressions. + + Call once per worker — typically from a ``multiprocessing.Pool`` + initializer or a Ray actor ``__init__``. Idempotent: overwrites any + previous value. Stored in a thread-local slot, so each thread within a + worker may install its own context independently. 
+ """ + _local.ctx = ctx + + +def clear_worker_ctx() -> None: + """Remove this worker's installed :class:`SessionContext`. + + After clearing, expressions reconstructed in this worker fall back to + the global :class:`SessionContext` — adequate for built-ins and Python + scalar UDFs, but anything imported via the FFI capsule protocol must + be registered on the global context to resolve. + """ + if hasattr(_local, "ctx"): + del _local.ctx + + +def get_worker_ctx() -> SessionContext | None: + """Return this worker's installed :class:`SessionContext`, or ``None``.""" + return getattr(_local, "ctx", None) + + +def _resolve_ctx( + explicit_ctx: SessionContext | None = None, +) -> SessionContext: + """Resolve a context for Expr reconstruction. + + Priority: explicit argument > worker context > global context. + Falling back to the global :class:`SessionContext` (instead of a + freshly constructed one) preserves any registrations the user has + installed on it. + """ + if explicit_ctx is not None: + return explicit_ctx + worker = get_worker_ctx() + if worker is not None: + return worker + # Lazy import: `datafusion/__init__.py` imports `datafusion.ipc` + # before `datafusion.context`, so a module-top import would force + # `datafusion.context` to load mid-init of `datafusion.ipc`. The + # cycle is benign today (context.py only pulls expr.py at module + # scope, neither pulls ipc.py back), but a single new import in + # context.py's transitive deps could turn it into a real cycle. + # Deferring keeps `datafusion.ipc` import-order-independent. 
+ from datafusion.context import SessionContext # noqa: PLC0415 + + return SessionContext.global_ctx() diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 848ab4cee..f80b613a2 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -141,6 +141,16 @@ def __init__( name, func, input_fields, return_field, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDF. + + For UDFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udf.name + def __repr__(self) -> str: """Print a string representation of the Scalar UDF.""" return self._udf.__repr__() diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6a466f6f2..e1fdeab44 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None: original = col("a") + lit(1) blob = original.to_bytes(ctx) - restored = Expr.from_bytes(ctx, blob) + restored = Expr.from_bytes(blob, ctx=ctx) # Canonical name preserves the structure of the expression even # though the underlying PyExpr instances are different. @@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None: fresh = SessionContext() original = col("a") * lit(2) blob = original.to_bytes() # encode side: default codec - restored = Expr.from_bytes(fresh, blob) + restored = Expr.from_bytes(blob, ctx=fresh) assert restored.canonical_name() == original.canonical_name() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py new file mode 100644 index 000000000..c0d749271 --- /dev/null +++ b/python/tests/test_pickle_expr.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""In-process pickle round-trip tests for :class:`Expr`. + +Built-in functions and Python scalar UDFs travel with the pickled +expression and do not need worker-side pre-registration. The worker +context (:mod:`datafusion.ipc`) is only consulted for UDFs imported +via the FFI capsule protocol. +""" + +from __future__ import annotations + +import pickle + +import pyarrow as pa +import pytest +from datafusion import Expr, SessionContext, col, lit, udf +from datafusion.ipc import ( + clear_worker_ctx, + set_worker_ctx, +) + + +@pytest.fixture(autouse=True) +def _reset_worker_ctx(): + """Ensure every test starts with no worker context installed.""" + clear_worker_ctx() + yield + clear_worker_ctx() + + +def _double_udf(): + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +class TestProtoRoundTrip: + def test_builtin_round_trip(self): + e = col("a") + lit(1) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + + def test_to_bytes_from_bytes(self): + e = col("x") * lit(7) + blob = e.to_bytes() + assert isinstance(blob, bytes) + decoded = Expr.from_bytes(blob) + assert decoded.canonical_name() == e.canonical_name() + 
+ def test_explicit_ctx_used(self, ctx): + e = col("a") + lit(1) + decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx) + assert decoded.canonical_name() == e.canonical_name() + + +class TestUDFCodec: + """Python scalar UDFs ride inside the proto blob via the Rust codec. + + No worker context needed on the receiver — the cloudpickled callable is + embedded in ``fun_definition`` and reconstructed automatically. + """ + + def test_udf_self_contained_blob(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + # The codec inlines the callable, so the blob is much bigger than a + # pure built-in blob but doesn't depend on receiver-side registration. + assert len(blob) > 200 + + def test_udf_decodes_into_fresh_ctx(self): + e = _double_udf()(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_no_worker_ctx(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_worker_ctx(self): + set_worker_ctx(SessionContext()) + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_closure_capturing_udf_names_match(self): + captured_multiplier = 7 + + def fn(arr): + return pa.array([(v.as_py() or 0) * captured_multiplier for v in arr]) + + u = udf( + fn, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="times_seven", + ) + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() diff --git a/uv.lock b/uv.lock index 3b7135e32..3fd3eec4b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ 
-257,6 +257,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767, upload-time = "2024-12-24T18:12:32.852Z" }, ] +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + [[package]] name = "codespell" version = "2.4.1" @@ -316,6 +325,7 @@ wheels = [ name = "datafusion" source = { editable = "." } dependencies = [ + { name = "cloudpickle" }, { name = "pyarrow" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] @@ -351,6 +361,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", specifier = ">=2.0" }, { name = "pyarrow", marker = "python_full_version < '3.14'", specifier = ">=16.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14'", specifier = ">=22.0.0" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, From ca6849e98009ad85ae9b1132f33afb29d57c0518 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 May 2026 15:35:34 -0400 Subject: [PATCH 2/4] docs(pickle): add cloudpickle security warnings, docstring examples, edge-case tests Inline `.. 
warning::` blocks on `Expr.to_bytes`, `Expr.from_bytes`, and `Expr.__reduce__` so the cloudpickle / arbitrary-code-execution caveat is visible at the public API surface in advance of the user-guide page that lands in PR 4. Add doctest-style `Examples:` blocks to `datafusion.ipc` functions (`set_worker_ctx`, `clear_worker_ctx`, `get_worker_ctx`, `_resolve_ctx`), `ScalarUDF.name`, and the new `Expr` pickle methods, per CLAUDE.md. Tighten `Expr.__reduce__` return annotation to `tuple[Callable[[bytes], Expr], tuple[bytes]]`. Tests: multi-arg UDF round-trip (covers synthetic `arg_{i}` schema-field loop in the codec) plus malformed-bytes paths through `Expr.from_bytes`. Co-Authored-By: Claude Opus 4.7 (1M context) --- python/datafusion/expr.py | 51 +++++++++++++++++++++++++++++-- python/datafusion/ipc.py | 35 ++++++++++++++++++++- python/datafusion/user_defined.py | 13 ++++++++ python/tests/test_pickle_expr.py | 30 ++++++++++++++++++ 4 files changed, 125 insertions(+), 4 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 10b011ffb..cdb8377de 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -46,7 +46,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa @@ -447,6 +447,19 @@ def to_bytes(self, ctx: SessionContext | None = None) -> bytes: returned bytes; the worker does not need to pre-register them. UDFs imported via the FFI capsule protocol travel by name only and must be registered on the worker. + + .. warning:: + Bytes returned here may embed a cloudpickled Python + callable (when the expression carries a Python scalar UDF). + Reconstructing them via :meth:`from_bytes` or + :func:`pickle.loads` executes arbitrary Python on the + receiver. Only accept payloads from trusted sources. 
+ + Examples: + >>> from datafusion import col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> isinstance(blob, bytes) + True """ ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) @@ -463,13 +476,25 @@ def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: context is installed, the global :class:`SessionContext` is used (sufficient for built-ins and Python scalar UDFs, plus any UDFs registered on the global context). + + .. warning:: + Decoding may invoke ``cloudpickle.loads`` on bytes embedded + in the payload, which executes arbitrary Python code. Treat + ``buf`` as code, not data — only decode bytes you produced + yourself or received from a trusted sender. + + Examples: + >>> from datafusion import Expr, col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> Expr.from_bytes(blob).canonical_name() + 'a + Int64(1)' """ from datafusion.ipc import _resolve_ctx resolved = _resolve_ctx(ctx) return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf)) - def __reduce__(self) -> tuple: + def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]: """Pickle protocol hook. Lets expressions be shipped to worker processes via @@ -480,12 +505,32 @@ def __reduce__(self) -> tuple: is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling back to the global :class:`SessionContext` if none has been installed on the worker. + + .. warning:: + :func:`pickle.loads` on the returned tuple executes + arbitrary Python on the receiver, including any + cloudpickled UDF callable embedded in the payload. Only + unpickle expressions from trusted sources. 
+ + Examples: + >>> import pickle + >>> from datafusion import col, lit + >>> e = col("a") * lit(2) + >>> pickle.loads(pickle.dumps(e)).canonical_name() + 'a * Int64(2)' """ return (Expr._reconstruct, (self.to_bytes(),)) @classmethod def _reconstruct(cls, proto_bytes: bytes) -> Expr: - """Internal entry point used by :meth:`__reduce__` on unpickle.""" + """Internal entry point used by :meth:`__reduce__` on unpickle. + + Examples: + >>> from datafusion import Expr, col, lit + >>> blob = (col("a") + lit(1)).to_bytes() + >>> Expr._reconstruct(blob).canonical_name() + 'a + Int64(1)' + """ return cls.from_bytes(proto_bytes) def __richcmp__(self, other: Expr, op: int) -> Expr: diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py index d1867a917..16e68c4d0 100644 --- a/python/datafusion/ipc.py +++ b/python/datafusion/ipc.py @@ -65,6 +65,14 @@ def set_worker_ctx(ctx: SessionContext) -> None: initializer or a Ray actor ``__init__``. Idempotent: overwrites any previous value. Stored in a thread-local slot, so each thread within a worker may install its own context independently. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import set_worker_ctx, get_worker_ctx, clear_worker_ctx + >>> set_worker_ctx(SessionContext()) + >>> get_worker_ctx() is not None + True + >>> clear_worker_ctx() """ _local.ctx = ctx @@ -76,13 +84,28 @@ def clear_worker_ctx() -> None: the global :class:`SessionContext` — adequate for built-ins and Python scalar UDFs, but anything imported via the FFI capsule protocol must be registered on the global context to resolve. 
+ + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import set_worker_ctx, clear_worker_ctx, get_worker_ctx + >>> set_worker_ctx(SessionContext()) + >>> clear_worker_ctx() + >>> get_worker_ctx() is None + True """ if hasattr(_local, "ctx"): del _local.ctx def get_worker_ctx() -> SessionContext | None: - """Return this worker's installed :class:`SessionContext`, or ``None``.""" + """Return this worker's installed :class:`SessionContext`, or ``None``. + + Examples: + >>> from datafusion.ipc import get_worker_ctx, clear_worker_ctx + >>> clear_worker_ctx() + >>> get_worker_ctx() is None + True + """ return getattr(_local, "ctx", None) @@ -95,6 +118,16 @@ def _resolve_ctx( Falling back to the global :class:`SessionContext` (instead of a freshly constructed one) preserves any registrations the user has installed on it. + + Examples: + >>> from datafusion import SessionContext + >>> from datafusion.ipc import _resolve_ctx, clear_worker_ctx + >>> clear_worker_ctx() + >>> isinstance(_resolve_ctx(), SessionContext) + True + >>> ctx = SessionContext() + >>> _resolve_ctx(ctx) is ctx + True """ if explicit_ctx is not None: return explicit_ctx diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index f80b613a2..d79cf22e8 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -148,6 +148,19 @@ def name(self) -> str: For UDFs imported via the FFI capsule protocol, this is the name the capsule itself reports — not the ``name`` argument passed to the constructor (which is ignored on the FFI path). + + Examples: + >>> import pyarrow as pa + >>> from datafusion import udf + >>> double = udf( + ... lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + ... [pa.int64()], + ... pa.int64(), + ... volatility="immutable", + ... name="double", + ... 
) + >>> double.name + 'double' """ return self._udf.name diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py index c0d749271..3e30ca14a 100644 --- a/python/tests/test_pickle_expr.py +++ b/python/tests/test_pickle_expr.py @@ -125,3 +125,33 @@ def fn(arr): blob = pickle.dumps(e) decoded = pickle.loads(blob) # noqa: S301 assert decoded.canonical_name() == e.canonical_name() + + def test_multi_arg_udf_round_trip(self): + """Wire format builds synthetic `arg_{i}` fields per input — exercise + with a 2-arg UDF spanning two distinct DataTypes.""" + add_scaled = udf( + lambda a, b: pa.array( + [ + (x.as_py() or 0) + (y.as_py() or 0.0) + for x, y in zip(a, b, strict=False) + ] + ), + [pa.int64(), pa.float64()], + pa.float64(), + volatility="immutable", + name="add_scaled", + ) + e = add_scaled(col("a"), col("b")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + assert "add_scaled" in decoded.canonical_name() + + +class TestErrorPaths: + def test_from_bytes_rejects_garbage(self): + with pytest.raises(Exception): # noqa: B017 + Expr.from_bytes(b"not a valid protobuf payload") + + def test_from_bytes_rejects_empty(self): + with pytest.raises(Exception): # noqa: B017 + Expr.from_bytes(b"") From df34e0aa43f250c825d123f2eb97b15f77078245 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 18 May 2026 12:00:50 -0400 Subject: [PATCH 3/4] as_any no longer in api --- crates/core/src/codec.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index c95d8cb19..ff052506e 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -381,11 +381,7 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { /// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling /// the caller to delegate to its `inner`. 
pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> { - let Some(py_udf) = node - .inner() - .as_any() - .downcast_ref::<PythonFunctionScalarUDF>() - else { + let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionScalarUDF>() else { return Ok(false); }; From 5c9991505569297dc7cda90ea67e35eeb27d7895 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 18 May 2026 12:33:01 -0400 Subject: [PATCH 4/4] feat(pickle): stamp Python (major, minor) in UDF wire header MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cloudpickle bytecode is not portable across Python minor versions — a payload produced on 3.11 fails to load on 3.12 with an opaque marshal/unpickle error. Embed the sender's (major, minor) in the DFPYUDF wire header and reject mismatches at decode time with an actionable error that names both versions, instead of letting the failure surface from inside cloudpickle.loads. Header layout becomes: DFPYUDF (7) | version (1) | py_major (1) | py_minor (1) | cloudpickle Extend the Security warnings on Expr.to_bytes / from_bytes / __reduce__ with a Portability section covering the cross-version constraint and cloudpickle's by-value/by-reference behavior (the callable inlines bytecode and closure cells, but imported names travel by reference and must be importable on the receiver). Add a matching Serialization model note to the datafusion.ipc module docstring.
New tests: - codec::wire_header_tests: py-major/minor mismatch, truncated py-version bytes, round-trip with py-version - test_pickle_expr::test_cross_version_error_message: patches the py_minor byte inside an emitted payload and asserts the error message identifies the version mismatch Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/codec.rs | 162 +++++++++++++++++++++++++------ python/datafusion/expr.py | 67 ++++++++++++- python/datafusion/ipc.py | 15 +++ python/tests/test_pickle_expr.py | 27 ++++++ 4 files changed, 237 insertions(+), 34 deletions(-) diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index ff052506e..cc038edc9 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -49,12 +49,16 @@ //! the same payload framing for that reason. //! //! Payloads emitted by these codecs are framed as -//! `<family_magic> <version_byte> <cloudpickle_payload>`. The -//! family magic identifies the UDF flavor; the version byte lets the -//! decoder reject too-new or too-old payloads with a clean error +//! `<family_magic> <version_byte> <py_major> <py_minor> <cloudpickle_payload>`. +//! The family magic identifies the UDF flavor; the version byte lets +//! the decoder reject too-new or too-old payloads with a clean error //! instead of falling into an opaque `cloudpickle` tuple-unpack -//! failure when the tuple shape changes. Dispatch precedence on -//! decode: **family match + supported version → `inner` codec → +//! failure when the tuple shape changes; the Python `(major, minor)` +//! bytes catch the cloudpickle-cross-minor-version case and raise an +//! actionable error instead of an opaque `marshal` failure on load +//! (cloudpickle payloads are not portable across Python minor +//! versions). Dispatch precedence on decode: **family match + +//! supported version + matching Python version → `inner` codec → //! caller's `FunctionRegistry` fallback.** //! //! ## Wire-format family registry @@ -105,13 +109,17 @@ use crate::udf::PythonFunctionScalarUDF; // Wire-format framing for inlined Python UDF payloads. // -// Layout: `<family_magic> <version_byte> <cloudpickle_payload>`.
+// Layout: `<family_magic> <version_byte> <py_major> <py_minor> <cloudpickle_payload>`. // The family magic identifies the UDF flavor; the version byte lets // the decoder reject too-new or too-old payloads with a clean error // instead of falling into an opaque `cloudpickle` tuple-unpack failure -// when the tuple shape changes. Bump [`WIRE_VERSION_CURRENT`] whenever -// the tuple shape changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when -// dropping support for an older shape. +// when the tuple shape changes; the Python `(major, minor)` bytes +// catch the cloudpickle-cross-minor-version case (cloudpickle is not +// portable across Python minor versions) and raise an actionable +// error instead of an opaque `marshal` failure on load. Bump +// [`WIRE_VERSION_CURRENT`] whenever the tuple shape changes; raise +// [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support for an older +// shape. /// Family prefix for an inlined Python scalar UDF /// (cloudpickled tuple of name, callable, input schema, return field, @@ -126,11 +134,14 @@ pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1; /// Tag `buf` with the framing header for `family` at the current -/// wire-format version. Append-only — the caller writes the -/// cloudpickle payload after. -fn write_wire_header(buf: &mut Vec<u8>, family: &[u8]) { +/// wire-format version, stamping `py_version` as `(major, minor)` +/// bytes. Append-only — the caller writes the cloudpickle payload +/// after. +fn write_wire_header(buf: &mut Vec<u8>, family: &[u8], py_version: (u8, u8)) { buf.extend_from_slice(family); buf.push(WIRE_VERSION_CURRENT); + buf.push(py_version.0); + buf.push(py_version.1); } /// Inspect the framing on `buf`. @@ -138,13 +149,20 @@ fn write_wire_header(buf: &mut Vec<u8>, family: &[u8]) { /// * `Ok(None)` — `buf` does not carry `family`. The caller should /// delegate to its `inner` codec. /// * `Ok(Some(payload))` — `buf` carries `family` at a version this -/// build accepts; `payload` is the cloudpickle blob.
-/// * `Err(_)` — `buf` carries `family` but at a version outside -/// `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. The error -/// names the version and the supported range so an operator can -/// diagnose sender/receiver version drift instead of seeing an -/// opaque cloudpickle tuple-unpack failure. -fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result> { +/// build accepts and a Python `(major, minor)` matching +/// `expected_py`; `payload` is the cloudpickle blob. +/// * `Err(_)` — `buf` carries `family` but the wire-format version +/// is outside `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`, +/// or the stamped Python `(major, minor)` does not match +/// `expected_py`. The error names the offending values so an +/// operator can diagnose sender/receiver drift instead of seeing +/// an opaque cloudpickle tuple-unpack or `marshal` failure. +fn strip_wire_header<'a>( + buf: &'a [u8], + family: &[u8], + kind: &str, + expected_py: (u8, u8), +) -> Result> { if !buf.starts_with(family) { return Ok(None); } @@ -161,7 +179,28 @@ fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result) }; Python::attach(|py| -> Result { + let py_version = current_python_version(py) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; let bytes = encode_python_scalar_udf(py, py_udf) .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; - write_wire_header(buf, PY_SCALAR_UDF_FAMILY); + write_wire_header(buf, PY_SCALAR_UDF_FAMILY, py_version); buf.extend_from_slice(&bytes); Ok(true) }) @@ -399,11 +440,13 @@ pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) /// the caller to delegate to its `inner` codec (and eventually the /// `FunctionRegistry`). pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result>> { - let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? 
else { - return Ok(None); - }; - Python::attach(|py| -> Result>> { + let py_version = current_python_version(py) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", py_version)? + else { + return Ok(None); + }; let udf = decode_python_scalar_udf(py, payload) .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) @@ -567,6 +610,20 @@ fn volatility_wire_str(v: Volatility) -> &'static str { } } +/// Read the interpreter's `sys.version_info` as `(major, minor)`. +/// +/// Used by encoder/decoder to stamp and verify the Python version a +/// cloudpickle payload was produced on. cloudpickle is not portable +/// across Python minor versions; the wire header carries these bytes +/// so a mismatch surfaces an actionable error instead of an opaque +/// `marshal` failure at `cloudpickle.loads` time. +fn current_python_version(py: Python<'_>) -> PyResult<(u8, u8)> { + let version_info = py.import("sys")?.getattr("version_info")?; + let major: u8 = version_info.getattr("major")?.extract()?; + let minor: u8 = version_info.getattr("minor")?.extract()?; + Ok((major, minor)) +} + /// Cached handle to the `cloudpickle` module. 
/// /// The encode/decode helpers above would otherwise re-resolve the @@ -589,11 +646,13 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { mod wire_header_tests { use super::*; + const TEST_PY: (u8, u8) = (3, 12); + #[test] fn strip_returns_none_when_family_absent() { let buf = b"OTHER_PAYLOAD"; assert!(matches!( - strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF"), + strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY), Ok(None) )); } @@ -601,7 +660,7 @@ mod wire_header_tests { #[test] fn strip_errors_on_truncated_version_byte() { let buf = PY_SCALAR_UDF_FAMILY; - let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err(); assert!(format!("{err}").contains("missing wire-format version byte")); } @@ -609,8 +668,10 @@ mod wire_header_tests { fn strip_errors_on_too_new_version() { let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); buf.push(WIRE_VERSION_CURRENT.saturating_add(1)); + buf.push(TEST_PY.0); + buf.push(TEST_PY.1); buf.extend_from_slice(b"payload"); - let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err(); let msg = format!("{err}"); assert!(msg.contains("wire-format version v")); assert!(msg.contains("supports")); @@ -624,17 +685,56 @@ mod wire_header_tests { } let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); buf.push(WIRE_VERSION_MIN_SUPPORTED - 1); + buf.push(TEST_PY.0); + buf.push(TEST_PY.1); + buf.extend_from_slice(b"payload"); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).is_err()); + } + + #[test] + fn strip_errors_on_truncated_py_major() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err(); + assert!(format!("{err}").contains("missing Python 
major version byte")); + } + + #[test] + fn strip_errors_on_truncated_py_minor() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT); + buf.push(TEST_PY.0); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY).unwrap_err(); + assert!(format!("{err}").contains("missing Python minor version byte")); + } + + #[test] + fn strip_errors_on_py_minor_mismatch() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, (3, 11)); + buf.extend_from_slice(b"payload"); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", (3, 12)).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("Python 3.11")); + assert!(msg.contains("Python 3.12")); + assert!(msg.contains("not portable across Python minor versions")); + } + + #[test] + fn strip_errors_on_py_major_mismatch() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, (3, 12)); buf.extend_from_slice(b"payload"); - assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").is_err()); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", (4, 0)).is_err()); } #[test] fn write_then_strip_round_trips_payload() { let mut buf = Vec::new(); - write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY); buf.extend_from_slice(b"scalar-payload"); - let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF") + let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", TEST_PY) .unwrap() .unwrap(); assert_eq!(payload, b"scalar-payload"); diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index ff8b7a391..645bd9c18 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -454,13 +454,59 @@ def to_bytes(self, ctx: SessionContext | None = None) -> bytes: UDFs imported via the FFI capsule protocol travel by name only and must be registered on the worker. - .. 
warning:: + .. warning:: Security Bytes returned here may embed a cloudpickled Python callable (when the expression carries a Python scalar UDF). Reconstructing them via :meth:`from_bytes` or :func:`pickle.loads` executes arbitrary Python on the receiver. Only accept payloads from trusted sources. + .. warning:: Portability + cloudpickle serializes Python bytecode, which is **not + stable across Python minor versions**. A payload produced + on Python 3.11 will fail to load on Python 3.12. The + wire format stamps the sender's ``(major, minor)``; + :meth:`from_bytes` raises a :class:`ValueError` naming + both versions on mismatch. + + cloudpickle captures the UDF callable **by value** — + bytecode and closure cells inlined — but names the + callable resolves via ``import`` are captured **by + reference** (module path only) and must be importable on + the receiver. + + **Self-contained — works anywhere:** + + .. code-block:: python + + # Lambda: bytecode captured inline + udf(lambda x: x * 2, [pa.int64()], pa.int64(), + volatility="immutable") + + # Locally-defined function: bytecode captured inline + def double(x): + return x * 2 + udf(double, [pa.int64()], pa.int64(), volatility="immutable") + + # Closure over a local variable: value captured inline + factor = 3 + udf(lambda x: x * factor, [pa.int64()], pa.int64(), + volatility="immutable") + + **Requires matching environment on receiver:** + + .. 
code-block:: python + + # Top-level import: `foo` must be installed on receiver + from foo import double + udf(double, [pa.int64()], pa.int64(), volatility="immutable") + + # Bound method of an imported class: same caveat + from mylib import Transformer + t = Transformer() + udf(t.transform, [pa.int64()], pa.int64(), + volatility="immutable") + Examples: >>> from datafusion import col, lit >>> blob = (col("a") + lit(1)).to_bytes() @@ -483,12 +529,21 @@ def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: (sufficient for built-ins and Python scalar UDFs, plus any UDFs registered on the global context). - .. warning:: + .. warning:: Security Decoding may invoke ``cloudpickle.loads`` on bytes embedded in the payload, which executes arbitrary Python code. Treat ``buf`` as code, not data — only decode bytes you produced yourself or received from a trusted sender. + .. warning:: Portability + cloudpickle payloads are **not portable across Python + minor versions**. The wire format stamps the sender's + ``(major, minor)``; if it does not match the current + interpreter, this method raises :class:`ValueError` + naming both versions. Modules the UDF imports must also + be importable on the receiver — see :meth:`to_bytes` for + by-value vs. by-reference details. + Examples: >>> from datafusion import Expr, col, lit >>> blob = (col("a") + lit(1)).to_bytes() @@ -512,12 +567,18 @@ def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]: back to the global :class:`SessionContext` if none has been installed on the worker. - .. warning:: + .. warning:: Security :func:`pickle.loads` on the returned tuple executes arbitrary Python on the receiver, including any cloudpickled UDF callable embedded in the payload. Only unpickle expressions from trusted sources. + .. warning:: Portability + Sender and receiver must run the same Python + ``(major, minor)`` version; cloudpickle bytecode is not + portable across minor versions. 
See :meth:`to_bytes` for + details on what travels by value vs. by reference. + Examples: >>> import pickle >>> from datafusion import col, lit diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py index 16e68c4d0..78b6873f7 100644 --- a/python/datafusion/ipc.py +++ b/python/datafusion/ipc.py @@ -37,6 +37,21 @@ def init_worker(): Built-in functions and Python scalar UDFs travel inside the shipped expression itself and do not need pre-registration on the worker. + +.. note:: Serialization model + + Expressions containing Python scalar UDFs are serialized using + :mod:`cloudpickle`. The callable itself travels **by value** + (bytecode and closure cells inlined), but any names the callable + resolves via ``import`` are captured **by reference** and must be + importable on the receiving worker. + + The serialized payload is stamped with the sender's Python + ``(major, minor)`` version. Loading on a different minor version + raises :class:`ValueError` with an actionable message — cloudpickle + payloads are not portable across Python minor versions. See + :meth:`datafusion.Expr.to_bytes` for examples of what travels by + value vs. by reference. """ from __future__ import annotations diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py index 3e30ca14a..5d8d9285f 100644 --- a/python/tests/test_pickle_expr.py +++ b/python/tests/test_pickle_expr.py @@ -155,3 +155,30 @@ def test_from_bytes_rejects_garbage(self): def test_from_bytes_rejects_empty(self): with pytest.raises(Exception): # noqa: B017 Expr.from_bytes(b"") + + def test_cross_version_error_message(self): + """Decoding a payload stamped with a different Python minor + version raises a clear, actionable error rather than an opaque + marshal/unpickle failure. + + The wire frame inside the protobuf is: + ``DFPYUDF (7) | version (1) | py_major (1) | py_minor (1) | cloudpickle``. + We locate the frame inside the outer protobuf and patch the + minor byte at offset 9. 
+ """ + import sys + + e = _double_udf()(col("a")) + blob = e.to_bytes() + + idx = blob.find(b"DFPYUDF") + assert idx >= 0, "DFPYUDF frame not found in payload" + + different_minor = (sys.version_info.minor + 1) % 256 + tampered = bytearray(blob) + tampered[idx + 9] = different_minor + + with pytest.raises( + Exception, match="not portable across Python minor versions" + ): + Expr.from_bytes(bytes(tampered))