diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 93a13a3b..82e66a28 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -110,6 +110,11 @@ jobs: RUST_LOG: DEBUG RUST_BACKTRACE: full + - name: Install lumina native library + run: | + pip install lumina-data + echo "LUMINA_LIB_PATH=$(python3 -c 'import lumina_data; print(lumina_data.__path__[0])')/lib/liblumina_py.so" >> $GITHUB_ENV + - name: DataFusion Integration Test run: cargo test -p paimon-datafusion --all-targets env: diff --git a/crates/integrations/datafusion/src/full_text_search.rs b/crates/integrations/datafusion/src/full_text_search.rs index f1b559b8..20ff38ff 100644 --- a/crates/integrations/datafusion/src/full_text_search.rs +++ b/crates/integrations/datafusion/src/full_text_search.rs @@ -37,11 +37,16 @@ use datafusion::error::Result as DFResult; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use paimon::catalog::{Catalog, Identifier}; +use paimon::catalog::Catalog; use crate::error::to_datafusion_error; use crate::runtime::{await_with_runtime, block_on_with_runtime}; use crate::table::{PaimonScanBuilder, PaimonTableProvider}; +use crate::table_function_args::{ + extract_int_literal, extract_string_literal, parse_table_identifier, +}; + +const FUNCTION_NAME: &str = "full_text_search"; /// Register the `full_text_search` table-valued function on a [`SessionContext`]. 
pub fn register_full_text_search( @@ -88,10 +93,10 @@ impl TableFunctionImpl for FullTextSearchFunction { )); } - let table_name = extract_string_literal(&args[0], "table_name")?; - let column_name = extract_string_literal(&args[1], "column_name")?; - let query_text = extract_string_literal(&args[2], "query_text")?; - let limit = extract_int_literal(&args[3], "limit")?; + let table_name = extract_string_literal(FUNCTION_NAME, &args[0], "table_name")?; + let column_name = extract_string_literal(FUNCTION_NAME, &args[1], "column_name")?; + let query_text = extract_string_literal(FUNCTION_NAME, &args[2], "query_text")?; + let limit = extract_int_literal(FUNCTION_NAME, &args[3], "limit")?; if limit <= 0 { return Err(datafusion::error::DataFusionError::Plan( @@ -99,7 +104,8 @@ impl TableFunctionImpl for FullTextSearchFunction { )); } - let identifier = parse_table_identifier(&table_name, &self.default_database)?; + let identifier = + parse_table_identifier(FUNCTION_NAME, &table_name, &self.default_database)?; let catalog = Arc::clone(&self.catalog); let table = block_on_with_runtime( @@ -201,58 +207,3 @@ impl TableProvider for FullTextSearchTableProvider { ]) } } - -fn extract_string_literal(expr: &Expr, name: &str) -> DFResult { - match expr { - Expr::Literal(scalar, _) => { - let s = scalar.try_as_str().flatten().ok_or_else(|| { - datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a string literal, got: {expr}" - )) - })?; - Ok(s.to_string()) - } - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn extract_int_literal(expr: &Expr, name: &str) -> DFResult { - use datafusion::common::ScalarValue; - match expr { - Expr::Literal(scalar, _) => match scalar { - ScalarValue::Int8(Some(v)) => Ok(*v as i64), - ScalarValue::Int16(Some(v)) => Ok(*v as i64), - ScalarValue::Int32(Some(v)) => Ok(*v as i64), - ScalarValue::Int64(Some(v)) => Ok(*v), - 
ScalarValue::UInt8(Some(v)) => Ok(*v as i64), - ScalarValue::UInt16(Some(v)) => Ok(*v as i64), - ScalarValue::UInt32(Some(v)) => Ok(*v as i64), - ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { - datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} value {v} exceeds i64 range" - )) - }), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be an integer literal, got: {expr}" - ))), - }, - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn parse_table_identifier(name: &str, default_database: &str) -> DFResult { - let parts: Vec<&str> = name.split('.').collect(); - match parts.len() { - 1 => Ok(Identifier::new(default_database, parts[0])), - 2 => Ok(Identifier::new(parts[0], parts[1])), - // 3-part name: catalog.database.table — ignore catalog prefix - 3 => Ok(Identifier::new(parts[1], parts[2])), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: invalid table name '{name}', expected 'table', 'database.table', or 'catalog.database.table'" - ))), - } -} diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs index 7dfa9b21..fa094147 100644 --- a/crates/integrations/datafusion/src/lib.rs +++ b/crates/integrations/datafusion/src/lib.rs @@ -48,7 +48,9 @@ pub mod runtime; mod sql_handler; mod system_tables; mod table; +mod table_function_args; mod update; +mod vector_search; pub use catalog::{PaimonCatalogProvider, PaimonSchemaProvider}; pub use error::to_datafusion_error; @@ -58,3 +60,4 @@ pub use physical_plan::PaimonTableScan; pub use relation_planner::PaimonRelationPlanner; pub use sql_handler::PaimonSqlHandler; pub use table::PaimonTableProvider; +pub use vector_search::{register_vector_search, VectorSearchFunction}; diff --git a/crates/integrations/datafusion/src/table_function_args.rs 
b/crates/integrations/datafusion/src/table_function_args.rs new file mode 100644 index 00000000..c0a6c833 --- /dev/null +++ b/crates/integrations/datafusion/src/table_function_args.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion::common::ScalarValue; +use datafusion::error::{DataFusionError, Result as DFResult}; +use datafusion::logical_expr::Expr; +use paimon::catalog::Identifier; + +pub(crate) fn extract_string_literal( + function_name: &str, + expr: &Expr, + name: &str, +) -> DFResult { + match expr { + Expr::Literal(scalar, _) => { + let s = scalar.try_as_str().flatten().ok_or_else(|| { + DataFusionError::Plan(format!( + "{function_name}: {name} must be a string literal, got: {expr}" + )) + })?; + Ok(s.to_string()) + } + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be a literal, got: {expr}" + ))), + } +} + +pub(crate) fn extract_int_literal(function_name: &str, expr: &Expr, name: &str) -> DFResult { + match expr { + Expr::Literal(scalar, _) => match scalar { + ScalarValue::Int8(Some(v)) => Ok(*v as i64), + ScalarValue::Int16(Some(v)) => Ok(*v as i64), + ScalarValue::Int32(Some(v)) => Ok(*v as i64), + ScalarValue::Int64(Some(v)) => Ok(*v), + ScalarValue::UInt8(Some(v)) => Ok(*v as i64), + ScalarValue::UInt16(Some(v)) => Ok(*v as i64), + ScalarValue::UInt32(Some(v)) => Ok(*v as i64), + ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { + DataFusionError::Plan(format!( + "{function_name}: {name} value {v} exceeds i64 range" + )) + }), + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be an integer literal, got: {expr}" + ))), + }, + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be a literal, got: {expr}" + ))), + } +} + +pub(crate) fn parse_table_identifier( + function_name: &str, + name: &str, + default_database: &str, +) -> DFResult { + let parts: Vec<&str> = name.split('.').collect(); + match parts.len() { + 1 => Ok(Identifier::new(default_database, parts[0])), + 2 => Ok(Identifier::new(parts[0], parts[1])), + 3 => Ok(Identifier::new(parts[1], parts[2])), + _ => Err(DataFusionError::Plan(format!( + "{function_name}: invalid table name '{name}', expected 'table', 
'database.table', or 'catalog.database.table'" + ))), + } +} diff --git a/crates/integrations/datafusion/src/vector_search.rs b/crates/integrations/datafusion/src/vector_search.rs new file mode 100644 index 00000000..34daadc8 --- /dev/null +++ b/crates/integrations/datafusion/src/vector_search.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; +use datafusion::catalog::Session; +use datafusion::catalog::TableFunctionImpl; +use datafusion::common::project_schema; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result as DFResult; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; +use datafusion::physical_plan::empty::EmptyExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use paimon::catalog::Catalog; + +use crate::error::to_datafusion_error; +use crate::runtime::{await_with_runtime, block_on_with_runtime}; +use crate::table::{PaimonScanBuilder, PaimonTableProvider}; +use crate::table_function_args::{ + extract_int_literal, extract_string_literal, parse_table_identifier, +}; + +const FUNCTION_NAME: &str = "vector_search"; + +pub fn register_vector_search( + ctx: &SessionContext, + catalog: Arc, + default_database: &str, +) { + ctx.register_udtf( + "vector_search", + Arc::new(VectorSearchFunction::new(catalog, default_database)), + ); +} + +pub struct VectorSearchFunction { + catalog: Arc, + default_database: String, +} + +impl Debug for VectorSearchFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VectorSearchFunction") + .field("default_database", &self.default_database) + .finish() + } +} + +impl VectorSearchFunction { + pub fn new(catalog: Arc, default_database: &str) -> Self { + Self { + catalog, + default_database: default_database.to_string(), + } + } +} + +impl TableFunctionImpl for VectorSearchFunction { + fn call(&self, args: &[Expr]) -> DFResult> { + if args.len() != 4 { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search requires 4 arguments: (table_name, column_name, query_vector_json, limit)".to_string(), + )); + } + + let table_name = 
extract_string_literal(FUNCTION_NAME, &args[0], "table_name")?; + let column_name = extract_string_literal(FUNCTION_NAME, &args[1], "column_name")?; + let query_vector_json = + extract_string_literal(FUNCTION_NAME, &args[2], "query_vector_json")?; + let limit = extract_int_literal(FUNCTION_NAME, &args[3], "limit")?; + + if limit <= 0 { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search: limit must be positive".to_string(), + )); + } + + let query_vector: Vec = serde_json::from_str(&query_vector_json).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "vector_search: query_vector_json must be a JSON array of floats, got '{}': {}", + query_vector_json, e + )) + })?; + + if query_vector.is_empty() { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search: query vector cannot be empty".to_string(), + )); + } + + let identifier = + parse_table_identifier(FUNCTION_NAME, &table_name, &self.default_database)?; + + let catalog = Arc::clone(&self.catalog); + let table = block_on_with_runtime( + async move { catalog.get_table(&identifier).await }, + "vector_search: catalog access thread panicked", + ) + .map_err(to_datafusion_error)?; + + let inner = PaimonTableProvider::try_new(table)?; + + Ok(Arc::new(VectorSearchTableProvider { + inner, + column_name, + query_vector, + limit: limit as usize, + })) + } +} + +#[derive(Debug)] +struct VectorSearchTableProvider { + inner: PaimonTableProvider, + column_name: String, + query_vector: Vec, + limit: usize, +} + +#[async_trait] +impl TableProvider for VectorSearchTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> ArrowSchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + limit: Option, + ) -> DFResult> { + let table = self.inner.table(); + + let row_ranges = await_with_runtime(async { + let mut builder 
= table.new_vector_search_builder(); + builder + .with_vector_column(&self.column_name) + .with_query_vector(self.query_vector.clone()) + .with_limit(self.limit); + builder.execute().await.map_err(to_datafusion_error) + }) + .await?; + + if row_ranges.is_empty() { + let schema = project_schema(&self.schema(), projection)?; + return Ok(Arc::new(EmptyExec::new(schema))); + } + + let mut read_builder = table.new_read_builder(); + if let Some(limit) = limit { + read_builder.with_limit(limit); + } + let scan = read_builder.new_scan().with_row_ranges(row_ranges); + let plan = await_with_runtime(scan.plan()) + .await + .map_err(to_datafusion_error)?; + + let target = state.config_options().execution.target_partitions; + PaimonScanBuilder { + table, + schema: &self.schema(), + plan: &plan, + projection, + pushed_predicate: None, + limit, + target_partitions: target, + filter_exact: false, + } + .build() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> DFResult> { + Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } +} diff --git a/crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz b/crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz new file mode 100644 index 00000000..3909f95c Binary files /dev/null and b/crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz differ diff --git a/crates/integrations/datafusion/tests/read_tables.rs b/crates/integrations/datafusion/tests/read_tables.rs index daaee458..9318a9a0 100644 --- a/crates/integrations/datafusion/tests/read_tables.rs +++ b/crates/integrations/datafusion/tests/read_tables.rs @@ -1046,3 +1046,111 @@ mod fulltext_tests { assert!(ids.contains(&3), "Searching 'search' should match row 3"); } } + +// ======================= Vector Search Tests ======================= + +mod vector_search_tests { + use std::sync::Arc; + + use datafusion::arrow::array::Int32Array; + use datafusion::prelude::SessionContext; + use paimon::{Catalog, 
CatalogOptions, FileSystemCatalog, Options}; + use paimon_datafusion::{register_vector_search, PaimonCatalogProvider}; + + fn extract_test_warehouse() -> (tempfile::TempDir, String) { + let archive_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("testdata/test_lumina_vector.tar.gz"); + let file = std::fs::File::open(&archive_path) + .unwrap_or_else(|e| panic!("Failed to open {}: {e}", archive_path.display())); + let decoder = flate2::read::GzDecoder::new(file); + let mut archive = tar::Archive::new(decoder); + + let tmp = tempfile::tempdir().expect("Failed to create temp dir"); + let db_dir = tmp.path().join("default.db"); + std::fs::create_dir_all(&db_dir).unwrap(); + archive.unpack(&db_dir).unwrap(); + + let warehouse = format!("file://{}", tmp.path().display()); + (tmp, warehouse) + } + + async fn create_vector_search_context() -> (SessionContext, tempfile::TempDir) { + let (tmp, warehouse) = extract_test_warehouse(); + let mut options = Options::new(); + options.set(CatalogOptions::WAREHOUSE, warehouse); + let catalog = FileSystemCatalog::new(options).expect("Failed to create catalog"); + let catalog: Arc = Arc::new(catalog); + + let ctx = SessionContext::new(); + ctx.register_catalog( + "paimon", + Arc::new(PaimonCatalogProvider::new(Arc::clone(&catalog))), + ); + register_vector_search(&ctx, catalog, "default"); + (ctx, tmp) + } + + fn extract_ids(batches: &[datafusion::arrow::record_batch::RecordBatch]) -> Vec { + let mut ids = Vec::new(); + for batch in batches { + let id_array = batch + .column_by_name("id") + .and_then(|c| c.as_any().downcast_ref::()) + .expect("Expected Int32Array for id"); + for i in 0..batch.num_rows() { + ids.push(id_array.value(i)); + } + } + ids.sort(); + ids + } + + #[tokio::test] + async fn test_vector_search_top3() { + let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'embedding', '[1.0, 0.0, 0.0, 0.0]', 3)") + 
.await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let ids = extract_ids(&batches); + assert_eq!(ids.len(), 3); + assert!(ids.contains(&0), "exact match [1,0,0,0] should be in top 3"); + } + + #[tokio::test] + async fn test_vector_search_top6_returns_all() { + let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'embedding', '[1.0, 0.0, 0.0, 0.0]', 6)") + .await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let ids = extract_ids(&batches); + assert_eq!(ids, vec![0, 1, 2, 3, 4, 5]); + } + + #[tokio::test] + async fn test_vector_search_without_matching_index_returns_empty() { + let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'missing_embedding', '[1.0]', 10)") + .await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + total_rows, 0, + "vector_search without a matching Lumina index should not fall back to a full table scan" + ); + } +} diff --git a/crates/paimon/Cargo.toml b/crates/paimon/Cargo.toml index 4bd16ad7..d1c428f9 100644 --- a/crates/paimon/Cargo.toml +++ b/crates/paimon/Cargo.toml @@ -88,6 +88,7 @@ tantivy = { version = "0.22", optional = true } tempfile = { version = "3", optional = true } vortex = { version = "0.68", features = ["tokio"], optional = true } kanal = { version = "0.1.1", optional = true } +libloading = "0.8" [dev-dependencies] axum = { version = "0.7", features = ["macros", "tokio", "http1", "http2"] } diff --git a/crates/paimon/src/lib.rs b/crates/paimon/src/lib.rs index 5aabe254..612477a1 100644 --- a/crates/paimon/src/lib.rs +++ b/crates/paimon/src/lib.rs @@ -31,6 +31,7 @@ pub mod catalog; mod deletion_vector; pub mod 
file_index; pub mod io; +pub mod lumina; mod predicate_stats; pub mod spec; pub mod table; diff --git a/crates/paimon/src/lumina/ffi.rs b/crates/paimon/src/lumina/ffi.rs new file mode 100644 index 00000000..75b89acb --- /dev/null +++ b/crates/paimon/src/lumina/ffi.rs @@ -0,0 +1,623 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use libloading::{Library, Symbol}; +use std::collections::HashMap; +use std::ffi::{c_char, c_float, c_int, c_void, CStr, CString}; +use std::io::{Read, Seek, SeekFrom}; +use std::sync::{Mutex, OnceLock}; + +const ERR_BUF_SIZE: usize = 4096; + +static LIBRARY: OnceLock = OnceLock::new(); +static LIBRARY_LOAD_LOCK: Mutex<()> = Mutex::new(()); + +fn load_library() -> crate::Result<&'static Library> { + if let Some(lib) = LIBRARY.get() { + return Ok(lib); + } + + let _guard = LIBRARY_LOAD_LOCK + .lock() + .map_err(|_| crate::Error::UnexpectedError { + message: "Lumina library load lock poisoned".to_string(), + source: None, + })?; + if let Some(lib) = LIBRARY.get() { + return Ok(lib); + } + + let lib_path = std::env::var("LUMINA_LIB_PATH").unwrap_or_else(|_| { + if cfg!(target_os = "macos") { + "liblumina_py.dylib".to_string() + } else if cfg!(target_os = "windows") { + "lumina_py.dll".to_string() + } else { + "liblumina_py.so".to_string() + } + }); + let lib = unsafe { + Library::new(&lib_path).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to load lumina library from '{}': {}", lib_path, e), + source: None, + })? 
+ }; + LIBRARY + .set(lib) + .map_err(|_| crate::Error::UnexpectedError { + message: "Lumina library was initialized unexpectedly while holding load lock" + .to_string(), + source: None, + })?; + Ok(LIBRARY + .get() + .expect("Lumina library should be initialized after successful set")) +} + +fn check_error(ret: c_int, err_buf: &[u8; ERR_BUF_SIZE]) -> crate::Result<()> { + if ret != 0 { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Lumina error: {}", msg), + source: None, + }); + } + Ok(()) +} + +fn options_to_json(options: &HashMap) -> crate::Result { + let json = serde_json::to_string(options).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to serialize options: {}", e), + source: None, + })?; + CString::new(json).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to create CString: {}", e), + source: None, + }) +} + +pub struct LuminaSearcher { + handle: *mut c_void, + /// Keeps the stream context alive while C-side holds a raw pointer to it. + stream_ctx_keepalive: Option>, +} + +// SAFETY: Each LuminaSearcher owns its handle exclusively and is not Sync. +// Send allows moving the searcher to another thread. 
+unsafe impl Send for LuminaSearcher {} + +impl LuminaSearcher { + pub fn create(options: &HashMap) -> crate::Result { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let handle: *mut c_void = unsafe { + let func: Symbol< + unsafe extern "C" fn(*const c_char, *mut c_char, c_int) -> *mut c_void, + > = lib + .get(b"lumina_searcher_create") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_create not found: {}", e), + source: None, + })?; + func( + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + if handle.is_null() { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Failed to create Lumina searcher: {}", msg), + source: None, + }); + } + + Ok(Self { + handle, + stream_ctx_keepalive: None, + }) + } + + #[allow(clippy::type_complexity)] + pub fn open_stream(&mut self, stream: S) -> crate::Result<()> { + if self.stream_ctx_keepalive.is_some() { + return Err(crate::Error::DataInvalid { + message: "A stream is already open; close the searcher before opening a new stream" + .to_string(), + source: None, + }); + } + + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ctx = Box::new(StreamContext::new(stream)); + let ctx_ptr = &*ctx as *const StreamContext as *mut c_void; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *mut c_void, + unsafe extern "C" fn(*mut c_void, *mut c_char, u64) -> c_int, + unsafe extern "C" fn(*mut c_void, u64) -> c_int, + unsafe extern "C" fn(*mut c_void) -> u64, + unsafe extern "C" fn(*mut c_void) -> u64, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_open_stream") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_open_stream not 
found: {}", e), + source: None, + })?; + func( + self.handle, + ctx_ptr, + stream_read_cb, + stream_seek_cb, + stream_tell_cb, + stream_length_cb, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf)?; + self.stream_ctx_keepalive = Some(ctx); + Ok(()) + } + + pub fn open_file(&mut self, path: &str) -> crate::Result<()> { + if self.stream_ctx_keepalive.is_some() { + return Err(crate::Error::DataInvalid { + message: "A stream is already open; close the searcher before opening a file" + .to_string(), + source: None, + }); + } + + let lib = load_library()?; + let c_path = CString::new(path).map_err(|e| crate::Error::DataInvalid { + message: format!("Invalid path: {}", e), + source: None, + })?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn(*mut c_void, *const c_char, *mut c_char, c_int) -> c_int, + > = lib + .get(b"lumina_searcher_open") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_open not found: {}", e), + source: None, + })?; + func( + self.handle, + c_path.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn search( + &self, + query: &[f32], + n: i32, + k: i32, + distances: &mut [f32], + labels: &mut [u64], + options: &HashMap, + ) -> crate::Result<()> { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_float, + *mut u64, + *const c_char, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_search") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_search not found: {}", e), + source: None, + })?; + func( + self.handle, + query.as_ptr(), + n, + k, + distances.as_mut_ptr(), + 
labels.as_mut_ptr(), + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + #[allow(clippy::too_many_arguments, clippy::type_complexity)] + pub fn search_with_filter( + &self, + query: &[f32], + n: i32, + k: i32, + distances: &mut [f32], + labels: &mut [u64], + filter_ids: &[u64], + options: &HashMap, + ) -> crate::Result<()> { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_float, + *mut u64, + *const u64, + u64, + *const c_char, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_search_with_filter") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_search_with_filter not found: {}", e), + source: None, + })?; + func( + self.handle, + query.as_ptr(), + n, + k, + distances.as_mut_ptr(), + labels.as_mut_ptr(), + filter_ids.as_ptr(), + filter_ids.len() as u64, + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn get_count(&self) -> crate::Result { + let lib = load_library()?; + unsafe { + let func: Symbol u64> = lib + .get(b"lumina_searcher_get_count") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_get_count not found: {}", e), + source: None, + })?; + Ok(func(self.handle)) + } + } + + pub fn get_dimension(&self) -> crate::Result { + let lib = load_library()?; + unsafe { + let func: Symbol u32> = lib + .get(b"lumina_searcher_get_dimension") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_get_dimension not found: {}", e), + source: None, + })?; + Ok(func(self.handle)) + } + } +} + +impl Drop for LuminaSearcher { + fn drop(&mut self) { + if !self.handle.is_null() { + if let 
Ok(lib) = load_library() { + unsafe { + if let Ok(func) = + lib.get::(b"lumina_searcher_destroy") + { + func(self.handle); + } + } + } + self.handle = std::ptr::null_mut(); + } + } +} + +pub struct LuminaBuilder { + handle: *mut c_void, +} + +// SAFETY: Same as LuminaSearcher — exclusively owned, not Sync. +unsafe impl Send for LuminaBuilder {} + +impl LuminaBuilder { + pub fn create(options: &HashMap) -> crate::Result { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let handle: *mut c_void = unsafe { + let func: Symbol< + unsafe extern "C" fn(*const c_char, *mut c_char, c_int) -> *mut c_void, + > = lib + .get(b"lumina_builder_create") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_create not found: {}", e), + source: None, + })?; + func( + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + if handle.is_null() { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Failed to create Lumina builder: {}", msg), + source: None, + }); + } + + Ok(Self { handle }) + } + + pub fn pretrain(&self, vectors: &[f32], n: i32, dim: i32) -> crate::Result<()> { + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_builder_pretrain") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_pretrain not found: {}", e), + source: None, + })?; + func( + self.handle, + vectors.as_ptr(), + n, + dim, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn insert(&self, vectors: &[f32], ids: &[u64], n: i32, dim: i32) 
-> crate::Result<()> { + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + *const u64, + c_int, + c_int, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_builder_insert") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_insert not found: {}", e), + source: None, + })?; + func( + self.handle, + vectors.as_ptr(), + ids.as_ptr(), + n, + dim, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn dump(&self, path: &str) -> crate::Result<()> { + let lib = load_library()?; + let c_path = CString::new(path).map_err(|e| crate::Error::DataInvalid { + message: format!("Invalid path: {}", e), + source: None, + })?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn(*mut c_void, *const c_char, *mut c_char, c_int) -> c_int, + > = lib + .get(b"lumina_builder_dump") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_dump not found: {}", e), + source: None, + })?; + func( + self.handle, + c_path.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } +} + +impl Drop for LuminaBuilder { + fn drop(&mut self) { + if !self.handle.is_null() { + if let Ok(lib) = load_library() { + unsafe { + if let Ok(func) = + lib.get::(b"lumina_builder_destroy") + { + func(self.handle); + } + } + } + self.handle = std::ptr::null_mut(); + } + } +} + +struct StreamContext { + inner: std::sync::Mutex>, +} + +trait ReadSeekLen: Read + Seek { + fn length(&self) -> u64; +} + +struct ReadSeekLenImpl { + stream: S, + len: u64, +} + +impl ReadSeekLenImpl { + fn new(mut stream: S) -> Self { + let len = stream.seek(SeekFrom::End(0)).unwrap_or(0); + let _ = stream.seek(SeekFrom::Start(0)); + Self { stream, len } + } +} 
+ +impl Read for ReadSeekLenImpl { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.stream.read(buf) + } +} + +impl Seek for ReadSeekLenImpl { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + self.stream.seek(pos) + } +} + +impl ReadSeekLen for ReadSeekLenImpl { + fn length(&self) -> u64 { + self.len + } +} + +impl StreamContext { + fn new(stream: S) -> Self { + Self { + inner: std::sync::Mutex::new(Box::new(ReadSeekLenImpl::new(stream))), + } + } +} + +unsafe extern "C" fn stream_read_cb(ctx: *mut c_void, buf: *mut c_char, size: u64) -> c_int { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return -1, + }; + let clamped_size = std::cmp::min(size, c_int::MAX as u64) as usize; + let slice = std::slice::from_raw_parts_mut(buf as *mut u8, clamped_size); + let mut total_read = 0usize; + while total_read < clamped_size { + match guard.read(&mut slice[total_read..]) { + Ok(0) => break, + Ok(n) => total_read += n, + Err(_) => return -1, + } + } + std::cmp::min(total_read, c_int::MAX as usize) as c_int +} + +unsafe extern "C" fn stream_seek_cb(ctx: *mut c_void, position: u64) -> c_int { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return -1, + }; + match guard.seek(SeekFrom::Start(position)) { + Ok(_) => 0, + Err(_) => -1, + } +} + +unsafe extern "C" fn stream_tell_cb(ctx: *mut c_void) -> u64 { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return 0, + }; + guard.seek(SeekFrom::Current(0)).unwrap_or(0) +} + +unsafe extern "C" fn stream_length_cb(ctx: *mut c_void) -> u64 { + let ctx = &*(ctx as *const StreamContext); + let guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return 0, + }; + guard.length() +} diff --git a/crates/paimon/src/lumina/mod.rs b/crates/paimon/src/lumina/mod.rs new file mode 100644 index 00000000..04b4f87e --- 
/dev/null +++ b/crates/paimon/src/lumina/mod.rs @@ -0,0 +1,711 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod ffi; +pub mod reader; + +use std::collections::HashMap; + +pub const LUMINA_VECTOR_ANN_IDENTIFIER: &str = "lumina-vector-ann"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LuminaVectorMetric { + L2, + Cosine, + InnerProduct, +} + +impl LuminaVectorMetric { + pub fn lumina_name(&self) -> &str { + match self { + LuminaVectorMetric::L2 => "l2", + LuminaVectorMetric::Cosine => "cosine", + LuminaVectorMetric::InnerProduct => "inner_product", + } + } + + pub fn from_string(name: &str) -> crate::Result { + match name.to_uppercase().as_str() { + "L2" => Ok(LuminaVectorMetric::L2), + "COSINE" => Ok(LuminaVectorMetric::Cosine), + "INNER_PRODUCT" => Ok(LuminaVectorMetric::InnerProduct), + _ => Err(crate::Error::DataInvalid { + message: format!("Unknown metric name: {}", name), + source: None, + }), + } + } + + pub fn from_lumina_name(lumina_name: &str) -> crate::Result { + match lumina_name { + "l2" => Ok(LuminaVectorMetric::L2), + "cosine" => Ok(LuminaVectorMetric::Cosine), + "inner_product" => Ok(LuminaVectorMetric::InnerProduct), + _ => Err(crate::Error::DataInvalid { + message: 
format!("Unknown lumina metric name: {}", lumina_name), + source: None, + }), + } + } +} + +const LUMINA_PREFIX: &str = "lumina."; + +const ALL_OPTIONS_DEFAULTS: &[(&str, &str)] = &[ + ("lumina.index.dimension", "128"), + ("lumina.index.type", "diskann"), + ("lumina.distance.metric", "inner_product"), + ("lumina.encoding.type", "pq"), + ("lumina.pretrain.sample_ratio", "0.2"), + ("lumina.diskann.build.ef_construction", "1024"), + ("lumina.diskann.build.neighbor_count", "64"), + ("lumina.diskann.build.thread_count", "32"), + ("lumina.diskann.search.beam_width", "4"), + ("lumina.encoding.pq.m", "64"), + ("lumina.search.parallel_number", "5"), +]; + +pub struct LuminaVectorIndexOptions { + pub dimension: i32, + pub metric: LuminaVectorMetric, + pub index_type: String, + lumina_options: HashMap, +} + +impl LuminaVectorIndexOptions { + pub fn new(paimon_options: &HashMap) -> crate::Result { + let dimension_str = paimon_options + .get("lumina.index.dimension") + .map(|s| s.as_str()) + .unwrap_or("128"); + let dimension: i32 = dimension_str + .parse() + .map_err(|_| crate::Error::DataInvalid { + message: format!("Invalid dimension: {}", dimension_str), + source: None, + })?; + if dimension <= 0 { + return Err(crate::Error::DataInvalid { + message: format!( + "Invalid value for 'lumina.index.dimension': {}. 
Must be a positive integer.", + dimension + ), + source: None, + }); + } + + let metric_str = paimon_options + .get("lumina.distance.metric") + .map(|s| s.as_str()) + .unwrap_or("inner_product"); + let metric = LuminaVectorMetric::from_lumina_name(metric_str) + .or_else(|_| LuminaVectorMetric::from_string(metric_str))?; + + let encoding = paimon_options + .get("lumina.encoding.type") + .map(|s| s.as_str()) + .unwrap_or("pq"); + validate_encoding_metric(encoding, metric)?; + + let index_type = paimon_options + .get("lumina.index.type") + .cloned() + .unwrap_or_else(|| "diskann".to_string()); + + let lumina_options = build_lumina_options(paimon_options, dimension)?; + + Ok(Self { + dimension, + metric, + index_type, + lumina_options, + }) + } + + pub fn to_lumina_options(&self) -> HashMap { + self.lumina_options.clone() + } +} + +fn validate_encoding_metric(encoding: &str, metric: LuminaVectorMetric) -> crate::Result<()> { + if encoding.eq_ignore_ascii_case("pq") && metric == LuminaVectorMetric::Cosine { + return Err(crate::Error::DataInvalid { + message: + "Lumina does not support PQ encoding with cosine metric. \ + Please use 'rawf32' or 'sq8' encoding, or switch to 'l2' or 'inner_product' metric." 
+ .to_string(), + source: None, + }); + } + Ok(()) +} + +fn validate_and_cap_pq_m(opts: &mut HashMap, dimension: i32) -> crate::Result<()> { + let encoding = opts.get("encoding.type").map(|s| s.as_str()).unwrap_or(""); + if !encoding.eq_ignore_ascii_case("pq") { + return Ok(()); + } + if let Some(pq_m_str) = opts.get("encoding.pq.m") { + let pq_m: i32 = pq_m_str.parse().map_err(|_| crate::Error::DataInvalid { + message: format!("encoding.pq.m must be an integer, got: {}", pq_m_str), + source: None, + })?; + if pq_m <= 0 { + return Err(crate::Error::DataInvalid { + message: format!("encoding.pq.m must be positive, got: {}", pq_m), + source: None, + }); + } + if pq_m > dimension { + opts.insert("encoding.pq.m".to_string(), dimension.to_string()); + } + } + Ok(()) +} + +fn build_lumina_options( + paimon_options: &HashMap, + dimension: i32, +) -> crate::Result> { + let mut result = HashMap::new(); + + for &(paimon_key, default_value) in ALL_OPTIONS_DEFAULTS { + let native_key = &paimon_key[LUMINA_PREFIX.len()..]; + let value = paimon_options + .get(paimon_key) + .map(|s| s.as_str()) + .unwrap_or(default_value); + result.insert(native_key.to_string(), value.to_string()); + } + + for (key, value) in paimon_options { + if let Some(native_key) = key.strip_prefix(LUMINA_PREFIX) { + result + .entry(native_key.to_string()) + .or_insert_with(|| value.to_string()); + } + } + + validate_and_cap_pq_m(&mut result, dimension)?; + Ok(result) +} + +pub fn strip_lumina_options(paimon_options: &HashMap) -> HashMap { + let mut result = HashMap::new(); + for (key, value) in paimon_options { + if let Some(native_key) = key.strip_prefix(LUMINA_PREFIX) { + result.insert(native_key.to_string(), value.to_string()); + } + } + result +} + +#[derive(Clone)] +pub struct VectorSearch { + pub vector: Vec, + pub limit: usize, + pub field_name: String, + pub include_row_ids: Option, +} + +impl VectorSearch { + pub fn new(vector: Vec, limit: usize, field_name: String) -> crate::Result { + if 
vector.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Search vector cannot be empty".to_string(), + source: None, + }); + } + if limit == 0 || limit > i32::MAX as usize { + return Err(crate::Error::DataInvalid { + message: format!("Limit must be between 1 and {}, got: {}", i32::MAX, limit), + source: None, + }); + } + if field_name.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Field name cannot be null or empty".to_string(), + source: None, + }); + } + Ok(Self { + vector, + limit, + field_name, + include_row_ids: None, + }) + } + + pub fn with_include_row_ids(mut self, include_row_ids: roaring::RoaringTreemap) -> Self { + self.include_row_ids = Some(include_row_ids); + self + } +} + +impl std::fmt::Display for VectorSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "VectorSearch(field_name={}, limit={})", + self.field_name, self.limit + ) + } +} + +pub struct GlobalIndexIOMeta { + pub file_path: String, + pub file_size: u64, + pub metadata: Vec, +} + +impl GlobalIndexIOMeta { + pub fn new(file_path: String, file_size: u64, metadata: Vec) -> Self { + Self { + file_path, + file_size, + metadata, + } + } +} + +pub const KEY_DIMENSION: &str = "index.dimension"; +pub const KEY_DISTANCE_METRIC: &str = "distance.metric"; +pub const KEY_INDEX_TYPE: &str = "index.type"; + +pub struct LuminaIndexMeta { + options: HashMap, +} + +impl LuminaIndexMeta { + pub fn new(options: HashMap) -> Self { + Self { options } + } + + pub fn options(&self) -> &HashMap { + &self.options + } + + pub fn dim(&self) -> crate::Result { + let val = self + .options + .get(KEY_DIMENSION) + .ok_or_else(|| crate::Error::DataInvalid { + message: format!("Missing required key: {}", KEY_DIMENSION), + source: None, + })?; + val.parse::().map_err(|_| crate::Error::DataInvalid { + message: format!("Invalid dimension value: {}", val), + source: None, + }) + } + + pub fn distance_metric(&self) -> &str { + self.options + 
.get(KEY_DISTANCE_METRIC) + .map(String::as_str) + .unwrap_or("") + } + + pub fn metric(&self) -> crate::Result { + LuminaVectorMetric::from_lumina_name(self.distance_metric()) + } + + pub fn index_type(&self) -> &str { + self.options + .get(KEY_INDEX_TYPE) + .map(String::as_str) + .unwrap_or("diskann") + } + + pub fn serialize(&self) -> crate::Result> { + serde_json::to_vec(&self.options).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to serialize LuminaIndexMeta: {}", e), + source: None, + }) + } + + pub fn deserialize(data: &[u8]) -> crate::Result { + let options: HashMap = + serde_json::from_slice(data).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to deserialize LuminaIndexMeta: {}", e), + source: None, + })?; + if !options.contains_key(KEY_DIMENSION) { + return Err(crate::Error::DataInvalid { + message: format!( + "Missing required key in Lumina index metadata: {}", + KEY_DIMENSION + ), + source: None, + }); + } + if !options.contains_key(KEY_DISTANCE_METRIC) { + return Err(crate::Error::DataInvalid { + message: format!( + "Missing required key in Lumina index metadata: {}", + KEY_DISTANCE_METRIC + ), + source: None, + }); + } + Ok(Self { options }) + } +} + +#[derive(Debug, Clone)] +pub struct SearchResult { + pub row_ids: Vec, + pub scores: Vec, +} + +impl SearchResult { + pub fn new(row_ids: Vec, scores: Vec) -> Self { + assert_eq!(row_ids.len(), scores.len()); + Self { row_ids, scores } + } + + pub fn empty() -> Self { + Self { + row_ids: Vec::new(), + scores: Vec::new(), + } + } + + pub fn from_scored_map(map: HashMap) -> Self { + let mut row_ids = Vec::with_capacity(map.len()); + let mut scores = Vec::with_capacity(map.len()); + for (id, score) in map { + row_ids.push(id); + scores.push(score); + } + Self { row_ids, scores } + } + + pub fn len(&self) -> usize { + self.row_ids.len() + } + + pub fn is_empty(&self) -> bool { + self.row_ids.is_empty() + } + + pub fn offset(&self, offset: i64) -> Self { + if offset == 
0 { + return self.clone(); + } + let row_ids = self + .row_ids + .iter() + .map(|&id| { + if offset >= 0 { + id.saturating_add(offset as u64) + } else { + id.saturating_sub(offset.unsigned_abs()) + } + }) + .collect(); + Self { + row_ids, + scores: self.scores.clone(), + } + } + + pub fn or(&self, other: &SearchResult) -> Self { + let mut row_ids = self.row_ids.clone(); + let mut scores = self.scores.clone(); + row_ids.extend_from_slice(&other.row_ids); + scores.extend_from_slice(&other.scores); + Self { row_ids, scores } + } + + pub fn top_k(&self, k: usize) -> Self { + if self.row_ids.len() <= k { + return self.clone(); + } + let mut indices: Vec = (0..self.row_ids.len()).collect(); + indices.sort_by(|&a, &b| { + self.scores[b] + .partial_cmp(&self.scores[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + indices.truncate(k); + let row_ids = indices.iter().map(|&i| self.row_ids[i]).collect(); + let scores = indices.iter().map(|&i| self.scores[i]).collect(); + Self { row_ids, scores } + } + + pub fn to_row_ranges(&self) -> crate::Result> { + if self.row_ids.is_empty() { + return Ok(Vec::new()); + } + + let mut sorted = self + .row_ids + .iter() + .copied() + .map(|id| { + i64::try_from(id).map_err(|_| crate::Error::DataInvalid { + message: format!( + "Lumina search row id {id} exceeds i64::MAX and cannot be converted to RowRange" + ), + source: None, + }) + }) + .collect::>>()?; + + sorted.sort_unstable(); + sorted.dedup(); + let mut ranges = Vec::new(); + let mut start = sorted[0]; + let mut end = start; + for &id in &sorted[1..] 
{ + if end.checked_add(1) == Some(id) { + end = id; + } else { + ranges.push(crate::table::RowRange::new(start, end)); + start = id; + end = id; + } + } + ranges.push(crate::table::RowRange::new(start, end)); + Ok(ranges) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metric_roundtrip() { + for metric in [ + LuminaVectorMetric::L2, + LuminaVectorMetric::Cosine, + LuminaVectorMetric::InnerProduct, + ] { + let name = metric.lumina_name(); + assert_eq!(LuminaVectorMetric::from_lumina_name(name).unwrap(), metric); + assert_eq!( + LuminaVectorMetric::from_string(&name.to_uppercase()).unwrap(), + metric + ); + } + assert!(LuminaVectorMetric::from_string("hamming").is_err()); + } + + #[test] + fn test_index_meta_serialize_deserialize() { + let mut options = HashMap::new(); + options.insert(KEY_DIMENSION.to_string(), "128".to_string()); + options.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + options.insert(KEY_INDEX_TYPE.to_string(), "diskann".to_string()); + let meta = LuminaIndexMeta::new(options); + + let bytes = meta.serialize().unwrap(); + let meta2 = LuminaIndexMeta::deserialize(&bytes).unwrap(); + assert_eq!(meta2.dim().unwrap(), 128); + assert_eq!(meta2.distance_metric(), "l2"); + assert_eq!(meta2.index_type(), "diskann"); + } + + #[test] + fn test_index_meta_deserialize_missing_fields() { + // missing dimension + let mut opts = HashMap::new(); + opts.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + assert!(LuminaIndexMeta::deserialize(&serde_json::to_vec(&opts).unwrap()).is_err()); + + // missing metric + let mut opts = HashMap::new(); + opts.insert(KEY_DIMENSION.to_string(), "128".to_string()); + assert!(LuminaIndexMeta::deserialize(&serde_json::to_vec(&opts).unwrap()).is_err()); + + // invalid json + assert!(LuminaIndexMeta::deserialize(b"not json").is_err()); + } + + #[test] + fn test_dim_error_on_invalid() { + let mut opts = HashMap::new(); + opts.insert(KEY_DIMENSION.to_string(), "abc".to_string()); + 
opts.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + assert!(LuminaIndexMeta::new(opts).dim().is_err()); + } + + #[test] + fn test_index_options_invalid_dimension() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "-1".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_strip_lumina_options() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert( + "lumina.diskann.search.beam_width".to_string(), + "8".to_string(), + ); + opts.insert("non_lumina_key".to_string(), "ignored".to_string()); + let result = strip_lumina_options(&opts); + assert_eq!(result.get("index.dimension").unwrap(), "128"); + assert_eq!(result.get("diskann.search.beam_width").unwrap(), "8"); + assert!(!result.contains_key("non_lumina_key")); + } + + #[test] + fn test_pq_cosine_rejected() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.distance.metric".to_string(), "cosine".to_string()); + opts.insert("lumina.encoding.type".to_string(), "pq".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_pq_l2_accepted() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.distance.metric".to_string(), "l2".to_string()); + opts.insert("lumina.encoding.type".to_string(), "pq".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_ok()); + } + + #[test] + fn test_pq_m_zero_rejected() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "0".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_pq_m_non_numeric_rejected() { + let mut opts = HashMap::new(); + 
opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "abc".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_cap_pq_m() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "32".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "64".to_string()); + let index_opts = LuminaVectorIndexOptions::new(&opts).unwrap(); + let lumina_opts = index_opts.to_lumina_options(); + assert_eq!(lumina_opts.get("encoding.pq.m").unwrap(), "32"); + } + + #[test] + fn test_build_lumina_options_defaults() { + let opts = HashMap::new(); + let index_opts = LuminaVectorIndexOptions::new(&opts).unwrap(); + let lumina_opts = index_opts.to_lumina_options(); + assert_eq!(lumina_opts.get("index.dimension").unwrap(), "128"); + assert_eq!(lumina_opts.get("distance.metric").unwrap(), "inner_product"); + assert_eq!(lumina_opts.get("encoding.type").unwrap(), "pq"); + assert_eq!(lumina_opts.get("pretrain.sample_ratio").unwrap(), "0.2"); + assert_eq!( + lumina_opts.get("diskann.build.ef_construction").unwrap(), + "1024" + ); + assert_eq!( + lumina_opts.get("diskann.build.neighbor_count").unwrap(), + "64" + ); + assert_eq!(lumina_opts.get("diskann.build.thread_count").unwrap(), "32"); + assert_eq!(lumina_opts.get("diskann.search.beam_width").unwrap(), "4"); + assert_eq!(lumina_opts.get("encoding.pq.m").unwrap(), "64"); + assert_eq!(lumina_opts.get("search.parallel_number").unwrap(), "5"); + } + + #[test] + fn test_vector_search_clone_preserves_include_row_ids() { + let mut include_row_ids = roaring::RoaringTreemap::new(); + include_row_ids.insert(1); + include_row_ids.insert(3); + + let vector_search = VectorSearch::new(vec![1.0, 2.0], 10, "embedding".to_string()) + .unwrap() + .with_include_row_ids(include_row_ids.clone()); + + let cloned = vector_search.clone(); + assert_eq!(cloned.vector, vector_search.vector); + assert_eq!(cloned.limit, 
vector_search.limit); + assert_eq!(cloned.field_name, vector_search.field_name); + assert_eq!(cloned.include_row_ids.as_ref(), Some(&include_row_ids)); + } + + #[test] + fn test_search_result_from_scored_map() { + let mut map = HashMap::new(); + map.insert(1u64, 0.9f32); + map.insert(2, 0.5); + let result = SearchResult::from_scored_map(map); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_search_result_top_k() { + let result = SearchResult::new(vec![1, 2, 3, 4, 5], vec![0.1, 0.9, 0.5, 0.8, 0.3]); + let top = result.top_k(2); + assert_eq!(top.len(), 2); + assert!(top.row_ids.contains(&2)); + assert!(top.row_ids.contains(&4)); + } + + #[test] + fn test_search_result_offset() { + let result = SearchResult::new(vec![0, 1], vec![0.5, 0.6]); + let offset = result.offset(100); + assert_eq!(offset.row_ids, vec![100, 101]); + assert_eq!(offset.scores, vec![0.5, 0.6]); + } + + #[test] + fn test_search_result_or() { + let a = SearchResult::new(vec![1, 2], vec![0.5, 0.6]); + let b = SearchResult::new(vec![3], vec![0.7]); + let merged = a.or(&b); + assert_eq!(merged.len(), 3); + } + + #[test] + fn test_search_result_to_row_ranges() { + let result = SearchResult::new(vec![5, 1, 2, 3, 10], vec![0.1; 5]); + let ranges = result.to_row_ranges().unwrap(); + assert_eq!(ranges.len(), 3); + assert_eq!(ranges[0].from(), 1); + assert_eq!(ranges[0].to(), 3); + assert_eq!(ranges[1].from(), 5); + assert_eq!(ranges[1].to(), 5); + assert_eq!(ranges[2].from(), 10); + assert_eq!(ranges[2].to(), 10); + } + + #[test] + fn test_search_result_to_row_ranges_rejects_i64_overflow() { + let result = SearchResult::new(vec![i64::MAX as u64 + 1], vec![0.1]); + let err = result.to_row_ranges().unwrap_err(); + assert!( + err.to_string().contains("exceeds i64::MAX"), + "unexpected error: {err}" + ); + } +} diff --git a/crates/paimon/src/lumina/reader.rs b/crates/paimon/src/lumina/reader.rs new file mode 100644 index 00000000..09acff88 --- /dev/null +++ b/crates/paimon/src/lumina/reader.rs @@ -0,0 
/// Smallest candidate-list size ever handed to DiskANN.
const MIN_SEARCH_LIST_SIZE: usize = 16;
// C ABI returns int64_t -1 for invalid results, which casts to u64::MAX in Rust.
const SENTINEL: u64 = u64::MAX;

