Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13586,8 +13586,12 @@ impl vortex_array::optimizer::kernels::ArrayKernels

pub fn vortex_array::optimizer::kernels::ArrayKernels::empty() -> Self

pub fn vortex_array::optimizer::kernels::ArrayKernels::find_execute_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<[vortex_array::optimizer::kernels::ExecuteParentFn]>>

pub fn vortex_array::optimizer::kernels::ArrayKernels::find_reduce_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<[vortex_array::optimizer::kernels::ReduceParentFn]>>

pub fn vortex_array::optimizer::kernels::ArrayKernels::register_execute_parent<I: core::iter::traits::collect::IntoIterator<Item = vortex_array::optimizer::kernels::ExecuteParentFn>>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I)

pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent<I: core::iter::traits::collect::IntoIterator<Item = vortex_array::optimizer::kernels::ReduceParentFn>>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I)

impl core::default::Default for vortex_array::optimizer::kernels::ArrayKernels
Expand All @@ -13612,6 +13616,8 @@ impl<S: vortex_session::SessionExt> vortex_array::optimizer::kernels::ArrayKerne

pub fn S::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels>

pub type vortex_array::optimizer::kernels::ExecuteParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>

pub type vortex_array::optimizer::kernels::ReduceParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>

pub mod vortex_array::optimizer::rules
Expand Down
247 changes: 177 additions & 70 deletions vortex-array/src/arrays/struct_/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,106 +6,149 @@ use vortex_error::VortexResult;
use vortex_error::vortex_ensure;

use crate::ArrayRef;
use crate::ArrayView;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::ConstantArray;
use crate::arrays::Struct;
use crate::arrays::StructArray;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::struct_::StructArrayExt;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::matcher::Matcher;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::cast::CastKernel;

