Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ pub const vortex_tensor::encodings::turboquant::MAX_CENTROIDS: usize

pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32

pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>

pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
pub fn vortex_tensor::encodings::turboquant::turboquant_encode_normalized(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub mod vortex_tensor::fixed_shape

Expand Down Expand Up @@ -218,12 +216,16 @@ pub enum vortex_tensor::matcher::TensorMatch<'a>

pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>)

pub vortex_tensor::matcher::TensorMatch::NormalizedVector(vortex_tensor::vector::VectorMatcherMetadata)

pub vortex_tensor::matcher::TensorMatch::Vector(vortex_tensor::vector::VectorMatcherMetadata)

impl vortex_tensor::matcher::TensorMatch<'_>

pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType

pub fn vortex_tensor::matcher::TensorMatch<'_>::is_normalized(self) -> bool

pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32

impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a>
Expand Down Expand Up @@ -252,6 +254,66 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::

pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub mod vortex_tensor::normalized_vector

pub struct vortex_tensor::normalized_vector::AnyNormalizedVector

impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector

pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub struct vortex_tensor::normalized_vector::NormalizedVector

impl vortex_tensor::normalized_vector::NormalizedVector

pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(fsl: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(fsl: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::wrap_vector_unchecked(vector: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector

pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector

impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector

impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector

pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool

impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector

pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector

impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector

pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector

pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)

impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector

impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector

pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata

pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue

pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>

pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId

pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>

pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>

pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>

pub mod vortex_tensor::scalar_fns

pub mod vortex_tensor::scalar_fns::cosine_similarity
Expand Down Expand Up @@ -384,8 +446,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:

pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>

pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>

pub mod vortex_tensor::scalar_fns::l2_norm

pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
Expand Down Expand Up @@ -490,7 +550,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) ->

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

Expand Down Expand Up @@ -578,7 +638,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32

pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType

pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(self) -> bool

pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult<Self>

impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata

Expand Down
75 changes: 70 additions & 5 deletions vortex-tensor/src/encodings/l2_denorm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use vortex_compressor::scheme::Scheme;
use vortex_compressor::stats::ArrayAndStats;
use vortex_error::VortexResult;

use crate::matcher::AnyTensor;
use crate::normalized_vector::AnyNormalizedVector;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
use crate::types::vector::AnyVector;

#[derive(Debug)]
pub struct L2DenormScheme;
Expand All @@ -26,10 +27,14 @@ impl Scheme for L2DenormScheme {
}

fn matches(&self, canonical: &Canonical) -> bool {
matches!(
canonical,
Canonical::Extension(ext) if ext.ext_dtype().is::<AnyTensor>()
)
let Canonical::Extension(ext) = canonical else {
return false;
};

// `AnyVector` matches any vector-shaped extension; we explicitly exclude `NormalizedVector`
// here because a normalized input already carries an authoritative unit-norm representation
// and does not need re-normalization.
ext.ext_dtype().is::<AnyVector>() && !ext.ext_dtype().is::<AnyNormalizedVector>()
}

fn expected_compression_ratio(
Expand All @@ -38,6 +43,7 @@ impl Scheme for L2DenormScheme {
_compress_ctx: CompressorContext,
_exec_ctx: &mut ExecutionCtx,
) -> CompressionEstimate {
// We almost always want to pre-normalize our data if the vector is not already normalized.
CompressionEstimate::Verdict(EstimateVerdict::AlwaysUse)
}

Expand All @@ -52,3 +58,62 @@ impl Scheme for L2DenormScheme {
Ok(l2_denorm.into_array())
}
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ExtensionArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_array::validity::Validity;
    use vortex_compressor::scheme::Scheme;
    use vortex_error::VortexResult;

    use super::L2DenormScheme;
    use crate::types::fixed_shape::FixedShapeTensor;
    use crate::types::fixed_shape::FixedShapeTensorMetadata;
    use crate::types::vector::Vector;

    /// Packs a flat slice of `f32` values into a non-nullable fixed-size-list array with
    /// `list_size` elements per row; the row count is `elements.len() / list_size`.
    fn fsl_storage(elements: &[f32], list_size: u32) -> VortexResult<FixedSizeListArray> {
        let row_count = elements.len() / list_size as usize;
        let values = PrimitiveArray::from_iter(elements.iter().copied()).into_array();
        FixedSizeListArray::try_new(values, list_size, Validity::NonNullable, row_count)
    }

    /// A plain `Vector` extension array is accepted by the scheme.
    #[test]
    fn matches_vector() -> VortexResult<()> {
        let storage = fsl_storage(&[1.0, 0.0], 2)?;
        let vector_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
        let input = Canonical::Extension(ExtensionArray::new(vector_dtype, storage.into_array()));

        assert!(L2DenormScheme.matches(&input));
        Ok(())
    }

    /// A `FixedShapeTensor` extension array (shape `[2, 2]` over 4-element rows) is not
    /// matched by the scheme.
    #[test]
    fn rejects_fixed_shape_tensor() -> VortexResult<()> {
        let storage = fsl_storage(&[1.0, 0.0, 0.0, 1.0], 4)?;

        // Spell out the storage dtype explicitly: non-nullable list of 4 non-nullable f32s.
        let element_dtype = Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable));
        let storage_dtype = DType::FixedSizeList(element_dtype, 4, Nullability::NonNullable);

        let tensor_metadata = FixedShapeTensorMetadata::new(vec![2, 2]);
        let tensor_dtype =
            ExtDType::<FixedShapeTensor>::try_new(tensor_metadata, storage_dtype)?.erased();
        let input = Canonical::Extension(ExtensionArray::new(tensor_dtype, storage.into_array()));

        assert!(!L2DenormScheme.matches(&input));
        Ok(())
    }
}
Loading
Loading