diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 629a7ebe3e2..e81b59dc9a4 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -32,7 +32,7 @@ impl CastReduce for ByteBool { let Some(new_validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/encodings/fastlanes/src/bitpacking/compute/cast.rs b/encodings/fastlanes/src/bitpacking/compute/cast.rs index 86bd701e2ee..3cb810e0442 100644 --- a/encodings/fastlanes/src/bitpacking/compute/cast.rs +++ b/encodings/fastlanes/src/bitpacking/compute/cast.rs @@ -42,7 +42,7 @@ impl CastReduce for BitPacked { } let Some(new_validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/encodings/fsst/src/compute/cast.rs b/encodings/fsst/src/compute/cast.rs index cd8c5e3a9c3..a1c96363ba0 100644 --- a/encodings/fsst/src/compute/cast.rs +++ b/encodings/fsst/src/compute/cast.rs @@ -50,7 +50,7 @@ impl CastReduce for FSST { let codes = array.codes(); let Some(new_codes_validity) = codes .validity()? - .trivial_cast_nullability(dtype.nullability(), codes.len())? + .trivially_cast_nullability(dtype.nullability(), codes.len())? else { return Ok(None); }; diff --git a/encodings/pco/src/compute/cast.rs b/encodings/pco/src/compute/cast.rs index 7cb1a28bb9c..b35526c1016 100644 --- a/encodings/pco/src/compute/cast.rs +++ b/encodings/pco/src/compute/cast.rs @@ -31,7 +31,7 @@ impl CastReduce for Pco { let unsliced_validity = child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability()); let Some(new_validity) = - unsliced_validity.trivial_cast_nullability(dtype.nullability(), array.len())? + unsliced_validity.trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 0a8bbfaa43a..e88a5f6792c 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -4662,14 +4662,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Struct pub fn vortex_array::arrays::Struct::slice(vortex_array::ArrayView<'_, Self>, core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::scalar_fn::fns::cast::CastKernel for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> - -impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType) -> vortex_error::VortexResult> - impl vortex_array::scalar_fn::fns::mask::MaskReduce for vortex_array::arrays::Struct pub fn vortex_array::arrays::Struct::mask(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::ArrayRef) -> vortex_error::VortexResult> @@ -6782,14 +6774,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Struct pub fn vortex_array::arrays::Struct::slice(vortex_array::ArrayView<'_, Self>, core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::scalar_fn::fns::cast::CastKernel for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> - -impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType) -> vortex_error::VortexResult> - impl vortex_array::scalar_fn::fns::mask::MaskReduce for vortex_array::arrays::Struct pub fn vortex_array::arrays::Struct::mask(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::ArrayRef) -> vortex_error::VortexResult> @@ -13586,9 +13570,13 @@ impl vortex_array::optimizer::kernels::ArrayKernels pub fn vortex_array::optimizer::kernels::ArrayKernels::empty() -> Self +pub fn vortex_array::optimizer::kernels::ArrayKernels::find_execute_parent(&self, vortex_session::registry::Id, vortex_session::registry::Id) -> core::option::Option> + pub fn vortex_array::optimizer::kernels::ArrayKernels::find_reduce_parent(&self, vortex_session::registry::Id, vortex_session::registry::Id) -> core::option::Option> -pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent>(&self, vortex_session::registry::Id, vortex_session::registry::Id, I) +pub fn vortex_array::optimizer::kernels::ArrayKernels::register_execute_parent(&self, vortex_session::registry::Id, vortex_session::registry::Id, &[vortex_array::optimizer::kernels::ExecuteParentFn]) + +pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent(&self, vortex_session::registry::Id, vortex_session::registry::Id, &[vortex_array::optimizer::kernels::ReduceParentFn]) impl core::default::Default for vortex_array::optimizer::kernels::ArrayKernels @@ -13612,6 +13600,8 @@ impl vortex_array::optimizer::kernels::ArrayKerne pub fn S::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels> +pub type vortex_array::optimizer::kernels::ExecuteParentFn = fn(&vortex_array::ArrayRef, &vortex_array::ArrayRef, usize, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub type vortex_array::optimizer::kernels::ReduceParentFn = fn(&vortex_array::ArrayRef, &vortex_array::ArrayRef, usize) -> vortex_error::VortexResult> pub mod vortex_array::optimizer::rules @@ -16246,10 +16236,6 @@ impl vortex_array::scalar_fn::fns::cast::CastKernel for vortex_array::arrays::Pr pub fn vortex_array::arrays::Primitive::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Primitive>, &vortex_array::dtype::DType, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::scalar_fn::fns::cast::CastKernel for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> - impl vortex_array::scalar_fn::fns::cast::CastKernel for vortex_array::arrays::VarBin pub fn vortex_array::arrays::VarBin::cast(vortex_array::ArrayView<'_, vortex_array::arrays::VarBin>, &vortex_array::dtype::DType, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -16298,10 +16284,6 @@ impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Pr pub fn vortex_array::arrays::Primitive::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Primitive>, &vortex_array::dtype::DType) -> vortex_error::VortexResult> -impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::Struct - -pub fn vortex_array::arrays::Struct::cast(vortex_array::ArrayView<'_, vortex_array::arrays::Struct>, &vortex_array::dtype::DType) -> vortex_error::VortexResult> - impl vortex_array::scalar_fn::fns::cast::CastReduce for vortex_array::arrays::VarBin pub fn vortex_array::arrays::VarBin::cast(vortex_array::ArrayView<'_, vortex_array::arrays::VarBin>, &vortex_array::dtype::DType) -> vortex_error::VortexResult> @@ -19582,10 +19564,10 @@ pub fn vortex_array::validity::Validity::to_array(&self, usize) -> vortex_array: pub fn vortex_array::validity::Validity::to_mask(&self, usize, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_array::validity::Validity::trivial_cast_nullability(self, vortex_array::dtype::Nullability, usize) -> vortex_error::VortexResult> - pub fn vortex_array::validity::Validity::trivial_into_non_nullable(self, usize) -> vortex_error::VortexResult> +pub fn vortex_array::validity::Validity::trivially_cast_nullability(self, vortex_array::dtype::Nullability, usize) -> vortex_error::VortexResult> + pub fn vortex_array::validity::Validity::union_nullability(self, vortex_array::dtype::Nullability) -> Self impl vortex_array::validity::Validity diff --git a/vortex-array/src/arrays/bool/compute/cast.rs b/vortex-array/src/arrays/bool/compute/cast.rs index 6bfaa03b9db..0855af21f25 100644 --- a/vortex-array/src/arrays/bool/compute/cast.rs +++ b/vortex-array/src/arrays/bool/compute/cast.rs @@ -22,7 +22,7 @@ impl CastReduce for Bool { let Some(new_validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 34ccc27d094..432313d3cb6 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -40,7 +40,7 @@ impl CastReduce for Decimal { let Some(new_validity) = array .validity()? - .trivial_cast_nullability(*to_nullability, array.len())? + .trivially_cast_nullability(*to_nullability, array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/fixed_size_list/compute/cast.rs b/vortex-array/src/arrays/fixed_size_list/compute/cast.rs index a54fe5c7c4d..4ac3bd1938f 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/cast.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/cast.rs @@ -39,7 +39,7 @@ impl CastReduce for FixedSizeList { let Some(validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/list/compute/cast.rs b/vortex-array/src/arrays/list/compute/cast.rs index e9b29b9b145..e6d71f6e80c 100644 --- a/vortex-array/src/arrays/list/compute/cast.rs +++ b/vortex-array/src/arrays/list/compute/cast.rs @@ -23,7 +23,7 @@ impl CastReduce for List { let Some(validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/listview/compute/cast.rs b/vortex-array/src/arrays/listview/compute/cast.rs index 473f76b12ee..7f4c972940f 100644 --- a/vortex-array/src/arrays/listview/compute/cast.rs +++ b/vortex-array/src/arrays/listview/compute/cast.rs @@ -42,7 +42,7 @@ impl CastReduce for ListView { }; let Some(validity) = array .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? + .trivially_cast_nullability(dtype.nullability(), array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs index 92140bf9b8b..68b3da04cca 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -36,7 +36,7 @@ impl CastReduce for Primitive { let Some(new_validity) = array .validity()? - .trivial_cast_nullability(*new_nullability, array.len())? + .trivially_cast_nullability(*new_nullability, array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/struct_/compute/cast.rs b/vortex-array/src/arrays/struct_/compute/cast.rs index ddb3cbeadf2..51f670546f9 100644 --- a/vortex-array/src/arrays/struct_/compute/cast.rs +++ b/vortex-array/src/arrays/struct_/compute/cast.rs @@ -6,97 +6,107 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use crate::ArrayRef; +use crate::ArrayView; use crate::ExecutionCtx; use crate::IntoArray; -use crate::array::ArrayView; use crate::arrays::ConstantArray; use crate::arrays::Struct; use crate::arrays::StructArray; +use crate::arrays::scalar_fn::ExactScalarFn; use crate::arrays::struct_::StructArrayExt; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; +use crate::matcher::Matcher; use crate::scalar::Scalar; -use crate::scalar_fn::fns::cast::CastKernel; - -impl CastKernel for Struct { - fn cast( - array: ArrayView<'_, Struct>, - dtype: &DType, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - let Some(target_sdtype) = dtype.as_struct_fields_opt() else { - return Ok(None); - }; - - let source_sdtype = array.struct_fields(); +use crate::scalar_fn::fns::cast::Cast; + +pub(crate) fn struct_cast_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let Some(array) = child.as_opt::() else { + return Ok(None); + }; + let Some(parent) = ExactScalarFn::::try_match(parent) else { + return Ok(None); + }; + + let dtype = parent.options; + if array.dtype() == parent.options { + return Ok(Some(array.array().clone())); + } - let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields() - && target_sdtype - .names() - .iter() - .zip(source_sdtype.names().iter()) - .all(|(f1, f2)| f1 == f2); + struct_cast(array, dtype, ctx) +} - let mut cast_fields = Vec::with_capacity(target_sdtype.nfields()); - if fields_match_order { - for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields()) - { - let cast_field = field.cast(target_type)?; - cast_fields.push(cast_field); +pub(crate) fn struct_cast( + array: ArrayView, + dtype: &DType, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let Some(target_sdtype) = dtype.as_struct_fields_opt() else { + return Ok(None); + }; + + let source_sdtype = array.struct_fields(); + + let mut cast_fields = Vec::with_capacity(target_sdtype.nfields()); + // Re-order, handle fields by value instead. + for (target_name, target_type) in target_sdtype.names().iter().zip_eq(target_sdtype.fields()) { + match source_sdtype.find(target_name) { + None => { + // No source field with this name => evolve the schema compatibly. + // If the field is nullable, we add a new ConstantArray field with the type. + vortex_ensure!( + target_type.is_nullable(), + "CAST for struct only supports added nullable fields" + ); + + cast_fields + .push(ConstantArray::new(Scalar::null(target_type), array.len()).into_array()); } - } else { - // Re-order, handle fields by value instead. - for (target_name, target_type) in - target_sdtype.names().iter().zip_eq(target_sdtype.fields()) - { - match source_sdtype.find(target_name) { - None => { - // No source field with this name => evolve the schema compatibly. - // If the field is nullable, we add a new ConstantArray field with the type. - vortex_ensure!( - target_type.is_nullable(), - "CAST for struct only supports added nullable fields" - ); - - cast_fields.push( - ConstantArray::new(Scalar::null(target_type), array.len()).into_array(), - ); - } - Some(src_field_idx) => { - // Field exists in source field. Cast it to the target type. - let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?; - cast_fields.push(cast_field); - } - } + Some(src_field_idx) => { + // Field exists in source field. Cast it to the target type. + let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?; + cast_fields.push(cast_field); } } + } - let validity = array - .validity()? - .cast_nullability(dtype.nullability(), array.len(), ctx)?; + let validity = array + .validity()? + .cast_nullability(dtype.nullability(), array.len(), ctx)?; - StructArray::try_new( - target_sdtype.names().clone(), - cast_fields, - array.len(), - validity, - ) - .map(|a| Some(a.into_array())) - } + Ok(Some( + unsafe { + StructArray::new_unchecked(cast_fields, target_sdtype.clone(), array.len(), validity) + } + .into_array(), + )) } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use rstest::rstest; use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + use crate::ArrayRef; + use crate::ExecutionCtx; use crate::IntoArray; - #[expect(deprecated)] - use crate::ToCanonical as _; + use crate::VortexSessionExecute; + use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::arrays::VarBinArray; + use crate::arrays::scalar_fn::ScalarFnFactoryExt; use crate::arrays::struct_::StructArrayExt; + use crate::assert_arrays_eq; use crate::builtins::ArrayBuiltins; use crate::compute::conformance::cast::test_cast_conformance; use crate::dtype::DType; @@ -104,8 +114,41 @@ mod tests { use crate::dtype::FieldNames; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::optimizer::kernels::ArrayKernels; + use crate::optimizer::kernels::ArrayKernelsExt; + use crate::optimizer::kernels::ExecuteParentFn; + use crate::scalar::Scalar; + use crate::scalar_fn::fns::cast::Cast; + use crate::session::ArraySession; use crate::validity::Validity; + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn null_struct_cast_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(target_fields) = parent.dtype().as_struct_fields_opt() else { + return Ok(None); + }; + let fields: Vec = target_fields + .fields() + .map(|dtype| ConstantArray::new(Scalar::null(dtype), child.len()).into_array()) + .collect(); + + StructArray::try_new( + target_fields.names().clone(), + fields, + child.len(), + Validity::from(parent.dtype().nullability()), + ) + .map(|array| Some(array.into_array())) + } + #[rstest] #[case(create_test_struct(false))] #[case(create_test_struct(true))] @@ -115,6 +158,64 @@ mod tests { test_cast_conformance(&array.into_array()); } + #[test] + fn struct_cast_execute_parent_is_not_static_kernel() { + let source = create_simple_struct().into_array(); + let target = DType::struct_( + [( + "value", + DType::Primitive(PType::I64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + + let cast = Cast + .try_new_array(source.len(), target, [source.clone()]) + .unwrap(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + assert!(source.execute_parent(&cast, 0, &mut ctx).unwrap().is_none()); + } + + #[test] + fn struct_cast_execute_parent_uses_session_plugin() { + let source = StructArray::try_new( + FieldNames::from(["a"]), + vec![VarBinArray::from_vec(vec!["A"], DType::Utf8(Nullability::Nullable)).into_array()], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + let child_id = source.encoding_id(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["b"]), vec![utf8_null.clone()]), + Nullability::NonNullable, + ); + + let cast = Cast + .try_new_array(source.len(), target.clone(), [source]) + .unwrap(); + let parent_id = cast.encoding_id(); + let session = VortexSession::empty().with::(); + session.kernels().register_execute_parent( + parent_id, + child_id, + &[null_struct_cast_execute_parent as ExecuteParentFn], + ); + let mut ctx = session.create_execution_ctx(); + + let result = cast.execute::(&mut ctx).unwrap(); + + assert_eq!(result.dtype(), &target); + assert_arrays_eq!( + result.unmasked_field_by_name("b").unwrap(), + ConstantArray::new(Scalar::null(utf8_null), 1) + ); + } + fn create_test_struct(nullable: bool) -> StructArray { let names = FieldNames::from(["a", "b"]); @@ -204,14 +305,17 @@ mod tests { let target_dtype = struct_array.dtype().as_nullable(); - let result = struct_array + let cast = struct_array .into_array() .cast(target_dtype.clone()) .unwrap(); - assert_eq!(result.dtype(), &target_dtype); - assert_eq!(result.len(), 3); - #[expect(deprecated)] - let nfields = result.to_struct().struct_fields().nfields(); + assert_eq!(cast.dtype(), &target_dtype); + assert_eq!(cast.len(), 3); + let nfields = cast + .execute::(&mut SESSION.create_execution_ctx()) + .unwrap() + .struct_fields() + .nfields(); assert_eq!(nfields, 2); } @@ -241,8 +345,11 @@ mod tests { .unwrap(); assert_eq!(result.dtype(), &target_dtype); assert_eq!(result.len(), 3); - #[expect(deprecated)] - let nfields = result.to_struct().struct_fields().nfields(); + let nfields = result + .execute::(&mut SESSION.create_execution_ctx()) + .unwrap() + .struct_fields() + .nfields(); assert_eq!(nfields, 3); } } diff --git a/vortex-array/src/arrays/struct_/compute/mod.rs b/vortex-array/src/arrays/struct_/compute/mod.rs index 9ddee9496aa..c30ef532008 100644 --- a/vortex-array/src/arrays/struct_/compute/mod.rs +++ b/vortex-array/src/arrays/struct_/compute/mod.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod cast; +pub(crate) mod cast; mod mask; pub(crate) mod rules; mod slice; diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index 6e95e3eaff8..304ce5e9b75 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -14,85 +14,91 @@ use crate::arrays::StructArray; use crate::arrays::dict::TakeReduceAdaptor; use crate::arrays::scalar_fn::ExactScalarFn; use crate::arrays::scalar_fn::ScalarFnArrayView; -use crate::arrays::scalar_fn::ScalarFnFactoryExt; use crate::arrays::slice::SliceReduceAdaptor; use crate::arrays::struct_::StructArrayExt; use crate::builtins::ArrayBuiltins; -use crate::dtype::DType; +use crate::matcher::Matcher; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::fns::cast::CastReduce; -use crate::scalar_fn::fns::cast::CastReduceAdaptor; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::cast::Cast; use crate::scalar_fn::fns::get_item::GetItem; -use crate::scalar_fn::fns::mask::Mask; use crate::scalar_fn::fns::mask::MaskReduceAdaptor; use crate::validity::Validity; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&CastReduceAdaptor(Struct)), ParentRuleSet::lift(&StructGetItemRule), ParentRuleSet::lift(&MaskReduceAdaptor(Struct)), ParentRuleSet::lift(&SliceReduceAdaptor(Struct)), ParentRuleSet::lift(&TakeReduceAdaptor(Struct)), ]); -/// Push the cast into struct fields without execution. -/// -/// Supports schema evolution by allowing new nullable fields to be added during the cast, -/// filled with null values. For nullability changes, only handles the cheap path -/// (`try_cast_nullability`); when statistics computation is required to determine whether -/// the array contains invalid values, returns `Ok(None)` so [`CastKernel`] can run instead. -/// -/// [`CastKernel`]: crate::scalar_fn::fns::cast::CastKernel -impl CastReduce for Struct { - fn cast(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult> { - let Some(target_fields) = dtype.as_struct_fields_opt() else { - return Ok(None); - }; - - let Some(validity) = array - .validity()? - .trivial_cast_nullability(dtype.nullability(), array.len())? - else { - return Ok(None); - }; - - let mut new_fields = Vec::with_capacity(target_fields.nfields()); - - for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields()) - { - match array.unmasked_field_by_name(target_name).ok() { - Some(field) => { - new_fields.push(field.cast(target_dtype)?); - } - None => { - // Not found - create NULL array (schema evolution) - vortex_ensure!( - target_dtype.is_nullable(), - "Cannot add non-nullable field '{}' during struct cast", - target_name - ); - new_fields.push( - ConstantArray::new(crate::scalar::Scalar::null(target_dtype), array.len()) - .into_array(), - ); - } - } - } +pub(crate) fn struct_cast_reduce_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, +) -> VortexResult> { + let Some(array) = child.as_opt::() else { + return Ok(None); + }; + let Some(parent) = ExactScalarFn::::try_match(parent) else { + return Ok(None); + }; + + if array.dtype() == parent.options { + return Ok(Some(array.array().clone())); + } + + reduce_struct_cast(array, parent) +} + +fn reduce_struct_cast( + array: ArrayView<'_, Struct>, + parent: ScalarFnArrayView<'_, Cast>, +) -> VortexResult> { + let Some(target_fields) = parent.options.as_struct_fields_opt() else { + return Ok(None); + }; + + let Some(validity) = array + .validity()? + .trivially_cast_nullability(parent.options.nullability(), array.len())? + else { + return Ok(None); + }; - Ok(Some( - unsafe { - StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity) + let mut new_fields = Vec::with_capacity(target_fields.nfields()); + + for (target_name, target_dtype) in target_fields.names().iter().zip(target_fields.fields()) { + match array.unmasked_field_by_name(target_name).ok() { + Some(field) => { + new_fields.push(field.cast(target_dtype)?); + } + None => { + // Not found - create NULL array (schema evolution) + vortex_ensure!( + target_dtype.is_nullable(), + "Cannot add non-nullable field '{}' during struct cast", + target_name + ); + new_fields + .push(ConstantArray::new(Scalar::null(target_dtype), array.len()).into_array()); } - .into_array(), - )) + } } + + Ok(Some( + unsafe { + StructArray::new_unchecked(new_fields, target_fields.clone(), array.len(), validity) + } + .into_array(), + )) } /// Rule to flatten get_item from struct by field name #[derive(Debug)] pub(crate) struct StructGetItemRule; + impl ArrayParentReduceRule for StructGetItemRule { type Parent = ExactScalarFn; @@ -121,17 +127,13 @@ impl ArrayParentReduceRule for StructGetItemRule { Validity::AllInvalid => { // If everything is invalid, the field is also all invalid Ok(Some( - ConstantArray::new( - crate::scalar::Scalar::null(field.dtype().clone()), - field.len(), - ) - .into_array(), + ConstantArray::new(Scalar::null(field.dtype().clone()), field.len()) + .into_array(), )) } Validity::Array(mask) => { // If the validity is an array, we need to combine it with the field's validity - Mask.try_new_array(field.len(), EmptyOptions, [field.clone(), mask]) - .map(Some) + field.clone().mask(mask).map(Some) } } } @@ -142,10 +144,14 @@ mod tests { use std::sync::LazyLock; use vortex_buffer::buffer; + use vortex_error::VortexResult; use vortex_session::VortexSession; + use crate::ArrayRef; use crate::IntoArray; - use crate::VortexSessionExecute; + use crate::array::ArrayPlugin; + use crate::arrays::ScalarFn; + use crate::arrays::Struct; use crate::arrays::StructArray; use crate::arrays::VarBinViewArray; use crate::arrays::struct_::StructArrayExt; @@ -157,12 +163,24 @@ mod tests { use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; + use crate::optimizer::ArrayOptimizer; + use crate::optimizer::kernels::ArrayKernels; + use crate::optimizer::kernels::ReduceParentFn; use crate::scalar::Scalar; - use crate::session::ArraySession; + use crate::scalar_fn::ScalarFnVTable; + use crate::scalar_fn::fns::cast::Cast; use crate::validity::Validity; static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + LazyLock::new(|| VortexSession::empty().with::()); + + fn no_struct_cast_plugin( + _child: &ArrayRef, + _parent: &ArrayRef, + _child_idx: usize, + ) -> VortexResult> { + Ok(None) + } #[test] fn test_struct_cast_field_reorder() { @@ -209,8 +227,67 @@ mod tests { ); } + #[test] + fn struct_cast_is_not_static_parent_rule() { + let source = StructArray::try_new( + FieldNames::from(["a", "b"]), + vec![ + VarBinViewArray::from_iter_str(["A"]).into_array(), + VarBinViewArray::from_iter_str(["B"]).into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["c", "b", "a"]), vec![utf8_null; 3]), + Nullability::NonNullable, + ); + + let cast = source.cast(target).unwrap(); + let optimized = cast.optimize().unwrap(); + assert!(optimized.is::()); + } + + #[test] + fn struct_cast_plugin_can_be_overridden() { + let source = StructArray::try_new( + FieldNames::from(["a", "b"]), + vec![ + VarBinViewArray::from_iter_str(["A"]).into_array(), + VarBinViewArray::from_iter_str(["B"]).into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let utf8_null = DType::Utf8(Nullability::Nullable); + let target = DType::Struct( + StructFields::new(FieldNames::from(["c", "b", "a"]), vec![utf8_null; 3]), + Nullability::NonNullable, + ); + + let cast = source.cast(target).unwrap(); + let kernels = ArrayKernels::empty(); + kernels.register_reduce_parent( + Cast.id(), + Struct.id(), + &[no_struct_cast_plugin as ReduceParentFn], + ); + let session = VortexSession::empty().with_some(kernels); + + let optimized = cast.optimize_ctx(&session).unwrap(); + assert!(optimized.is::()); + } + /// Regression test: casting a struct to a non-struct DType must not panic. Previously, - /// `StructCastPushDownRule` called `as_struct_fields()` which panics on non-struct types. + /// the Struct/Cast reduce-parent rewrite called `as_struct_fields()` which panics on + /// non-struct types. #[test] fn cast_struct_to_non_struct_does_not_panic() { let source = StructArray::try_new( @@ -222,10 +299,12 @@ mod tests { .unwrap(); // Casting a struct to a primitive type should not panic. Before the fix, - // `StructCastPushDownRule` would panic via `as_struct_fields()` on the non-struct target. + // the reduce-parent rewrite would panic via `as_struct_fields()` on the non-struct target. let result = source .into_array() - .cast(DType::Primitive(PType::I32, Nullability::NonNullable)); + .cast(DType::Primitive(PType::I32, Nullability::NonNullable)) + .and_then(|arr| arr.optimize_ctx(&SESSION)); + // Whether this errors or succeeds depends on execution, but the key invariant is that the // optimizer rule does not panic. if let Ok(arr) = &result { @@ -336,6 +415,10 @@ mod tests { Nullability::NonNullable, ); - assert!(source.into_array().cast(target).is_err()); + let arr = source + .into_array() + .cast(target) + .and_then(|arr| arr.optimize_ctx(&SESSION)); + assert!(arr.is_err()); } } diff --git a/vortex-array/src/arrays/struct_/vtable/kernel.rs b/vortex-array/src/arrays/struct_/vtable/kernel.rs index 1c5c9f3db3a..eac7158921d 100644 --- a/vortex-array/src/arrays/struct_/vtable/kernel.rs +++ b/vortex-array/src/arrays/struct_/vtable/kernel.rs @@ -3,10 +3,7 @@ use crate::arrays::Struct; use crate::kernel::ParentKernelSet; -use crate::scalar_fn::fns::cast::CastExecuteAdaptor; use crate::scalar_fn::fns::zip::ZipExecuteAdaptor; -pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ - ParentKernelSet::lift(&CastExecuteAdaptor(Struct)), - ParentKernelSet::lift(&ZipExecuteAdaptor(Struct)), -]); +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&ZipExecuteAdaptor(Struct))]); diff --git a/vortex-array/src/arrays/varbin/compute/cast.rs b/vortex-array/src/arrays/varbin/compute/cast.rs index e6403535e2d..ca5dfde56e5 100644 --- a/vortex-array/src/arrays/varbin/compute/cast.rs +++ b/vortex-array/src/arrays/varbin/compute/cast.rs @@ -38,7 +38,7 @@ impl CastReduce for VarBin { let new_nullability = dtype.nullability(); let Some(new_validity) = array .validity()? - .trivial_cast_nullability(new_nullability, array.len())? + .trivially_cast_nullability(new_nullability, array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/arrays/varbinview/compute/cast.rs b/vortex-array/src/arrays/varbinview/compute/cast.rs index 9b3d6b45dc7..449e735d1fc 100644 --- a/vortex-array/src/arrays/varbinview/compute/cast.rs +++ b/vortex-array/src/arrays/varbinview/compute/cast.rs @@ -42,7 +42,7 @@ impl CastReduce for VarBinView { let new_nullability = dtype.nullability(); let Some(new_validity) = array .validity()? - .trivial_cast_nullability(new_nullability, array.len())? + .trivially_cast_nullability(new_nullability, array.len())? else { return Ok(None); }; diff --git a/vortex-array/src/executor.rs b/vortex-array/src/executor.rs index 694fd25c3c6..caa0c1d61e9 100644 --- a/vortex-array/src/executor.rs +++ b/vortex-array/src/executor.rs @@ -22,6 +22,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_panic; +use vortex_session::Ref; +use vortex_session::SessionExt; use vortex_session::VortexSession; use crate::AnyCanonical; @@ -36,6 +38,7 @@ use crate::matcher::Matcher; use crate::memory::HostAllocatorRef; use crate::memory::MemorySessionExt; use crate::optimizer::ArrayOptimizer; +use crate::optimizer::kernels::ArrayKernels; use crate::stats::ArrayStats; use crate::stats::StatsSet; @@ -418,9 +421,15 @@ impl Executable for ArrayRef { } } + let tmp_session = ctx.session().clone(); + let kernels = tmp_session.get_opt::(); + + for (slot_idx, slot) in array.slots().iter().enumerate() { let Some(child) = slot else { continue }; - if let Some(executed_parent) = child.execute_parent(&array, slot_idx, ctx)? { + if let Some(executed_parent) = + execute_parent_for_child(&array, child, slot_idx, kernels.as_ref(), ctx)? + { ctx.log(format_args!( "execute_parent: slot[{}]({}) rewrote {} -> {}", slot_idx, @@ -527,15 +536,48 @@ fn finalize_done( Ok((output, None)) } +fn execute_parent_for_child( + parent: &ArrayRef, + child: &ArrayRef, + slot_idx: usize, + kernels: Option<&Ref>, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if let Some(kernels) = kernels + && let Some(plugins) = + kernels.find_execute_parent(parent.encoding_id(), child.encoding_id()) + { + for plugin in plugins.as_ref() { + if let Some(result) = plugin(child, parent, slot_idx, ctx)? { + return Ok(Some(result)); + } + } + } + + child.execute_parent(parent, slot_idx, ctx) +} + /// Try execute_parent on each occupied slot of the array. fn try_execute_parent(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { + let tmp_session = ctx.session().clone(); + let kernels = tmp_session.get_opt::(); + for (slot_idx, slot) in array.slots().iter().enumerate() { - let Some(child) = slot else { - continue; - }; - if let Some(result) = child.execute_parent(array, slot_idx, ctx)? { - result.statistics().inherit_from(array.statistics()); - return Ok(Some(result)); + let Some(child) = slot else { continue }; + if let Some(executed_parent) = + execute_parent_for_child(&array, child, slot_idx, kernels.as_ref(), ctx)? + { + ctx.log(format_args!( + "execute_parent: slot[{}]({}) rewrote {} -> {}", + slot_idx, + child.encoding_id(), + array, + executed_parent + )); + executed_parent + .statistics() + .inherit_from(array.statistics()); + return Ok(Some(executed_parent)); } } Ok(None) diff --git a/vortex-array/src/optimizer/kernels.rs b/vortex-array/src/optimizer/kernels.rs index a31192f8016..d38bc9402d1 100644 --- a/vortex-array/src/optimizer/kernels.rs +++ b/vortex-array/src/optimizer/kernels.rs @@ -3,22 +3,30 @@ //! Session-scoped registry for optimizer kernels. //! -//! [`ArrayKernels`] stores function pointers that participate in array optimization without -//! adding rules to an encoding vtable. The optimizer currently consults it for parent-reduce -//! rewrites before the child encoding's static `PARENT_RULES`. A registered function can -//! therefore add a rule for an extension encoding or take precedence over a built-in rule. +//! [`ArrayKernels`] stores function pointers that participate in array optimization and execution +//! without adding rules or kernels to an encoding vtable. The optimizer consults it for +//! parent-reduce rewrites before the child encoding's static `PARENT_RULES`, and the executor +//! consults it for parent execution before the child encoding's static parent kernels. A +//! registered function can therefore add support for an extension encoding or take precedence over +//! a built-in rule or kernel. When several functions are registered for the same key and kind, +//! they are tried in registration order until one applies. //! -//! Kernel entries are addressed by `(outer_id, child_id, kind)`. For parent-reduce kernels, -//! `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is the -//! child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the parent -//! id is the scalar function id. +//! Kernel entries are addressed by `(outer_id, child_id)`. For parent-reduce and execute-parent +//! kernels, `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is +//! the child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the +//! parent id is the scalar function id. //! -//! Sessions created by the top-level `vortex` crate install an empty registry by default. Other -//! sessions can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely -//! on [`ArrayKernelsExt::kernels`] to insert the default value. +//! Because registered functions have different signatures for each kernel kind, the registry +//! maintains one storage map per function type rather than a single type-erased map. +//! +//! Sessions created by the top-level `vortex` crate install the default registry. Other sessions +//! can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely on +//! [`ArrayKernelsExt::kernels`] to insert the default value. use std::any::Any; +use std::borrow::Borrow; use std::hash::BuildHasher; +use std::hash::Hash; use std::sync::Arc; use std::sync::LazyLock; @@ -32,8 +40,15 @@ use vortex_utils::aliases::DefaultHashBuilder; use vortex_utils::aliases::hash_map::HashMap; use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::array::VTable; +use crate::arrays::Struct; +use crate::arrays::struct_::compute::cast::struct_cast_execute_parent; +use crate::arrays::struct_::compute::rules::struct_cast_reduce_parent; +use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::fns::cast::Cast; -/// Shared hasher used to combine `(outer, child, FnKind)` tuples into [`FnRegistry`] keys. +/// Shared hasher used to combine `(outer, child)` tuples into registry keys. static FN_HASHER: LazyLock = LazyLock::new(DefaultHashBuilder::default); /// Function pointer for a plugin-provided parent-reduce rewrite. @@ -47,59 +62,164 @@ static FN_HASHER: LazyLock = LazyLock::new(DefaultHashBuilde pub type ReduceParentFn = fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult>; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[repr(transparent)] +struct ReduceParentFnId(u64); + +impl From for ReduceParentFnId { + fn from(id: u64) -> Self { + Self(id) + } +} + +impl Borrow for ReduceParentFnId { + fn borrow(&self) -> &u64 { + &self.0 + } +} + +/// Function pointer for a plugin-provided parent execution. +/// +/// The executor calls this with the matched `child`, its `parent`, the slot index where the child +/// appears, and the current [`ExecutionCtx`]. Return `Ok(Some(new_parent))` to replace the parent +/// with an executed result, or `Ok(None)` when the kernel does not apply. +/// +/// Implementations must preserve the parent's logical length and dtype, matching the invariant +/// required of static `execute_parent` kernels. +pub type ExecuteParentFn = fn( + child: &ArrayRef, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[repr(transparent)] +struct ExecuteParentFnId(u64); + +impl From for ExecuteParentFnId { + fn from(id: u64) -> Self { + Self(id) + } +} + +impl Borrow for ExecuteParentFnId { + fn borrow(&self) -> &u64 { + &self.0 + } +} + /// Session-scoped registry of optimizer kernel functions. -#[derive(Debug, Default)] +/// +/// Each kernel kind has its own storage map, keyed by `(outer_id, child_id)`. Registering +/// functions for an existing key appends them to that key's ordered list. +#[derive(Debug)] pub struct ArrayKernels { - reduce_parent: ArcSwap>>, + reduce_parent: ArcSwap>>, + execute_parent: ArcSwap>>, +} + +impl Default for ArrayKernels { + fn default() -> ArrayKernels { + let this = Self::empty(); + this.register_builtin_reduce_parent(); + this.register_builtin_execute_parent(); + this + } } impl ArrayKernels { /// Create an empty [`ArrayKernels`] with no kernels registered. pub fn empty() -> Self { - Self::default() + Self { + reduce_parent: ArcSwap::from_pointee(HashMap::default()), + execute_parent: ArcSwap::from_pointee(HashMap::default()), + } + } + + fn register_builtin_reduce_parent(&self) { + self.register_reduce_parent( + Cast.id(), + Struct.id(), + &[struct_cast_reduce_parent as ReduceParentFn], + ); + } + + fn register_builtin_execute_parent(&self) { + self.register_execute_parent( + Cast.id(), + Struct.id(), + &[struct_cast_execute_parent as ExecuteParentFn], + ); } - /// Register a [`ReduceParentFn`] for `(outer, child)`. + /// Register [`ReduceParentFn`]s for `(parent, child)`. /// - /// The optimizer will invoke `f` when it sees a parent with encoding id `outer` holding a - /// child with encoding id `child` during a `reduce_parent` step, before trying the child - /// encoding's static `PARENT_RULES`. `outer` is usually the parent array's encoding id. For - /// `ScalarFnArray`, it is the scalar function id, for example `Cast.id()`. + /// The optimizer invokes these functions in registration order when it sees a parent with + /// encoding id `parent` holding a child with encoding id `child` during a `reduce_parent` + /// step, before trying the child encoding's static `PARENT_RULES`. `parent` is usually the + /// parent array's encoding id. For `ScalarFnArray`, it is the scalar function id, for example + /// `Cast.id()`. /// - /// Replaces any function already registered for the same pair. - pub fn register_reduce_parent>( - &self, - parent: Id, - child: Id, - fns: I, - ) { - let registry = self.reduce_parent.load(); - let id = self.hash_fn_ids(parent, child); - let mut owned_registry = registry.as_ref().clone(); - if let Some(existing) = owned_registry.remove(&id) { - owned_registry.insert(id, existing.as_ref().iter().cloned().chain(fns).collect()); - } else { - owned_registry.insert(id, fns.into_iter().collect()); - } - self.reduce_parent.store(Arc::new(owned_registry)); + /// If functions have already been registered for the same pair, these functions are appended + /// after them. + pub fn register_reduce_parent(&self, parent: Id, child: Id, fns: &[ReduceParentFn]) { + self.reduce_parent.rcu(move |registry| { + update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns) + }); } - /// Look up the [`ReduceParentFn`] registered for `(outer, child)`. + /// Look up the [`ReduceParentFn`]s registered for `(parent, child)`. /// /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the - /// function. + /// functions. pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option> { - let id = self.hash_fn_ids(parent, child); - let map = self.reduce_parent.load(); - let entry = map.get(&id)?; - Some(Arc::clone(entry)) + let id = hash_fn_id(parent, child); + self.reduce_parent.load().get(&id).cloned() + } + + /// Register [`ExecuteParentFn`]s for `(parent, child)`. + /// + /// The executor invokes these functions in registration order when it sees a parent with + /// encoding id `parent` holding a child with encoding id `child` during a parent execution + /// step, before trying the child encoding's static parent kernels. + /// + /// If functions have already been registered for the same pair, these functions are appended + /// after them. + pub fn register_execute_parent(&self, parent: Id, child: Id, fns: &[ExecuteParentFn]) { + self.execute_parent.rcu(move |registry| { + update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns) + }); + } + + /// Look up the [`ExecuteParentFn`]s registered for `(parent, child)`. + /// + /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the + /// functions. + pub fn find_execute_parent(&self, parent: Id, child: Id) -> Option> { + let id = hash_fn_id(parent, child); + self.execute_parent.load().get(&id).cloned() } +} + +fn hash_fn_id(parent: Id, child: Id) -> u64 { + FN_HASHER.hash_one((parent, child)) +} - /// Combine a typed kernel id tuple into the `u64` key expected by the underlying - /// [`FnRegistry`]. All typed helpers use this path so registration and lookup agree. - fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 { - FN_HASHER.hash_one((parent, child)) +fn update_fns + Eq + Hash + From>( + mut existing: HashMap>, + id: u64, + fns: &[F], +) -> HashMap> { + if let Some(existing_fns) = existing.remove(&id) { + existing.insert( + id.into(), + existing_fns.as_ref().iter().chain(fns).cloned().collect(), + ); + } else { + existing.insert(id.into(), fns.into()); } + existing } impl SessionVar for ArrayKernels { diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index cbab9615e97..a007788acb2 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -27,8 +27,8 @@ use crate::arrays::FixedSizeList; use crate::arrays::ListView; use crate::arrays::Null; use crate::arrays::Primitive; -use crate::arrays::Struct; use crate::arrays::VarBinView; +use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::StatsCatalog; @@ -218,7 +218,7 @@ fn cast_canonical( CanonicalView::VarBinView(a) => ::cast(a, dtype, ctx), CanonicalView::List(a) => ::cast(a, dtype, ctx), CanonicalView::FixedSizeList(a) => ::cast(a, dtype, ctx), - CanonicalView::Struct(a) => ::cast(a, dtype, ctx), + CanonicalView::Struct(a) => struct_cast(a, dtype, ctx), CanonicalView::Extension(a) => ::cast(a, dtype), CanonicalView::Variant(_) => { vortex_bail!("Variant arrays don't support casting") diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index b081783b809..204205d1f51 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -385,10 +385,10 @@ impl Validity { /// Convert into a variant compatible with the given nullability. /// /// This is the execution-time half of the nullability-cast pair. It is paired with - /// [`Self::trivial_cast_nullability`], which is used by `CastReduce` rules. The pattern is: + /// [`Self::trivially_cast_nullability`], which is used by `CastReduce` rules. The pattern is: /// /// - **`CastReduce` rules** (metadata-only rewrites in the optimizer) call - /// [`Self::trivial_cast_nullability`]. If it returns `Ok(None)`, the rule returns `Ok(None)` + /// [`Self::trivially_cast_nullability`]. If it returns `Ok(None)`, the rule returns `Ok(None)` /// and the cast is deferred to execution. /// - **`CastKernel` impls** (executed via [`ExecuteParentKernel`]) call this method, which /// may run the underlying validity array to compute statistics. @@ -437,7 +437,7 @@ impl Validity { /// }; /// ``` #[inline] - pub fn trivial_cast_nullability( + pub fn trivially_cast_nullability( self, nullability: Nullability, len: usize,