impl CastKernel for Struct {
fn cast(
array: ArrayView<'_, Struct>,
dtype: &DType,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
return Ok(None);
};

let source_sdtype = array.struct_fields();
use crate::scalar_fn::fns::cast::Cast;

pub(crate) fn struct_cast_execute_parent(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an execute kernel if the execution context is unused (`_ctx`)?

Copy link
Copy Markdown
Contributor Author

@robert3005 robert3005 Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be just a reduce rule, apart from the case where you want to convert validity from nullable to non-nullable, which MIGHT require compute. We need to refactor validity casting.

child: &ArrayRef,
parent: &ArrayRef,
_child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(array) = child.as_opt::<Struct>() else {
return Ok(None);
};
let Some(parent) = ExactScalarFn::<Cast>::try_match(parent) else {
return Ok(None);
};

let dtype = parent.options;
if array.dtype() == parent.options {
return Ok(Some(array.array().clone()));
}

let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
&& target_sdtype
.names()
.iter()
.zip(source_sdtype.names().iter())
.all(|(f1, f2)| f1 == f2);
struct_cast(array, dtype, ctx)
}

let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
if fields_match_order {
for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields())
{
let cast_field = field.cast(target_type)?;
cast_fields.push(cast_field);
pub(crate) fn struct_cast(
array: ArrayView<Struct>,
dtype: &DType,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
return Ok(None);
};

let source_sdtype = array.struct_fields();

let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
// Re-order, handle fields by value instead.
for (target_name, target_type) in target_sdtype.names().iter().zip_eq(target_sdtype.fields()) {
match source_sdtype.find(target_name) {
None => {
// No source field with this name => evolve the schema compatibly.
// If the field is nullable, we add a new ConstantArray field with the type.
vortex_ensure!(
target_type.is_nullable(),
"CAST for struct only supports added nullable fields"
);

cast_fields
.push(ConstantArray::new(Scalar::null(target_type), array.len()).into_array());
}
} else {
// Re-order, handle fields by value instead.
for (target_name, target_type) in
target_sdtype.names().iter().zip_eq(target_sdtype.fields())
{
match source_sdtype.find(target_name) {
None => {
// No source field with this name => evolve the schema compatibly.
// If the field is nullable, we add a new ConstantArray field with the type.
vortex_ensure!(
target_type.is_nullable(),
"CAST for struct only supports added nullable fields"
);

cast_fields.push(
ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
);
}
Some(src_field_idx) => {
// Field exists in source field. Cast it to the target type.
let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
cast_fields.push(cast_field);
}
}
Some(src_field_idx) => {
// Field exists in source field. Cast it to the target type.
let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
cast_fields.push(cast_field);
}
}
}

let validity = array
.validity()?
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
let validity = array
.validity()?
.cast_nullability(dtype.nullability(), array.len(), ctx)?;

StructArray::try_new(
target_sdtype.names().clone(),
cast_fields,
array.len(),
validity,
)
.map(|a| Some(a.into_array()))
}
Ok(Some(
unsafe {
StructArray::new_unchecked(cast_fields, target_sdtype.clone(), array.len(), validity)
}
.into_array(),
))
}

#[cfg(test)]
mod tests {
use std::sync::LazyLock;

use rstest::rstest;
use vortex_buffer::buffer;
use vortex_error::VortexResult;
use vortex_session::VortexSession;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
#[expect(deprecated)]
use crate::ToCanonical as _;
use crate::VortexSessionExecute;
use crate::arrays::ConstantArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::StructArray;
use crate::arrays::VarBinArray;
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
use crate::arrays::struct_::StructArrayExt;
use crate::assert_arrays_eq;
use crate::builtins::ArrayBuiltins;
use crate::compute::conformance::cast::test_cast_conformance;
use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::optimizer::kernels::ArrayKernels;
use crate::optimizer::kernels::ArrayKernelsExt;
use crate::optimizer::kernels::ExecuteParentFn;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::cast::Cast;
use crate::session::ArraySession;
use crate::validity::Validity;

// Session shared by the tests in this module. It is built lazily on first access
// from an empty `VortexSession` extended with `ArraySession`.
static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

/// Test-only execute-parent function that ignores the child's contents.
///
/// When `parent`'s dtype is a struct, returns a struct array with the parent's field names
/// in which every field is a constant null column of `child.len()` rows, and whose validity
/// is derived from the parent dtype's nullability. Returns `Ok(None)` when the parent's
/// dtype is not a struct.
fn null_struct_cast_execute_parent(
    child: &ArrayRef,
    parent: &ArrayRef,
    _child_idx: usize,
    _ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    let target_dtype = parent.dtype();

    // Decline when the parent does not produce a struct.
    let struct_fields = match target_dtype.as_struct_fields_opt() {
        Some(struct_fields) => struct_fields,
        None => return Ok(None),
    };

    // One all-null constant column per target field.
    let mut null_columns: Vec<ArrayRef> = Vec::new();
    for field_dtype in struct_fields.fields() {
        null_columns.push(ConstantArray::new(Scalar::null(field_dtype), child.len()).into_array());
    }

    let replacement = StructArray::try_new(
        struct_fields.names().clone(),
        null_columns,
        child.len(),
        Validity::from(target_dtype.nullability()),
    )?;
    Ok(Some(replacement.into_array()))
}

#[rstest]
#[case(create_test_struct(false))]
#[case(create_test_struct(true))]
Expand All @@ -115,6 +158,64 @@ mod tests {
test_cast_conformance(&array.into_array());
}

/// With a bare `VortexSession::empty()` context, calling `execute_parent` on a struct child
/// under a cast parent yields `None`.
#[test]
fn struct_cast_execute_parent_is_not_static_kernel() {
    let child = create_simple_struct().into_array();

    let value_dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
    let target_dtype = DType::struct_([("value", value_dtype)], Nullability::NonNullable);

    let parent = Cast
        .try_new_array(child.len(), target_dtype, [child.clone()])
        .unwrap();

    // The context is built from an empty session; nothing is registered on it here.
    let mut ctx = ExecutionCtx::new(VortexSession::empty());

    let outcome = child.execute_parent(&parent, 0, &mut ctx).unwrap();
    assert!(outcome.is_none());
}

/// Registers `null_struct_cast_execute_parent` on the session for the (cast, struct)
/// encoding pair, executes the cast, and checks that the result has the target dtype and
/// that its field `b` is a constant null column.
#[test]
fn struct_cast_execute_parent_uses_session_plugin() {
    // Single-row source struct with one nullable Utf8 field named "a".
    let source = StructArray::try_new(
        FieldNames::from(["a"]),
        vec![VarBinArray::from_vec(vec!["A"], DType::Utf8(Nullability::Nullable)).into_array()],
        1,
        Validity::NonNullable,
    )
    .unwrap()
    .into_array();
    let child_id = source.encoding_id();

    // The target struct has a single nullable Utf8 field "b", which the source does not have.
    let utf8_null = DType::Utf8(Nullability::Nullable);
    let target = DType::Struct(
        StructFields::new(FieldNames::from(["b"]), vec![utf8_null.clone()]),
        Nullability::NonNullable,
    );

    let cast = Cast
        .try_new_array(source.len(), target.clone(), [source])
        .unwrap();
    let parent_id = cast.encoding_id();

    let session = VortexSession::empty().with::<ArrayKernels>();
    // `register_execute_parent` accepts `IntoIterator<Item = ExecuteParentFn>`. The array is
    // passed by value so that it yields the function pointers themselves; iterating a
    // borrowed array would yield `&ExecuteParentFn` and not satisfy that bound.
    session.kernels().register_execute_parent(
        parent_id,
        child_id,
        [null_struct_cast_execute_parent as ExecuteParentFn],
    );
    let mut ctx = session.create_execution_ctx();

    let result = cast.execute::<StructArray>(&mut ctx).unwrap();

    assert_eq!(result.dtype(), &target);
    assert_arrays_eq!(
        result.unmasked_field_by_name("b").unwrap(),
        ConstantArray::new(Scalar::null(utf8_null), 1)
    );
}

fn create_test_struct(nullable: bool) -> StructArray {
let names = FieldNames::from(["a", "b"]);

Expand Down Expand Up @@ -204,14 +305,17 @@ mod tests {

let target_dtype = struct_array.dtype().as_nullable();

let result = struct_array
let cast = struct_array
.into_array()
.cast(target_dtype.clone())
.unwrap();
assert_eq!(result.dtype(), &target_dtype);
assert_eq!(result.len(), 3);
#[expect(deprecated)]
let nfields = result.to_struct().struct_fields().nfields();
assert_eq!(cast.dtype(), &target_dtype);
assert_eq!(cast.len(), 3);
let nfields = cast
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
.unwrap()
.struct_fields()
.nfields();
assert_eq!(nfields, 2);
}

Expand Down Expand Up @@ -241,8 +345,11 @@ mod tests {
.unwrap();
assert_eq!(result.dtype(), &target_dtype);
assert_eq!(result.len(), 3);
#[expect(deprecated)]
let nfields = result.to_struct().struct_fields().nfields();
let nfields = result
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
.unwrap()
.struct_fields()
.nfields();
assert_eq!(nfields, 3);
}
}
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/struct_/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

mod cast;
pub(crate) mod cast;
mod mask;
pub(crate) mod rules;
mod slice;
Expand Down
Loading
Loading