diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 550f342c9c7..605989a8a5a 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -13586,8 +13586,12 @@ impl vortex_array::optimizer::kernels::ArrayKernels pub fn vortex_array::optimizer::kernels::ArrayKernels::empty() -> Self +pub fn vortex_array::optimizer::kernels::ArrayKernels::find_execute_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option> + pub fn vortex_array::optimizer::kernels::ArrayKernels::find_reduce_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option> +pub fn vortex_array::optimizer::kernels::ArrayKernels::register_execute_parent>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I) + pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I) impl core::default::Default for vortex_array::optimizer::kernels::ArrayKernels @@ -13612,6 +13616,8 @@ impl vortex_array::optimizer::kernels::ArrayKerne pub fn S::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels> +pub type vortex_array::optimizer::kernels::ExecuteParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub type vortex_array::optimizer::kernels::ReduceParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> pub mod vortex_array::optimizer::rules diff --git a/vortex-array/src/arrays/struct_/compute/cast.rs b/vortex-array/src/arrays/struct_/compute/cast.rs index ddb3cbeadf2..51f670546f9 100644 --- a/vortex-array/src/arrays/struct_/compute/cast.rs +++ b/vortex-array/src/arrays/struct_/compute/cast.rs @@ -6,97 +6,107 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use crate::ArrayRef; +use crate::ArrayView; use crate::ExecutionCtx; use crate::IntoArray; -use crate::array::ArrayView; use crate::arrays::ConstantArray; use crate::arrays::Struct; use crate::arrays::StructArray; +use crate::arrays::scalar_fn::ExactScalarFn; use crate::arrays::struct_::StructArrayExt; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; +use crate::matcher::Matcher; use crate::scalar::Scalar; -use crate::scalar_fn::fns::cast::CastKernel; - -impl CastKernel for Struct { - fn cast( - array: ArrayView<'_, Struct>, - dtype: &DType, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - let Some(target_sdtype) = dtype.as_struct_fields_opt() else { - return Ok(None); - }; - - let source_sdtype = array.struct_fields(); +use crate::scalar_fn::fns::cast::Cast; + +pub(crate) fn struct_cast_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let Some(array) = child.as_opt::() else { + return Ok(None); + }; + let Some(parent) = ExactScalarFn::::try_match(parent) else { + return Ok(None); + }; + + let dtype = parent.options; + if array.dtype() == parent.options { + return Ok(Some(array.array().clone())); + } - let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields() - && target_sdtype - .names() - .iter() - .zip(source_sdtype.names().iter()) - .all(|(f1, f2)| f1 == f2); + struct_cast(array, dtype, ctx) +} - let mut cast_fields = Vec::with_capacity(target_sdtype.nfields()); - if fields_match_order { - for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields()) - { - let cast_field = field.cast(target_type)?; - cast_fields.push(cast_field); +pub(crate) fn struct_cast( + array: ArrayView, + dtype: &DType, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let Some(target_sdtype) = dtype.as_struct_fields_opt() else { + return Ok(None); + }; + + let source_sdtype = array.struct_fields(); + + let mut cast_fields = Vec::with_capacity(target_sdtype.nfields()); + // Re-order, handle fields by value instead. + for (target_name, target_type) in target_sdtype.names().iter().zip_eq(target_sdtype.fields()) { + match source_sdtype.find(target_name) { + None => { + // No source field with this name => evolve the schema compatibly. + // If the field is nullable, we add a new ConstantArray field with the type. + vortex_ensure!( + target_type.is_nullable(), + "CAST for struct only supports added nullable fields" + ); + + cast_fields + .push(ConstantArray::new(Scalar::null(target_type), array.len()).into_array()); } - } else { - // Re-order, handle fields by value instead. - for (target_name, target_type) in - target_sdtype.names().iter().zip_eq(target_sdtype.fields()) - { - match source_sdtype.find(target_name) { - None => { - // No source field with this name => evolve the schema compatibly. - // If the field is nullable, we add a new ConstantArray field with the type. - vortex_ensure!( - target_type.is_nullable(), - "CAST for struct only supports added nullable fields" - ); - - cast_fields.push( - ConstantArray::new(Scalar::null(target_type), array.len()).into_array(), - ); - } - Some(src_field_idx) => { - // Field exists in source field. Cast it to the target type. - let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?; - cast_fields.push(cast_field); - } - } + Some(src_field_idx) => { + // Field exists in source field. Cast it to the target type. + let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?; + cast_fields.push(cast_field); } } + } - let validity = array - .validity()? - .cast_nullability(dtype.nullability(), array.len(), ctx)?; + let validity = array + .validity()? + .cast_nullability(dtype.nullability(), array.len(), ctx)?; - StructArray::try_new( - target_sdtype.names().clone(), - cast_fields, - array.len(), - validity, - ) - .map(|a| Some(a.into_array())) - } + Ok(Some( + unsafe { + StructArray::new_unchecked(cast_fields, target_sdtype.clone(), array.len(), validity) + } + .into_array(), + )) } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use rstest::rstest; use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + use crate::ArrayRef; + use crate::ExecutionCtx; use crate::IntoArray; - #[expect(deprecated)] - use crate::ToCanonical as _; + use crate::VortexSessionExecute; + use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::arrays::VarBinArray; + use crate::arrays::scalar_fn::ScalarFnFactoryExt; use crate::arrays::struct_::StructArrayExt; + use crate::assert_arrays_eq; use crate::builtins::ArrayBuiltins; use crate::compute::conformance::cast::test_cast_conformance; use crate::dtype::DType; @@ -104,8 +114,41 @@ mod tests { use crate::dtype::FieldNames; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::optimizer::kernels::ArrayKernels; + use crate::optimizer::kernels::ArrayKernelsExt; + use crate::optimizer::kernels::ExecuteParentFn; + use crate::scalar::Scalar; + use crate::scalar_fn::fns::cast::Cast; + use crate::session::ArraySession; use crate::validity::Validity; + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn null_struct_cast_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(target_fields) = parent.dtype().as_struct_fields_opt() else { + return Ok(None); + }; + let fields: Vec = target_fields + .fields() + .map(|dtype| ConstantArray::new(Scalar::null(dtype), child.len()).into_array()) + .collect(); + + StructArray::try_new( + target_fields.names().clone(), + fields, + child.len(), + Validity::from(parent.dtype().nullability()), + ) + .map(|array| Some(array.into_array())) + } + #[rstest] #[case(create_test_struct(false))] #[case(create_test_struct(true))] @@ -115,6 +158,64 @@ mod tests { test_cast_conformance(&array.into_array()); } + #[test] + fn struct_cast_execute_parent_is_not_static_kernel() { + let source = create_simple_struct().into_array(); + let target = DType::struct_( + [( + "value", + DType::Primitive(PType::I64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + + let cast = Cast + .try_new_array(source.len(), target, [source.clone()]) + .unwrap(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + assert!(source.execute_parent(&cast, 0, &mut ctx).unwrap().is_none()); + } + + #[test] + fn struct_cast_execute_parent_uses_session_plugin() { + let source = StructArray::try_new( + FieldNames::from(["a"]), + vec![VarBinArray::from_vec(vec!["A"], DType::Utf8(Nullability::Nullable)).into_array()], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + let child_id = source.encoding_id(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["b"]), vec![utf8_null.clone()]), + Nullability::NonNullable, + ); + + let cast = Cast + .try_new_array(source.len(), target.clone(), [source]) + .unwrap(); + let parent_id = cast.encoding_id(); + let session = VortexSession::empty().with::(); + session.kernels().register_execute_parent( + parent_id, + child_id, + &[null_struct_cast_execute_parent as ExecuteParentFn], + ); + let mut ctx = session.create_execution_ctx(); + + let result = cast.execute::(&mut ctx).unwrap(); + + assert_eq!(result.dtype(), &target); + assert_arrays_eq!( + result.unmasked_field_by_name("b").unwrap(), + ConstantArray::new(Scalar::null(utf8_null), 1) + ); + } + fn create_test_struct(nullable: bool) -> StructArray { let names = FieldNames::from(["a", "b"]); @@ -204,14 +305,17 @@ mod tests { let target_dtype = struct_array.dtype().as_nullable(); - let result = struct_array + let cast = struct_array .into_array() .cast(target_dtype.clone()) .unwrap(); - assert_eq!(result.dtype(), &target_dtype); - assert_eq!(result.len(), 3); - #[expect(deprecated)] - let nfields = result.to_struct().struct_fields().nfields(); + assert_eq!(cast.dtype(), &target_dtype); + assert_eq!(cast.len(), 3); + let nfields = cast + .execute::(&mut SESSION.create_execution_ctx()) + .unwrap() + .struct_fields() + .nfields(); assert_eq!(nfields, 2); } @@ -241,8 +345,11 @@ mod tests { .unwrap(); assert_eq!(result.dtype(), &target_dtype); assert_eq!(result.len(), 3); - #[expect(deprecated)] - let nfields = result.to_struct().struct_fields().nfields(); + let nfields = result + .execute::(&mut SESSION.create_execution_ctx()) + .unwrap() + .struct_fields() + .nfields(); assert_eq!(nfields, 3); } } diff --git a/vortex-array/src/arrays/struct_/compute/mod.rs b/vortex-array/src/arrays/struct_/compute/mod.rs index 9ddee9496aa..c30ef532008 100644 --- a/vortex-array/src/arrays/struct_/compute/mod.rs +++ b/vortex-array/src/arrays/struct_/compute/mod.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod cast; +pub(crate) mod cast; mod mask; pub(crate) mod rules; mod slice; diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index a39b8b37a38..204458771e0 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -14,85 +14,91 @@ use crate::arrays::StructArray; use crate::arrays::dict::TakeReduceAdaptor; use crate::arrays::scalar_fn::ExactScalarFn; use crate::arrays::scalar_fn::ScalarFnArrayView; -use crate::arrays::scalar_fn::ScalarFnFactoryExt; use crate::arrays::slice::SliceReduceAdaptor; use crate::arrays::struct_::StructArrayExt; use crate::builtins::ArrayBuiltins; -use crate::dtype::DType; +use crate::matcher::Matcher; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::fns::cast::CastReduce; -use crate::scalar_fn::fns::cast::CastReduceAdaptor; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::cast::Cast; use crate::scalar_fn::fns::get_item::GetItem; -use crate::scalar_fn::fns::mask::Mask; use crate::scalar_fn::fns::mask::MaskReduceAdaptor; use crate::validity::Validity; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&CastReduceAdaptor(Struct)), ParentRuleSet::lift(&StructGetItemRule), ParentRuleSet::lift(&MaskReduceAdaptor(Struct)), ParentRuleSet::lift(&SliceReduceAdaptor(Struct)), ParentRuleSet::lift(&TakeReduceAdaptor(Struct)), ]); -/// Push the cast into struct fields without execution. -/// -/// Supports schema evolution by allowing new nullable fields to be added during the cast, -/// filled with null values. For nullability changes, only handles the cheap path -/// (`try_cast_nullability`); when statistics computation is required to determine whether -/// the array contains invalid values, returns `Ok(None)` so [`CastKernel`] can run instead. -/// -/// [`CastKernel`]: crate::scalar_fn::fns::cast::CastKernel -impl CastReduce for Struct { - fn cast(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult> { - let Some(target_fields) = dtype.as_struct_fields_opt() else { - return Ok(None); - }; - - let Some(validity) = array - .validity()? - .try_cast_nullability(dtype.nullability(), array.len())? - else { - return Ok(None); - }; - - let mut new_fields = Vec::with_capacity(target_fields.nfields()); - - for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields()) - { - match array.unmasked_field_by_name(target_name).ok() { - Some(field) => { - new_fields.push(field.cast(target_dtype)?); - } - None => { - // Not found - create NULL array (schema evolution) - vortex_ensure!( - target_dtype.is_nullable(), - "Cannot add non-nullable field '{}' during struct cast", - target_name - ); - new_fields.push( - ConstantArray::new(crate::scalar::Scalar::null(target_dtype), array.len()) - .into_array(), - ); - } - } - } +pub(crate) fn struct_cast_reduce_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, +) -> VortexResult> { + let Some(array) = child.as_opt::() else { + return Ok(None); + }; + let Some(parent) = ExactScalarFn::::try_match(parent) else { + return Ok(None); + }; + + if array.dtype() == parent.options { + return Ok(Some(array.array().clone())); + } + + reduce_struct_cast(array, parent) +} + +fn reduce_struct_cast( + array: ArrayView<'_, Struct>, + parent: ScalarFnArrayView<'_, Cast>, +) -> VortexResult> { + let Some(target_fields) = parent.options.as_struct_fields_opt() else { + return Ok(None); + }; + + let Some(validity) = array + .validity()? + .try_cast_nullability(parent.options.nullability(), array.len())? + else { + return Ok(None); + }; - Ok(Some( - unsafe { - StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity) + let mut new_fields = Vec::with_capacity(target_fields.nfields()); + + for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields()) { + match array.unmasked_field_by_name(target_name).ok() { + Some(field) => { + new_fields.push(field.cast(target_dtype)?); + } + None => { + // Not found - create NULL array (schema evolution) + vortex_ensure!( + target_dtype.is_nullable(), + "Cannot add non-nullable field '{}' during struct cast", + target_name + ); + new_fields + .push(ConstantArray::new(Scalar::null(target_dtype), array.len()).into_array()); } - .into_array(), - )) + } } + + Ok(Some( + unsafe { + StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity) + } + .into_array(), + )) } /// Rule to flatten get_item from struct by field name #[derive(Debug)] pub(crate) struct StructGetItemRule; + impl ArrayParentReduceRule for StructGetItemRule { type Parent = ExactScalarFn; @@ -121,17 +127,13 @@ impl ArrayParentReduceRule for StructGetItemRule { Validity::AllInvalid => { // If everything is invalid, the field is also all invalid Ok(Some( - ConstantArray::new( - crate::scalar::Scalar::null(field.dtype().clone()), - field.len(), - ) - .into_array(), + ConstantArray::new(Scalar::null(field.dtype().clone()), field.len()) + .into_array(), )) } Validity::Array(mask) => { // If the validity is an array, we need to combine it with the field's validity - Mask.try_new_array(field.len(), EmptyOptions, [field.clone(), mask]) - .map(Some) + field.clone().mask(mask).map(Some) } } } @@ -142,10 +144,14 @@ mod tests { use std::sync::LazyLock; use vortex_buffer::buffer; + use vortex_error::VortexResult; use vortex_session::VortexSession; + use crate::ArrayRef; use crate::IntoArray; - use crate::VortexSessionExecute; + use crate::array::ArrayPlugin; + use crate::arrays::ScalarFn; + use crate::arrays::Struct; use crate::arrays::StructArray; use crate::arrays::VarBinViewArray; use crate::arrays::struct_::StructArrayExt; @@ -157,12 +163,24 @@ mod tests { use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; + use crate::optimizer::ArrayOptimizer; + use crate::optimizer::kernels::ArrayKernels; + use crate::optimizer::kernels::ReduceParentFn; use crate::scalar::Scalar; - use crate::session::ArraySession; + use crate::scalar_fn::ScalarFnVTable; + use crate::scalar_fn::fns::cast::Cast; use crate::validity::Validity; static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + LazyLock::new(|| VortexSession::empty().with::()); + + fn no_struct_cast_plugin( + _child: &ArrayRef, + _parent: &ArrayRef, + _child_idx: usize, + ) -> VortexResult> { + Ok(None) + } #[test] fn test_struct_cast_field_reorder() { @@ -209,8 +227,67 @@ mod tests { ); } + #[test] + fn struct_cast_is_not_static_parent_rule() { + let source = StructArray::try_new( + FieldNames::from(["a", "b"]), + vec![ + VarBinViewArray::from_iter_str(["A"]).into_array(), + VarBinViewArray::from_iter_str(["B"]).into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["c", "b", "a"]), vec![utf8_null; 3]), + Nullability::NonNullable, + ); + + let cast = source.cast(target).unwrap(); + let optimized = cast.optimize().unwrap(); + assert!(optimized.is::()); + } + + #[test] + fn struct_cast_plugin_can_be_overridden() { + let source = StructArray::try_new( + FieldNames::from(["a", "b"]), + vec![ + VarBinViewArray::from_iter_str(["A"]).into_array(), + VarBinViewArray::from_iter_str(["B"]).into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["c", "b", "a"]), vec![utf8_null; 3]), + Nullability::NonNullable, + ); + + let cast = source.cast(target).unwrap(); + let kernels = ArrayKernels::empty(); + kernels.register_reduce_parent( + Cast.id(), + Struct.id(), + &[no_struct_cast_plugin as ReduceParentFn], + ); + let session = VortexSession::empty().with_some(kernels); + + let optimized = cast.optimize_ctx(&session).unwrap(); + assert!(optimized.is::()); + } + /// Regression test: casting a struct to a non-struct DType must not panic. Previously, - /// `StructCastPushDownRule` called `as_struct_fields()` which panics on non-struct types. + /// the Struct/Cast reduce-parent rewrite called `as_struct_fields()` which panics on + /// non-struct types. #[test] fn cast_struct_to_non_struct_does_not_panic() { let source = StructArray::try_new( @@ -222,10 +299,12 @@ mod tests { .unwrap(); // Casting a struct to a primitive type should not panic. Before the fix, - // `StructCastPushDownRule` would panic via `as_struct_fields()` on the non-struct target. + // the reduce-parent rewrite would panic via `as_struct_fields()` on the non-struct target. let result = source .into_array() - .cast(DType::Primitive(PType::I32, Nullability::NonNullable)); + .cast(DType::Primitive(PType::I32, Nullability::NonNullable)) + .and_then(|arr| arr.optimize_ctx(&SESSION)); + // Whether this errors or succeeds depends on execution, but the key invariant is that the // optimizer rule does not panic. if let Ok(arr) = &result { @@ -336,6 +415,10 @@ mod tests { Nullability::NonNullable, ); - assert!(source.into_array().cast(target).is_err()); + let arr = source + .into_array() + .cast(target) + .and_then(|arr| arr.optimize_ctx(&SESSION)); + assert!(arr.is_err()); } } diff --git a/vortex-array/src/arrays/struct_/vtable/kernel.rs b/vortex-array/src/arrays/struct_/vtable/kernel.rs index 1c5c9f3db3a..eac7158921d 100644 --- a/vortex-array/src/arrays/struct_/vtable/kernel.rs +++ b/vortex-array/src/arrays/struct_/vtable/kernel.rs @@ -3,10 +3,7 @@ use crate::arrays::Struct; use crate::kernel::ParentKernelSet; -use crate::scalar_fn::fns::cast::CastExecuteAdaptor; use crate::scalar_fn::fns::zip::ZipExecuteAdaptor; -pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ - ParentKernelSet::lift(&CastExecuteAdaptor(Struct)), - ParentKernelSet::lift(&ZipExecuteAdaptor(Struct)), -]); +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&ZipExecuteAdaptor(Struct))]); diff --git a/vortex-array/src/executor.rs b/vortex-array/src/executor.rs index 1ca27bc1de5..49315babdf0 100644 --- a/vortex-array/src/executor.rs +++ b/vortex-array/src/executor.rs @@ -21,12 +21,14 @@ use std::env::VarError; use std::fmt; use std::fmt::Display; +use std::sync::Arc; use std::sync::LazyLock; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_panic; +use vortex_session::SessionExt; use vortex_session::VortexSession; use crate::AnyCanonical; @@ -38,6 +40,8 @@ use crate::matcher::Matcher; use crate::memory::HostAllocatorRef; use crate::memory::MemorySessionExt; use crate::optimizer::ArrayOptimizer; +use crate::optimizer::kernels::ArrayKernels; +use crate::optimizer::kernels::ExecuteParentFn; /// Returns the maximum number of iterations to attempt when executing an array before giving up and returning /// an error, can be by the `VORTEX_MAX_ITERATIONS` env variables, otherwise defaults to 128. @@ -360,7 +364,7 @@ impl Executable for ArrayRef { let Some(child) = slot else { continue; }; - if let Some(executed_parent) = child.execute_parent(&array, slot_idx, ctx)? { + if let Some(executed_parent) = execute_parent_for_child(&array, child, slot_idx, ctx)? { ctx.log(format_args!( "execute_parent: slot[{}]({}) rewrote {} -> {}", slot_idx, @@ -402,13 +406,44 @@ fn execute_step(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult Option> { + ctx.session() + .get_opt::() + .and_then(|s| s.find_execute_parent(parent.encoding_id(), child.encoding_id())) +} + +fn execute_parent_for_child( + parent: &ArrayRef, + child: &ArrayRef, + slot_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if let Some(plugins) = plugin_execute_parent(ctx, parent, child) { + for plugin in plugins.as_ref() { + if let Some(result) = plugin(child, parent, slot_idx, ctx)? { + return Ok(Some(result)); + } + } + } + + child.execute_parent(parent, slot_idx, ctx) +} + /// Try execute_parent on each occupied slot of the array. fn try_execute_parent(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { for (slot_idx, slot) in array.slots().iter().enumerate() { let Some(child) = slot else { continue; }; - if let Some(result) = child.execute_parent(array, slot_idx, ctx)? { + if let Some(result) = execute_parent_for_child(array, child, slot_idx, ctx)? { result.statistics().inherit_from(array.statistics()); return Ok(Some(result)); } diff --git a/vortex-array/src/optimizer/kernels.rs b/vortex-array/src/optimizer/kernels.rs index a31192f8016..d38bc9402d1 100644 --- a/vortex-array/src/optimizer/kernels.rs +++ b/vortex-array/src/optimizer/kernels.rs @@ -3,22 +3,30 @@ //! Session-scoped registry for optimizer kernels. //! -//! [`ArrayKernels`] stores function pointers that participate in array optimization without -//! adding rules to an encoding vtable. The optimizer currently consults it for parent-reduce -//! rewrites before the child encoding's static `PARENT_RULES`. A registered function can -//! therefore add a rule for an extension encoding or take precedence over a built-in rule. +//! [`ArrayKernels`] stores function pointers that participate in array optimization and execution +//! without adding rules or kernels to an encoding vtable. The optimizer consults it for +//! parent-reduce rewrites before the child encoding's static `PARENT_RULES`, and the executor +//! consults it for parent execution before the child encoding's static parent kernels. A +//! registered function can therefore add support for an extension encoding or take precedence over +//! a built-in rule or kernel. When several functions are registered for the same key and kind, +//! they are tried in registration order until one applies. //! -//! Kernel entries are addressed by `(outer_id, child_id, kind)`. For parent-reduce kernels, -//! `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is the -//! child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the parent -//! id is the scalar function id. +//! Kernel entries are addressed by `(outer_id, child_id)`. For parent-reduce and execute-parent +//! kernels, `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is +//! the child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the +//! parent id is the scalar function id. //! -//! Sessions created by the top-level `vortex` crate install an empty registry by default. Other -//! sessions can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely -//! on [`ArrayKernelsExt::kernels`] to insert the default value. +//! Because registered functions have different signatures for each kernel kind, the registry +//! maintains one storage map per function type rather than a single type-erased map. +//! +//! Sessions created by the top-level `vortex` crate install the default registry. Other sessions +//! can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely on +//! [`ArrayKernelsExt::kernels`] to insert the default value. use std::any::Any; +use std::borrow::Borrow; use std::hash::BuildHasher; +use std::hash::Hash; use std::sync::Arc; use std::sync::LazyLock; @@ -32,8 +40,15 @@ use vortex_utils::aliases::DefaultHashBuilder; use vortex_utils::aliases::hash_map::HashMap; use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::array::VTable; +use crate::arrays::Struct; +use crate::arrays::struct_::compute::cast::struct_cast_execute_parent; +use crate::arrays::struct_::compute::rules::struct_cast_reduce_parent; +use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::fns::cast::Cast; -/// Shared hasher used to combine `(outer, child, FnKind)` tuples into [`FnRegistry`] keys. +/// Shared hasher used to combine `(outer, child)` tuples into registry keys. static FN_HASHER: LazyLock = LazyLock::new(DefaultHashBuilder::default); /// Function pointer for a plugin-provided parent-reduce rewrite. @@ -47,59 +62,164 @@ static FN_HASHER: LazyLock = LazyLock::new(DefaultHashBuilde pub type ReduceParentFn = fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult>; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[repr(transparent)] +struct ReduceParentFnId(u64); + +impl From for ReduceParentFnId { + fn from(id: u64) -> Self { + Self(id) + } +} + +impl Borrow for ReduceParentFnId { + fn borrow(&self) -> &u64 { + &self.0 + } +} + +/// Function pointer for a plugin-provided parent execution. +/// +/// The executor calls this with the matched `child`, its `parent`, the slot index where the child +/// appears, and the current [`ExecutionCtx`]. Return `Ok(Some(new_parent))` to replace the parent +/// with an executed result, or `Ok(None)` when the kernel does not apply. +/// +/// Implementations must preserve the parent's logical length and dtype, matching the invariant +/// required of static `execute_parent` kernels. +pub type ExecuteParentFn = fn( + child: &ArrayRef, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[repr(transparent)] +struct ExecuteParentFnId(u64); + +impl From for ExecuteParentFnId { + fn from(id: u64) -> Self { + Self(id) + } +} + +impl Borrow for ExecuteParentFnId { + fn borrow(&self) -> &u64 { + &self.0 + } +} + /// Session-scoped registry of optimizer kernel functions. -#[derive(Debug, Default)] +/// +/// Each kernel kind has its own storage map, keyed by `(outer_id, child_id)`. Registering +/// functions for an existing key appends them to that key's ordered list. +#[derive(Debug)] pub struct ArrayKernels { - reduce_parent: ArcSwap>>, + reduce_parent: ArcSwap>>, + execute_parent: ArcSwap>>, +} + +impl Default for ArrayKernels { + fn default() -> ArrayKernels { + let this = Self::empty(); + this.register_builtin_reduce_parent(); + this.register_builtin_execute_parent(); + this + } } impl ArrayKernels { /// Create an empty [`ArrayKernels`] with no kernels registered. pub fn empty() -> Self { - Self::default() + Self { + reduce_parent: ArcSwap::from_pointee(HashMap::default()), + execute_parent: ArcSwap::from_pointee(HashMap::default()), + } + } + + fn register_builtin_reduce_parent(&self) { + self.register_reduce_parent( + Cast.id(), + Struct.id(), + &[struct_cast_reduce_parent as ReduceParentFn], + ); + } + + fn register_builtin_execute_parent(&self) { + self.register_execute_parent( + Cast.id(), + Struct.id(), + &[struct_cast_execute_parent as ExecuteParentFn], + ); } - /// Register a [`ReduceParentFn`] for `(outer, child)`. + /// Register [`ReduceParentFn`]s for `(parent, child)`. /// - /// The optimizer will invoke `f` when it sees a parent with encoding id `outer` holding a - /// child with encoding id `child` during a `reduce_parent` step, before trying the child - /// encoding's static `PARENT_RULES`. `outer` is usually the parent array's encoding id. For - /// `ScalarFnArray`, it is the scalar function id, for example `Cast.id()`. + /// The optimizer invokes these functions in registration order when it sees a parent with + /// encoding id `parent` holding a child with encoding id `child` during a `reduce_parent` + /// step, before trying the child encoding's static `PARENT_RULES`. `parent` is usually the + /// parent array's encoding id. For `ScalarFnArray`, it is the scalar function id, for example + /// `Cast.id()`. /// - /// Replaces any function already registered for the same pair. - pub fn register_reduce_parent>( - &self, - parent: Id, - child: Id, - fns: I, - ) { - let registry = self.reduce_parent.load(); - let id = self.hash_fn_ids(parent, child); - let mut owned_registry = registry.as_ref().clone(); - if let Some(existing) = owned_registry.remove(&id) { - owned_registry.insert(id, existing.as_ref().iter().cloned().chain(fns).collect()); - } else { - owned_registry.insert(id, fns.into_iter().collect()); - } - self.reduce_parent.store(Arc::new(owned_registry)); + /// If functions have already been registered for the same pair, these functions are appended + /// after them. + pub fn register_reduce_parent(&self, parent: Id, child: Id, fns: &[ReduceParentFn]) { + self.reduce_parent.rcu(move |registry| { + update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns) + }); } - /// Look up the [`ReduceParentFn`] registered for `(outer, child)`. + /// Look up the [`ReduceParentFn`]s registered for `(parent, child)`. /// /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the - /// function. + /// functions. pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option> { - let id = self.hash_fn_ids(parent, child); - let map = self.reduce_parent.load(); - let entry = map.get(&id)?; - Some(Arc::clone(entry)) + let id = hash_fn_id(parent, child); + self.reduce_parent.load().get(&id).cloned() + } + + /// Register [`ExecuteParentFn`]s for `(parent, child)`. + /// + /// The executor invokes these functions in registration order when it sees a parent with + /// encoding id `parent` holding a child with encoding id `child` during a parent execution + /// step, before trying the child encoding's static parent kernels. + /// + /// If functions have already been registered for the same pair, these functions are appended + /// after them. + pub fn register_execute_parent(&self, parent: Id, child: Id, fns: &[ExecuteParentFn]) { + self.execute_parent.rcu(move |registry| { + update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns) + }); + } + + /// Look up the [`ExecuteParentFn`]s registered for `(parent, child)`. + /// + /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the + /// functions. + pub fn find_execute_parent(&self, parent: Id, child: Id) -> Option> { + let id = hash_fn_id(parent, child); + self.execute_parent.load().get(&id).cloned() } +} + +fn hash_fn_id(parent: Id, child: Id) -> u64 { + FN_HASHER.hash_one((parent, child)) +} - /// Combine a typed kernel id tuple into the `u64` key expected by the underlying - /// [`FnRegistry`]. All typed helpers use this path so registration and lookup agree. - fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 { - FN_HASHER.hash_one((parent, child)) +fn update_fns + Eq + Hash + From>( + mut existing: HashMap>, + id: u64, + fns: &[F], +) -> HashMap> { + if let Some(existing_fns) = existing.remove(&id) { + existing.insert( + id.into(), + existing_fns.as_ref().iter().chain(fns).cloned().collect(), + ); + } else { + existing.insert(id.into(), fns.into()); } + existing } impl SessionVar for ArrayKernels { diff --git a/vortex-array/src/optimizer/mod.rs b/vortex-array/src/optimizer/mod.rs index 70a041bcc18..aab974cd29e 100644 --- a/vortex-array/src/optimizer/mod.rs +++ b/vortex-array/src/optimizer/mod.rs @@ -15,8 +15,6 @@ //! - [`ArrayOptimizer::optimize_recursive`] applies the session-aware optimizer to the root and //! every descendant. -use std::sync::Arc; - use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_session::SessionExt; @@ -24,7 +22,6 @@ use vortex_session::VortexSession; use crate::ArrayRef; use crate::optimizer::kernels::ArrayKernels; -use crate::optimizer::kernels::ReduceParentFn; pub mod kernels; pub mod rules; @@ -65,26 +62,13 @@ impl ArrayOptimizer for ArrayRef { } } -/// Resolve a session-registered [`ReduceParentFn`] for the `(parent, child)` pair. -/// -/// The returned [`Arc`] is owned so the caller can drop the [`ArrayKernels`] borrow before -/// invoking the function. -fn plugin_reduce_parent( - session: &VortexSession, - parent: &ArrayRef, - child: &ArrayRef, -) -> Option> { - session - .get_opt::() - .and_then(|s| s.find_reduce_parent(parent.encoding_id(), child.encoding_id())) -} - fn try_optimize( array: &ArrayRef, session: Option<&VortexSession>, ) -> VortexResult> { let mut current_array = array.clone(); let mut any_optimizations = false; + let array_kernels = session.and_then(|s| s.get_opt::()); // Apply reduction rules to the current array until no more rules apply. let mut loop_counter = 0; @@ -106,8 +90,9 @@ fn try_optimize( let Some(child) = slot else { continue }; // Session kernels take precedence over the child encoding's static PARENT_RULES. - if let Some(session) = session - && let Some(plugins) = plugin_reduce_parent(session, ¤t_array, child) + if let Some(kernels) = &array_kernels + && let Some(plugins) = + kernels.find_reduce_parent(current_array.encoding_id(), child.encoding_id()) { for plugin in plugins.as_ref() { if let Some(new_array) = plugin(child, ¤t_array, slot_idx)? { diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index cbab9615e97..a007788acb2 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -27,8 +27,8 @@ use crate::arrays::FixedSizeList; use crate::arrays::ListView; use crate::arrays::Null; use crate::arrays::Primitive; -use crate::arrays::Struct; use crate::arrays::VarBinView; +use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::StatsCatalog; @@ -218,7 +218,7 @@ fn cast_canonical( CanonicalView::VarBinView(a) => ::cast(a, dtype, ctx), CanonicalView::List(a) => ::cast(a, dtype, ctx), CanonicalView::FixedSizeList(a) => ::cast(a, dtype, ctx), - CanonicalView::Struct(a) => ::cast(a, dtype, ctx), + CanonicalView::Struct(a) => struct_cast(a, dtype, ctx), CanonicalView::Extension(a) => ::cast(a, dtype), CanonicalView::Variant(_) => { vortex_bail!("Variant arrays don't support casting")