diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 7300bb399e6..d07b61afa73 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -78,11 +78,9 @@ pub const vortex_tensor::encodings::turboquant::MAX_CENTROIDS: usize pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 -pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult - pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_encode_normalized(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub mod vortex_tensor::fixed_shape @@ -218,12 +216,16 @@ pub enum vortex_tensor::matcher::TensorMatch<'a> pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>) +pub vortex_tensor::matcher::TensorMatch::NormalizedVector(vortex_tensor::vector::VectorMatcherMetadata) + pub vortex_tensor::matcher::TensorMatch::Vector(vortex_tensor::vector::VectorMatcherMetadata) impl vortex_tensor::matcher::TensorMatch<'_> pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType +pub fn 
vortex_tensor::matcher::TensorMatch<'_>::is_normalized(self) -> bool + pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32 impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a> @@ -252,6 +254,66 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher:: pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option +pub mod vortex_tensor::normalized_vector + +pub struct vortex_tensor::normalized_vector::AnyNormalizedVector + +impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector + +pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata + +pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option + +pub struct vortex_tensor::normalized_vector::NormalizedVector + +impl vortex_tensor::normalized_vector::NormalizedVector + +pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(fsl: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(fsl: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::wrap_vector_unchecked(vector: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector + +impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector + +impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector + 
+pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool + +impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector + +impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector + +pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector + +pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata + +pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType) -> 
vortex_error::VortexResult<()> + pub mod vortex_tensor::scalar_fns pub mod vortex_tensor::scalar_fns::cosine_similarity @@ -384,8 +446,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> - pub mod vortex_tensor::scalar_fns::l2_norm pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm @@ -490,7 +550,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform -pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> @@ -578,7 +638,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32 pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> 
vortex_array::dtype::ptype::PType -pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult +pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(self) -> bool + +pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata diff --git a/vortex-tensor/src/encodings/l2_denorm.rs b/vortex-tensor/src/encodings/l2_denorm.rs index 6cb4fcb0626..64b876d4afa 100644 --- a/vortex-tensor/src/encodings/l2_denorm.rs +++ b/vortex-tensor/src/encodings/l2_denorm.rs @@ -14,8 +14,9 @@ use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexResult; -use crate::matcher::AnyTensor; +use crate::normalized_vector::AnyNormalizedVector; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::types::vector::AnyVector; #[derive(Debug)] pub struct L2DenormScheme; @@ -26,10 +27,14 @@ impl Scheme for L2DenormScheme { } fn matches(&self, canonical: &Canonical) -> bool { - matches!( - canonical, - Canonical::Extension(ext) if ext.ext_dtype().is::() - ) + let Canonical::Extension(ext) = canonical else { + return false; + }; + + // `AnyVector` matches any vector-shaped extension; we explicitly exclude `NormalizedVector` + // here because a normalized input already carries an authoritative unit-norm representation + // and does not need re-normalization. + ext.ext_dtype().is::() && !ext.ext_dtype().is::() } fn expected_compression_ratio( @@ -38,6 +43,7 @@ impl Scheme for L2DenormScheme { _compress_ctx: CompressorContext, _exec_ctx: &mut ExecutionCtx, ) -> CompressionEstimate { + // We almost always want to pre-normalize our data if the vector is not already normalized. 
CompressionEstimate::Verdict(EstimateVerdict::AlwaysUse) } @@ -52,3 +58,62 @@ impl Scheme for L2DenormScheme { Ok(l2_denorm.into_array()) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::Canonical; + use vortex_array::IntoArray; + use vortex_array::arrays::ExtensionArray; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_array::validity::Validity; + use vortex_compressor::scheme::Scheme; + use vortex_error::VortexResult; + + use super::L2DenormScheme; + use crate::types::fixed_shape::FixedShapeTensor; + use crate::types::fixed_shape::FixedShapeTensorMetadata; + use crate::types::vector::Vector; + + fn fsl_storage(elements: &[f32], list_size: u32) -> VortexResult { + let len = elements.len() / list_size as usize; + let elements = PrimitiveArray::from_iter(elements.iter().copied()).into_array(); + FixedSizeListArray::try_new(elements, list_size, Validity::NonNullable, len) + } + + #[test] + fn matches_vector() -> VortexResult<()> { + let fsl = fsl_storage(&[1.0, 0.0], 2)?; + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array())); + + assert!(L2DenormScheme.matches(&canonical)); + Ok(()) + } + + #[test] + fn rejects_fixed_shape_tensor() -> VortexResult<()> { + let fsl = fsl_storage(&[1.0, 0.0, 0.0, 1.0], 4)?; + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 4, + Nullability::NonNullable, + ); + let ext_dtype = ExtDType::::try_new( + FixedShapeTensorMetadata::new(vec![2, 2]), + storage_dtype, + )? 
+ .erased(); + let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array())); + + assert!(!L2DenormScheme.matches(&canonical)); + Ok(()) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index ca32faa6ec9..7e4f37d59a2 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -3,10 +3,12 @@ //! TurboQuant encoding (quantization) logic. //! -//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector) -//! extension array whose rows are already L2-normalized (unit norm). Normalization is handled -//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm), -//! which the [`TurboQuantScheme`] calls before invoking this function. +//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`] extension array whose rows +//! are already L2-normalized (unit norm). Normalization is handled externally by +//! [`normalize_as_l2_denorm`], which the [`TurboQuantScheme`] calls before invoking this function. +//! +//! If you already have a [`NormalizedVector`] array, then use the [`turboquant_encode_normalized`] +//! function instead. //! //! 
[`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme @@ -15,6 +17,7 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; @@ -22,6 +25,7 @@ use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; +use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; @@ -34,12 +38,14 @@ use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::compute_or_get_centroids; use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::normalized_vector::AnyNormalizedVector; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; -use crate::types::vector::AnyVector; +use crate::types::normalized_vector::NormalizedVector; +#[expect(unused, reason = "docs")] use crate::types::vector::Vector; use crate::utils::cast_to_f32; @@ -54,6 +60,7 @@ pub struct TurboQuantConfig { pub num_rounds: u8, } +// TODO(connor): We should be able to modify this more easily from the `TurboQuantScheme`! 
impl Default for TurboQuantConfig { fn default() -> Self { Self { @@ -64,24 +71,24 @@ impl Default for TurboQuantConfig { } } -/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) -/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized -/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer -/// [`L2Denorm`] wrapper. +/// Apply the full TurboQuant compression pipeline to a [`Vector`] extension array: normalize the +/// rows via [`normalize_as_l2_denorm`], quantize the normalized child via +/// [`turboquant_encode_normalized`], and reattach the stored norms as the outer [`L2Denorm`] +/// wrapper. /// /// The returned array has the canonical TurboQuant shape: /// /// ```text /// ScalarFnArray(L2Denorm, [ -/// ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +/// NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))])), /// norms, /// ]) /// ``` /// /// # Errors /// -/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or -/// if [`turboquant_encode_unchecked`] rejects the input shape. +/// Returns an error if `input` is not a vector-family extension array, if normalization fails, or if +/// [`turboquant_encode_normalized`] rejects the input shape. pub fn turboquant_encode( input: ArrayRef, config: &TurboQuantConfig, @@ -89,7 +96,16 @@ pub fn turboquant_encode( ) -> VortexResult { // We must normalize the array before we can encode it with TurboQuant. 
let l2_denorm = normalize_as_l2_denorm(input, ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + vortex_ensure!( + normalized + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()), + "TurboQuant requires a Vector or NormalizedVector input, got normalized child {}", + normalized.dtype(), + ); let norms = l2_denorm.child_at(1).clone(); let num_rows = l2_denorm.len(); @@ -97,136 +113,184 @@ pub fn turboquant_encode( .as_opt::() .vortex_expect("normalize_as_l2_denorm always produces an Extension array child"); - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows). - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?; + let tq = turboquant_encode_normalized(normalized_ext, config, ctx)?; // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. + // bypass the strict normalized-row and zero-row validation when reattaching the stored norms. Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } -/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm -/// precondition. -/// -/// # Safety -/// -/// The caller must ensure: -/// -/// - The input dtype is non-nullable. -/// - Every row is L2-normalized (unit norm) or is a zero vector. -/// -/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently -/// incorrect quantization results. -pub unsafe fn turboquant_encode_unchecked( +/// Encode a non-nullable [`NormalizedVector`] extension array into a lossy +/// `NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))]))`, +/// without validating the decoded unit-norm precondition. 
+pub fn turboquant_encode_normalized( ext: ArrayView, config: &TurboQuantConfig, ctx: &mut ExecutionCtx, ) -> VortexResult { let ext_dtype = ext.dtype().clone(); - let storage = ext.storage_array(); - let fsl = storage.clone().execute::(ctx)?; + + let vector_metadata = ext_dtype.as_extension().metadata::(); + let element_ptype = vector_metadata.element_ptype(); + let dimensions = vector_metadata.dimensions(); + + // `NormalizedVector` storage is `Extension(Vector(FSL))`, so drill past the inner `Vector` to + // reach the underlying `FixedSizeList`. + let inner_vector: ExtensionArray = ext.storage_array().clone().execute(ctx)?; + let fsl: FixedSizeListArray = inner_vector.storage_array().clone().execute(ctx)?; vortex_ensure!( config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH, "bit_width must be 1-{MAX_BIT_WIDTH}, got {}", config.bit_width ); - let dimension = fsl.list_size(); vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}", + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", ); - let vector_metadata = ext_dtype.as_extension().metadata::(); - let element_ptype = vector_metadata.element_ptype(); - - let seed = config.seed; let num_rows = fsl.len(); - if fsl.is_empty() { - let padded_dim = dimension.next_power_of_two(); - let empty_codes = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_dict = - DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?; - let empty_fsl = FixedSizeListArray::try_new( - empty_dict.into_array(), - padded_dim, - Validity::NonNullable, - 0, - )?; - let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?; - - let sorf_options = SorfOptions { - seed, - num_rounds: config.num_rounds, - dimensions: dimension, - element_ptype, - }; - return Ok( - SorfTransform::try_new_array(&sorf_options, 
empty_padded_vector, 0)?.into_array(), - ); - } + // No data to quantize: short-circuit by returning an empty `NormalizedVector` directly at + // the final output shape `(dimensions, element_ptype)`. The non-empty path only goes + // through `SorfTransform` because the inverse rotation reshapes + // `(padded_dim, f32) → (dimensions, element_ptype)`; with zero rows there is no rotation + // to apply and we can construct an FSL with the destination dtype straight away. + if num_rows == 0 { + return match_each_float_ptype!(element_ptype, |T| { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let empty_fsl = FixedSizeListArray::try_new( + elements.into_array(), + dimensions, + Validity::NonNullable, + 0, + )?; - let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; - let quantized_fsl = - build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?; - let padded_vector = Vector::try_new_vector_array(quantized_fsl)?; + // SAFETY: An empty FSL contains no rows, so the unit-norm-or-zero invariant holds + // vacuously. + unsafe { NormalizedVector::new_unchecked(empty_fsl.into_array()) } + }); + } let sorf_options = SorfOptions { - seed, + seed: config.seed, num_rounds: config.num_rounds, - dimensions: dimension, + dimensions, element_ptype, }; - Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) -} -/// Shared intermediate results from the quantization loop. -struct QuantizationResult { - centroids: Buffer, - all_indices: Buffer, - padded_dim: usize, + let quantized_fsl = turboquant_quantize_fsl(&fsl, config.bit_width, &sorf_options, ctx)?; + + // SAFETY: TurboQuant is a lossy approximation of the already-unit-norm input. 
+ let padded_vector = unsafe { NormalizedVector::new_unchecked(quantized_fsl) }?; + + let sorf = SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); + // SAFETY: Inverse SORF followed by truncation can lose energy, and quantization is already + // lossy, so this is a semantic assertion made by TurboQuant rather than an exact validation. + // Downstream vector operators treat the compressed unit-vector claim as authoritative. + unsafe { NormalizedVector::wrap_vector_unchecked(sorf) } } -/// Core quantization: rotate and quantize already-normalized rows. +/// Rotate and quantize already-normalized vector rows into a dict-encoded `FixedSizeList`. +/// +/// The input `fsl` must contain non-nullable, unit-norm vectors of float values (already +/// L2-normalized). Null vectors are not supported and must be zeroed out before reaching this +/// function. The rotation and centroid lookup happen in f32. +/// +/// The returned array is `FSL(DictArray(codes, centroids), padded_dim)`. The `FixedSizeList` has +/// Dict-encoded elements, where each row of `padded_dim` u8 codes indexes into the centroid +/// codebook. /// -/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null -/// vectors are not supported and must be zeroed out before reaching this function. The rotation -/// and centroid lookup happen in f32. -fn turboquant_quantize_core( +/// This allows the FSL (via the Dict-encoded elements) to be independently sliced, taken, or +/// executed (dequantized) without knowledge of the rotation. +/// +/// Internally, this function: +/// +/// 1. Builds a [`SorfMatrix`] structured rotation from the seed/rounds in `sorf_options`. +/// 2. For each row, zero-pads to the next power of 2, applies the rotation, and maps each rotated +/// coordinate to its nearest centroid index via binary search on precomputed boundaries. +/// 3. 
Packs the per-row centroid indices and the shared centroid codebook into a `DictArray`-backed +/// `FixedSizeListArray`. +fn turboquant_quantize_fsl( fsl: &FixedSizeListArray, - seed: u64, bit_width: u8, - num_rounds: u8, + sorf_options: &SorfOptions, ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = fsl.list_size() as usize; +) -> VortexResult { + vortex_ensure!(!fsl.dtype().is_nullable()); + + let dimensions = fsl.list_size() as usize; let num_rows = fsl.len(); - let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; + let rotation = SorfMatrix::try_new( + sorf_options.seed, + dimensions, + sorf_options.num_rounds as usize, + )?; let padded_dim = rotation.padded_dim(); let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); + // Compute the centroids for the given (dimension, bit_width) combination (or retrieve them from a + // previous computation) + let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?; + + // Extract out the elements of the FSL and cast to f32. In the f64 case, we intentionally lose + // information here because we are already going to be quantizing to a smaller set of centroids, + // so we are fine with this loss. let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; let f32_elements = cast_to_f32(elements_prim)?; - let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?; - let boundaries = compute_centroid_boundaries(&centroids); + // Take the float values and quantize by finding the closest centroid in the codebook to each + // and recording the index of that centroid. + let all_indices = rotate_and_quantize( + f32_elements.as_slice(), + num_rows, + dimensions, + &rotation, + &centroids, + ); + + // Build the Dict-encoded FSL from the centroid indices and codebook. Everything is non-null + // since our input is non-null. 
+ let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); + let values = PrimitiveArray::new::(centroids, Validity::NonNullable); + let dict = DictArray::try_new(codes.into_array(), values.into_array())?; + + Ok(FixedSizeListArray::try_new( + dict.into_array(), + padded_dim_u32, + Validity::NonNullable, + num_rows, + )? + .into_array()) +} + +/// Rotate each row via the structured rotation and quantize every rotated coordinate to its nearest +/// centroid index via binary search on precomputed boundaries. +/// +/// Returns a flat [`Buffer`] of length `num_rows * padded_dim` containing the per-coordinate +/// centroid indices. +fn rotate_and_quantize( + f32_slice: &[f32], + num_rows: usize, + dimensions: usize, + rotation: &SorfMatrix, + centroids: &[f32], +) -> Buffer { + let padded_dim = rotation.padded_dim(); + let boundaries = compute_centroid_boundaries(centroids); let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; - let f32_slice = f32_elements.as_slice(); for row in 0..num_rows { - let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let x = &f32_slice[row * dimensions..][..dimensions]; // Zero-pad to the next power of 2. - padded[..dimension].copy_from_slice(x); - padded[dimension..].fill(0.0); + padded[..dimensions].copy_from_slice(x); + padded[dimensions..].fill(0.0); rotation.rotate(&padded, &mut rotated); @@ -235,36 +299,5 @@ fn turboquant_quantize_core( } } - Ok(QuantizationResult { - centroids, - all_indices: all_indices.freeze(), - padded_dim, - }) -} - -/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. -/// -/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes indexes into the -/// centroid codebook. The Dict can be independently sliced, taken, or executed (dequantized) -/// without knowledge of the rotation. 
-fn build_quantized_fsl( - num_rows: usize, - all_indices: Buffer, - centroids: Buffer, - padded_dim: usize, -) -> VortexResult { - let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); - let centroids_array = PrimitiveArray::new::(centroids, Validity::NonNullable); - - let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; - - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - Ok(FixedSizeListArray::try_new( - dict.into_array(), - padded_dim_u32, - Validity::NonNullable, - num_rows, - )? - .into_array()) + all_indices.freeze() } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 50cef7b721e..71eb873810a 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -29,13 +29,15 @@ //! //! ```text //! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +//! NormalizedVector(ScalarFnArray(SorfTransform, [NormalizedVector(Vector(FSL(Dict)))])), //! norms //! ]) //! ``` //! //! When executed, the tree automatically decompresses: Dict dequantizes codes → SorfTransform -//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). +//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). The +//! `NormalizedVector` wrappers mark the unit-vector contract that the lossy encoding treats as +//! authoritative. //! //! [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm //! [`SorfTransform`]: crate::scalar_fns::sorf_transform::SorfTransform @@ -134,7 +136,7 @@ pub(crate) mod compress; mod scheme; pub use compress::TurboQuantConfig; pub use compress::turboquant_encode; -pub use compress::turboquant_encode_unchecked; +pub use compress::turboquant_encode_normalized; pub use scheme::TurboQuantScheme; /// Minimum vector dimension for TurboQuant encoding. 
@@ -149,34 +151,5 @@ pub const MAX_BIT_WIDTH: u8 = 8; /// Maximum supported number of centroids in the scalar quantizer codebook. pub const MAX_CENTROIDS: usize = 1usize << (MAX_BIT_WIDTH as usize); -use vortex_array::dtype::DType; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; - -use crate::types::vector::AnyVector; -use crate::types::vector::VectorMatcherMetadata; - -/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with -/// dimension >= [`MIN_DIMENSION`]. -/// -/// Returns the validated vector metadata on success. -pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { - let vector_metadata = dtype - .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") - })?; - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", - ); - - Ok(vector_metadata) -} - #[cfg(test)] mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index d4362096bd2..af782c4abb2 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -3,25 +3,32 @@ //! TurboQuant compression scheme. //! -//! The scheme is a thin [`Scheme`] adapter over [`turboquant_encode`], which produces: +//! Plain [`Vector`](crate::vector::Vector) inputs are normalized and encoded via +//! [`turboquant_encode`], which produces: //! //! ```text //! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray( -//! SorfTransform, -//! FSL(Dict(codes, centroids)) -//! ), +//! NormalizedVector(ScalarFnArray(SorfTransform, [ +//! NormalizedVector(Vector(FSL(Dict(codes, centroids)))) +//! ])), //! norms //! ]) //! ``` //! +//! 
Non-nullable [`NormalizedVector`](crate::normalized_vector::NormalizedVector) inputs skip the +//! outer [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) wrapper and are encoded directly via +//! [`turboquant_encode_normalized`]. +//! //! Decompression is automatic: executing the outer array walks the ScalarFn tree. //! //! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode +//! [`turboquant_encode_normalized`]: crate::encodings::turboquant::turboquant_encode_normalized use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::ExecutionCtx; +use vortex_array::arrays::Extension; +use vortex_array::dtype::DType; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::estimate::CompressionEstimate; @@ -30,11 +37,16 @@ use vortex_compressor::scheme::Scheme; use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use crate::encodings::turboquant::MAX_CENTROIDS; +use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::tq_validate_vector_dtype; use crate::encodings::turboquant::turboquant_encode; +use crate::encodings::turboquant::turboquant_encode_normalized; +use crate::vector::AnyVector; +use crate::vector::VectorMatcherMetadata; /// TurboQuant compression scheme for [`Vector`] extension types. /// @@ -101,7 +113,29 @@ impl Scheme for TurboQuantScheme { _compress_ctx: CompressorContext, exec_ctx: &mut ExecutionCtx, ) -> VortexResult { - turboquant_encode(data.array().clone(), &TurboQuantConfig::default(), exec_ctx) + // TODO(connor): If we ever add scheme vtables with metadata, we would need to pass in the + // config as a parameter here. 
+ let config = TurboQuantConfig::default(); + turboquant_encode_for_scheme(data.array().clone(), &config, exec_ctx) + } +} + +fn turboquant_encode_for_scheme( + input: ArrayRef, + config: &TurboQuantConfig, + exec_ctx: &mut ExecutionCtx, +) -> VortexResult { + let vector_metadata = tq_validate_vector_dtype(input.dtype())?; + if vector_metadata.is_normalized() { + let ext = input.as_opt::().ok_or_else(|| { + vortex_err!( + "TurboQuant normalized input must be an Extension array, got {}", + input.encoding_id() + ) + })?; + turboquant_encode_normalized(ext, config, exec_ctx) + } else { + turboquant_encode(input, config, exec_ctx) } } @@ -133,11 +167,49 @@ fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vector uncompressed_size_bits as f64 / compressed_size_bits as f64 } +/// Validates that `dtype` is a plain [`Vector`](crate::vector::Vector) or non-nullable +/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) extension type with dimension +/// >= [`MIN_DIMENSION`]. +/// +/// Returns the validated vector metadata on success. 
+pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { + let vector_metadata = dtype + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + vortex_ensure!( + !vector_metadata.is_normalized() || !dtype.is_nullable(), + "TurboQuant cannot encode nullable NormalizedVector inputs because normalized encode has \ + no norms child to carry validity", + ); + + Ok(vector_metadata) +} + #[cfg(test)] mod tests { use rstest::rstest; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::arrays::ScalarFn; + use vortex_array::dtype::Nullability; + use vortex_array::validity::Validity; + use vortex_buffer::BufferMut; use super::*; + use crate::tests::SESSION; + use crate::types::normalized_vector::NormalizedVector; + use crate::utils::test_helpers::normalized_vector_array; /// Verify compression ratio for typical embedding dimensions. 
/// @@ -207,6 +279,48 @@ mod tests { ); } + #[test] + fn scheme_routes_normalized_vector_without_l2_denorm_wrapper() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let mut values = vec![0.0f32; 2 * 128]; + values[0] = 1.0; + values[128 + 1] = 1.0; + let input = normalized_vector_array(128, &values, &mut ctx)?; + + let encoded = turboquant_encode_for_scheme(input, &TurboQuantConfig::default(), &mut ctx)?; + + assert!(encoded.dtype().as_extension().is::()); + assert!( + encoded.as_opt::().is_none(), + "NormalizedVector scheme path should not add an outer L2Denorm ScalarFnArray", + ); + Ok(()) + } + + #[test] + fn validate_rejects_nullable_normalized_vector() -> VortexResult<()> { + let dim = 128u32; + let mut values = BufferMut::::with_capacity(2 * dim as usize); + for row in 0..2 { + for col in 0..dim { + values.push(if col == row { 1.0 } else { 0.0 }); + } + } + let elements = PrimitiveArray::new::(values.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim, + Validity::from_iter([true, false]), + 2, + )?; + let mut ctx = SESSION.create_execution_ctx(); + let normalized = NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; + + assert_eq!(normalized.dtype().nullability(), Nullability::Nullable); + assert!(tq_validate_vector_dtype(normalized.dtype()).is_err()); + Ok(()) + } + /// Power-of-2 dimensions should have better ratios than their non-power-of-2 /// predecessors due to no padding waste. 
#[test] diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index ec4182dcc3d..c9e7b8410fc 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -16,6 +16,7 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Dict; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -71,7 +72,7 @@ fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { .vortex_expect("test FSL satisfies Vector storage constraints") } -/// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). +/// Unwrap an L2Denorm ScalarFnArray into (normalized_sorf_child, norms_child). fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded .as_opt::() @@ -84,16 +85,21 @@ fn unwrap_codes_centroids_norms( encoded: &ArrayRef, ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { - let (sorf_child, norms_child) = unwrap_l2denorm(encoded); + let (normalized_sorf_child, norms_child) = unwrap_l2denorm(encoded); + let normalized_sorf = normalized_sorf_child + .as_opt::() + .expect("expected NormalizedVector wrapping SorfTransform"); + let sorf_child = normalized_sorf.storage_array(); let padded_vector_child = sorf_child .as_opt::() .expect("expected SorfTransform ScalarFnArray") .child_at(0) .clone(); - // Vector wrapping FSL(Dict(codes, centroids)) - let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; - let fsl: FixedSizeListArray = padded_vector.storage_array().clone().execute(ctx)?; + // NormalizedVector wrapping Vector wrapping FSL(Dict(codes, centroids)). 
+ let normalized_vector: ExtensionArray = padded_vector_child.execute(ctx)?; + let inner_vector: ExtensionArray = normalized_vector.storage_array().clone().execute(ctx)?; + let fsl: FixedSizeListArray = inner_vector.storage_array().clone().execute(ctx)?; let dict = fsl .elements() .as_opt::() diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index d82be3cf714..3a0d5a061ee 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -11,10 +11,9 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; -use vortex_error::vortex_err; use super::*; -use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::encodings::turboquant::turboquant_encode_normalized; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; #[rstest] @@ -185,7 +184,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { let normalized_ext = normalized .as_opt::() .expect("normalized child should be Extension"); - assert!(unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx) }.is_err()); + assert!(turboquant_encode_normalized(normalized_ext, &config, &mut ctx).is_err()); } #[test] @@ -196,7 +195,8 @@ fn all_zero_vectors_roundtrip() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), Validity::NonNullable, num_rows, )?; @@ -245,7 +245,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { let num_rows = 10; let dim = 128; let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).map_err(|e| vortex_err!("{e}"))?; + let normal = Normal::new(0.0f64, 
1.0).unwrap(); let mut buf = BufferMut::::with_capacity(num_rows * dim); for _ in 0..(num_rows * dim) { @@ -254,7 +254,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into().unwrap(), Validity::NonNullable, num_rows, )?; @@ -278,7 +278,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { let num_rows = 10; let dim = 128; let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).map_err(|e| vortex_err!("{e}"))?; + let normal = Normal::new(0.0f32, 1.0).unwrap(); let mut buf = BufferMut::::with_capacity(num_rows * dim); for _ in 0..(num_rows * dim) { @@ -287,7 +287,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new( elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, + dim.try_into().unwrap(), Validity::NonNullable, num_rows, )?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 3913cf3d8fe..4b153596ec4 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -4,16 +4,21 @@ //! Tests that verify the internal structure of the encoded tree. 
use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_error::VortexResult; use super::*; +use crate::encodings::turboquant::centroids::compute_or_get_centroids; +use crate::types::normalized_vector::NormalizedVector; -/// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` computes. +/// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` +/// computes. #[test] fn stored_centroids_match_computed() -> VortexResult<()> { let fsl = make_fsl(10, 128, 42); @@ -30,7 +35,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { let stored = centroids.as_slice::(); // padded_dim for dim=128 is 128. - let computed = crate::encodings::turboquant::centroids::compute_or_get_centroids(128, 3)?; + let computed = compute_or_get_centroids(128, 3)?; assert_eq!(stored.len(), computed.len()); for i in 0..stored.len() { @@ -108,6 +113,42 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { Ok(()) } +/// Verify the L2Denorm child keeps the normalized-vector marker even though SorfTransform itself +/// returns a plain Vector. 
+#[test] +fn encoded_l2_denorm_child_is_normalized_sorf_transform() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: 123, + num_rounds: 2, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(ext, &config, &mut ctx)?; + + let (normalized_child, _norms) = unwrap_l2denorm(&encoded); + assert!( + normalized_child + .dtype() + .as_extension() + .is::(), + "L2Denorm child should carry NormalizedVector dtype" + ); + + let normalized_ext = normalized_child + .as_opt::() + .expect("normalized child should be an Extension array"); + assert!( + normalized_ext + .storage_array() + .as_opt::() + .is_some(), + "NormalizedVector storage should be the SorfTransform ScalarFnArray" + ); + Ok(()) +} + /// Verify approximate cosine similarity in the quantized domain. #[test] fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 7beadc02e93..72fd137ecaf 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -22,6 +22,7 @@ use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::fixed_shape::FixedShapeTensor; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; pub mod matcher; @@ -30,6 +31,7 @@ pub mod scalar_fns; mod types; pub use types::fixed_shape; +pub use types::normalized_vector; pub use types::vector; pub mod encodings; @@ -48,6 +50,7 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P /// Initialize the Vortex tensor library with a Vortex session. 
pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); + session.dtypes().register(NormalizedVector); session.dtypes().register(FixedShapeTensor); let session_fns = session.scalar_fns(); diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 4566dcb3a38..786f78036d9 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -9,6 +9,7 @@ use vortex_array::dtype::extension::Matcher; use crate::types::fixed_shape::AnyFixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMatcherMetadata; +use crate::types::normalized_vector::AnyNormalizedVector; use crate::types::vector::AnyVector; use crate::types::vector::VectorMatcherMetadata; @@ -18,6 +19,7 @@ use crate::types::vector::VectorMatcherMetadata; /// /// - `FixedShapeTensor` /// - `Vector` +/// - `NormalizedVector` pub struct AnyTensor; /// The matched variant of a tensor-like extension type. @@ -30,6 +32,10 @@ pub enum TensorMatch<'a> { /// /// Note that we store an owned type here wrapping (copyable) data from the dtype. Vector(VectorMatcherMetadata), + + /// A [`NormalizedVector`](crate::normalized_vector::NormalizedVector) extension over + /// [`Vector`](crate::vector::Vector) storage. 
+ NormalizedVector(VectorMatcherMetadata), } impl TensorMatch<'_> { @@ -37,7 +43,7 @@ impl TensorMatch<'_> { pub fn element_ptype(self) -> PType { match self { Self::FixedShapeTensor(metadata) => metadata.element_ptype(), - Self::Vector(metadata) => metadata.element_ptype(), + Self::Vector(metadata) | Self::NormalizedVector(metadata) => metadata.element_ptype(), } } @@ -45,9 +51,15 @@ impl TensorMatch<'_> { pub fn list_size(self) -> u32 { match self { Self::FixedShapeTensor(metadata) => metadata.flat_list_size(), - Self::Vector(metadata) => metadata.dimensions(), + Self::Vector(metadata) | Self::NormalizedVector(metadata) => metadata.dimensions(), } } + + /// Returns `true` when the dtype is a + /// [`NormalizedVector`](crate::normalized_vector::NormalizedVector). + pub fn is_normalized(self) -> bool { + matches!(self, Self::NormalizedVector(_)) + } } impl Matcher for AnyTensor { @@ -58,7 +70,13 @@ impl Matcher for AnyTensor { return Some(TensorMatch::FixedShapeTensor(metadata)); } - // Special logic for vectors to get convenience metadata (instead of `EmptyMetadata`). + // Check `AnyNormalizedVector` first because `AnyVector` is inclusive: it would otherwise + // match `NormalizedVector` and we'd lose the normalized variant in the returned + // `TensorMatch`. 
+ if let Some(metadata) = ext_dtype.metadata_opt::() { + return Some(TensorMatch::NormalizedVector(metadata)); + } + if let Some(metadata) = ext_dtype.metadata_opt::() { return Some(TensorMatch::Vector(metadata)); } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 7819e0b46f0..dd7fc58e455 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -34,11 +34,10 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; -use crate::scalar_fns::l2_denorm::DenormOrientation; -use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; +use crate::scalar_fns::l2_denorm::NormalForm; +use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm_from_constant; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::BinaryTensorOpMetadata; -use crate::utils::extract_l2_denorm_children; use crate::utils::validate_binary_tensor_float_inputs; /// Cosine similarity between two columns. @@ -131,30 +130,41 @@ impl ScalarFnVTable for CosineSimilarity { let mut rhs_ref = args.get(1)?; let len = args.row_count(); - // If either side is a constant tensor-like extension array, eagerly normalize the single + // If either side is a constant vector extension array, eagerly normalize the single // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. // The L2Denorm fast path below then picks it up. - if let Some(sfn) = try_build_constant_l2_denorm(&lhs_ref, len, ctx)? { + if let Some(sfn) = try_build_constant_l2_denorm_from_constant(&lhs_ref, len, ctx)? { lhs_ref = sfn.into_array(); } - if let Some(sfn) = try_build_constant_l2_denorm(&rhs_ref, len, ctx)? { + if let Some(sfn) = try_build_constant_l2_denorm_from_constant(&rhs_ref, len, ctx)? { rhs_ref = sfn.into_array(); } - // Take any L2Denorm-wrapped fast path that applies. 
- match DenormOrientation::classify(&lhs_ref, &rhs_ref) { - DenormOrientation::Both { lhs, rhs } => { - return self.execute_both_denorm(lhs, rhs, len); + // The combined validity always comes from the original operands. Compute it once up front + // so the unit-form helpers below can take it directly without re-deriving from an + // `L2Denorm` wrapper they no longer hold. + let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; + + // Classify each operand by its normal form. + let lhs_form = NormalForm::classify(&lhs_ref); + let rhs_form = NormalForm::classify(&rhs_ref); + match (lhs_form.normalized_array(), rhs_form.normalized_array()) { + (Some(unit_lhs), Some(unit_rhs)) => { + // When both operands carry a known unit-norm representation, cosine similarity + // collapses to the dot product of the unit vectors. + return self.execute_both_unit(unit_lhs, unit_rhs, validity, len); } - DenormOrientation::One { denorm, plain } => { - return self.execute_one_denorm(denorm, plain, len, ctx); + // When one operand carries a unit-norm representation, then we can skip one of the + // division steps. + (Some(unit_lhs), None) => { + return self.execute_one_unit(unit_lhs, &rhs_ref, validity, len, ctx); } - DenormOrientation::Neither => {} + (None, Some(unit_rhs)) => { + return self.execute_one_unit(unit_rhs, &lhs_ref, validity, len, ctx); + } + (None, None) => {} } - // Compute combined validity. - let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - // Compute inner product and norms as columnar operations, and propagate the options. let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone(), len)?; let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone(), len)?; @@ -167,7 +177,7 @@ impl ScalarFnVTable for CosineSimilarity { let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. - // TODO(connor): This can be written in a more SIMD-friendly manner. 
+ // TODO(connor): This can probably be written in a more SIMD-friendly manner. match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); let norms_l = norm_l.as_slice::(); @@ -229,6 +239,7 @@ impl ScalarFnArrayVTable for CosineSimilarity { ) -> VortexResult> { let reconstructed = BinaryTensorOpMetadata::decode_children(metadata, len, children, session)?; + Ok(ScalarFnArrayParts { options: EmptyOptions, children: reconstructed, @@ -237,22 +248,17 @@ impl ScalarFnArrayVTable for CosineSimilarity { } impl CosineSimilarity { - /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so - /// `cosine_similarity = dot(n_l, n_r)`. - fn execute_both_denorm( + /// Both sides carry a known unit-norm representation: cosine similarity collapses to the + /// dot product of the unit children. + fn execute_both_unit( &self, - lhs_ref: &ArrayRef, - rhs_ref: &ArrayRef, + unit_lhs: &ArrayRef, + unit_rhs: &ArrayRef, + validity: Validity, len: usize, ) -> VortexResult { - let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - - let (normalized_l, _) = extract_l2_denorm_children(lhs_ref); - let (normalized_r, _) = extract_l2_denorm_children(rhs_ref); - - // `L2Denorm` makes the normalized children authoritative, so their dot product is the - // cosine similarity even for lossy storage wrappers. - let dot = InnerProduct::try_new_array(normalized_l, normalized_r, len)?.into_array(); + let dot = + InnerProduct::try_new_array(unit_lhs.clone(), unit_rhs.clone(), len)?.into_array(); if !matches!(validity, Validity::NonNullable) { // Masking always changes the nullability to nullable. @@ -262,25 +268,20 @@ impl CosineSimilarity { } } - /// One side is `L2Denorm`: treat the normalized child as authoritative, so - /// `cosine_similarity = dot(n, b) / ||b||`. - /// - /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. 
- fn execute_one_denorm( + /// Exactly one side carries a unit-norm representation: cosine similarity reduces to + /// `dot(unit, plain) / ||plain||`. + fn execute_one_unit( &self, - denorm_ref: &ArrayRef, - plain_ref: &ArrayRef, + unit: &ArrayRef, + plain: &ArrayRef, + validity: Validity, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?; - - let (normalized, _) = extract_l2_denorm_children(denorm_ref); - - let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(unit.clone(), plain.clone(), len)?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; - let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; + let norm_arr = L2Norm::try_new_array(plain.clone(), len)?; let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. @@ -326,6 +327,7 @@ mod tests { use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; use crate::utils::test_helpers::l2_denorm_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -514,13 +516,25 @@ mod tests { Ok(()) } + /// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands take the + /// fast path: cosine similarity collapses to the dot product without computing norms. + #[test] + fn naked_normalized_vector_cosine() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = normalized_vector_array(2, &[0.6, 0.8, 0.0, 1.0], &mut ctx)?; + // Row 0: identical -> 1.0, Row 1: orthogonal -> 0.0. 
+ assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + #[test] fn both_denorm_self_similarity() -> VortexResult<()> { // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8]. // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0]. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); @@ -532,8 +546,8 @@ mod tests { // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0. // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[1.0, 0.0], &[3.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.0, 1.0], &[4.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); Ok(()) @@ -543,8 +557,8 @@ mod tests { fn both_denorm_zero_norm() -> VortexResult<()> { // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); @@ -557,8 +571,8 @@ mod tests { // RHS is plain [3.0, 4.0]. // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0. 
let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = tensor_array(&[2], &[3.0, 4.0])?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0])?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); Ok(()) @@ -569,8 +583,8 @@ mod tests { // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6. let mut ctx = SESSION.create_execution_ctx(); - let lhs = tensor_array(&[2], &[1.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let lhs = vector_array(2, &[1.0, 0.0])?; + let rhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); Ok(()) @@ -580,9 +594,9 @@ mod tests { fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on rhs). let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; + let normalized_r = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); @@ -698,9 +712,45 @@ mod tests { Ok(()) } + #[test] + fn serde_round_trip_mixed_vector_and_normalized_vector() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), 2)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(CosineSimilarity); + let metadata = 
plugin + .serialize(&original, &SESSION)? + .expect("CosineSimilarity serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } + #[rstest] - #[case::vector(cosine_vector_lhs(), cosine_vector_rhs(), 2)] - #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs(), 2)] + #[case::vector( + vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).unwrap(), + 2, + )] fn serde_round_trip( #[case] lhs: ArrayRef, #[case] rhs: ArrayRef, @@ -728,20 +778,4 @@ mod tests { assert_eq!(recovered.encoding_id(), original.encoding_id()); Ok(()) } - - fn cosine_vector_lhs() -> ArrayRef { - vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).expect("valid vector array") - } - - fn cosine_vector_rhs() -> ArrayRef { - vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).expect("valid vector array") - } - - fn cosine_tensor_lhs() -> ArrayRef { - tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).expect("valid tensor array") - } - - fn cosine_tensor_rhs() -> ArrayRef { - tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).expect("valid tensor array") - } } diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index d60938dfbd8..d0215b7e25e 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -46,14 +46,14 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::DenormOrientation; +use 
crate::scalar_fns::l2_denorm::NormalForm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::vector::Vector; use crate::utils::BinaryTensorOpMetadata; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; -use crate::utils::extract_l2_denorm_children; +use crate::utils::inner_vector_array; use crate::utils::validate_binary_tensor_float_inputs; /// Inner product (dot product) between two columns. @@ -124,7 +124,8 @@ impl ScalarFnVTable for InnerProduct { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // TODO(connor): relax the float-only gate once integer tensors are supported. + // TODO(connor): Relax the float-only gate once integer tensors are supported, since inner + // product is defined for integer tensors. let tensor_match = validate_binary_tensor_float_inputs(lhs, rhs)?; let ptype = tensor_match.element_ptype(); let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); @@ -141,28 +142,35 @@ impl ScalarFnVTable for InnerProduct { let rhs_ref = args.get(1)?; let len = args.row_count(); - // Take any L2Denorm-wrapped fast path that applies. - match DenormOrientation::classify(&lhs_ref, &rhs_ref) { - DenormOrientation::Both { lhs, rhs } => { - return self.execute_both_denorm(lhs, rhs, len, ctx); - } - DenormOrientation::One { denorm, plain } => { - return self.execute_one_denorm(denorm, plain, len, ctx); - } - DenormOrientation::Neither => {} + // Take the factored fast path only when at least one operand wraps stored norms (the + // `Denormalized` form). Routing through this lets us extract the stored norms instead of + // canonicalizing the `L2Denorm` ScalarFnArray, which would materialize `unit · norms` + // row-by-row before the dot, an avoidable `O(N·D)` pass. + let lhs_form = NormalForm::classify(&lhs_ref); + let rhs_form = NormalForm::classify(&rhs_ref); + if matches!(lhs_form, NormalForm::Denormalized { .. 
}) + || matches!(rhs_form, NormalForm::Denormalized { .. }) + { + return self.execute_factored_dot(&lhs_form, &rhs_form, &lhs_ref, &rhs_ref, len, ctx); } + // Peel any `NormalizedVector` wrapper before checking reduction cases. TurboQuant marks + // the decoded SORF output as normalized, but the optimization patterns still live on the + // inner vector-shaped storage. + let lhs_inner = inner_vector_array(&lhs_ref, ctx)?; + let rhs_inner = inner_vector_array(&rhs_ref, ctx)?; + // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to // `InnerProduct(x, forward_rotate(zero_pad(const)))`. Re-executes recursively so // case 2 can fire on the rewritten tree. - if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_ref, &rhs_ref, len, ctx)? { + if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_inner, &rhs_inner, len, ctx)? { return Ok(rewritten); } // Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32))], const)` is computed by // gather-summing `q[j] * values[codes[j] as usize]` per row, reading the codebook // directly instead of decoding the column into dense vectors. - if let Some(result) = self.try_execute_dict_constant(&lhs_ref, &rhs_ref, len, ctx)? { + if let Some(result) = self.try_execute_dict_constant(&lhs_inner, &rhs_inner, len, ctx)? { return Ok(result); } @@ -170,8 +178,8 @@ impl ScalarFnVTable for InnerProduct { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; // Canonicalize so we can perform the math directly. - let lhs: ExtensionArray = lhs_ref.execute(ctx)?; - let rhs: ExtensionArray = rhs_ref.execute(ctx)?; + let lhs: ExtensionArray = lhs_inner.execute(ctx)?; + let rhs: ExtensionArray = rhs_inner.execute(ctx)?; // We validated that both inputs have the same type. 
let ext = lhs.dtype().as_extension(); @@ -239,6 +247,7 @@ impl ScalarFnArrayVTable for InnerProduct { ) -> VortexResult> { let reconstructed = BinaryTensorOpMetadata::decode_children(metadata, len, children, session)?; + Ok(ScalarFnArrayParts { options: EmptyOptions, children: reconstructed, @@ -247,9 +256,18 @@ impl ScalarFnArrayVTable for InnerProduct { } impl InnerProduct { - /// Both sides are `L2Denorm`: `inner_product = s_l * s_r * dot(n_l, n_r)`. - fn execute_both_denorm( + /// Compute `` after factoring each operand into a `(vector, optional_scale)` pair + /// via [`factor_operand`]. The math is ` = scale_l · scale_r · `, where + /// a `None` scale acts as `1.0` (skipping the per-row multiply). + /// + /// This is **not** restricted to unit-norm operands. `Plain` factors as `(operand, None)` with + /// `scale = 1`, and the formula still holds: `` = + /// `scale_r · `. The win over the standard path is avoiding canonicalizing the + /// `L2Denorm` ScalarFnArray (which would materialize `unit · norms` per row before the dot). + fn execute_factored_dot( &self, + lhs_form: &NormalForm<'_>, + rhs_form: &NormalForm<'_>, lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, @@ -257,50 +275,34 @@ impl InnerProduct { ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; - let (normalized_l, norms_l) = extract_l2_denorm_children(lhs_ref); - let (normalized_r, norms_r) = extract_l2_denorm_children(rhs_ref); - - let norms_l: PrimitiveArray = norms_l.execute(ctx)?; - let norms_r: PrimitiveArray = norms_r.execute(ctx)?; - - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? - .into_array() - .execute(ctx)?; - - match_each_float_ptype!(dot.ptype(), |T| { - let dots = dot.as_slice::(); - let nl = norms_l.as_slice::(); - let nr = norms_r.as_slice::(); - let buffer: Buffer = (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect(); - - // SAFETY: The buffer length equals `len`, which matches the source validity length. 
- Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) - }) - } - - /// One side is `L2Denorm`: `inner_product = s * dot(n, other)`. - /// - /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. - fn execute_one_denorm( - &self, - denorm_ref: &ArrayRef, - plain_ref: &ArrayRef, - len: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?; - - let (normalized, norms) = extract_l2_denorm_children(denorm_ref); - let denorm_norms: PrimitiveArray = norms.execute(ctx)?; + let (vec_lhs, lhs_scale) = factor_operand(lhs_form, lhs_ref, ctx)?; + let (vec_rhs, rhs_scale) = factor_operand(rhs_form, rhs_ref, ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)? + // NB: The call into `dot(vec_l, vec_r)` here dispatches back through `InnerProduct`, which + // lets the SORF and Dict reductions fire on TurboQuant's `SorfTransform` child. + let dot: PrimitiveArray = InnerProduct::try_new_array(vec_lhs, vec_rhs, len)? .into_array() .execute(ctx)?; + // TODO(connor): This should use the binary `Mul` expressions. match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); - let ns = denorm_norms.as_slice::(); - let buffer: Buffer = (0..len).map(|i| ns[i] * dots[i]).collect(); + let buffer: Buffer = match (lhs_scale.as_ref(), rhs_scale.as_ref()) { + (Some(nl), Some(nr)) => { + let nl = nl.as_slice::(); + let nr = nr.as_slice::(); + (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect() + } + (Some(nl), None) => { + let nl = nl.as_slice::(); + (0..len).map(|i| nl[i] * dots[i]).collect() + } + (None, Some(nr)) => { + let nr = nr.as_slice::(); + (0..len).map(|i| nr[i] * dots[i]).collect() + } + (None, None) => dots.iter().copied().collect(), + }; // SAFETY: The buffer length equals `len`, which matches the source validity length. 
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) @@ -498,6 +500,37 @@ impl InnerProduct { } } +/// Factor an operand classified by [`NormalForm`] into the `(vector, optional_scale)` pair consumed +/// by [`InnerProduct::execute_factored_dot`]. The factorization satisfies +/// `original = scale · vector` (with `scale = 1` when the returned scale is `None`), so the inner +/// product distributes as ` = scale_l · scale_r · `. +/// +/// The "vector" component is **not** required to be unit-norm: for `Plain` operands the entire +/// operand is returned as the "vector" with an implicit scale of `1`. The point of the +/// factorization is to surface the stored norms of `Denormalized` operands so they can be applied +/// after the dot, not to assert anything about the geometry of the vector component. +/// +/// - `Plain`: `(original, None)`. Implicit `scale = 1`; the operand passes through to the dot +/// unchanged. +/// - `Normalized`: `(inner_vector, None)`. Implicit `scale = 1`; the unit-norm child passes +/// through to the dot unchanged, with the wrapper peeled so SORF/dict reductions can still fire. +/// - `Denormalized`: `(inner_vector, Some(stored_norms))`. The dot is computed over the unit +/// child and the caller multiplies row-wise by the materialized stored norms afterward. +fn factor_operand( + form: &NormalForm<'_>, + original: &ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult<(ArrayRef, Option)> { + match form { + NormalForm::Plain => Ok((original.clone(), None)), + NormalForm::Normalized { array } => Ok((inner_vector_array(array, ctx)?, None)), + NormalForm::Denormalized { normalized, norms } => Ok(( + inner_vector_array(normalized, ctx)?, + Some(norms.clone().execute(ctx)?), + )), + } +} + /// Return the storage constant for a canonical tensor-like constant query. 
fn constant_tensor_storage(array: &ArrayRef) -> Option { let constant = array.as_opt::()?; @@ -581,6 +614,7 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::l2_denorm_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -701,8 +735,8 @@ mod tests { // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0). // dot([3.0, 4.0], [1.0, 0.0]) = 3.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); @@ -714,8 +748,8 @@ mod tests { // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0. // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; + let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); Ok(()) @@ -727,8 +761,8 @@ mod tests { // RHS: plain [1.0, 2.0]. // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0. 
let mut ctx = SESSION.create_execution_ctx(); - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - let rhs = tensor_array(&[2], &[1.0, 2.0])?; + let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = vector_array(2, &[1.0, 2.0])?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -740,8 +774,8 @@ mod tests { // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0. let mut ctx = SESSION.create_execution_ctx(); - let lhs = tensor_array(&[2], &[1.0, 2.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let lhs = vector_array(2, &[1.0, 2.0])?; + let rhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -750,12 +784,16 @@ mod tests { #[test] fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on lhs). - let normalized_l = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; + let normalized_l = normalized_vector_array( + 2, + &[0.6, 0.8, 1.0, 0.0], + &mut SESSION.create_execution_ctx(), + )?; let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; @@ -768,9 +806,58 @@ mod tests { Ok(()) } + /// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands fall + /// through to the regular dot path (no extra scaling). The result is just `dot(lhs, rhs)`. 
+ #[test] + fn naked_normalized_vector_dot() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = normalized_vector_array(2, &[0.6, 0.8, 0.0, 1.0], &mut ctx)?; + + // Row 0: dot([0.6,0.8],[0.6,0.8]) = 1.0, Row 1: dot([1.0,0.0],[0.0,1.0]) = 0.0. + assert_close(&eval_inner_product(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + + #[test] + fn serde_round_trip_mixed_vector_and_normalized_vector() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let rhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), 2)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(InnerProduct); + let metadata = plugin + .serialize(&original, &SESSION)? + .expect("InnerProduct serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } + #[rstest] - #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs(), 2)] - #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs(), 2)] + #[case::vector( + vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), + vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).unwrap(), + 2, + )] fn serde_round_trip( #[case] lhs: ArrayRef, #[case] rhs: ArrayRef, @@ -799,22 +886,6 @@ mod tests { Ok(()) } - fn inner_product_vector_lhs() -> ArrayRef { - vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid 
vector array") - } - - fn inner_product_vector_rhs() -> ArrayRef { - vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("valid vector array") - } - - fn inner_product_tensor_lhs() -> ArrayRef { - tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).expect("valid tensor array") - } - - fn inner_product_tensor_rhs() -> ArrayRef { - tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).expect("valid tensor array") - } - // ---- Tests for the `SorfTransform + constant` and `Dict + constant` fast paths ---- #[allow( diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 1bdd81833d9..c3841cbb600 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! L2 denormalization expression for tensor-like types. +//! L2 denormalization expression for normalized vectors. use std::fmt::Formatter; @@ -31,9 +31,11 @@ use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; +use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; +use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; use vortex_array::scalar::Scalar; use vortex_array::scalar::ScalarValue; @@ -57,34 +59,36 @@ use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_session::VortexSession; -use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; +use crate::types::normalized_vector::AnyNormalizedVector; +use crate::types::normalized_vector::NormalizedVector; +use crate::types::vector::AnyVector; +use crate::types::vector::Vector; use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; +use crate::utils::extract_l2_denorm_children; +use 
crate::utils::inner_vector_array; use crate::utils::unit_norm_tolerance; -use crate::utils::validate_tensor_float_input; +use crate::utils::vector_fsl_storage_dtype; -/// Re-applies authoritative L2 norms to a normalized tensor column. +/// Re-applies authoritative L2 norms to a normalized vector column. /// -/// Computes `normalized * norm` on each row over the flat backing buffer of each tensor-like type. +/// Computes `normalized * norm` on each row over the flat backing buffer of the vector child. /// -/// The normalized input must be a tensor-like extension array with a float element type and each -/// non-null row is semantically required to already be L2-normalized. +/// The first child must be a [`NormalizedVector`]. Exact callers should use +/// [`try_new_array`](Self::try_new_array), which verifies that stored norms are non-negative and +/// that a zero stored norm is paired with an all-zero normalized row. Lossy encodings may use +/// [`new_array_unchecked`](Self::new_array_unchecked) when the decoded child is only an +/// approximation but the stored norms are still authoritative. /// -/// The norms input must be a primitive float column with the same element type as the normalized -/// tensor elements. -/// -/// [`L2Denorm`] is the norm-splitting wrapper used throughout the tensor crate. Callers that build -/// it through [`try_new_array`](Self::try_new_array) get an exact unit-norm invariant on the -/// `normalized` child. -/// -/// Advanced callers can also use [`new_array_unchecked`](Self::new_array_unchecked) to attach -/// authoritative stored norms to a lossy approximation of that child, such as quantized normalized -/// vectors. +/// The norms input must be a primitive float column with the same element type as the +/// normalized vector elements. 
/// /// Downstream readthrough rules intentionally treat the stored norms and normalized child as the /// encoding contract, even when that differs slightly from recomputing over fully decoded /// coordinates. +/// +/// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector #[derive(Clone)] pub struct L2Denorm; @@ -99,45 +103,44 @@ impl L2Denorm { /// Constructs a validated [`ScalarFnArray`] that lazily re-applies `norms` to `normalized`. /// - /// This is the correct constructor for [`L2Denorm`] arrays. In addition to the structural - /// checks performed by [`ScalarFnArray::try_new`], it validates that every valid row of the - /// `normalized` child has L2 norm `1.0` (or `0.0` for zero rows), within the tolerance implied - /// by the child element precision. It also validates that stored norms are non-negative, and - /// that any row with stored norm `0.0` has an all-zero normalized row. + /// In addition to the structural checks performed by [`ScalarFnArray::try_new`], this + /// constructor verifies that the first child is a [`NormalizedVector`], that stored norms are + /// non-negative, and that any row with stored norm `0.0` has an all-zero normalized row. /// /// # Errors /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype - /// mismatches) or if the `normalized` child is not row-wise L2-normalized. + /// mismatches), if a stored norm is negative, or if a zero-norm row is paired with a + /// non-zero normalized row. pub fn try_new_array( normalized: ArrayRef, norms: ArrayRef, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; + validate_norms_against_normalized(&normalized, &norms, ctx)?; - // SAFETY: We just validated that it is normalized. + // SAFETY: The validation above established the exact L2Denorm invariants. 
unsafe { Self::new_array_unchecked(normalized, norms, len) } } - /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually - /// row-wise L2-normalized. + /// Constructs an [`L2Denorm`] array without validating row values against `norms`. /// - /// This escape hatch is intended for advanced callers that already established, or - /// intentionally relax, the normalized-child invariant. Structural validation still runs via - /// [`ScalarFnArray::try_new`]. + /// Structural validation still runs via [`ScalarFnArray::try_new`], so the first child must be + /// a [`NormalizedVector`]. Use this when the normalized child is a lossy approximation whose + /// rows may not be exactly unit-norm or may not preserve exact zero-ness. /// /// # Safety /// - /// The caller must ensure the `normalized` child is semantically suitable for L2 - /// denormalization. For exact wrappers, that means every valid row is unit-norm or zero. + /// The caller must ensure the first child is semantically suitable for L2 denormalization and + /// is wrapped as a [`NormalizedVector`]. For exact wrappers, every valid row must be unit-norm + /// or zero and stored norms must be non-negative. Lossy encodings may deliberately relax the + /// exact row invariant while still treating the stored norms as authoritative. /// - /// Lossy encodings may deliberately relax that invariant while still treating the stored norms - /// as authoritative. + /// # Errors /// - /// Violating the intended contract will not cause memory unsafety, but may produce incorrect - /// results. + /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype + /// mismatches). 
pub unsafe fn new_array_unchecked( normalized: ArrayRef, norms: ArrayRef, @@ -183,20 +186,7 @@ impl ScalarFnVTable for L2Denorm { let normalized = &arg_dtypes[0]; let norms = &arg_dtypes[1]; - let tensor_match = validate_tensor_float_input(normalized)?; - let element_ptype = tensor_match.element_ptype(); - - let DType::Primitive(norms_ptype, _) = norms else { - vortex_bail!("L2Denorm norms must be a primitive float array, got {norms}"); - }; - vortex_ensure_eq!( - *norms_ptype, - element_ptype, - "L2Denorm norms dtype must match normalized element dtype ({element_ptype}), \ - got {norms_ptype}", - ); - - Ok(normalized.union_nullability(norms.nullability())) + l2_denorm_output_dtype(normalized, norms) } fn execute( @@ -207,9 +197,7 @@ impl ScalarFnVTable for L2Denorm { ) -> VortexResult { let normalized_ref = args.get(0)?; let norms_ref = args.get(1)?; - let output_dtype = normalized_ref - .dtype() - .union_nullability(norms_ref.dtype().nullability()); + let output_dtype = l2_denorm_output_dtype(normalized_ref.dtype(), norms_ref.dtype())?; let validity = normalized_ref.validity()?.and(norms_ref.validity()?)?; if let Some(const_norms) = norms_ref.as_opt::() { @@ -232,16 +220,19 @@ impl ScalarFnVTable for L2Denorm { } } - let normalized: ExtensionArray = normalized_ref.execute(ctx)?; + // Drill past any `NormalizedVector` wrapper so we always work with the underlying + // `Vector` extension array. 
+ let vector_ref = inner_vector_array(&normalized_ref, ctx)?; + let normalized: ExtensionArray = vector_ref.execute(ctx)?; let norms: PrimitiveArray = norms_ref.execute(ctx)?; let row_count = args.row_count(); - let tensor_match = normalized + let vector_metadata = normalized .dtype() .as_extension() - .metadata_opt::() + .metadata_opt::() .vortex_expect("we already validated this in `return_dtype`"); - let tensor_flat_size = tensor_match.list_size() as usize; + let tensor_flat_size = vector_metadata.dimensions() as usize; let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; @@ -286,6 +277,40 @@ impl ScalarFnVTable for L2Denorm { } } +/// Returns the denormalized output dtype for a normalized vector child and matching norms. +fn l2_denorm_output_dtype(normalized: &DType, norms: &DType) -> VortexResult { + let normalized_ext = normalized.as_extension_opt().ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let normalized_metadata = normalized_ext + .metadata_opt::() + .ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let element_ptype = normalized_metadata.element_ptype(); + + let DType::Primitive(norms_ptype, _) = norms else { + vortex_bail!("L2Denorm norms must be a primitive float array, got {norms}"); + }; + vortex_ensure!( + norms_ptype.is_float(), + "L2Denorm norms must be a primitive float array, got {norms}", + ); + vortex_ensure_eq!( + *norms_ptype, + element_ptype, + "L2Denorm norms dtype must match normalized element dtype ({element_ptype}), \ + got {norms_ptype}", + ); + + let fsl_dtype = vector_fsl_storage_dtype(normalized_ext).ok_or_else(|| { + vortex_err!("L2Denorm normalized child must be a NormalizedVector, got {normalized}") + })?; + let output = DType::Extension(ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased()); + + Ok(output.union_nullability(norms.nullability())) +} + /// Metadata 
for a serialized [`L2Denorm`] array: both children's full [`DType`]s. The parent's /// dtype is `normalized.union_nullability(norms.nullability())`, which loses both children's /// individual nullabilities, so we persist them directly. @@ -360,28 +385,49 @@ fn execute_l2_denorm_constant_norms( .vortex_expect("we know that this is a float, so it must fit in f64") - 1.0f64; - let tensor_match = normalized_ref + let normalized_metadata = normalized_ref .dtype() .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) + .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { vortex_err!( - "L2Denorm normalized child must be a tensor-like extension, got {}", + "L2Denorm normalized child must be a NormalizedVector, got {}", normalized_ref.dtype(), ) })?; let tolerance = unit_norm_tolerance( norm_scalar.dtype().as_ptype(), - tensor_match.list_size() as usize, + normalized_metadata.dimensions() as usize, ); + + // Drill past any outer `NormalizedVector` wrapper so we always work with the inner plain + // `Vector` extension array (and its `FixedSizeList` storage). + let vector_ref = inner_vector_array(&normalized_ref, ctx)?; + if err.abs() < tolerance { - return Ok(normalized_ref); + // The output dtype is the sibling plain `Vector`. Rebuild the FSL wrapper with the + // combined validity so the executed array's storage nullability matches `output_dtype`. + let normalized: ExtensionArray = vector_ref.execute(ctx)?; + + let storage_fsl: FixedSizeListArray = normalized.storage_array().clone().execute(ctx)?; + let new_fsl = FixedSizeListArray::try_new( + storage_fsl.elements().clone(), + storage_fsl.list_size(), + new_validity, + storage_fsl.len(), + )?; + + return Ok(ExtensionArray::try_new( + output_dtype.as_extension().clone(), + new_fsl.into_array(), + )? + .into_array()); } // Even if the norms are not all 1, if they are all the same then we can multiply // the entire elements array by the same number. 
- let normalized: ExtensionArray = normalized_ref.execute(ctx)?; + let normalized: ExtensionArray = vector_ref.execute(ctx)?; let storage_fsl: FixedSizeListArray = normalized.storage_array().clone().execute(ctx)?; // Replace the elements array with an array that multiplies it by the constant @@ -407,40 +453,75 @@ fn execute_l2_denorm_constant_norms( Ok(ExtensionArray::new(output_dtype.as_extension().clone(), new_fsl.into_array()).into_array()) } -/// Builds an unexecuted [`L2Denorm`] expression by normalizing `input` and reattaching the exact -/// norms as the norms child. +/// Builds an unexecuted [`L2Denorm`] expression by normalizing a vector input and reattaching the +/// exact norms as the `norms` child. /// /// The returned array is a lazy `L2Denorm(normalized, norms)` scalar function array. /// /// # Normalized child /// -/// The normalized child is always **non-nullable** with [`Validity::NonNullable`]. Every non-null -/// row with a positive L2 norm is divided by its norm to produce a unit-norm vector. +/// For plain [`Vector`] input, every non-null row with a positive L2 norm is divided by its norm +/// to produce a unit-norm vector, and the normalized child is promoted to [`NormalizedVector`]. +/// The normalized child is forced **non-nullable** with [`Validity::NonNullable`] so optimized +/// kernels only have to reason about unit-norm vs. zero rows, not nulls. Rows that are null in the +/// original input are **zeroed out** in the normalized output to avoid leaking undefined physical +/// storage values into downstream encodings. /// -/// Rows that are null in the original input are **zeroed out** in the normalized output. This is -/// necessary because null rows may have undefined (garbage) physical storage values, and we do not -/// want to let those propagate into downstream encodings (like TurboQuant). 
+/// For [`NormalizedVector`] input, the function takes a fast path that returns the input +/// unchanged as the normalized child and asks [`L2Norm`] for the per-row norms. The fast path +/// preserves the input's outer nullability rather than rewriting null rows to zero, since the +/// caller has already committed to a [`NormalizedVector`] shape and we do not want to +/// re-allocate storage just to coerce nullability. /// /// # Nullability /// -/// Nullability is tracked entirely by the norms child. Null input rows produce null norms via -/// [`L2Norm`]'s validity propagation. When the [`L2Denorm`] wrapper is executed, its validity is -/// `and(normalized_validity, norms_validity)`, which correctly identifies originally-null rows -/// since the normalized child is all-valid and the norms child carries the original nulls. +/// Nullability is tracked entirely by the `norms` child. Null input rows produce null `norms` via +/// [`L2Norm`]'s validity propagation. When the [`L2Denorm`] wrapper is executed, the output +/// validity is `and(normalized_validity, norms_validity)`, which correctly identifies +/// originally-null rows. +/// +/// Because this helper computes exact `norms` and (on the slow path) divides by them, the +/// returned `normalized` child satisfies the unit-norm invariant required by [`L2Denorm`]. /// -/// Because this helper computes exact norms first and then divides by those norms, the returned -/// `normalized` child satisfies the strict unit-norm invariant required by [`L2Denorm`]. 
+/// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector pub fn normalize_as_l2_denorm( input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { let row_count = input.len(); - let tensor_match = validate_tensor_float_input(input.dtype())?; - let tensor_flat_size = tensor_match.list_size() as usize; + let input_dtype = input.dtype().clone(); + let vector_metadata = input_dtype + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!( + "normalize_as_l2_denorm requires a Vector or NormalizedVector extension input, \ + got {input_dtype}", + ) + })?; + let tensor_flat_size = vector_metadata.dimensions() as usize; + + // Fast path: input is already a `NormalizedVector`. The slow path below would compute exact + // norms and divide every row by its norm, but for a `NormalizedVector` the divisor is always + // 1.0 (or 0.0 for zero rows). Skip the divide entirely and reattach `L2Norm`'s + // short-circuited per-row 0.0 / 1.0 norms. Crucially, this preserves the invariant required + // by [`L2Denorm::try_new_array`] that a zero-norm row is paired with an all-zero normalized + // row, because [`L2Norm`]'s `NormalizedVector` short-circuit emits 0.0 exactly when the row + // is all zero. + // This also has the added benefit of correcting any lossy-encoded `NormalizedVector` arrays. + if vector_metadata.is_normalized() { + let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; + let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; + + // SAFETY: `input` is a `NormalizedVector`, so every valid row is unit-norm or zero by + // type. `norms_array` was produced by `L2Norm`, so stored norms are non-negative and a + // zero norm is always paired with an all-zero row. 
+ return unsafe { L2Denorm::new_array_unchecked(input, norms_array, row_count) }; + } // Constant fast path: if the input is a constant-backed extension, normalize the single // stored row once and return an `L2Denorm` whose children are both `ConstantArray`s. - if let Some(wrapped) = try_build_constant_l2_denorm(&input, row_count, ctx)? { + if let Some(wrapped) = try_build_constant_l2_denorm_from_constant(&input, row_count, ctx)? { return Ok(wrapped); } @@ -451,11 +532,10 @@ pub fn normalize_as_l2_denorm( let norms_validity = primitive_norms.validity()?; let input: ExtensionArray = input.execute(ctx)?; - let normalized_dtype = input.dtype().as_nonnullable(); let flat = extract_flat_elements(input.storage_array(), tensor_flat_size, ctx)?; // Normalize all of the vectors. - let normalized = match_each_float_ptype!(flat.ptype(), |T| { + let normalized_storage = match_each_float_ptype!(flat.ptype(), |T| { let norm_values = primitive_norms.as_slice::(); let total_elements = row_count * tensor_flat_size; @@ -478,39 +558,35 @@ pub fn normalize_as_l2_denorm( } // Since L2Denorm's validity is the `and` of its child validities, we can make the - // normalized array non-nullable. - build_tensor_array( - normalized_dtype, - tensor_flat_size, - row_count, - Validity::NonNullable, - elements.freeze(), - ) + // normalized child non-nullable. + build_normalized_storage(tensor_flat_size, row_count, elements.freeze()) })?; // SAFETY: // - `norms_array` was produced by `L2Norm(input)`, so every stored norm is non-negative and // null rows already carry null validity through that child. // - For every valid row, we either emit all zeros when the norm is zero or divide every - // element by the exact stored norm, so the normalized child is unit-norm (or zero) by + // element by the exact stored norm, so the normalized storage is unit-norm (or zero) by // construction. 
- // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values into - // downstream lossy encodings. + // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values + // into downstream lossy encodings. + let normalized = unsafe { NormalizedVector::new_unchecked(normalized_storage) }?; unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } } /// Attempts to build an [`L2Denorm`] whose two children are both [`ConstantArray`]s by eagerly /// normalizing `input`'s single stored row. /// -/// Returns `Ok(None)` when `input` is not a tensor-like extension array whose storage is a -/// [`ConstantArray`] with a non-null fixed-size-list scalar. +/// Returns `Ok(None)` when `input` is not a plain vector extension array whose storage is a +/// [`ConstantArray`] with a non-null fixed-size-list scalar, or when it is already a +/// [`NormalizedVector`]. /// /// When `input` matches, the returned [`ScalarFnArray`] is equivalent to [`normalize_as_l2_denorm`] /// but runs in `O(list_size)` time instead of `O(row_count * list_size)`. /// /// This is helpful in some of the reduction steps for cosine similarity execution into inner /// product execution. -pub(crate) fn try_build_constant_l2_denorm( +pub(crate) fn try_build_constant_l2_denorm_from_constant( input: &ArrayRef, len: usize, ctx: &mut ExecutionCtx, @@ -526,16 +602,19 @@ pub(crate) fn try_build_constant_l2_denorm( return Ok(None); } - // The caller is expected to have already validated that `input` is an `AnyTensor` - // extension dtype. 
- let tensor_match = input + let Some(vector_metadata) = input .dtype() - .as_extension() - .metadata_opt::() - .vortex_expect("caller validated input has AnyTensor metadata"); - let list_size = tensor_match.list_size() as usize; + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + else { + return Ok(None); + }; + if vector_metadata.is_normalized() { + return Ok(None); + } + + let list_size = vector_metadata.dimensions() as usize; let original_nullability = input.dtype().nullability(); - let ext_dtype = input.dtype().as_extension().clone(); let storage_fsl_nullability = storage.dtype().nullability(); // Materialize just the single stored row; this does not expand the constant to the full @@ -551,8 +630,8 @@ pub(crate) fn try_build_constant_l2_denorm( } let norm_t: T = sum_sq.sqrt(); - // Zero-norm rows must be stored as all-zeros so [`L2Denorm`]'s unit-norm-or-zero - // invariant holds. This mirrors the per-row logic in `normalize_as_l2_denorm`. + // Zero-norm rows must be stored as all-zeros so the `NormalizedVector` invariant holds. + // This mirrors the per-row logic in `normalize_as_l2_denorm`. let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); let children: Vec = if norm_t == T::zero() { (0..list_size) @@ -564,26 +643,25 @@ pub(crate) fn try_build_constant_l2_denorm( .collect() }; - // The rebuilt FSL scalar preserves the original storage FSL's nullability so the - // resulting `ExtensionArray::new` call accepts the same extension dtype. 
let fsl_scalar = Scalar::fixed_size_list(element_dtype, children, storage_fsl_nullability); let norms_scalar = Scalar::primitive(norm_t, original_nullability); (fsl_scalar, norms_scalar) }); let normalized_storage = ConstantArray::new(normalized_fsl_scalar, len).into_array(); - let normalized_ext = ExtensionArray::new(ext_dtype, normalized_storage).into_array(); + // SAFETY: The single stored row is either `v / ||v||` (unit norm within floating-point + // tolerance) or all zeros when `||v|| == 0`. + let normalized = unsafe { NormalizedVector::new_unchecked(normalized_storage) }?; let norms_array = ConstantArray::new(norms_scalar, len).into_array(); - // SAFETY: Each row of `normalized_ext` is either `v / ||v||` (unit norm within floating - // point tolerance) or all zeros when `||v|| == 0`. Stored norms are non-negative by - // construction (`sqrt`). These are exactly the invariants required by - // [`L2Denorm::new_array_unchecked`]. - let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array, len)? }; - Ok(Some(wrapped)) + // SAFETY: The single stored row is exactly normalized above (or all zeros), and the norm was + // computed with `sqrt`, so it is non-negative. + Ok(Some(unsafe { + L2Denorm::new_array_unchecked(normalized, norms_array, len)? + })) } -/// Rebuilds a tensor-like extension array from flat primitive elements. +/// Rebuilds a vector extension array from flat primitive elements. /// /// # Errors /// @@ -596,97 +674,115 @@ fn build_tensor_array( validity: Validity, elements: Buffer, ) -> VortexResult { + let storage = build_fsl_storage(tensor_flat_size, row_count, validity, elements)?.into_array(); + Ok(ExtensionArray::new(dtype.as_extension().clone(), storage).into_array()) +} + +/// Build a non-nullable [`FixedSizeListArray`] suitable for wrapping as a +/// [`NormalizedVector`] storage. 
+fn build_normalized_storage( + tensor_flat_size: usize, + row_count: usize, + elements: Buffer, +) -> VortexResult { + Ok( + build_fsl_storage(tensor_flat_size, row_count, Validity::NonNullable, elements)? + .into_array(), + ) +} + +/// Build a [`FixedSizeListArray`] from a flat element buffer and a per-row validity. +fn build_fsl_storage( + tensor_flat_size: usize, + row_count: usize, + validity: Validity, + elements: Buffer, +) -> VortexResult { let list_size = u32::try_from(tensor_flat_size).vortex_expect("tensor flat size must fit into `u32`"); - - // SAFETY: Validity has no length (because tensor elements are always non-nullable). + // SAFETY: Validity has no length (because vector elements are always non-nullable). let elements = unsafe { PrimitiveArray::new_unchecked(elements, Validity::NonNullable) }; - - let storage = - FixedSizeListArray::try_new(elements.into_array(), list_size, validity, row_count)?; - - Ok(ExtensionArray::new(dtype.as_extension().clone(), storage.into_array()).into_array()) + FixedSizeListArray::try_new(elements.into_array(), list_size, validity, row_count) } -/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the -/// [`L2Denorm`] invariants: -/// -/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision -/// tolerance). -/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored norm is -/// `0.0` is exactly the zero vector in `normalized`. -pub fn validate_l2_normalized_rows_against_norms( +/// Cross-check that `normalized` and `norms` agree on per-row zero-ness, and that stored norms +/// are non-negative. Unit-norm enforcement on the rows lives on the +/// [`NormalizedVector`](crate::normalized_vector::NormalizedVector) dtype itself. 
+fn validate_norms_against_normalized( normalized: &ArrayRef, - norms: Option<&ArrayRef>, + norms: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let vector_metadata = normalized + .dtype() + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!( + "L2Denorm normalized child must be a NormalizedVector, got {}", + normalized.dtype(), + ) + })?; let row_count = normalized.len(); - if row_count == 0 { - return Ok(()); - } + let element_ptype = vector_metadata.element_ptype(); + let tensor_flat_size = vector_metadata.dimensions() as usize; - let tensor_match = validate_tensor_float_input(normalized.dtype())?; - let element_ptype = tensor_match.element_ptype(); - let tensor_flat_size = tensor_match.list_size() as usize; - let tolerance = unit_norm_tolerance(element_ptype, tensor_flat_size); + vortex_ensure_eq!( + norms.len(), + row_count, + "L2Denorm normalized and norms children must have the same length" + ); - if let Some(norms) = norms { - vortex_ensure_eq!( - norms.dtype().as_ptype(), - element_ptype, - "L2Denorm norms ptype must match normalized element ptype" + let DType::Primitive(norms_ptype, _) = norms.dtype() else { + vortex_bail!( + "L2Denorm norms must be a primitive float array, got {}", + norms.dtype() ); - } + }; + vortex_ensure_eq!( + *norms_ptype, + element_ptype, + "L2Denorm norms ptype must match normalized element ptype" + ); - let normalized: ExtensionArray = normalized.clone().execute(ctx)?; - let normalized_validity = normalized.as_ref().validity()?; + if row_count == 0 { + return Ok(()); + } - let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; - let norms = norms - .map(|norms| norms.clone().execute::(ctx)) - .transpose()?; + // Drill past the outer `NormalizedVector` wrapper so we always iterate the FSL of the inner + // plain `Vector`. 
+ let vector_ref = inner_vector_array(normalized, ctx)?; + let vector_ext: ExtensionArray = vector_ref.execute(ctx)?; + let normalized_validity = normalized.validity()?; - let combined_validity = match &norms { - Some(norms) => normalized_validity.and(norms.validity()?)?, - None => normalized_validity, - }; + let flat = extract_flat_elements(vector_ext.storage_array(), tensor_flat_size, ctx)?; + let norms_prim: PrimitiveArray = norms.clone().execute(ctx)?; + let combined_validity = normalized_validity.and(norms_prim.validity()?)?; match_each_float_ptype!(element_ptype, |T| { - let stored_norms = norms.as_ref().map(|norms| norms.as_slice::()); + let stored_norms = norms_prim.as_slice::(); for i in 0..row_count { if !combined_validity.is_valid(i)? { continue; } - let (row_norm_sq, is_zero_row) = - flat.row::(i) - .iter() - .fold((0.0f64, true), |(sum_sq, is_zero), x| { - let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); - (sum_sq + value * value, is_zero && value.abs() <= tolerance) - }); - let row_norm = row_norm_sq.sqrt(); - + let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN); vortex_ensure!( - row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, - "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \ - {row_norm:.6}", + stored_norm_f64 >= 0.0, + "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}", ); - if let Some(stored_norms) = stored_norms { - let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN); + let is_zero_row = flat.row::(i).iter().all(|x| { + let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); + value == 0.0 + }); + + if stored_norm_f64 == 0.0 { vortex_ensure!( - stored_norm_f64 >= 0.0, - "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}", + is_zero_row, + "L2Denorm normalized child must be all zeros when norms row {i} is 0.0", ); - - if stored_norm_f64 == 0.0 { - vortex_ensure!( - is_zero_row, - "L2Denorm normalized 
child must be all zeros when norms row {i} is 0.0", - ); - } } } }); @@ -694,47 +790,61 @@ pub fn validate_l2_normalized_rows_against_norms( Ok(()) } -/// Classification of a binary operand pair by which side (if any) is wrapped in [`L2Denorm`]. +/// Per-operand classification of a tensor argument by whether it carries an authoritative unit-norm +/// representation, and whether stored norms accompany it. /// -/// Symmetric binary tensor operators (e.g. [`CosineSimilarity`], [`InnerProduct`]) have identical -/// fast paths for "only the lhs is denormalized" and "only the rhs is denormalized", and a separate -/// fast path for "both are denormalized". Rather than hand-rolling the commutative swap at every -/// call site, callers classify their operands with [`Self::classify`] and pattern-match on the -/// returned variant. +/// Symmetric binary tensor operators ([`CosineSimilarity`], [`InnerProduct`]) and unary ones +/// ([`L2Norm`]) take a fast path whenever an operand carries a unit-norm representation. Callers +/// classify each operand individually via [`Self::classify`] and pattern-match on the resulting +/// variant. /// /// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity /// [`InnerProduct`]: crate::scalar_fns::inner_product::InnerProduct -pub(crate) enum DenormOrientation<'a> { - /// Both operands are [`ExactScalarFn`] arrays. - Both { - lhs: &'a ArrayRef, - rhs: &'a ArrayRef, - }, - /// Exactly one operand is an [`ExactScalarFn`]; the other is plain. - One { - denorm: &'a ArrayRef, - plain: &'a ArrayRef, +pub(crate) enum NormalForm<'a> { + /// A plain `Vector`. + Plain, + + /// An already-normalized `NormalizedVector`, which has implicit norms of `1.0`. + Normalized { array: &'a ArrayRef }, + + /// Decomposed `L2Denorm(normalized, norms)`. + /// + /// The normalized child is a `NormalizedVector` by structural contract. 
It is usually + /// non-null, with validity stored in `norms`, except when callers use + /// [`L2Denorm::new_array_unchecked`] directly. + Denormalized { + normalized: ArrayRef, + norms: ArrayRef, }, - /// Neither operand is an [`ExactScalarFn`]. - Neither, } -impl<'a> DenormOrientation<'a> { - /// Classify `(lhs, rhs)` by which side (if any) is wrapped in [`L2Denorm`]. - pub(crate) fn classify(lhs: &'a ArrayRef, rhs: &'a ArrayRef) -> Self { - let lhs_denorm = lhs.is::>(); - let rhs_denorm = rhs.is::>(); - match (lhs_denorm, rhs_denorm) { - (true, true) => Self::Both { lhs, rhs }, - (true, false) => Self::One { - denorm: lhs, - plain: rhs, - }, - (false, true) => Self::One { - denorm: rhs, - plain: lhs, - }, - (false, false) => Self::Neither, +impl<'a> NormalForm<'a> { + /// Classify `array` by its tensor extension type and (if present) any wrapping `L2Denorm`. + pub(crate) fn classify(array: &'a ArrayRef) -> Self { + if array.is::>() { + let (normalized, norms) = extract_l2_denorm_children(array); + return Self::Denormalized { normalized, norms }; + } + + let is_normalized = array + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()); + + if is_normalized { + Self::Normalized { array } + } else { + Self::Plain + } + } + + /// Returns the unit-norm "shape" of the operand suitable for inner-product fast paths, if + /// one exists. For [`Self::Plain`] this returns `None`. + pub(crate) fn normalized_array(&self) -> Option<&ArrayRef> { + match self { + Self::Plain => None, + Self::Normalized { array } => Some(array), + Self::Denormalized { normalized, .. 
} => Some(normalized), } } } @@ -769,17 +879,26 @@ mod tests { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; - use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms; use crate::tests::SESSION; + use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; use crate::utils::test_helpers::assert_close; - use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. - fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { + /// Evaluates L2 denorm on a [`Vector`] (rewrapped as a [`NormalizedVector`]) and the matching + /// norms, returning the executed array. Convenience wrapper for tests that already hold a + /// pre-normalized [`Vector`] (possibly wrapped in another encoding such as `MaskedArray`). 
+ fn eval_l2_denorm( + vector_input: ArrayRef, + norms: ArrayRef, + len: usize, + ) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); + let canonical: ExtensionArray = vector_input.execute(&mut ctx)?; + let storage = canonical.storage_array().clone(); + let normalized = NormalizedVector::try_new(storage, &mut ctx)?; let result = L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?; result.into_array().execute(&mut ctx) } @@ -827,17 +946,6 @@ mod tests { Ok(()) } - #[test] - fn l2_denorm_fixed_shape_tensors() -> VortexResult<()> { - let lhs = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; - let rhs = PrimitiveArray::from_iter([4.0f64, 2.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; - let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0])?; - - assert_tensor_arrays_eq(actual, expected)?; - Ok(()) - } - #[test] fn l2_denorm_null_propagation() -> VortexResult<()> { let lhs = vector_array(2, &[0.6, 0.8, 1.0, 0.0, 0.0, 0.0])?; @@ -878,8 +986,8 @@ mod tests { } #[test] - fn l2_denorm_rejects_integer_tensor_lhs() -> VortexResult<()> { - let lhs = tensor_array(&[2], &[1i32, 2, 3, 4])?; + fn l2_denorm_rejects_plain_unit_vector_lhs() -> VortexResult<()> { + let lhs = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); @@ -888,51 +996,74 @@ mod tests { Ok(()) } + #[test] + fn l2_denorm_rejects_unnormalized_plain_vector_lhs() -> VortexResult<()> { + let lhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?; + let rhs = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + + let mut ctx = SESSION.create_execution_ctx(); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + assert!(result.is_err()); + Ok(()) + } + #[test] fn l2_denorm_rejects_mismatched_rhs_ptype() -> VortexResult<()> { - let lhs = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs 
= normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; let rhs = PrimitiveArray::from_iter([1.0f32, 1.0]).into_array(); - let mut ctx = SESSION.create_execution_ctx(); let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } #[test] - fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { - let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?; + fn l2_denorm_rejects_non_primitive_rhs_without_panic() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; - validate_l2_normalized_rows_against_norms(&roundtrip.child_at(0).clone(), None, &mut ctx)?; + let lhs = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let rhs = vector_array(2, &[1.0f64, 0.0, 0.0, 1.0])?; + + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + assert!(result.is_err()); Ok(()) } #[test] - fn validate_l2_normalized_rows_rejects_unnormalized_input() -> VortexResult<()> { - let input = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; + fn l2_denorm_rejects_length_mismatch_without_panic() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); - let result = validate_l2_normalized_rows_against_norms(&input, None, &mut ctx); + let lhs = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let rhs = PrimitiveArray::from_iter([1.0f64]).into_array(); + + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } #[test] - fn l2_denorm_try_new_array_rejects_unnormalized_child() -> VortexResult<()> { - let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; - let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + fn normalized_vector_try_new_accepts_normalized_f16_input() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); + let elements = [3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32); + let roundtrip = 
normalize_as_l2_denorm(vector_array(2, &elements)?, &mut ctx)?; + // The first child is already a `NormalizedVector` by construction. + let normalized = roundtrip.child_at(0).clone(); + assert!(normalized.dtype().as_extension().is::(),); + Ok(()) + } - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + #[test] + fn normalized_vector_try_new_rejects_unnormalized_input() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let result = normalized_vector_array(2, &[3.0, 4.0, 1.0, 0.0], &mut ctx); assert!(result.is_err()); Ok(()) } #[test] fn l2_denorm_try_new_array_rejects_nonzero_row_with_zero_norm() -> VortexResult<()> { - let normalized = vector_array(2, &[1.0, 0.0, 0.0, 0.0])?; - let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 0.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); @@ -941,9 +1072,9 @@ mod tests { #[test] fn l2_denorm_try_new_array_rejects_negative_norms() -> VortexResult<()> { - let normalized = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; - let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); @@ -951,29 +1082,34 @@ mod tests { } #[test] - fn l2_denorm_new_array_unchecked_accepts_unnormalized_child() -> VortexResult<()> { - let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; - let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); + fn l2_denorm_new_array_unchecked_skips_zero_row_cross_check() -> VortexResult<()> { + // 
`L2Denorm::new_array_unchecked` accepts a NormalizedVector + norms without the zero-norm + // cross-check; useful for lossy encodings (e.g. TurboQuant). + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[1.0, 0.0, 0.0, 1.0], &mut ctx)?; + let norms = PrimitiveArray::from_iter([0.0f64, 1.0]).into_array(); + // SAFETY: This test intentionally exercises the lossy escape hatch. let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 2) }; assert!(result.is_ok()); Ok(()) } #[test] - fn normalize_as_l2_denorm_roundtrips_vectors() -> VortexResult<()> { - let input = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; - let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; - let actual = roundtrip.into_array().execute(&mut ctx)?; + fn l2_denorm_new_array_unchecked_rejects_plain_vector_lhs() -> VortexResult<()> { + let vector = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?; + let norms = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); - assert_tensor_arrays_eq(actual, input)?; + // SAFETY: This deliberately checks that structural validation still rejects a plain + // `Vector` child. 
+ let result = unsafe { L2Denorm::new_array_unchecked(vector, norms, 2) }; + assert!(result.is_err()); Ok(()) } #[test] - fn normalize_as_l2_denorm_roundtrips_fixed_shape_tensors() -> VortexResult<()> { - let input = tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; + fn normalize_as_l2_denorm_roundtrips_vectors() -> VortexResult<()> { + let input = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -983,13 +1119,11 @@ mod tests { } #[test] - fn normalize_as_l2_denorm_supports_constant_tensors() -> VortexResult<()> { - let input = constant_tensor_array(&[2], &[3.0, 4.0], 3)?; + fn normalize_as_l2_denorm_rejects_fixed_shape_tensor() -> VortexResult<()> { + let input = tensor_array(&[2, 2], &[3.0, 4.0, 0.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; - let actual = roundtrip.into_array().execute(&mut ctx)?; - assert_tensor_arrays_eq(actual, input)?; + assert!(normalize_as_l2_denorm(input, &mut ctx).is_err()); Ok(()) } @@ -1013,16 +1147,18 @@ mod tests { let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; - // The normalized child must be an extension array whose storage is still constant. + // The normalized child is a `NormalizedVector(Vector(Constant))`. Drill past both + // extension layers and confirm the innermost FSL storage is still constant-backed. 
let normalized = roundtrip.child_at(0).clone(); let normalized_ext = normalized .as_opt::() .expect("normalized child should be an Extension array"); + let inner_vector = normalized_ext + .storage_array() + .as_opt::() + .expect("NormalizedVector storage should be a Vector extension array"); assert!( - normalized_ext - .storage_array() - .as_opt::() - .is_some(), + inner_vector.storage_array().as_opt::().is_some(), "normalized storage should stay constant after the fast path" ); @@ -1047,8 +1183,11 @@ mod tests { let input = vector_array(2, &[0.0, 0.0, 3.0, 4.0])?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; + // Normalized child is a `NormalizedVector` wrapping a `Vector` wrapping the FSL; drill + // past the outer `NormalizedVector` to reach the underlying `Vector`. let normalized: ExtensionArray = roundtrip.child_at(0).clone().execute(&mut ctx)?; - let storage: FixedSizeListArray = normalized.storage_array().clone().execute(&mut ctx)?; + let vector: ExtensionArray = normalized.storage_array().clone().execute(&mut ctx)?; + let storage: FixedSizeListArray = vector.storage_array().clone().execute(&mut ctx)?; let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -1057,6 +1196,56 @@ mod tests { Ok(()) } + /// `NormalizedVector` input takes the fast path: re-applying norms must reconstruct the + /// original element values bit-for-bit (since the divisor in the slow path would be 1.0 + /// for unit rows and 0.0 for zero rows). 
+ #[test] + fn normalize_as_l2_denorm_normalized_vector_round_trips_values() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let elements = [0.6, 0.8, 0.0, 0.0, 1.0, 0.0]; + let input = normalized_vector_array(2, &elements, &mut ctx)?; + + let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; + let executed: ExtensionArray = roundtrip.into_array().execute(&mut ctx)?; + let storage: FixedSizeListArray = executed.storage_array().clone().execute(&mut ctx)?; + let executed_elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; + + assert_close(executed_elements.as_slice::(), &elements); + Ok(()) + } + + /// The `NormalizedVector` fast path borrows `L2Norm`'s short-circuit, which emits `1.0` for + /// unit rows and `0.0` for zero rows. Tag the zero row with norm `0.0` here (not `1.0`) so a + /// downstream `L2Norm` over the resulting `L2Denorm` continues to read the right value. + #[test] + fn normalize_as_l2_denorm_normalized_vector_emits_unit_or_zero_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[0.6, 0.8, 0.0, 0.0, 1.0, 0.0], &mut ctx)?; + + let l2_denorm = normalize_as_l2_denorm(input, &mut ctx)?; + let norms: PrimitiveArray = l2_denorm.child_at(1).clone().execute(&mut ctx)?; + + assert_close(norms.as_slice::(), &[1.0, 0.0, 1.0]); + Ok(()) + } + + /// The `NormalizedVector` fast path returns the input unchanged as the `normalized` child + /// rather than re-allocating storage to satisfy the slow path's "always non-nullable" + /// invariant. Verify that the child dtype is still a `NormalizedVector` extension after the + /// fast path. 
+ #[test] + fn normalize_as_l2_denorm_normalized_vector_preserves_normalized_child_dtype() + -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + + let l2_denorm = normalize_as_l2_denorm(input, &mut ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + + assert!(normalized.dtype().as_extension().is::()); + Ok(()) + } + /// Builds a non-nullable constant f64 norms array of length `len`. fn constant_f64_norms(value: f64, len: usize) -> ArrayRef { ConstantArray::new(Scalar::primitive(value, Nullability::NonNullable), len).into_array() @@ -1099,16 +1288,30 @@ mod tests { Ok(()) } + /// Regression: a non-nullable [`NormalizedVector`] child paired with a nullable-dtype + /// constant norms array (whose value happens to be non-null `1.0`) used to fail in the + /// constant-unit fast path because the extension's declared storage nullability no longer + /// matched the storage array's own nullability. #[test] - fn l2_denorm_constant_nonunit_norms_scales_fixed_shape_tensors() -> VortexResult<()> { - // The same constant-scaling fast path must also cover multi-dimensional fixed-shape - // tensors, where the backing elements buffer spans more than one slot per row. 
- let normalized = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; - let norms = constant_f64_norms(4.0, 2); + fn l2_denorm_constant_unit_norms_nullable_scalar_nonnullable_normalized() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &mut ctx)?; + let nullable_unit_norms = + ConstantArray::new(Scalar::primitive(1.0f64, Nullability::Nullable), 2).into_array(); - let actual = eval_l2_denorm(normalized, norms, 2)?; - let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 4.0, 0.0, 0.0, 0.0])?; - assert_tensor_arrays_eq(actual, expected)?; + let result = L2Denorm::try_new_array(normalized, nullable_unit_norms, 2, &mut ctx)?; + let actual: ArrayRef = result.into_array().execute(&mut ctx)?; + + // The output surfaces as a plain `Vector` whose outer nullability is the union of the + // two children (nullable here, since the norms child was nullable). + assert!(actual.dtype().as_extension().is::()); + assert!(actual.dtype().is_nullable()); + + // The element values round-trip: scaling unit vectors by `1.0` is a no-op. + let ext: ExtensionArray = actual.execute(&mut ctx)?; + let storage: FixedSizeListArray = ext.storage_array().clone().execute(&mut ctx)?; + let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; + assert_close(elements.as_slice::(), &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]); Ok(()) } @@ -1117,8 +1320,7 @@ mod tests { /// inherits the input's nullability, giving us two different per-child nullabilities to /// round-trip. 
#[rstest] - #[case::vector(l2_denorm_vector_input())] - #[case::fixed_shape_tensor(l2_denorm_tensor_input())] + #[case::vector(vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).unwrap())] fn serde_round_trip(#[case] input: ArrayRef) -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); let original = normalize_as_l2_denorm(input, &mut ctx)?.into_array(); @@ -1145,13 +1347,4 @@ mod tests { assert_eq!(recovered.encoding_id(), original.encoding_id()); Ok(()) } - - fn l2_denorm_vector_input() -> ArrayRef { - vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).expect("valid vector array") - } - - fn l2_denorm_tensor_input() -> ArrayRef { - tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]) - .expect("valid tensor array") - } } diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 5d741eef55e..bcabe823466 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -6,6 +6,8 @@ use std::fmt::Formatter; use num_traits::Float; +use num_traits::One; +use num_traits::Zero; use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -17,7 +19,6 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; @@ -25,6 +26,7 @@ use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::match_each_float_ptype; @@ -40,14 +42,13 @@ use 
vortex_array::serde::ArrayChildren; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::NormalForm; use crate::utils::extract_flat_elements; -use crate::utils::extract_l2_denorm_children; +use crate::utils::inner_vector_array; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -57,10 +58,11 @@ use crate::utils::validate_tensor_float_input; /// The input must be a tensor-like extension array with a float element type. The output is a float /// column of the same float type. /// -/// When the input is wrapped in [`L2Denorm`], this operator treats the stored norms as -/// authoritative. For lossy encodings such as TurboQuant, that means `L2Norm` may intentionally -/// read the stored norms instead of re-deriving them from fully decoded coordinates. That behavior -/// is part of the lossy storage contract, not a separate lossy-compute mode. +/// When the input is wrapped in [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm), this operator +/// treats the stored norms as authoritative. For lossy encodings such as TurboQuant, that means +/// `L2Norm` may intentionally read the stored norms instead of re-deriving them from fully decoded +/// coordinates. That behavior is part of the lossy storage contract, not a separate lossy-compute +/// mode. #[derive(Clone)] pub struct L2Norm; @@ -115,6 +117,7 @@ impl ScalarFnVTable for L2Norm { let tensor_match = validate_tensor_float_input(input_dtype)?; let ptype = tensor_match.element_ptype(); + // Inherit the nullability from the vectors themselves. 
let nullability = Nullability::from(input_dtype.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -137,13 +140,23 @@ impl ScalarFnVTable for L2Norm { let norm_dtype = DType::Primitive(element_ptype, ext.nullability()); - // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored - // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics - // instead of forcing a decode-and-recompute path here. - if input_ref.is::>() { - let (_, norms) = extract_l2_denorm_children(&input_ref); - vortex_ensure_eq!(norms.dtype(), &norm_dtype); - return Ok(norms); + // Short-circuit when the input carries a unit-norm representation already. + match NormalForm::classify(&input_ref) { + NormalForm::Denormalized { norms, .. } => { + return Ok(norms); + } + NormalForm::Normalized { .. } => { + // A naked `NormalizedVector` row is either unit norm or the zero vector by type. + // We still have to distinguish those two cases and preserve row validity. + return execute_normalized_vector_norms( + &input_ref, + element_ptype, + tensor_flat_size, + row_count, + ctx, + ); + } + NormalForm::Plain => {} } // Optimize for the constant array case. @@ -172,6 +185,9 @@ impl ScalarFnVTable for L2Norm { return Ok(norms); } + // Drill past any `NormalizedVector` wrapper so we always work with the underlying + // `Vector` extension array. + let input_ref = inner_vector_array(&input_ref, ctx)?; let input: ExtensionArray = input_ref.execute(ctx)?; let validity = input.as_ref().validity()?; @@ -261,6 +277,38 @@ fn l2_norm_row(v: &[T]) -> T { sum_sq.sqrt() } +/// Computes L2 norms for a [`NormalizedVector`](crate::normalized_vector::NormalizedVector) +/// without taking square roots: valid rows are either all-zero (`0.0`) or unit-norm (`1.0`). 
+fn execute_normalized_vector_norms( + input_ref: &ArrayRef, + element_ptype: PType, + tensor_flat_size: usize, + row_count: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + // `NormalizedVector` storage is `Extension(Vector(FSL))`; drill to the inner `Vector` to + // reach the underlying FSL. + let vector_ref = inner_vector_array(input_ref, ctx)?; + let input: ExtensionArray = vector_ref.execute(ctx)?; + let validity = input.as_ref().validity()?; + let flat = extract_flat_elements(input.storage_array(), tensor_flat_size, ctx)?; + + match_each_float_ptype!(element_ptype, |T| { + let buffer: Buffer = (0..row_count) + .map(|i| { + if flat.row::(i).iter().all(|&x| x == T::zero()) { + T::zero() + } else { + T::one() + } + }) + .collect(); + + // SAFETY: The buffer length equals `row_count`, which matches the source validity length. + Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) + }) +} + #[cfg(test)] mod tests { @@ -289,6 +337,7 @@ mod tests { use crate::types::vector::Vector; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::normalized_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -418,8 +467,8 @@ mod tests { } #[rstest] - #[case::fixed_shape_tensor(l2_norm_tensor_child(), 2)] - #[case::vector(l2_norm_vector_child(), 2)] + #[case::fixed_shape_tensor(tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] + #[case::vector(vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] fn serde_round_trip(#[case] child: ArrayRef, #[case] len: usize) -> VortexResult<()> { let original = L2Norm::try_new_array(child.clone(), len)?.into_array(); @@ -444,11 +493,69 @@ mod tests { Ok(()) } - fn l2_norm_tensor_child() -> ArrayRef { - tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid tensor array") + /// A naked 
[`NormalizedVector`](crate::normalized_vector::NormalizedVector) input must + /// short-circuit to `1.0` for unit rows and `0.0` for zero rows without taking square roots. + #[test] + fn naked_normalized_vector_returns_unit_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[1.0, 0.0, 0.6, 0.8, 0.0, 0.0], &mut ctx)?; + assert_close(&eval_l2_norm(input, 3)?, &[1.0, 1.0, 0.0]); + Ok(()) + } + + #[test] + fn naked_normalized_vector_preserves_nulls() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let input = normalized_vector_array(2, &[1.0, 0.0, 0.0, 0.0], &mut ctx)?; + let input = MaskedArray::try_new(input, Validity::from_iter([true, false]))?.into_array(); + + let result = L2Norm::try_new_array(input, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert!(prim.is_valid(0, &mut ctx)?); + assert!(!prim.is_valid(1, &mut ctx)?); + assert_close(&[prim.as_slice::()[0]], &[1.0]); + Ok(()) } - fn l2_norm_vector_child() -> ArrayRef { - vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid vector array") + /// `L2Norm(L2Denorm(normalized, norms))` must return the stored norms verbatim — that is the + /// `NormalForm::Denormalized` short-circuit's whole purpose. We use a deliberately oddball + /// norm value (`7.0`) that no row could plausibly produce from a unit-norm child, so a + /// regression that fell through to the recompute path would round-trip a different number. 
+ #[test] + fn denormalized_input_returns_stored_norms() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let l2_denorm = crate::utils::test_helpers::l2_denorm_array( + 2, + &[1.0, 0.0, 0.6, 0.8], + &[7.0, 5.0], + &mut ctx, + )?; + + let result = L2Norm::try_new_array(l2_denorm, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert_close(prim.as_slice::(), &[7.0, 5.0]); + Ok(()) + } + + /// The `Denormalized` short-circuit must propagate null rows in the stored norms child, + /// since validity on a `L2Denorm` lives entirely in its norms. + #[test] + fn denormalized_input_preserves_norm_nulls() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?; + let norms = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); + let l2_denorm = + crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(normalized, norms, 2, &mut ctx)? + .into_array(); + + let result = L2Norm::try_new_array(l2_denorm, 2)?; + let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; + + assert!(prim.is_valid(0, &mut ctx)?); + assert!(!prim.is_valid(1, &mut ctx)?); + assert_close(&[prim.as_slice::()[0]], &[5.0]); + Ok(()) } } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs index 26d38e87a1e..7666e58cdac 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -8,10 +8,10 @@ //! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost //! of a dense orthogonal matrix. //! -//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension -//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform -//! 
at execution time, producing a [`Vector`] extension array with the original (pre-padding) -//! dimensionality. +//! This module wraps a [`Vector`] or [`NormalizedVector`] extension array whose dimension is the +//! padded SORF dimension (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the +//! inverse SORF transform at execution time, producing a plain [`Vector`] extension array with the +//! original (pre-padding) dimensionality. //! //! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the //! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's @@ -19,9 +19,9 @@ //! //! # Input element type: `f32` only (TODO(connor): for now...) //! -//! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is -//! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need -//! to cast to `f32` before wrapping in a [`Vector`] and handing it to SorfTransform. +//! The child vector extension **must** have `f32` storage elements. This is a hard constraint that +//! is enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data +//! need to cast to `f32` before wrapping in a vector extension and handing it to SorfTransform. //! //! The reason for this constraint is that TurboQuant (the only production caller today) stores its //! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`. @@ -34,12 +34,14 @@ //! //! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It //! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast -//! while materializing the output. This lets SorfTransform hand its result directly to a -//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose -//! 
element-type expectation may differ from the `f32` the transform operated on internally. +//! while materializing the output. Callers that intentionally treat the decoded output as +//! normalized (for example TurboQuant) must wrap the result as a +//! [`NormalizedVector`](crate::normalized_vector::NormalizedVector) before handing it to consumers +//! such as [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm). //! //! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf //! [`Vector`]: crate::vector::Vector +//! [`NormalizedVector`]: crate::normalized_vector::NormalizedVector use std::fmt; use std::fmt::Formatter; @@ -59,10 +61,10 @@ mod vtable; /// Inverse SORF orthogonal transform scalar function. /// -/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32` -/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the -/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and -/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array. +/// Takes a vector extension child at the padded dimension with `f32` storage, applies the inverse +/// structured Walsh-Hadamard orthogonal transform, truncates to the original (pre-padding) +/// dimension, casts element-wise to [`SorfOptions::element_ptype`], and wraps the result in a new +/// plain [`Vector`](crate::vector::Vector) extension array. /// /// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the /// `f32`-only input constraint. @@ -77,11 +79,14 @@ pub struct SorfTransform; pub struct SorfOptions { /// Seed used to generate the structured sign diagonals via Vortex's frozen SplitMix64 stream. pub seed: u64, + /// Number of sign-diagonal + WHT rounds in the structured orthogonal transform. pub num_rounds: u8, + /// Original vector dimension (before power-of-2 padding). 
The output /// [`Vector`](crate::vector::Vector) has this dimension. pub dimensions: u32, + /// Element type of the output [`Vector`](crate::vector::Vector). The child input must always /// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final /// `f32 -> element_ptype` cast happens while building the output. @@ -89,23 +94,26 @@ pub struct SorfOptions { } impl SorfTransform { - /// Creates a new [`TypedScalarFnInstance`] wrapping the SORF inverse transform with the given options. + /// Creates a new [`TypedScalarFnInstance`] wrapping the SORF inverse transform with the given + /// options. pub fn new(options: &SorfOptions) -> TypedScalarFnInstance { TypedScalarFnInstance::new(SorfTransform, options.clone()) } /// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform. /// - /// The `child` must be a [`Vector`] extension array (or an array that executes to one) with: + /// The `child` must be a [`Vector`] or [`NormalizedVector`] extension array (or an array that + /// executes to one) with: /// /// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and /// - `f32` storage elements. This is a hard requirement today; see the /// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale. /// - /// The output [`Vector`] has dimension `options.dimension` and element type + /// The output plain [`Vector`] has dimension `options.dimension` and element type /// `options.element_ptype`. /// /// [`Vector`]: crate::vector::Vector + /// [`NormalizedVector`]: crate::normalized_vector::NormalizedVector pub fn try_new_array( options: &SorfOptions, child: ArrayRef, @@ -118,7 +126,7 @@ impl SorfTransform { } /// Checks that the SORF configuration is valid. 
-pub(crate) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { +pub(super) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { vortex_ensure!( options.num_rounds >= 1, "SorfTransform num_rounds must be >= 1, got {}", diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs index ff8aebd0f11..ea6e776ecdb 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs @@ -61,17 +61,19 @@ pub struct SorfMatrix { /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = /// multiply by -1 (flip sign bit). sign_masks: Vec, + /// The number of sign-diagonal + WHT rounds. num_rounds: usize, /// The padded dimension (next power of 2 >= dimension). padded_dim: usize, - /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. + /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. This is stored + /// for convenience. norm_factor: f32, } impl SorfMatrix { - /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic - /// seed. + // TODO(connor): Should this just only allow power-of-2 dimensions? Require the caller to do it? + /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic seed. /// /// The seed is expanded using Vortex's frozen local SplitMix64 stream. 
Signs are generated in /// round-major, block-major order, with each `u64` contributing 64 sign bits in diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 46abc66db71..d65bf5f15c0 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -14,9 +14,11 @@ use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -36,6 +38,7 @@ use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::compute_or_get_centroids; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::tests::SESSION; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; /// Build a unit-normalized input vector array and forward-transform + quantize it, returning @@ -361,6 +364,46 @@ fn rejects_non_vector_extension_child_at_construction() { assert!(err.to_string().contains("Vector extension")); } +#[test] +fn accepts_normalized_vector_child_but_returns_plain_vector() -> VortexResult<()> { + let options = default_options(128, 42); + let mut values = vec![0.0f32; 128]; + values[0] = 1.0; + let elements = PrimitiveArray::from_iter(values).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1)?; + let mut ctx = SESSION.create_execution_ctx(); + let child = 
NormalizedVector::try_new(fsl.into_array(), &mut ctx)?; + + // A `NormalizedVector` child is accepted, but the output is a plain `Vector`: inverse SORF is + // followed by truncation, which cannot generally preserve the unit-norm invariant. + let sorf = SorfTransform::try_new_array(&options, child, 1)?.into_array(); + assert!(sorf.dtype().as_extension().is::()); + assert!(!sorf.dtype().as_extension().is::()); + + let result: ExtensionArray = sorf.execute(&mut ctx)?; + assert!(result.dtype().as_extension().is::()); + assert!(!result.dtype().as_extension().is::()); + Ok(()) +} + +#[test] +fn accepts_plain_vector_child_and_returns_plain_vector() -> VortexResult<()> { + let options = default_options(128, 42); + let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1)?; + let child = wrap_as_vector(fsl, Validity::NonNullable)?; + + let sorf = SorfTransform::try_new_array(&options, child.into_array(), 1)?.into_array(); + assert!(sorf.dtype().as_extension().is::()); + assert!(!sorf.dtype().as_extension().is::()); + + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.execute(&mut ctx)?; + assert!(result.dtype().as_extension().is::()); + assert!(!result.dtype().as_extension().is::()); + Ok(()) +} + #[test] fn rejects_wrong_padded_dimension_at_construction() { // Options say dimension=128 so padded_dim should be 128. Pass a Vector<256> instead. 
@@ -455,6 +498,20 @@ fn trivial_padded_vector(padded_dim: u32, num_rows: usize, validity: Validity) - ExtensionArray::new(ext_dtype, fsl.into_array()).into_array() } +fn trivial_padded_normalized_vector( + padded_dim: u32, + num_rows: usize, + validity: Validity, +) -> VortexResult { + let elements = PrimitiveArray::new( + Buffer::::zeroed(num_rows * padded_dim as usize), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + NormalizedVector::try_new(fsl.into_array(), &mut ctx) +} + #[rstest::rstest] // Non-power-of-two dimension to exercise `padded_dim = dim.next_power_of_two()`. #[case::power_of_two_dim(128, Validity::NonNullable)] @@ -491,5 +548,54 @@ fn serde_round_trip(#[case] dimensions: u32, #[case] validity: Validity) -> Vort assert_eq!(recovered.dtype(), original.dtype()); assert_eq!(recovered.len(), original.len()); assert_eq!(recovered.encoding_id(), original.encoding_id()); + let recovered_scalar_fn = recovered.as_::(); + assert!( + recovered_scalar_fn + .child_at(0) + .dtype() + .as_extension() + .is::() + ); + Ok(()) +} + +#[test] +fn serde_round_trip_preserves_normalized_vector_child_dtype() -> VortexResult<()> { + let dimension = 128; + let num_rows = 4; + let options = default_options(dimension, 42); + let child = trivial_padded_normalized_vector( + dimension.next_power_of_two(), + num_rows, + Validity::NonNullable, + )?; + let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(SorfTransform); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("SorfTransform serialize must produce metadata"); + + let children = vec![child]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + let recovered_scalar_fn = recovered.as_::(); + assert!( + recovered_scalar_fn + .child_at(0) + .dtype() + .as_extension() + .is::() + ); Ok(()) } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 827f8e6a796..9551a942c0a 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -52,6 +52,7 @@ use super::rotation::SorfMatrix; use super::validate_sorf_options; use crate::types::vector::AnyVector; use crate::types::vector::Vector; +use crate::utils::inner_vector_array; impl ScalarFnVTable for SorfTransform { type Options = SorfOptions; @@ -90,7 +91,10 @@ impl ScalarFnVTable for SorfTransform { .as_extension_opt() .and_then(|ext| ext.metadata_opt::()) .ok_or_else(|| { - vortex_err!("SorfTransform child must be a Vector extension, got {child_dtype}") + vortex_err!( + "SorfTransform child must be a Vector or NormalizedVector extension, \ + got {child_dtype}" + ) })?; let expected_padded = options.dimensions.next_power_of_two(); @@ -102,8 +106,7 @@ impl ScalarFnVTable for SorfTransform { options.dimensions, ); - // For now, the child Vector storage must be f32. TurboQuant stores its centroids as f32, - // and the SORF transform itself operates in f32, so any other input type would require an + // For now, the child Vector storage must be f32, so any other input type would require an // implicit cast that we do not yet support. The output element type is independently // specified via `options.element_ptype` and is built below. 
vortex_ensure_eq!( @@ -114,14 +117,16 @@ impl ScalarFnVTable for SorfTransform { ); let output_elem_dtype = DType::Primitive(options.element_ptype, Nullability::NonNullable); - let storage_dtype = DType::FixedSizeList( + let fsl_dtype = DType::FixedSizeList( Arc::new(output_elem_dtype), options.dimensions, child_dtype.nullability(), ); - let _ = vector_metadata; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); + // The inverse SORF rotation is orthogonal over the padded dimension, but this scalar + // function then truncates back to the original dimension. Truncation can drop energy, so + // even a `NormalizedVector` child cannot generally produce a `NormalizedVector` parent. + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl_dtype)?.erased(); Ok(DType::Extension(ext_dtype)) } @@ -135,48 +140,46 @@ impl ScalarFnVTable for SorfTransform { let dim = options.dimensions as usize; let num_rows = args.row_count(); - if num_rows == 0 { - let child_dtype = args.get(0)?.dtype().clone(); - let validity = Validity::from(child_dtype.nullability()); - - return match_each_float_ptype!(options.element_ptype, |T| { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - options.dimensions, - validity, - 0, - )?; - Vector::try_new_vector_array(fsl.into_array()) - }); - } + let child_arg = args.get(0)?; + + let fsl_array = if num_rows == 0 { + let validity = Validity::from(child_arg.dtype().nullability()); + let elements = match_each_float_ptype!(options.element_ptype, |T| { + PrimitiveArray::empty::(Nullability::NonNullable) + }) + .into_array(); + + FixedSizeListArray::try_new(elements, options.dimensions, validity, 0)?.into_array() + } else { + // Execute the child, since we cannot apply rotations over compressed data. 
+ let child_ref = inner_vector_array(&child_arg, ctx)?; + let child_ext: ExtensionArray = child_ref.execute(ctx)?; + let child_validity = child_ext.as_ref().validity()?; + let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; + + let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; + let f32_elements = elements_prim.into_buffer::(); + + let padded_dim = + usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); + + // Reconstruct the orthogonal transform matrix from the seed. + let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; + + // Inverse transform each row, truncate to original dimension, cast to target type. + match_each_float_ptype!(options.element_ptype, |T| { + inverse_rotate_typed::( + &f32_elements, + &rotation, + dim, + padded_dim, + num_rows, + child_validity, + ) + })? + }; - // Execute the child to get the Vector extension wrapping an FSL of f32 coordinates. The - // `return_dtype` check guarantees the child is a `Vector`, so the - // materialized FSL elements are always f32. - let child_ext: ExtensionArray = args.get(0)?.execute(ctx)?; - let child_validity = child_ext.as_ref().validity()?; - let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; - let padded_dim = - usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); - - let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; - let f32_elements = elements_prim.into_buffer::(); - - // Reconstruct the orthogonal transform matrix from the seed. - let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; - - // Inverse transform each row, truncate to original dimension, cast to target type. 
- match_each_float_ptype!(options.element_ptype, |T| { - inverse_rotate_typed::( - &f32_elements, - &rotation, - dim, - padded_dim, - num_rows, - child_validity, - ) - }) + Vector::try_new_vector_array(fsl_array) } fn validity( @@ -237,31 +240,33 @@ impl ScalarFnArrayVTable for SorfTransform { len: usize, metadata: &[u8], children: &dyn ArrayChildren, - session: &VortexSession, + _session: &VortexSession, ) -> VortexResult> { let metadata = SorfTransformMetadata::decode(metadata) .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; let options = metadata.to_options()?; - // `return_dtype` sets the output FSL's nullability to the child's nullability (see - // `return_dtype` above), so we read the child nullability back from the parent dtype. - let child_nullability = dtype + let parent_ext = dtype .as_extension_opt() + .filter(|ext| ext.is::()) .ok_or_else(|| { - vortex_err!("SorfTransform parent dtype must be a Vector extension, got {dtype}") - })? - .storage_dtype() - .nullability(); + vortex_err!("SorfTransform parent dtype must be a `Vector` extension, got {dtype}",) + })?; + + // The nullability of the parent extension type is the same as the storage type. 
+ let fsl_nullability = parent_ext.nullability(); + let padded_dim = options.dimensions.next_power_of_two(); - let child_storage = DType::FixedSizeList( + let child_fsl_dtype = DType::FixedSizeList( Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), padded_dim, - child_nullability, + fsl_nullability, ); let child_dtype = match metadata.child_dtype.as_ref() { - Some(dtype) => DType::from_proto(dtype, session)?, + Some(dtype) => DType::from_proto(dtype, _session)?, None => { - let child_ext = ExtDType::::try_new(EmptyMetadata, child_storage)?.erased(); + let child_ext = + ExtDType::::try_new(EmptyMetadata, child_fsl_dtype)?.erased(); DType::Extension(child_ext) } }; @@ -283,7 +288,8 @@ fn float_from_f32(v: f32) -> T { } /// Apply the inverse SORF transform on f32 data, truncate to the original dimension, cast each -/// element to `T`, and build a plain [`Vector`](crate::vector::Vector) extension array. +/// element to `T`, and return the resulting `FixedSizeList` storage array. The caller wraps the +/// FSL as a plain [`Vector`](crate::vector::Vector) extension array. 
fn inverse_rotate_typed( f32_elements: &[f32], rotation: &SorfMatrix, @@ -297,7 +303,7 @@ fn inverse_rotate_typed( let mut unrotated = vec![0.0f32; padded_dim]; for row in 0..num_rows { - let row_data = &f32_elements[row * padded_dim..(row + 1) * padded_dim]; + let row_data = &f32_elements[row * padded_dim..][..padded_dim]; rotation.inverse_rotate(row_data, &mut unrotated); @@ -309,7 +315,7 @@ fn inverse_rotate_typed( let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; - Vector::try_new_vector_array(fsl.into_array()) + Ok(fsl.into_array()) } impl From<&SorfOptions> for SorfTransformMetadata { diff --git a/vortex-tensor/src/types/mod.rs b/vortex-tensor/src/types/mod.rs index 3ecd2826743..97aa932f9d6 100644 --- a/vortex-tensor/src/types/mod.rs +++ b/vortex-tensor/src/types/mod.rs @@ -4,4 +4,5 @@ //! Internal homes for tensor extension types. pub mod fixed_shape; +pub mod normalized_vector; pub mod vector; diff --git a/vortex-tensor/src/types/normalized_vector/matcher.rs b/vortex-tensor/src/types/normalized_vector/matcher.rs new file mode 100644 index 00000000000..b1974e13b52 --- /dev/null +++ b/vortex-tensor/src/types/normalized_vector/matcher.rs @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_array::dtype::extension::Matcher; +use vortex_error::VortexExpect; +use vortex_error::vortex_panic; + +use crate::types::normalized_vector::NormalizedVector; +use crate::types::vector::Vector; +use crate::types::vector::VectorMatcherMetadata; + +/// Matcher that accepts only the [`NormalizedVector`] extension type. +/// +/// Use this when a consumer requires the unit-norm guarantee. Callers that accept any +/// vector-shaped extension should use [`AnyTensor`](crate::matcher::AnyTensor). 
+pub struct AnyNormalizedVector; + +impl Matcher for AnyNormalizedVector { + type Match<'a> = VectorMatcherMetadata; + + fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option> { + if !ext_dtype.is::() { + return None; + } + + // `NormalizedVector` stores a `Vector(FixedSizeList)`. Drill into the inner + // `Vector` to recover the dimension and element dtype. + let DType::Extension(inner_ext) = ext_dtype.storage_dtype() else { + vortex_panic!( + "`NormalizedVector` storage must be `DType::Extension(Vector)`, \ + got {}", + ext_dtype.storage_dtype(), + ) + }; + if !inner_ext.is::() { + vortex_panic!( + "`NormalizedVector` inner extension must be `Vector`, got {}", + inner_ext.id(), + ) + } + let DType::FixedSizeList(element_dtype, list_size, _) = inner_ext.storage_dtype() else { + vortex_panic!( + "inner `Vector` storage must be `FixedSizeList`, got {}", + inner_ext.storage_dtype(), + ) + }; + assert!(element_dtype.is_float(), "element dtype must be float"); + assert!( + !element_dtype.is_nullable(), + "element dtype must be non-nullable" + ); + + let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size, true) + .vortex_expect("`NormalizedVector` inner Vector did not have float elements"); + + Some(metadata) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; + + use super::*; + use crate::types::vector::AnyVector; + use crate::types::vector::Vector; + + fn fsl_storage(element_ptype: PType, dimensions: u32) -> DType { + DType::FixedSizeList( + Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)), + dimensions, + Nullability::NonNullable, + ) + } + + fn nv_storage(element_ptype: PType, dimensions: u32) -> VortexResult { + let vector = + ExtDType::::try_new(EmptyMetadata, 
fsl_storage(element_ptype, dimensions))? + .erased(); + Ok(DType::Extension(vector)) + } + + #[test] + fn matches_normalized_vector_dtype() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, nv_storage(PType::F32, 128)?)? + .erased(); + + let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 128); + Ok(()) + } + + #[test] + fn rejects_plain_vector() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl_storage(PType::F32, 128))?.erased(); + + assert!(ext_dtype.metadata_opt::().is_none()); + Ok(()) + } + + #[test] + fn any_vector_matches_normalized_vector() -> VortexResult<()> { + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, nv_storage(PType::F32, 128)?)? + .erased(); + + // `AnyVector` is the inclusive matcher: it matches both `Vector` and `NormalizedVector`. + // Callers that need to distinguish the two should pair it with an + // [`AnyNormalizedVector`] check, or use [`AnyTensor`](crate::matcher::AnyTensor) to also + // accept `FixedShapeTensor`. + let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 128); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/normalized_vector/mod.rs b/vortex-tensor/src/types/normalized_vector/mod.rs new file mode 100644 index 00000000000..7242bab9868 --- /dev/null +++ b/vortex-tensor/src/types/normalized_vector/mod.rs @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Normalized vector extension type over [`Vector`](crate::vector::Vector) storage whose +//! rows are guaranteed (or asserted, for lossy encodings) to have unit L2 norm. 
+ +use num_traits::ToPrimitive; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use crate::types::vector::AnyVector; +use crate::types::vector::Vector; +use crate::utils::extract_flat_elements; +use crate::utils::unit_norm_tolerance; + +/// Extension type over [`Vector`](crate::vector::Vector) storage that asserts every valid row is +/// L2-normalized (unit-norm) or the zero vector. +/// +/// The storage dtype is `DType::Extension(Vector(FixedSizeList))`, i.e. a +/// [`Vector`](crate::vector::Vector) extension array. Downstream operators such as +/// [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm), +/// [`L2Norm`](crate::scalar_fns::l2_norm::L2Norm), +/// [`InnerProduct`](crate::scalar_fns::inner_product::InnerProduct), and +/// [`CosineSimilarity`](crate::scalar_fns::cosine_similarity::CosineSimilarity) short-circuit +/// arithmetic when they see this type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct NormalizedVector; + +impl NormalizedVector { + /// Wraps a [`FixedSizeList`](vortex_array::arrays::FixedSizeListArray) of float elements + /// as a [`NormalizedVector`] extension array, wrapping the FSL in a + /// [`Vector`](crate::vector::Vector) first. + /// + /// Every valid row is checked to be unit-norm or the zero vector before returning. + /// + /// # Errors + /// + /// Returns an error if `fsl` is not a `FixedSizeList` of non-nullable float elements, or if + /// any valid row's L2 norm is not `1.0` (or `0.0`) within the tolerance implied by the + /// element precision. 
+ pub fn try_new(fsl: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let vector = Vector::try_new_vector_array(fsl)?; + // Validate before wrapping so we iterate the inner `Vector` storage directly. The + // `ExtensionArray::try_new_from_vtable` call below runs `validate_dtype` (which only + // checks the storage dtype shape), but the unit-norm check is a bulk row operation we + // run explicitly here. + validate_unit_norm_rows(&vector, ctx)?; + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? + .into_array(), + ) + } + + /// Wraps `fsl` as a [`NormalizedVector`] extension array **without** validating that rows + /// are unit-norm. The FSL is still wrapped in a [`Vector`](crate::vector::Vector) first. + /// + /// # Safety + /// + /// Every valid row must be unit-norm or the zero vector. Lossy approximations (e.g. + /// TurboQuant) deliberately relax this, but still treat the claim as authoritative + /// downstream. Violating this does not cause memory unsafety but will produce silently + /// incorrect results. + /// + /// # Errors + /// + /// Returns an error if `fsl` is not a `FixedSizeList` of non-nullable float elements. + pub unsafe fn new_unchecked(fsl: ArrayRef) -> VortexResult { + let vector = Vector::try_new_vector_array(fsl)?; + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? + .into_array(), + ) + } + + /// Wraps an already-constructed [`Vector`](crate::vector::Vector) extension array as a + /// [`NormalizedVector`] **without** validating that rows are unit-norm. + /// + /// # Safety + /// + /// Every valid row of `vector` must be unit-norm or the zero vector. + /// + /// # Errors + /// + /// Returns an error if `vector.dtype()` is not a `Vector` extension dtype. + pub unsafe fn wrap_vector_unchecked(vector: ArrayRef) -> VortexResult { + Ok( + ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, vector)? 
+ .into_array(), + ) + } +} + +/// Validates that every valid row of a [`Vector`](crate::vector::Vector) extension array has L2 +/// norm `1.0` or `0.0` within the element-precision tolerance. +/// +/// The input is expected to be a `Vector` extension array (not a raw `FixedSizeList`), matching +/// the storage of a `NormalizedVector`. +pub(crate) fn validate_unit_norm_rows( + vector_array: &ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let row_count = vector_array.len(); + if row_count == 0 { + return Ok(()); + } + + let vector_metadata = vector_array.dtype().as_extension().metadata::(); + let element_ptype = vector_metadata.element_ptype(); + let dim = vector_metadata.dimensions() as usize; + let tolerance = unit_norm_tolerance(element_ptype, dim); + + let ext: ExtensionArray = vector_array.clone().execute(ctx)?; + let validity = ext.as_ref().validity()?; + let flat = extract_flat_elements(ext.storage_array(), dim, ctx)?; + + match_each_float_ptype!(element_ptype, |T| { + for i in 0..row_count { + if !validity.is_valid(i)? 
{ + continue; + } + + let row_norm_sq = flat.row::(i).iter().fold(0.0f64, |sum_sq, x| { + let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN); + sum_sq + value * value + }); + let row_norm = row_norm_sq.sqrt(); + + vortex_ensure!( + row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance, + "NormalizedVector row {i} has L2 norm {row_norm:.6}, expected 1.0 or 0.0", + ); + } + }); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use half::f16; + use vortex_array::IntoArray; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::dtype::PType; + use vortex_array::validity::Validity; + use vortex_error::VortexResult; + + use super::NormalizedVector; + use crate::tests::SESSION; + use crate::utils::unit_norm_tolerance; + + #[test] + fn f16_unit_norm_tolerance_is_capped() { + assert!(unit_norm_tolerance(PType::F16, 768) <= 1e-3); + } + + #[test] + fn try_new_rejects_f16_row_outside_capped_tolerance() -> VortexResult<()> { + let dim = 768u32; + let dim_usize = usize::try_from(dim).expect("dim fits usize"); + let mut values = vec![f16::from_f32(0.0); dim_usize]; + values[0] = f16::from_f32(0.99); + + let elements = PrimitiveArray::from_iter(values).into_array(); + let fsl = FixedSizeListArray::try_new(elements, dim, Validity::NonNullable, 1)?; + let mut ctx = SESSION.create_execution_ctx(); + + assert!(NormalizedVector::try_new(fsl.into_array(), &mut ctx).is_err()); + Ok(()) + } +} + +mod matcher; +mod vtable; + +pub use matcher::AnyNormalizedVector; diff --git a/vortex-tensor/src/types/normalized_vector/vtable.rs b/vortex-tensor/src/types/normalized_vector/vtable.rs new file mode 100644 index 00000000000..f50a0ac98fd --- /dev/null +++ b/vortex-tensor/src/types/normalized_vector/vtable.rs @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::dtype::DType; +use 
vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; + +use crate::types::normalized_vector::NormalizedVector; +use crate::types::vector::Vector; + +impl ExtVTable for NormalizedVector { + type Metadata = EmptyMetadata; + type NativeValue<'a> = &'a ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new("vortex.tensor.normalized_vector") + } + + fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { + // Storage must be an extension-wrapped `Vector`. The inner `Vector` vtable's + // `validate_dtype` already ran when the inner `ExtDType` was constructed, so we + // only need to confirm the storage is in fact a `Vector` extension. + let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_bail!( + "`NormalizedVector` storage must be an extension type, got {}", + ext_dtype.storage_dtype(), + ); + }; + vortex_ensure!( + inner.is::(), + "`NormalizedVector` storage must be a `Vector` extension, got {}", + inner.id(), + ); + Ok(()) + } + + fn unpack_native<'a>( + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + // Per-scalar validation is a no-op: unit-norm is enforced in bulk by + // `validate_unit_norm_rows` at array construction, matching how `L2Denorm` + // validates up front rather than on each scalar access. 
+ Ok(storage_value) + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(Vec::new()) + } + + fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::dtype::extension::ExtVTable; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; + + use crate::types::normalized_vector::NormalizedVector; + use crate::types::vector::Vector; + + /// The NormalizedVector storage dtype is `DType::Extension(Vector(FSL))`. + fn nv_storage_dtype(ptype: PType, size: u32, nullability: Nullability) -> VortexResult { + let fsl = DType::FixedSizeList( + Arc::new(DType::Primitive(ptype, Nullability::NonNullable)), + size, + nullability, + ); + let vector = ExtDType::::try_new(EmptyMetadata, fsl)?.erased(); + Ok(DType::Extension(vector)) + } + + #[rstest] + #[case::f16(PType::F16)] + #[case::f32(PType::F32)] + #[case::f64(PType::F64)] + fn validate_accepts_float_types(#[case] ptype: PType) -> VortexResult<()> { + let storage = nv_storage_dtype(ptype, 64, Nullability::NonNullable)?; + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[rstest] + #[case::nullable(Nullability::Nullable)] + #[case::non_nullable(Nullability::NonNullable)] + fn validate_accepts_any_outer_nullability( + #[case] nullability: Nullability, + ) -> VortexResult<()> { + let storage = nv_storage_dtype(PType::F32, 64, nullability)?; + ExtDType::::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[test] + fn validate_rejects_non_extension_storage() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 64, + Nullability::NonNullable, + ); + assert!(ExtDType::::try_new(EmptyMetadata, storage).is_err()); + } + + 
 #[test] + fn roundtrip_metadata() -> VortexResult<()> { + let vtable = NormalizedVector; + let bytes = vtable.serialize_metadata(&EmptyMetadata)?; + assert!(bytes.is_empty()); + let deserialized = vtable.deserialize_metadata(&bytes)?; + assert_eq!(deserialized, EmptyMetadata); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/vector/matcher.rs b/vortex-tensor/src/types/vector/matcher.rs index 7ac75f097db..e7a4b5c39e9 100644 --- a/vortex-tensor/src/types/vector/matcher.rs +++ b/vortex-tensor/src/types/vector/matcher.rs @@ -10,8 +10,16 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_panic; +use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; +/// Matcher that accepts any vector-shaped extension type (both plain [`Vector`] and +/// [`NormalizedVector`]). +/// +/// To match a plain [`Vector`] only (excluding [`NormalizedVector`]), pair this matcher with a +/// negated `is::()` check; to match a `NormalizedVector` only, use +/// [`AnyNormalizedVector`](crate::normalized_vector::AnyNormalizedVector) directly. Use +/// [`AnyTensor`](crate::matcher::AnyTensor) when `FixedShapeTensor` should also match. pub struct AnyVector; /// Convenience metadata for vectors. @@ -33,17 +41,41 @@ pub struct VectorMatcherMetadata { /// The number of dimensions of the vector. This is always fixed. dimensions: u32, + + /// `true` when the dtype is a [`NormalizedVector`]. + is_normalized: bool, + } impl Matcher for AnyVector { type Match<'a> = VectorMatcherMetadata; fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option> { - if !ext_dtype.is::() { + // Walk to the inner `FixedSizeList` for whichever vector-shaped wrapper this is. Plain + // `Vector` stores the FSL directly; `NormalizedVector` wraps a `Vector` extension which + // in turn stores the FSL. 
+ let (fsl_dtype, is_normalized) = if ext_dtype.is::() { + let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_panic!( + "`NormalizedVector` storage must be `DType::Extension(Vector)`, got {}", + ext_dtype.storage_dtype(), + ) + }; + + if !inner.is::() { + vortex_panic!( + "`NormalizedVector` inner extension must be `Vector`, got {}", + inner.id(), + ) + } + + (inner.storage_dtype(), true) + } else if ext_dtype.is::() { + (ext_dtype.storage_dtype(), false) + } else { return None; - } + }; - let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else { + let DType::FixedSizeList(element_dtype, list_size, _) = fsl_dtype else { vortex_panic!("`Vector` type somehow did not have a `FixedSizeList` storage type") }; @@ -56,8 +88,9 @@ impl Matcher for AnyVector { ); let element_ptype = element_dtype.as_ptype(); - let vector_metadata = VectorMatcherMetadata::try_new(element_ptype, dimensions) - .vortex_expect("`Vector` type somehow did not have float elements"); + let vector_metadata = + VectorMatcherMetadata::try_new(element_ptype, dimensions, is_normalized) + .vortex_expect("`Vector` type somehow did not have float elements"); Some(vector_metadata) } @@ -69,12 +102,17 @@ impl VectorMatcherMetadata { /// # Errors /// /// Returns an error if the element type is not a float. - pub fn try_new(element_ptype: PType, dimensions: u32) -> VortexResult { + pub fn try_new( + element_ptype: PType, + dimensions: u32, + is_normalized: bool, + ) -> VortexResult { vortex_ensure!(element_ptype.is_float()); Ok(Self { element_ptype, dimensions, + is_normalized, }) } @@ -87,6 +125,12 @@ impl VectorMatcherMetadata { pub fn dimensions(&self) -> u32 { self.dimensions } + + /// Returns `true` when the dtype is a + /// [`NormalizedVector`](crate::normalized_vector::NormalizedVector). 
+ pub fn is_normalized(self) -> bool { + self.is_normalized + } } #[cfg(test)] @@ -112,6 +156,18 @@ mod tests { ) } + fn normalized_vector_storage_dtype( + element_ptype: PType, + dimensions: u32, + ) -> VortexResult { + let inner = ExtDType::::try_new( + EmptyMetadata, + vector_storage_dtype(element_ptype, dimensions), + )? + .erased(); + Ok(DType::Extension(inner)) + } + #[test] fn matches_vector_dtype_metadata() -> VortexResult<()> { let ext_dtype = @@ -124,6 +180,22 @@ mod tests { Ok(()) } + #[test] + fn matches_normalized_vector_dtype_metadata() -> VortexResult<()> { + let ext_dtype = ExtDType::::try_new( + EmptyMetadata, + normalized_vector_storage_dtype(PType::F32, 256)?, + )? + .erased(); + + // `AnyVector` is the inclusive matcher: it matches `NormalizedVector` too and surfaces + // the inner `Vector`'s element ptype and dimensionality. + let metadata = ext_dtype.metadata::(); + assert_eq!(metadata.element_ptype(), PType::F32); + assert_eq!(metadata.dimensions(), 256); + Ok(()) + } + #[test] fn does_not_match_fixed_shape_tensor() -> VortexResult<()> { let ext_dtype = ExtDType::::try_new( diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 9dc097e11e0..bab89ce2b6c 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -8,9 +8,11 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Constant; use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFn; +use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::primitive::PrimitiveArrayExt; use vortex_array::arrays::scalar_fn::ExactScalarFn; @@ -19,6 +21,7 @@ use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use 
vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDTypeRef; use vortex_array::dtype::proto::dtype as pb; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_buffer::Buffer; @@ -31,12 +34,22 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::matcher::TensorMatch; +use crate::normalized_vector::NormalizedVector; use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::types::vector::VectorMatcherMetadata; +use crate::vector::AnyVector; +use crate::vector::Vector; /// Safety factor for unit-norm tolerance. Applied as a constant multiplier on the probabilistic /// `√d · ε` bound so that legitimate round-off noise clears the check with headroom. pub(crate) const SAFETY_FACTOR: usize = 10; +/// Upper bound for unit-norm validation drift. +/// +/// This keeps low-precision element types (especially f16) from accepting vectors whose norms are +/// materially different from 1.0 at common embedding dimensions. +pub(crate) const MAX_UNIT_NORM_TOLERANCE: f64 = 1e-3; + /// Returns the acceptable unit-norm drift for the given element precision and dimension count. /// /// Uses the `c · √d · ε` bound where ε is machine epsilon and d is the vector dimension. Under @@ -57,7 +70,7 @@ pub fn unit_norm_tolerance(element_ptype: PType, dimensions: usize) -> f64 { let dimensions_root = (dimensions as f64).sqrt(); - SAFETY_FACTOR as f64 * machine_epsilon * dimensions_root + (SAFETY_FACTOR as f64 * machine_epsilon * dimensions_root).min(MAX_UNIT_NORM_TOLERANCE) } /// Extracts the `(normalized, norms)` children from an [`L2Denorm`] scalar function array. 
@@ -94,18 +107,77 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult( lhs: &'a DType, rhs: &DType, ) -> VortexResult> { + let dtypes_match = lhs.eq_ignore_nullability(rhs) || vector_shapes_match(lhs, rhs); vortex_ensure!( - lhs.eq_ignore_nullability(rhs), + dtypes_match, "binary tensor expression expects inputs to have the same dtype, got {lhs} and {rhs}" ); validate_tensor_float_input(lhs) } +/// Returns `true` when `lhs` and `rhs` are both within the vector extension family (plain +/// `Vector` or `NormalizedVector`) and share the same float ptype and dimension. +fn vector_shapes_match(lhs: &DType, rhs: &DType) -> bool { + fn vector_family_match(dtype: &DType) -> Option { + dtype.as_extension_opt()?.metadata_opt::() + } + + matches!( + (vector_family_match(lhs), vector_family_match(rhs)), + (Some(l), Some(r)) + if l.element_ptype() == r.element_ptype() && l.dimensions() == r.dimensions() + ) +} + +/// Returns the underlying `FixedSizeList` storage dtype for a vector-shaped extension dtype. +/// +/// For a plain [`Vector`], this is the direct storage dtype. For a [`NormalizedVector`] +/// it drills through one extra extension layer. +pub fn vector_fsl_storage_dtype(ext: &ExtDTypeRef) -> Option { + use vortex_array::dtype::DType; + if ext.is::() { + Some(ext.storage_dtype().clone()) + } else if ext.is::() { + let DType::Extension(inner) = ext.storage_dtype() else { + return None; + }; + if !inner.is::() { + return None; + } + Some(inner.storage_dtype().clone()) + } else { + None + } +} + +/// Returns the underlying `Vector` extension array inside a vector-shaped extension array. +/// +/// For a [`NormalizedVector`] array, this executes the outer extension and returns its +/// `Vector` storage child. For a plain [`Vector`] array, it returns the array itself (after +/// canonicalizing to an `ExtensionArray`). 
+pub fn inner_vector_array(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let is_normalized = array + .dtype() + .as_extension_opt() + .is_some_and(|ext| ext.is::()); + if is_normalized { + let ext: ExtensionArray = array.clone().execute(ctx)?; + Ok(ext.storage_array().clone()) + } else { + Ok(array.clone()) + } +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. /// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively @@ -334,6 +406,7 @@ pub mod test_helpers { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; + use crate::types::normalized_vector::NormalizedVector; use crate::types::vector::Vector; /// Builds a `FixedSizeList` storage array from flat `elements`. The row count is @@ -372,6 +445,16 @@ pub mod test_helpers { Vector::try_new_vector_array(flat_fsl(elements, dim)) } + /// Builds a [`NormalizedVector`] extension array from pre-normalized `elements` and a vector + /// dimension size. The caller must ensure each row is unit-norm or the zero vector. + pub fn normalized_vector_array( + dim: u32, + elements: &[T], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + NormalizedVector::try_new(flat_fsl(elements, dim), ctx) + } + /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], /// representing a single query tensor broadcast to `len` rows. pub fn constant_tensor_array>( @@ -399,17 +482,21 @@ pub mod test_helpers { ConstantArray::new(ext_scalar, len).into_array() } - /// Creates an [`L2Denorm`] scalar function array from pre-normalized tensor elements and + /// Creates an [`L2Denorm`] scalar function array from pre-normalized vector elements and /// matching norms. The caller must ensure every row of `normalized_elements` is unit-norm or /// zero. + /// + /// `dim` is the vector dimension (the inner `FixedSizeList` width). 
The number of rows is + /// inferred from `normalized_elements.len() / dim`. pub fn l2_denorm_array( - shape: &[usize], + dim: u32, normalized_elements: &[T], norms: &[T], ctx: &mut ExecutionCtx, ) -> VortexResult { let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; + let storage = flat_fsl(normalized_elements, dim); + let normalized = NormalizedVector::try_new(storage, ctx)?; let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index be253187956..10547ba011d 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -504,7 +504,7 @@ mod turboquant_benches { use vortex_array::VortexSessionExecute; use vortex_buffer::BufferMut; use vortex_tensor::encodings::turboquant::TurboQuantConfig; - use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked; + use vortex_tensor::encodings::turboquant::turboquant_encode_normalized; use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; use vortex_tensor::vector::Vector; @@ -573,10 +573,9 @@ mod turboquant_benches { .as_ref() .as_opt::() .expect("normalized benchmark input should be an Extension array"); - // SAFETY: Benchmark inputs are normalized once up front so the timed - // region measures only TurboQuant encoding. - unsafe { turboquant_encode_unchecked(normalized, &config, ctx) } - .unwrap() + // Benchmark inputs are normalized once up front so the timed region + // measures only TurboQuant encoding. 
+ turboquant_encode_normalized(normalized, &config, ctx).unwrap() }); } } @@ -588,10 +587,9 @@ mod turboquant_benches { let normalized_ext = setup_normalized_vector_ext($dim); let config = turboquant_config($bits); let mut ctx = SESSION.create_execution_ctx(); - let compressed = unsafe { - turboquant_encode_unchecked(normalized_ext.as_view(), &config, &mut ctx) - } - .unwrap(); + let compressed = + turboquant_encode_normalized(normalized_ext.as_view(), &config, &mut ctx) + .unwrap(); with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) .with_inputs(|| (&compressed, SESSION.create_execution_ctx())) .bench_refs(|(a, ctx)| {