/// Ensure `diskann.search.list_size` is populated: a caller-supplied value is
/// kept untouched, otherwise it is derived as 1.5 * `top_k`, floored at
/// [`MIN_SEARCH_LIST_SIZE`].
fn ensure_search_list_size(search_options: &mut HashMap<String, String>, top_k: usize) {
    search_options
        .entry("diskann.search.list_size".to_string())
        .or_insert_with(|| {
            let derived = (top_k as f64 * 1.5) as usize;
            derived.max(MIN_SEARCH_LIST_SIZE).to_string()
        });
}
+fn collect_results( + labels: &[u64], + distances: &[f32], + top_k: usize, + metric: LuminaVectorMetric, +) -> HashMap { + #[derive(PartialEq)] + struct ScoredRow { + row_id: u64, + score: f32, + } + impl Eq for ScoredRow {} + impl PartialOrd for ScoredRow { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + impl Ord for ScoredRow { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other.score.total_cmp(&self.score) + } + } + + let mut min_heap: BinaryHeap = BinaryHeap::with_capacity(top_k + 1); + for (&row_id, &distance) in labels.iter().zip(distances.iter()) { + if row_id == SENTINEL { + continue; + } + let score = convert_distance_to_score(distance, metric); + if min_heap.len() < top_k { + min_heap.push(ScoredRow { row_id, score }); + } else if let Some(peek) = min_heap.peek() { + if score > peek.score { + min_heap.pop(); + min_heap.push(ScoredRow { row_id, score }); + } + } + } + + let mut result = HashMap::with_capacity(min_heap.len()); + for entry in min_heap { + result.insert(entry.row_id, entry.score); + } + result +} + +pub struct LuminaVectorGlobalIndexReader { + io_meta: GlobalIndexIOMeta, + options: HashMap, + searcher: Option, + index_meta: Option, + search_options: Option>, + local_index_file: Option, +} + +impl LuminaVectorGlobalIndexReader { + pub fn new(io_meta: GlobalIndexIOMeta, options: HashMap) -> Self { + Self { + io_meta, + options, + searcher: None, + index_meta: None, + search_options: None, + local_index_file: None, + } + } + + pub fn visit_vector_search( + &mut self, + vector_search: &VectorSearch, + stream_fn: impl FnOnce(&str) -> crate::Result, + ) -> crate::Result>> { + self.ensure_loaded(stream_fn)?; + self.search(vector_search) + } + + fn search(&self, vector_search: &VectorSearch) -> crate::Result>> { + let index_meta = self + .index_meta + .as_ref() + .ok_or_else(|| crate::Error::DataInvalid { + message: "index_meta not initialized".to_string(), + source: None, + })?; + let searcher = self + 
.searcher + .as_ref() + .ok_or_else(|| crate::Error::DataInvalid { + message: "searcher not initialized".to_string(), + source: None, + })?; + let search_options_base = + self.search_options + .as_ref() + .ok_or_else(|| crate::Error::DataInvalid { + message: "search_options not initialized".to_string(), + source: None, + })?; + + let expected_dim = index_meta.dim()? as usize; + if vector_search.vector.len() != expected_dim { + return Err(crate::Error::DataInvalid { + message: format!( + "Query vector dimension mismatch: index expects {}, but got {}", + expected_dim, + vector_search.vector.len() + ), + source: None, + }); + } + + let limit = vector_search.limit; + let index_metric = index_meta.metric()?; + let count = searcher.get_count()? as usize; + let effective_k = std::cmp::min(limit, count); + if effective_k == 0 { + return Ok(None); + } + + let include_row_ids = &vector_search.include_row_ids; + + let (distances, labels) = if let Some(ref include_ids) = include_row_ids { + let filter_id_list: Vec = include_ids.iter().collect(); + if filter_id_list.is_empty() { + return Ok(None); + } + let ek = std::cmp::min(effective_k, filter_id_list.len()); + let mut distances = vec![0.0f32; ek]; + let mut labels = vec![0u64; ek]; + let mut search_opts: HashMap = search_options_base.clone(); + search_opts.insert("search.thread_safe_filter".to_string(), "true".to_string()); + ensure_search_list_size(&mut search_opts, ek); + searcher.search_with_filter( + &vector_search.vector, + 1, + ek as i32, + &mut distances, + &mut labels, + &filter_id_list, + &search_opts, + )?; + (distances, labels) + } else { + let mut distances = vec![0.0f32; effective_k]; + let mut labels = vec![0u64; effective_k]; + let mut search_opts: HashMap = search_options_base.clone(); + ensure_search_list_size(&mut search_opts, effective_k); + searcher.search( + &vector_search.vector, + 1, + effective_k as i32, + &mut distances, + &mut labels, + &search_opts, + )?; + (distances, labels) + }; + + let 
id_to_scores = collect_results(&labels, &distances, effective_k, index_metric); + if id_to_scores.is_empty() { + return Ok(None); + } + + Ok(Some(id_to_scores)) + } + + fn ensure_loaded( + &mut self, + stream_fn: impl FnOnce(&str) -> crate::Result, + ) -> crate::Result<()> { + if self.searcher.is_some() { + return Ok(()); + } + + let index_meta = LuminaIndexMeta::deserialize(&self.io_meta.metadata)?; + + let mut searcher_options = index_meta.options().clone(); + for (k, v) in strip_lumina_options(&self.options) { + searcher_options.insert(k, v); + } + + let mut searcher = LuminaSearcher::create(&searcher_options)?; + + let mut stream = stream_fn(&self.io_meta.file_path)?; + let local_index_file = write_temp_index_file(&mut stream)?; + let local_index_path = + local_index_file + .to_str() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "Temporary Lumina index path is not valid UTF-8: {}", + local_index_file.display() + ), + source: None, + })?; + if let Err(err) = searcher.open_file(local_index_path) { + let _ = std::fs::remove_file(&local_index_file); + return Err(err); + } + + self.search_options = Some(searcher_options); + self.index_meta = Some(index_meta); + self.searcher = Some(searcher); + self.local_index_file = Some(local_index_file); + Ok(()) + } + + pub fn close(&mut self) { + self.searcher = None; + self.index_meta = None; + self.search_options = None; + if let Some(path) = self.local_index_file.take() { + let _ = std::fs::remove_file(path); + } + } +} + +fn write_temp_index_file(stream: &mut S) -> crate::Result { + stream + .seek(SeekFrom::Start(0)) + .map_err(|e| crate::Error::UnexpectedError { + message: format!("Failed to seek Lumina index stream to start: {}", e), + source: Some(Box::new(e)), + })?; + + let path = std::env::temp_dir().join(format!( + "paimon-lumina-index-{}.index", + uuid::Uuid::new_v4() + )); + let mut file = std::fs::File::create(&path).map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to 
create temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + std::io::copy(stream, &mut file).map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to write temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + file.sync_all().map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to sync temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + Ok(path) +} + +impl Drop for LuminaVectorGlobalIndexReader { + fn drop(&mut self) { + self.close(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lumina::GlobalIndexIOMeta; + use std::io::Cursor; + + #[test] + fn test_convert_distance_to_score() { + assert_eq!(convert_distance_to_score(0.0, LuminaVectorMetric::L2), 1.0); + assert_eq!(convert_distance_to_score(1.0, LuminaVectorMetric::L2), 0.5); + assert_eq!( + convert_distance_to_score(0.0, LuminaVectorMetric::Cosine), + 1.0 + ); + assert_eq!( + convert_distance_to_score(1.0, LuminaVectorMetric::Cosine), + 0.0 + ); + assert_eq!( + convert_distance_to_score(0.75, LuminaVectorMetric::InnerProduct), + 0.75 + ); + } + + #[test] + fn test_ensure_search_list_size() { + let mut opts = HashMap::new(); + ensure_search_list_size(&mut opts, 10); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), "16"); // max(15, 16) + + let mut opts = HashMap::new(); + ensure_search_list_size(&mut opts, 100); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), "150"); // 100*1.5 + + // does not override existing + let mut opts = HashMap::new(); + opts.insert("diskann.search.list_size".to_string(), "999".to_string()); + ensure_search_list_size(&mut opts, 100); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), "999"); + } + + #[test] + fn test_collect_results() { + let labels = vec![0, 1, 2, SENTINEL, 3]; + let distances = vec![0.5, 0.3, 0.1, 0.0, 0.9]; + let result = 
collect_results(&labels, &distances, 2, LuminaVectorMetric::InnerProduct); + assert_eq!(result.len(), 2); + // top 2 by score: row 3 (0.9) and row 0 (0.5) + assert!(result.contains_key(&3)); + assert!(result.contains_key(&0)); + assert!(!result.contains_key(&2)); // 0.1 is lowest + } + + #[test] + fn test_reader_new() { + let m = GlobalIndexIOMeta::new("a".into(), 100, vec![]); + let reader = LuminaVectorGlobalIndexReader::new(m, HashMap::new()); + assert!(reader.searcher.is_none()); + } + + #[test] + fn test_write_temp_index_file_copies_stream() { + let bytes = b"lumina-index-bytes".to_vec(); + let mut stream = Cursor::new(bytes.clone()); + stream.seek(SeekFrom::End(0)).unwrap(); + + let path = write_temp_index_file(&mut stream).unwrap(); + let actual = std::fs::read(&path).unwrap(); + std::fs::remove_file(&path).unwrap(); + + assert_eq!(actual, bytes); + } +} diff --git a/crates/paimon/src/table/full_text_search_builder.rs b/crates/paimon/src/table/full_text_search_builder.rs index 783bd572..41297b7a 100644 --- a/crates/paimon/src/table/full_text_search_builder.rs +++ b/crates/paimon/src/table/full_text_search_builder.rs @@ -21,7 +21,7 @@ use crate::spec::{DataField, FileKind, IndexManifest}; use crate::table::snapshot_manager::SnapshotManager; -use crate::table::{RowRange, Table}; +use crate::table::{find_field_id_by_name, RowRange, Table}; use crate::tantivy::full_text_search::{FullTextSearch, SearchResult}; use crate::tantivy::reader::TantivyFullTextReader; @@ -201,7 +201,3 @@ async fn evaluate_full_text_search( Ok(merged.top_k(search.limit).to_row_ranges()) } - -fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { - fields.iter().find(|f| f.name() == name).map(|f| f.id()) -} diff --git a/crates/paimon/src/table/global_index_scanner.rs b/crates/paimon/src/table/global_index_scanner.rs index 06aea9c3..cdbccaca 100644 --- a/crates/paimon/src/table/global_index_scanner.rs +++ b/crates/paimon/src/table/global_index_scanner.rs @@ -369,12 +369,10 
@@ impl GlobalIndexScanner { } fn find_field_id_by_name(&self, column: &str) -> Result> { - for field in &self.schema_fields { - if field.name() == column { - return Ok(Some(field.id())); - } - } - Ok(None) + Ok(crate::table::find_field_id_by_name( + &self.schema_fields, + column, + )) } fn entries_for_field(&self, field_id: i32) -> Option<&[GlobalIndexEntry]> { diff --git a/crates/paimon/src/table/mod.rs b/crates/paimon/src/table/mod.rs index 147239f8..c8ef5ede 100644 --- a/crates/paimon/src/table/mod.rs +++ b/crates/paimon/src/table/mod.rs @@ -49,6 +49,7 @@ mod table_read; mod table_scan; pub(crate) mod table_write; mod tag_manager; +mod vector_search_builder; mod write_builder; use crate::Result; @@ -71,11 +72,12 @@ pub use table_read::TableRead; pub use table_scan::TableScan; pub use table_write::TableWrite; pub use tag_manager::TagManager; +pub use vector_search_builder::VectorSearchBuilder; pub use write_builder::WriteBuilder; use crate::catalog::Identifier; use crate::io::FileIO; -use crate::spec::TableSchema; +use crate::spec::{DataField, TableSchema}; use std::collections::HashMap; /// Table represents a table in the catalog. @@ -149,6 +151,10 @@ impl Table { FullTextSearchBuilder::new(self) } + pub fn new_vector_search_builder(&self) -> VectorSearchBuilder<'_> { + VectorSearchBuilder::new(self) + } + /// Create a write builder for write/commit. /// /// Reference: [pypaimon FileStoreTable.new_write_builder](https://github.com/apache/paimon/blob/master/paimon-python/pypaimon/table/file_store_table.py). @@ -171,3 +177,7 @@ impl Table { /// A stream of arrow [`RecordBatch`]es. 
pub type ArrowRecordBatchStream = BoxStream<'static, Result>; + +pub(crate) fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { + fields.iter().find(|f| f.name() == name).map(|f| f.id()) +} diff --git a/crates/paimon/src/table/vector_search_builder.rs b/crates/paimon/src/table/vector_search_builder.rs new file mode 100644 index 00000000..1f79e540 --- /dev/null +++ b/crates/paimon/src/table/vector_search_builder.rs @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::lumina::reader::LuminaVectorGlobalIndexReader; +use crate::lumina::{GlobalIndexIOMeta, SearchResult, VectorSearch, LUMINA_VECTOR_ANN_IDENTIFIER}; +use crate::spec::{DataField, FileKind, IndexManifest}; +use crate::table::snapshot_manager::SnapshotManager; +use crate::table::{find_field_id_by_name, RowRange, Table}; +use std::collections::HashMap; +use std::io::Cursor; + +const INDEX_DIR: &str = "index"; + +pub struct VectorSearchBuilder<'a> { + table: &'a Table, + vector_column: Option, + query_vector: Option>, + limit: Option, +} + +impl<'a> VectorSearchBuilder<'a> { + pub(crate) fn new(table: &'a Table) -> Self { + Self { + table, + vector_column: None, + query_vector: None, + limit: None, + } + } + + pub fn with_vector_column(&mut self, name: &str) -> &mut Self { + self.vector_column = Some(name.to_string()); + self + } + + pub fn with_query_vector(&mut self, vector: Vec) -> &mut Self { + self.query_vector = Some(vector); + self + } + + pub fn with_limit(&mut self, limit: usize) -> &mut Self { + self.limit = Some(limit); + self + } + + pub async fn execute(&self) -> crate::Result> { + let vector_column = + self.vector_column + .as_deref() + .ok_or_else(|| crate::Error::ConfigInvalid { + message: "Vector column must be set via with_vector_column()".to_string(), + })?; + let query_vector = + self.query_vector + .as_ref() + .ok_or_else(|| crate::Error::ConfigInvalid { + message: "Query vector must be set via with_query_vector()".to_string(), + })?; + let limit = self.limit.ok_or_else(|| crate::Error::ConfigInvalid { + message: "Limit must be set via with_limit()".to_string(), + })?; + + let vector_search = + VectorSearch::new(query_vector.clone(), limit, vector_column.to_string())?; + + let snapshot_manager = SnapshotManager::new( + self.table.file_io().clone(), + self.table.location().to_string(), + ); + + let snapshot = match snapshot_manager.get_latest_snapshot().await? 
{ + Some(s) => s, + None => return Ok(Vec::new()), + }; + + let index_manifest_name = match snapshot.index_manifest() { + Some(name) => name.to_string(), + None => return Ok(Vec::new()), + }; + + let manifest_path = format!( + "{}/manifest/{}", + self.table.location().trim_end_matches('/'), + index_manifest_name + ); + let index_entries = IndexManifest::read(self.table.file_io(), &manifest_path).await?; + + evaluate_vector_search( + self.table.file_io(), + self.table.location(), + self.table.schema().options(), + &index_entries, + &vector_search, + self.table.schema().fields(), + ) + .await + } +} + +async fn evaluate_vector_search( + file_io: &crate::io::FileIO, + table_path: &str, + table_options: &HashMap, + index_entries: &[crate::spec::IndexManifestEntry], + vector_search: &VectorSearch, + schema_fields: &[DataField], +) -> crate::Result> { + let table_path = table_path.trim_end_matches('/'); + + let field_id = match find_field_id_by_name(schema_fields, &vector_search.field_name) { + Some(id) => id, + None => return Ok(Vec::new()), + }; + + let lumina_entries: Vec<_> = index_entries + .iter() + .filter(|e| { + e.kind == FileKind::Add + && e.index_file.index_type == LUMINA_VECTOR_ANN_IDENTIFIER + && e.index_file + .global_index_meta + .as_ref() + .is_some_and(|m| m.index_field_id == field_id) + }) + .collect(); + + if lumina_entries.is_empty() { + return Ok(Vec::new()); + } + + let futures: Vec<_> = lumina_entries + .into_iter() + .map(|entry| { + let global_meta = entry.index_file.global_index_meta.as_ref().unwrap(); + let path = format!("{table_path}/{INDEX_DIR}/{}", entry.index_file.file_name); + let file_name = entry.index_file.file_name.clone(); + let file_size = entry.index_file.file_size as u64; + let index_meta_bytes = global_meta.index_meta.clone().unwrap_or_default(); + let row_range_start = global_meta.row_range_start; + let vector_search_clone = vector_search.clone(); + let options = table_options.clone(); + let input = file_io.new_input(&path); + 
async move { + let input = input?; + let bytes = input.read().await.map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to read Lumina index file '{}': {}", file_name, e), + source: None, + })?; + + let io_meta = + GlobalIndexIOMeta::new(file_name.clone(), file_size, index_meta_bytes); + let mut reader = LuminaVectorGlobalIndexReader::new(io_meta, options); + let data = bytes.to_vec(); + let result = + reader.visit_vector_search(&vector_search_clone, |_| Ok(Cursor::new(data)))?; + + match result { + Some(scored_map) => Ok::<_, crate::Error>( + SearchResult::from_scored_map(scored_map).offset(row_range_start), + ), + None => Ok(SearchResult::empty()), + } + } + }) + .collect(); + + let results = futures::future::try_join_all(futures).await?; + let mut merged = SearchResult::empty(); + for r in &results { + merged = merged.or(r); + } + + merged.top_k(vector_search.limit).to_row_ranges() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{DataType, GlobalIndexMeta, IndexFileMeta, IndexManifestEntry, IntType}; + + fn make_field(id: i32, name: &str) -> DataField { + DataField::new(id, name.to_string(), DataType::Int(IntType::default())) + } + + #[test] + fn test_find_field_id_by_name() { + let fields = vec![make_field(1, "id"), make_field(2, "embedding")]; + assert_eq!(find_field_id_by_name(&fields, "embedding"), Some(2)); + assert_eq!(find_field_id_by_name(&fields, "nonexistent"), None); + } + + #[tokio::test] + async fn test_evaluate_no_matching_entries() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(1, "id"), make_field(2, "embedding")]; + let vs = VectorSearch::new(vec![1.0, 2.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Add, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: "btree".to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: 
None, + global_index_meta: None, + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + &HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_evaluate_no_matching_field() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(1, "id")]; + let vs = VectorSearch::new(vec![1.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Add, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: LUMINA_VECTOR_ANN_IDENTIFIER.to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: None, + global_index_meta: Some(GlobalIndexMeta { + row_range_start: 0, + row_range_end: 9, + index_field_id: 99, + extra_field_ids: None, + index_meta: None, + }), + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + &HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_evaluate_skips_delete_entries() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(2, "embedding")]; + let vs = VectorSearch::new(vec![1.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Delete, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: LUMINA_VECTOR_ANN_IDENTIFIER.to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: None, + global_index_meta: Some(GlobalIndexMeta { + row_range_start: 0, + row_range_end: 9, + index_field_id: 2, + extra_field_ids: None, + index_meta: None, + }), + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + 
&HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } +}