From a00225612afce4687b49133909e7b5bb230df245 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Mon, 20 Apr 2026 17:18:47 +0800 Subject: [PATCH 1/9] feat: add Lumina vector index read infrastructure --- crates/paimon/Cargo.toml | 2 + crates/paimon/src/globalindex/mod.rs | 509 ++++++++++++++++++++++++ crates/paimon/src/lib.rs | 3 + crates/paimon/src/lumina/ffi.rs | 567 +++++++++++++++++++++++++++ crates/paimon/src/lumina/mod.rs | 449 +++++++++++++++++++++ crates/paimon/src/lumina/reader.rs | 316 +++++++++++++++ 6 files changed, 1846 insertions(+) create mode 100644 crates/paimon/src/globalindex/mod.rs create mode 100644 crates/paimon/src/lumina/ffi.rs create mode 100644 crates/paimon/src/lumina/mod.rs create mode 100644 crates/paimon/src/lumina/reader.rs diff --git a/crates/paimon/Cargo.toml b/crates/paimon/Cargo.toml index 4bd16ad7..2906b619 100644 --- a/crates/paimon/Cargo.toml +++ b/crates/paimon/Cargo.toml @@ -39,6 +39,7 @@ storage-fs = ["opendal/services-fs"] storage-oss = ["opendal/services-oss"] storage-s3 = ["opendal/services-s3"] storage-hdfs = ["opendal/services-hdfs-native"] +lumina = ["libloading"] [dependencies] url = "2.5.2" @@ -86,8 +87,9 @@ uuid = { version = "1", features = ["v4"] } urlencoding = "2.1" tantivy = { version = "0.22", optional = true } tempfile = { version = "3", optional = true } vortex = { version = "0.68", features = ["tokio"], optional = true } kanal = { version = "0.1.1", optional = true } +libloading = { version = "0.8", optional = true } [dev-dependencies] axum = { version = "0.7", features = ["macros", "tokio", "http1", "http2"] } diff --git a/crates/paimon/src/globalindex/mod.rs b/crates/paimon/src/globalindex/mod.rs new file mode 100644 index 00000000..921127d5 --- /dev/null +++ b/crates/paimon/src/globalindex/mod.rs @@ -0,0 +1,509 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use roaring::RoaringTreemap; +use std::collections::BinaryHeap; +use std::sync::Arc; + +pub type ScoreGetter = Arc f32 + Send + Sync>; + +pub trait GlobalIndexResult: Send + Sync { + fn results(&self) -> &RoaringTreemap; + + fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { + None + } + + fn offset(&self, start_offset: u64) -> Box { + let bitmap = self.results(); + let offset_bitmap = if start_offset == 0 { + bitmap.clone() + } else { + let mut offset_bitmap = RoaringTreemap::new(); + for row_id in bitmap.iter() { + offset_bitmap.insert(row_id + start_offset); + } + offset_bitmap + }; + + if let Some(scored) = self.as_scored() { + let sg = scored.clone_score_getter(); + let score_getter = if start_offset == 0 { + sg + } else { + Arc::new(move |row_id| sg(row_id - start_offset)) + }; + return Box::new(SimpleScoredGlobalIndexResult::new( + offset_bitmap, + score_getter, + )); + } + + Box::new(SimpleGlobalIndexResult::new_ready(offset_bitmap)) + } + + fn and(&self, other: &dyn GlobalIndexResult) -> crate::Result> { + if self.as_scored().is_some() || other.as_scored().is_some() { + return Err(crate::Error::DataInvalid { + message: "and() is not supported for scored global index results".to_string(), + source: None, + }); + } + let result = 
self.results() & other.results(); + Ok(Box::new(SimpleGlobalIndexResult::new_ready(result))) + } + + fn or(&self, other: &dyn GlobalIndexResult) -> crate::Result> { + match (self.as_scored(), other.as_scored()) { + (Some(this), Some(that)) => { + let this_row_ids = self.results().clone(); + let result_or = &this_row_ids | other.results(); + let this_sg = this.clone_score_getter(); + let other_sg = that.clone_score_getter(); + // For overlapping IDs, use left-side score + return Ok(Box::new(SimpleScoredGlobalIndexResult::new( + result_or, + Arc::new(move |row_id| { + if this_row_ids.contains(row_id) { + this_sg(row_id) + } else { + other_sg(row_id) + } + }), + ))); + } + (None, None) => {} + _ => { + return Err(crate::Error::DataInvalid { + message: "Cannot union scored and unscored global index results".to_string(), + source: None, + }); + } + } + + let result = self.results() | other.results(); + Ok(Box::new(SimpleGlobalIndexResult::new_ready(result))) + } + + fn is_empty(&self) -> bool { + self.results().is_empty() + } +} + +pub struct SimpleGlobalIndexResult { + bitmap: RoaringTreemap, +} + +impl SimpleGlobalIndexResult { + pub fn new_ready(bitmap: RoaringTreemap) -> Self { + Self { bitmap } + } + + pub fn create_empty() -> Self { + Self::new_ready(RoaringTreemap::new()) + } +} + +impl GlobalIndexResult for SimpleGlobalIndexResult { + fn results(&self) -> &RoaringTreemap { + &self.bitmap + } +} + +pub trait ScoredGlobalIndexResult: GlobalIndexResult { + fn score_getter(&self) -> &ScoreGetter; + + fn scored_offset(&self, offset: u64) -> Box { + if offset == 0 { + let bitmap = self.results().clone(); + let sg = self.clone_score_getter(); + return Box::new(SimpleScoredGlobalIndexResult::new(bitmap, sg)); + } + let bitmap = self.results(); + let mut offset_bitmap = RoaringTreemap::new(); + for row_id in bitmap.iter() { + offset_bitmap.insert(row_id + offset); + } + let sg = self.clone_score_getter(); + Box::new(SimpleScoredGlobalIndexResult::new( + offset_bitmap, + 
Arc::new(move |row_id| sg(row_id - offset)), + )) + } + + // For overlapping IDs, use left-side score + fn scored_or(&self, other: &dyn ScoredGlobalIndexResult) -> Box { + let this_row_ids = self.results().clone(); + let result_or = &this_row_ids | other.results(); + let this_sg = self.clone_score_getter(); + let other_sg = other.clone_score_getter(); + Box::new(SimpleScoredGlobalIndexResult::new( + result_or, + Arc::new(move |row_id| { + if this_row_ids.contains(row_id) { + this_sg(row_id) + } else { + other_sg(row_id) + } + }), + )) + } + + fn top_k(&self, k: usize) -> Box { + let row_ids = self.results(); + if row_ids.len() <= k as u64 { + let bitmap = row_ids.clone(); + let sg = self.clone_score_getter(); + return Box::new(SimpleScoredGlobalIndexResult::new(bitmap, sg)); + } + + let score_getter_fn = self.score_getter(); + + #[derive(PartialEq)] + struct ScoredEntry { + row_id: u64, + score: f32, + } + impl Eq for ScoredEntry {} + impl PartialOrd for ScoredEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + impl Ord for ScoredEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Min-heap: reverse order so smallest score is at top. + // Use total_cmp to handle NaN deterministically. 
+ other.score.total_cmp(&self.score) + } + } + + let mut min_heap: BinaryHeap = BinaryHeap::with_capacity(k + 1); + for row_id in row_ids.iter() { + let score = score_getter_fn(row_id); + if min_heap.len() < k { + min_heap.push(ScoredEntry { row_id, score }); + } else if let Some(peek) = min_heap.peek() { + if score > peek.score { + min_heap.pop(); + min_heap.push(ScoredEntry { row_id, score }); + } + } + } + + let mut top_k_ids = RoaringTreemap::new(); + for entry in &min_heap { + top_k_ids.insert(entry.row_id); + } + + let sg = self.clone_score_getter(); + Box::new(SimpleScoredGlobalIndexResult::new(top_k_ids, sg)) + } + + fn clone_score_getter(&self) -> ScoreGetter; +} + +pub struct SimpleScoredGlobalIndexResult { + bitmap: RoaringTreemap, + score_getter: ScoreGetter, +} + +impl SimpleScoredGlobalIndexResult { + pub fn new(bitmap: RoaringTreemap, score_getter: ScoreGetter) -> Self { + Self { + bitmap, + score_getter, + } + } + + pub fn create_empty() -> Self { + Self { + bitmap: RoaringTreemap::new(), + score_getter: Arc::new(|_| 0.0), + } + } +} + +impl GlobalIndexResult for SimpleScoredGlobalIndexResult { + fn results(&self) -> &RoaringTreemap { + &self.bitmap + } + + fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { + Some(self) + } +} + +impl ScoredGlobalIndexResult for SimpleScoredGlobalIndexResult { + fn score_getter(&self) -> &ScoreGetter { + &self.score_getter + } + + fn clone_score_getter(&self) -> ScoreGetter { + self.score_getter.clone() + } +} + +pub struct DictBasedScoredIndexResult { + bitmap: RoaringTreemap, + score_getter_fn: ScoreGetter, +} + +impl DictBasedScoredIndexResult { + pub fn new(id_to_scores: std::collections::HashMap) -> Self { + let mut bitmap = RoaringTreemap::new(); + for &row_id in id_to_scores.keys() { + bitmap.insert(row_id); + } + let map = Arc::new(id_to_scores); + let score_getter_fn: ScoreGetter = + Arc::new(move |row_id| map.get(&row_id).copied().unwrap_or(0.0)); + Self { + bitmap, + score_getter_fn, + } + } +} 
+ +impl GlobalIndexResult for DictBasedScoredIndexResult { + fn results(&self) -> &RoaringTreemap { + &self.bitmap + } + + fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { + Some(self) + } +} + +impl ScoredGlobalIndexResult for DictBasedScoredIndexResult { + fn score_getter(&self) -> &ScoreGetter { + &self.score_getter_fn + } + + fn clone_score_getter(&self) -> ScoreGetter { + self.score_getter_fn.clone() + } +} + +pub struct VectorSearch { + pub vector: Vec, + pub limit: usize, + pub field_name: String, + pub include_row_ids: Option, +} + +impl VectorSearch { + pub fn new(vector: Vec, limit: usize, field_name: String) -> crate::Result { + if vector.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Search vector cannot be empty".to_string(), + source: None, + }); + } + if limit == 0 || limit > i32::MAX as usize { + return Err(crate::Error::DataInvalid { + message: format!("Limit must be between 1 and {}, got: {}", i32::MAX, limit), + source: None, + }); + } + if field_name.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Field name cannot be null or empty".to_string(), + source: None, + }); + } + Ok(Self { + vector, + limit, + field_name, + include_row_ids: None, + }) + } + + pub fn with_include_row_ids(mut self, include_row_ids: RoaringTreemap) -> Self { + self.include_row_ids = Some(include_row_ids); + self + } + + pub fn offset_range(&self, from: u64, to: u64) -> Self { + if let Some(ref include_row_ids) = self.include_row_ids { + let mut range_bitmap = RoaringTreemap::new(); + if to == u64::MAX { + range_bitmap.insert_range(from..u64::MAX); + range_bitmap.insert(u64::MAX); + } else { + range_bitmap.insert_range(from..to + 1); + } + let and_result = include_row_ids & &range_bitmap; + let mut offset_bitmap = RoaringTreemap::new(); + for row_id in and_result.iter() { + offset_bitmap.insert(row_id - from); + } + VectorSearch { + vector: self.vector.clone(), + limit: self.limit, + field_name: self.field_name.clone(), + 
include_row_ids: Some(offset_bitmap), + } + } else { + VectorSearch { + vector: self.vector.clone(), + limit: self.limit, + field_name: self.field_name.clone(), + include_row_ids: None, + } + } + } +} + +impl std::fmt::Display for VectorSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "VectorSearch(field_name={}, limit={})", + self.field_name, self.limit + ) + } +} + +pub struct GlobalIndexIOMeta { + pub file_path: String, + pub file_size: u64, + pub metadata: Vec, +} + +impl GlobalIndexIOMeta { + pub fn new(file_path: String, file_size: u64, metadata: Vec) -> Self { + Self { + file_path, + file_size, + metadata, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vector_search_offset_range() { + let mut bitmap = RoaringTreemap::new(); + bitmap.insert_range(100..200); + let vs = VectorSearch::new(vec![1.0, 2.0], 10, "vec".to_string()) + .unwrap() + .with_include_row_ids(bitmap); + + let result = vs.offset_range(60, 150); + let ids = result.include_row_ids.unwrap(); + // [100,200) means row ids [100,199]. Inclusive [60,150] keeps [100,150]. 
+ assert_eq!(ids.len(), 51); + assert!(ids.contains(40)); + assert!(ids.contains(90)); + assert!(ids.contains(89)); + assert!(!ids.contains(39)); + assert!(!ids.contains(91)); + } + + #[test] + fn test_invalid_top_k() { + assert!(VectorSearch::new(vec![1.0], 0, "f".to_string()).is_err()); + assert!(VectorSearch::new(vec![1.0], i32::MAX as usize + 1, "f".to_string()).is_err()); + } + + #[test] + fn test_offset_range_no_filter() { + let vs = VectorSearch::new(vec![1.0], 5, "f".to_string()).unwrap(); + let result = vs.offset_range(100, 200); + assert!(result.include_row_ids.is_none()); + } + + fn make_dict(entries: Vec<(u64, f32)>) -> DictBasedScoredIndexResult { + DictBasedScoredIndexResult::new(entries.into_iter().collect()) + } + + #[test] + fn test_top_k_selects_highest() { + let r = make_dict(vec![(1, 0.1), (2, 0.9), (3, 0.5), (4, 0.8), (5, 0.3)]); + let top = r.top_k(2); + assert_eq!(top.results().len(), 2); + assert!(top.results().contains(2)); + assert!(top.results().contains(4)); + } + + #[test] + fn test_scored_offset_preserves_scores() { + let r = make_dict(vec![(1, 0.5), (2, 0.6)]); + let o = r.scored_offset(100); + assert!(o.results().contains(101)); + assert_eq!(o.score_getter()(101), 0.5); + assert_eq!(o.score_getter()(102), 0.6); + } + + #[test] + fn test_base_offset_preserves_scores() { + let r = make_dict(vec![(1, 0.5), (2, 0.6)]); + let o = r.offset(100); + let scored = o + .as_scored() + .expect("offset should preserve scored results"); + assert!(o.results().contains(101)); + assert_eq!(scored.score_getter()(101), 0.5); + assert_eq!(scored.score_getter()(102), 0.6); + } + + #[test] + fn test_base_or_preserves_scores() { + let left = make_dict(vec![(1, 0.5), (2, 0.6)]); + let right = make_dict(vec![(3, 0.7), (4, 0.8)]); + let merged = left.or(&right).unwrap(); + let scored = merged + .as_scored() + .expect("or should preserve scored results"); + assert_eq!(merged.results().len(), 4); + assert_eq!(scored.score_getter()(1), 0.5); + 
assert_eq!(scored.score_getter()(4), 0.8); + } + + #[test] + fn test_or_overlapping_uses_left_score() { + let left = make_dict(vec![(1, 0.3), (2, 0.9)]); + let right = make_dict(vec![(1, 0.7), (2, 0.4)]); + let merged = left.or(&right).unwrap(); + let scored = merged + .as_scored() + .expect("or should preserve scored results"); + assert_eq!(merged.results().len(), 2); + assert_eq!(scored.score_getter()(1), 0.3); + assert_eq!(scored.score_getter()(2), 0.9); + } + + #[test] + fn test_scored_and_returns_error() { + let left = make_dict(vec![(1, 0.5), (2, 0.6)]); + let right = make_dict(vec![(1, 0.7), (3, 0.8)]); + assert!(left.and(&right).is_err()); + } + + #[test] + fn test_clone_score_getter() { + let r = make_dict(vec![(10, 1.5), (20, 2.5)]); + let cloned = r.clone_score_getter(); + assert_eq!(cloned(10), 1.5); + assert_eq!(cloned(20), 2.5); + } +} diff --git a/crates/paimon/src/lib.rs b/crates/paimon/src/lib.rs index 5aabe254..f2fbfe05 100644 --- a/crates/paimon/src/lib.rs +++ b/crates/paimon/src/lib.rs @@ -30,7 +30,10 @@ pub mod btree; pub mod catalog; mod deletion_vector; pub mod file_index; +pub mod globalindex; pub mod io; +#[cfg(feature = "lumina")] +pub mod lumina; mod predicate_stats; pub mod spec; pub mod table; diff --git a/crates/paimon/src/lumina/ffi.rs b/crates/paimon/src/lumina/ffi.rs new file mode 100644 index 00000000..4891a131 --- /dev/null +++ b/crates/paimon/src/lumina/ffi.rs @@ -0,0 +1,567 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use libloading::{Library, Symbol}; +use std::collections::HashMap; +use std::ffi::{c_char, c_float, c_int, c_void, CStr, CString}; +use std::io::{Read, Seek, SeekFrom}; +use std::sync::OnceLock; + +const ERR_BUF_SIZE: usize = 4096; + +static LIBRARY: OnceLock = OnceLock::new(); + +fn load_library() -> crate::Result<&'static Library> { + if let Some(lib) = LIBRARY.get() { + return Ok(lib); + } + let lib_path = std::env::var("LUMINA_LIB_PATH").unwrap_or_else(|_| { + if cfg!(target_os = "macos") { + "liblumina_py.dylib".to_string() + } else if cfg!(target_os = "windows") { + "lumina_py.dll".to_string() + } else { + "liblumina_py.so".to_string() + } + }); + let lib = unsafe { + Library::new(&lib_path).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to load lumina library from '{}': {}", lib_path, e), + source: None, + })? 
+ }; + let _ = LIBRARY.set(lib); + Ok(LIBRARY.get().unwrap()) +} + +fn check_error(ret: c_int, err_buf: &[u8; ERR_BUF_SIZE]) -> crate::Result<()> { + if ret != 0 { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Lumina error: {}", msg), + source: None, + }); + } + Ok(()) +} + +fn options_to_json(options: &HashMap) -> crate::Result { + let json = serde_json::to_string(options).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to serialize options: {}", e), + source: None, + })?; + CString::new(json).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to create CString: {}", e), + source: None, + }) +} + +pub struct LuminaSearcher { + handle: *mut c_void, + /// Keeps the stream context alive while C-side holds a raw pointer to it. + stream_ctx_keepalive: Option>, +} + +// SAFETY: Each LuminaSearcher owns its handle exclusively and is not Sync. +// Send allows moving the searcher to another thread. 
+unsafe impl Send for LuminaSearcher {} + +impl LuminaSearcher { + pub fn create(options: &HashMap) -> crate::Result { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let handle: *mut c_void = unsafe { + let func: Symbol< + unsafe extern "C" fn(*const c_char, *mut c_char, c_int) -> *mut c_void, + > = lib + .get(b"lumina_searcher_create") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_create not found: {}", e), + source: None, + })?; + func( + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + if handle.is_null() { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Failed to create Lumina searcher: {}", msg), + source: None, + }); + } + + Ok(Self { + handle, + stream_ctx_keepalive: None, + }) + } + + #[allow(clippy::type_complexity)] + pub fn open_stream(&mut self, stream: S) -> crate::Result<()> { + if self.stream_ctx_keepalive.is_some() { + return Err(crate::Error::DataInvalid { + message: "A stream is already open; close the searcher before opening a new stream" + .to_string(), + source: None, + }); + } + + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ctx = Box::new(StreamContext::new(stream)); + let ctx_ptr = &*ctx as *const StreamContext as *mut c_void; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *mut c_void, + unsafe extern "C" fn(*mut c_void, *mut c_char, u64) -> c_int, + unsafe extern "C" fn(*mut c_void, u64) -> c_int, + unsafe extern "C" fn(*mut c_void) -> u64, + unsafe extern "C" fn(*mut c_void) -> u64, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_open_stream") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_open_stream not 
found: {}", e), + source: None, + })?; + func( + self.handle, + ctx_ptr, + stream_read_cb, + stream_seek_cb, + stream_tell_cb, + stream_length_cb, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf)?; + self.stream_ctx_keepalive = Some(ctx); + Ok(()) + } + + pub fn search( + &self, + query: &[f32], + n: i32, + k: i32, + distances: &mut [f32], + labels: &mut [u64], + options: &HashMap, + ) -> crate::Result<()> { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_float, + *mut u64, + *const c_char, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_search") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_search not found: {}", e), + source: None, + })?; + func( + self.handle, + query.as_ptr(), + n, + k, + distances.as_mut_ptr(), + labels.as_mut_ptr(), + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + #[allow(clippy::too_many_arguments, clippy::type_complexity)] + pub fn search_with_filter( + &self, + query: &[f32], + n: i32, + k: i32, + distances: &mut [f32], + labels: &mut [u64], + filter_ids: &[u64], + options: &HashMap, + ) -> crate::Result<()> { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_float, + *mut u64, + *const u64, + u64, + *const c_char, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_searcher_search_with_filter") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_search_with_filter not found: {}", e), + source: None, 
+ })?; + func( + self.handle, + query.as_ptr(), + n, + k, + distances.as_mut_ptr(), + labels.as_mut_ptr(), + filter_ids.as_ptr(), + filter_ids.len() as u64, + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn get_count(&self) -> crate::Result { + let lib = load_library()?; + unsafe { + let func: Symbol u64> = lib + .get(b"lumina_searcher_get_count") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_get_count not found: {}", e), + source: None, + })?; + Ok(func(self.handle)) + } + } + + pub fn get_dimension(&self) -> crate::Result { + let lib = load_library()?; + unsafe { + let func: Symbol u32> = lib + .get(b"lumina_searcher_get_dimension") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_get_dimension not found: {}", e), + source: None, + })?; + Ok(func(self.handle)) + } + } +} + +impl Drop for LuminaSearcher { + fn drop(&mut self) { + if !self.handle.is_null() { + if let Ok(lib) = load_library() { + unsafe { + if let Ok(func) = + lib.get::(b"lumina_searcher_destroy") + { + func(self.handle); + } + } + } + self.handle = std::ptr::null_mut(); + } + } +} + +pub struct LuminaBuilder { + handle: *mut c_void, +} + +// SAFETY: Same as LuminaSearcher — exclusively owned, not Sync. 
+unsafe impl Send for LuminaBuilder {} + +impl LuminaBuilder { + pub fn create(options: &HashMap) -> crate::Result { + let lib = load_library()?; + let opts_json = options_to_json(options)?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let handle: *mut c_void = unsafe { + let func: Symbol< + unsafe extern "C" fn(*const c_char, *mut c_char, c_int) -> *mut c_void, + > = lib + .get(b"lumina_builder_create") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_create not found: {}", e), + source: None, + })?; + func( + opts_json.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + if handle.is_null() { + let c_str = unsafe { CStr::from_ptr(err_buf.as_ptr() as *const c_char) }; + let msg = c_str.to_string_lossy().to_string(); + return Err(crate::Error::DataInvalid { + message: format!("Failed to create Lumina builder: {}", msg), + source: None, + }); + } + + Ok(Self { handle }) + } + + pub fn pretrain(&self, vectors: &[f32], n: i32, dim: i32) -> crate::Result<()> { + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + c_int, + c_int, + *mut c_char, + c_int, + ) -> c_int, + > = lib + .get(b"lumina_builder_pretrain") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_pretrain not found: {}", e), + source: None, + })?; + func( + self.handle, + vectors.as_ptr(), + n, + dim, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn insert(&self, vectors: &[f32], ids: &[u64], n: i32, dim: i32) -> crate::Result<()> { + let lib = load_library()?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn( + *mut c_void, + *const c_float, + *const u64, + c_int, + c_int, + *mut c_char, + c_int, + ) -> c_int, + > = lib + 
.get(b"lumina_builder_insert") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_insert not found: {}", e), + source: None, + })?; + func( + self.handle, + vectors.as_ptr(), + ids.as_ptr(), + n, + dim, + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + + pub fn dump(&self, path: &str) -> crate::Result<()> { + let lib = load_library()?; + let c_path = CString::new(path).map_err(|e| crate::Error::DataInvalid { + message: format!("Invalid path: {}", e), + source: None, + })?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn(*mut c_void, *const c_char, *mut c_char, c_int) -> c_int, + > = lib + .get(b"lumina_builder_dump") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_builder_dump not found: {}", e), + source: None, + })?; + func( + self.handle, + c_path.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } +} + +impl Drop for LuminaBuilder { + fn drop(&mut self) { + if !self.handle.is_null() { + if let Ok(lib) = load_library() { + unsafe { + if let Ok(func) = + lib.get::(b"lumina_builder_destroy") + { + func(self.handle); + } + } + } + self.handle = std::ptr::null_mut(); + } + } +} + +struct StreamContext { + inner: std::sync::Mutex>, +} + +trait ReadSeekLen: Read + Seek { + fn length(&self) -> u64; +} + +struct ReadSeekLenImpl { + stream: S, + len: u64, +} + +impl ReadSeekLenImpl { + fn new(mut stream: S) -> Self { + let len = stream.seek(SeekFrom::End(0)).unwrap_or(0); + let _ = stream.seek(SeekFrom::Start(0)); + Self { stream, len } + } +} + +impl Read for ReadSeekLenImpl { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.stream.read(buf) + } +} + +impl Seek for ReadSeekLenImpl { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + self.stream.seek(pos) + } +} + +impl ReadSeekLen for 
ReadSeekLenImpl { + fn length(&self) -> u64 { + self.len + } +} + +impl StreamContext { + fn new(stream: S) -> Self { + Self { + inner: std::sync::Mutex::new(Box::new(ReadSeekLenImpl::new(stream))), + } + } +} + +unsafe extern "C" fn stream_read_cb(ctx: *mut c_void, buf: *mut c_char, size: u64) -> c_int { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return -1, + }; + let clamped_size = std::cmp::min(size, c_int::MAX as u64) as usize; + let slice = std::slice::from_raw_parts_mut(buf as *mut u8, clamped_size); + let mut total_read = 0usize; + while total_read < clamped_size { + match guard.read(&mut slice[total_read..]) { + Ok(0) => break, + Ok(n) => total_read += n, + Err(_) => return -1, + } + } + std::cmp::min(total_read, c_int::MAX as usize) as c_int +} + +unsafe extern "C" fn stream_seek_cb(ctx: *mut c_void, position: u64) -> c_int { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return -1, + }; + match guard.seek(SeekFrom::Start(position)) { + Ok(_) => 0, + Err(_) => -1, + } +} + +unsafe extern "C" fn stream_tell_cb(ctx: *mut c_void) -> u64 { + let ctx = &*(ctx as *const StreamContext); + let mut guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return 0, + }; + guard.seek(SeekFrom::Current(0)).unwrap_or(0) +} + +unsafe extern "C" fn stream_length_cb(ctx: *mut c_void) -> u64 { + let ctx = &*(ctx as *const StreamContext); + let guard = match ctx.inner.lock() { + Ok(g) => g, + Err(_) => return 0, + }; + guard.length() +} diff --git a/crates/paimon/src/lumina/mod.rs b/crates/paimon/src/lumina/mod.rs new file mode 100644 index 00000000..d6eb6cdf --- /dev/null +++ b/crates/paimon/src/lumina/mod.rs @@ -0,0 +1,449 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod ffi; +pub mod reader; + +use std::collections::HashMap; + +pub const LUMINA_VECTOR_ANN_IDENTIFIER: &str = "lumina-vector-ann"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LuminaVectorMetric { + L2, + Cosine, + InnerProduct, +} + +impl LuminaVectorMetric { + pub fn lumina_name(&self) -> &str { + match self { + LuminaVectorMetric::L2 => "l2", + LuminaVectorMetric::Cosine => "cosine", + LuminaVectorMetric::InnerProduct => "inner_product", + } + } + + pub fn from_string(name: &str) -> crate::Result { + match name.to_uppercase().as_str() { + "L2" => Ok(LuminaVectorMetric::L2), + "COSINE" => Ok(LuminaVectorMetric::Cosine), + "INNER_PRODUCT" => Ok(LuminaVectorMetric::InnerProduct), + _ => Err(crate::Error::DataInvalid { + message: format!("Unknown metric name: {}", name), + source: None, + }), + } + } + + pub fn from_lumina_name(lumina_name: &str) -> crate::Result { + match lumina_name { + "l2" => Ok(LuminaVectorMetric::L2), + "cosine" => Ok(LuminaVectorMetric::Cosine), + "inner_product" => Ok(LuminaVectorMetric::InnerProduct), + _ => Err(crate::Error::DataInvalid { + message: format!("Unknown lumina metric name: {}", lumina_name), + source: None, + }), + } + } +} + +const LUMINA_PREFIX: &str = "lumina."; + +const ALL_OPTIONS_DEFAULTS: &[(&str, &str)] = &[ 
+ ("lumina.index.dimension", "128"), + ("lumina.index.type", "diskann"), + ("lumina.distance.metric", "inner_product"), + ("lumina.encoding.type", "pq"), + ("lumina.pretrain.sample_ratio", "0.2"), + ("lumina.diskann.build.ef_construction", "1024"), + ("lumina.diskann.build.neighbor_count", "64"), + ("lumina.diskann.build.thread_count", "32"), + ("lumina.diskann.search.beam_width", "4"), + ("lumina.encoding.pq.m", "64"), + ("lumina.search.parallel_number", "5"), +]; + +pub struct LuminaVectorIndexOptions { + pub dimension: i32, + pub metric: LuminaVectorMetric, + pub index_type: String, + lumina_options: HashMap, +} + +impl LuminaVectorIndexOptions { + pub fn new(paimon_options: &HashMap) -> crate::Result { + let dimension_str = paimon_options + .get("lumina.index.dimension") + .map(|s| s.as_str()) + .unwrap_or("128"); + let dimension: i32 = dimension_str + .parse() + .map_err(|_| crate::Error::DataInvalid { + message: format!("Invalid dimension: {}", dimension_str), + source: None, + })?; + if dimension <= 0 { + return Err(crate::Error::DataInvalid { + message: format!( + "Invalid value for 'lumina.index.dimension': {}. 
Must be a positive integer.", + dimension + ), + source: None, + }); + } + + let metric_str = paimon_options + .get("lumina.distance.metric") + .map(|s| s.as_str()) + .unwrap_or("inner_product"); + let metric = LuminaVectorMetric::from_lumina_name(metric_str) + .or_else(|_| LuminaVectorMetric::from_string(metric_str))?; + + let encoding = paimon_options + .get("lumina.encoding.type") + .map(|s| s.as_str()) + .unwrap_or("pq"); + validate_encoding_metric(encoding, metric)?; + + let index_type = paimon_options + .get("lumina.index.type") + .cloned() + .unwrap_or_else(|| "diskann".to_string()); + + let lumina_options = build_lumina_options(paimon_options, dimension)?; + + Ok(Self { + dimension, + metric, + index_type, + lumina_options, + }) + } + + pub fn to_lumina_options(&self) -> HashMap { + self.lumina_options.clone() + } +} + +fn validate_encoding_metric(encoding: &str, metric: LuminaVectorMetric) -> crate::Result<()> { + if encoding.eq_ignore_ascii_case("pq") && metric == LuminaVectorMetric::Cosine { + return Err(crate::Error::DataInvalid { + message: + "Lumina does not support PQ encoding with cosine metric. \ + Please use 'rawf32' or 'sq8' encoding, or switch to 'l2' or 'inner_product' metric." 
+ .to_string(), + source: None, + }); + } + Ok(()) +} + +fn validate_and_cap_pq_m(opts: &mut HashMap, dimension: i32) -> crate::Result<()> { + let encoding = opts.get("encoding.type").map(|s| s.as_str()).unwrap_or(""); + if !encoding.eq_ignore_ascii_case("pq") { + return Ok(()); + } + if let Some(pq_m_str) = opts.get("encoding.pq.m") { + let pq_m: i32 = pq_m_str.parse().map_err(|_| crate::Error::DataInvalid { + message: format!("encoding.pq.m must be an integer, got: {}", pq_m_str), + source: None, + })?; + if pq_m <= 0 { + return Err(crate::Error::DataInvalid { + message: format!("encoding.pq.m must be positive, got: {}", pq_m), + source: None, + }); + } + if pq_m > dimension { + opts.insert("encoding.pq.m".to_string(), dimension.to_string()); + } + } + Ok(()) +} + +fn build_lumina_options( + paimon_options: &HashMap, + dimension: i32, +) -> crate::Result> { + let mut result = HashMap::new(); + + for &(paimon_key, default_value) in ALL_OPTIONS_DEFAULTS { + let native_key = &paimon_key[LUMINA_PREFIX.len()..]; + let value = paimon_options + .get(paimon_key) + .map(|s| s.as_str()) + .unwrap_or(default_value); + result.insert(native_key.to_string(), value.to_string()); + } + + for (key, value) in paimon_options { + if let Some(native_key) = key.strip_prefix(LUMINA_PREFIX) { + result + .entry(native_key.to_string()) + .or_insert_with(|| value.to_string()); + } + } + + validate_and_cap_pq_m(&mut result, dimension)?; + Ok(result) +} + +pub fn strip_lumina_options(paimon_options: &HashMap) -> HashMap { + let mut result = HashMap::new(); + for (key, value) in paimon_options { + if let Some(native_key) = key.strip_prefix(LUMINA_PREFIX) { + result.insert(native_key.to_string(), value.to_string()); + } + } + result +} + +pub const KEY_DIMENSION: &str = "index.dimension"; +pub const KEY_DISTANCE_METRIC: &str = "distance.metric"; +pub const KEY_INDEX_TYPE: &str = "index.type"; + +pub struct LuminaIndexMeta { + options: HashMap, +} + +impl LuminaIndexMeta { + pub fn new(options: 
HashMap) -> Self { + Self { options } + } + + pub fn options(&self) -> &HashMap { + &self.options + } + + pub fn dim(&self) -> crate::Result { + let val = self + .options + .get(KEY_DIMENSION) + .ok_or_else(|| crate::Error::DataInvalid { + message: format!("Missing required key: {}", KEY_DIMENSION), + source: None, + })?; + val.parse::().map_err(|_| crate::Error::DataInvalid { + message: format!("Invalid dimension value: {}", val), + source: None, + }) + } + + pub fn distance_metric(&self) -> &str { + self.options + .get(KEY_DISTANCE_METRIC) + .map(String::as_str) + .unwrap_or("") + } + + pub fn metric(&self) -> crate::Result { + LuminaVectorMetric::from_lumina_name(self.distance_metric()) + } + + pub fn index_type(&self) -> &str { + self.options + .get(KEY_INDEX_TYPE) + .map(String::as_str) + .unwrap_or("diskann") + } + + pub fn serialize(&self) -> crate::Result> { + serde_json::to_vec(&self.options).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to serialize LuminaIndexMeta: {}", e), + source: None, + }) + } + + pub fn deserialize(data: &[u8]) -> crate::Result { + let options: HashMap = + serde_json::from_slice(data).map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to deserialize LuminaIndexMeta: {}", e), + source: None, + })?; + if !options.contains_key(KEY_DIMENSION) { + return Err(crate::Error::DataInvalid { + message: format!( + "Missing required key in Lumina index metadata: {}", + KEY_DIMENSION + ), + source: None, + }); + } + if !options.contains_key(KEY_DISTANCE_METRIC) { + return Err(crate::Error::DataInvalid { + message: format!( + "Missing required key in Lumina index metadata: {}", + KEY_DISTANCE_METRIC + ), + source: None, + }); + } + Ok(Self { options }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metric_roundtrip() { + for metric in [ + LuminaVectorMetric::L2, + LuminaVectorMetric::Cosine, + LuminaVectorMetric::InnerProduct, + ] { + let name = metric.lumina_name(); + 
assert_eq!(LuminaVectorMetric::from_lumina_name(name).unwrap(), metric); + assert_eq!( + LuminaVectorMetric::from_string(&name.to_uppercase()).unwrap(), + metric + ); + } + assert!(LuminaVectorMetric::from_string("hamming").is_err()); + } + + #[test] + fn test_index_meta_serialize_deserialize() { + let mut options = HashMap::new(); + options.insert(KEY_DIMENSION.to_string(), "128".to_string()); + options.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + options.insert(KEY_INDEX_TYPE.to_string(), "diskann".to_string()); + let meta = LuminaIndexMeta::new(options); + + let bytes = meta.serialize().unwrap(); + let meta2 = LuminaIndexMeta::deserialize(&bytes).unwrap(); + assert_eq!(meta2.dim().unwrap(), 128); + assert_eq!(meta2.distance_metric(), "l2"); + assert_eq!(meta2.index_type(), "diskann"); + } + + #[test] + fn test_index_meta_deserialize_missing_fields() { + // missing dimension + let mut opts = HashMap::new(); + opts.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + assert!(LuminaIndexMeta::deserialize(&serde_json::to_vec(&opts).unwrap()).is_err()); + + // missing metric + let mut opts = HashMap::new(); + opts.insert(KEY_DIMENSION.to_string(), "128".to_string()); + assert!(LuminaIndexMeta::deserialize(&serde_json::to_vec(&opts).unwrap()).is_err()); + + // invalid json + assert!(LuminaIndexMeta::deserialize(b"not json").is_err()); + } + + #[test] + fn test_dim_error_on_invalid() { + let mut opts = HashMap::new(); + opts.insert(KEY_DIMENSION.to_string(), "abc".to_string()); + opts.insert(KEY_DISTANCE_METRIC.to_string(), "l2".to_string()); + assert!(LuminaIndexMeta::new(opts).dim().is_err()); + } + + #[test] + fn test_index_options_invalid_dimension() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "-1".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_strip_lumina_options() { + let mut opts = HashMap::new(); + 
opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert( + "lumina.diskann.search.beam_width".to_string(), + "8".to_string(), + ); + opts.insert("non_lumina_key".to_string(), "ignored".to_string()); + let result = strip_lumina_options(&opts); + assert_eq!(result.get("index.dimension").unwrap(), "128"); + assert_eq!(result.get("diskann.search.beam_width").unwrap(), "8"); + assert!(!result.contains_key("non_lumina_key")); + } + + #[test] + fn test_pq_cosine_rejected() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.distance.metric".to_string(), "cosine".to_string()); + opts.insert("lumina.encoding.type".to_string(), "pq".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_pq_l2_accepted() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.distance.metric".to_string(), "l2".to_string()); + opts.insert("lumina.encoding.type".to_string(), "pq".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_ok()); + } + + #[test] + fn test_pq_m_zero_rejected() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "0".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_pq_m_non_numeric_rejected() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "128".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "abc".to_string()); + assert!(LuminaVectorIndexOptions::new(&opts).is_err()); + } + + #[test] + fn test_cap_pq_m() { + let mut opts = HashMap::new(); + opts.insert("lumina.index.dimension".to_string(), "32".to_string()); + opts.insert("lumina.encoding.pq.m".to_string(), "64".to_string()); + let index_opts = 
LuminaVectorIndexOptions::new(&opts).unwrap(); + let lumina_opts = index_opts.to_lumina_options(); + assert_eq!(lumina_opts.get("encoding.pq.m").unwrap(), "32"); + } + + #[test] + fn test_build_lumina_options_defaults() { + let opts = HashMap::new(); + let index_opts = LuminaVectorIndexOptions::new(&opts).unwrap(); + let lumina_opts = index_opts.to_lumina_options(); + assert_eq!(lumina_opts.get("index.dimension").unwrap(), "128"); + assert_eq!(lumina_opts.get("distance.metric").unwrap(), "inner_product"); + assert_eq!(lumina_opts.get("encoding.type").unwrap(), "pq"); + assert_eq!(lumina_opts.get("pretrain.sample_ratio").unwrap(), "0.2"); + assert_eq!( + lumina_opts.get("diskann.build.ef_construction").unwrap(), + "1024" + ); + assert_eq!( + lumina_opts.get("diskann.build.neighbor_count").unwrap(), + "64" + ); + assert_eq!(lumina_opts.get("diskann.build.thread_count").unwrap(), "32"); + assert_eq!(lumina_opts.get("diskann.search.beam_width").unwrap(), "4"); + assert_eq!(lumina_opts.get("encoding.pq.m").unwrap(), "64"); + assert_eq!(lumina_opts.get("search.parallel_number").unwrap(), "5"); + } +} diff --git a/crates/paimon/src/lumina/reader.rs b/crates/paimon/src/lumina/reader.rs new file mode 100644 index 00000000..542a7a33 --- /dev/null +++ b/crates/paimon/src/lumina/reader.rs @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::globalindex::{ + DictBasedScoredIndexResult, GlobalIndexIOMeta, ScoredGlobalIndexResult, VectorSearch, +}; +use crate::lumina::ffi::LuminaSearcher; +use crate::lumina::{strip_lumina_options, LuminaIndexMeta, LuminaVectorMetric}; +use std::collections::BinaryHeap; +use std::collections::HashMap; +use std::io::{Read, Seek}; + +const MIN_SEARCH_LIST_SIZE: usize = 16; +// C ABI returns int64_t -1 for invalid results, which casts to u64::MAX in Rust. +const SENTINEL: u64 = u64::MAX; + +fn ensure_search_list_size(search_options: &mut HashMap, top_k: usize) { + if !search_options.contains_key("diskann.search.list_size") { + let list_size = std::cmp::max((top_k as f64 * 1.5) as usize, MIN_SEARCH_LIST_SIZE); + search_options.insert( + "diskann.search.list_size".to_string(), + list_size.to_string(), + ); + } +} + +fn convert_distance_to_score(distance: f32, metric: LuminaVectorMetric) -> f32 { + match metric { + LuminaVectorMetric::L2 => 1.0 / (1.0 + distance), + LuminaVectorMetric::Cosine => 1.0 - distance, + LuminaVectorMetric::InnerProduct => distance, + } +} + +/// Post-filter search results to top_k. 
+fn collect_results( + labels: &[u64], + distances: &[f32], + top_k: usize, + metric: LuminaVectorMetric, +) -> HashMap { + #[derive(PartialEq)] + struct ScoredRow { + row_id: u64, + score: f32, + } + impl Eq for ScoredRow {} + impl PartialOrd for ScoredRow { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + impl Ord for ScoredRow { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other.score.total_cmp(&self.score) + } + } + + let mut min_heap: BinaryHeap = BinaryHeap::with_capacity(top_k + 1); + for (&row_id, &distance) in labels.iter().zip(distances.iter()) { + if row_id == SENTINEL { + continue; + } + let score = convert_distance_to_score(distance, metric); + if min_heap.len() < top_k { + min_heap.push(ScoredRow { row_id, score }); + } else if let Some(peek) = min_heap.peek() { + if score > peek.score { + min_heap.pop(); + min_heap.push(ScoredRow { row_id, score }); + } + } + } + + let mut result = HashMap::with_capacity(min_heap.len()); + for entry in min_heap { + result.insert(entry.row_id, entry.score); + } + result +} + +pub struct LuminaVectorGlobalIndexReader { + io_meta: GlobalIndexIOMeta, + options: HashMap, + searcher: Option, + index_meta: Option, + search_options: Option>, +} + +impl LuminaVectorGlobalIndexReader { + pub fn new(io_meta: GlobalIndexIOMeta, options: HashMap) -> Self { + Self { + io_meta, + options, + searcher: None, + index_meta: None, + search_options: None, + } + } + + pub fn visit_vector_search( + &mut self, + vector_search: &VectorSearch, + stream_fn: impl FnOnce(&str) -> crate::Result, + ) -> crate::Result>> { + self.ensure_loaded(stream_fn)?; + self.search(vector_search) + } + + fn search( + &self, + vector_search: &VectorSearch, + ) -> crate::Result>> { + let index_meta = self + .index_meta + .as_ref() + .ok_or_else(|| crate::Error::DataInvalid { + message: "index_meta not initialized".to_string(), + source: None, + })?; + let searcher = self + .searcher + .as_ref() + .ok_or_else(|| 
crate::Error::DataInvalid { + message: "searcher not initialized".to_string(), + source: None, + })?; + let search_options_base = + self.search_options + .as_ref() + .ok_or_else(|| crate::Error::DataInvalid { + message: "search_options not initialized".to_string(), + source: None, + })?; + + let expected_dim = index_meta.dim()? as usize; + if vector_search.vector.len() != expected_dim { + return Err(crate::Error::DataInvalid { + message: format!( + "Query vector dimension mismatch: index expects {}, but got {}", + expected_dim, + vector_search.vector.len() + ), + source: None, + }); + } + + let limit = vector_search.limit; + let index_metric = index_meta.metric()?; + let count = searcher.get_count()? as usize; + let effective_k = std::cmp::min(limit, count); + if effective_k == 0 { + return Ok(None); + } + + let include_row_ids = &vector_search.include_row_ids; + + let (distances, labels) = if let Some(ref include_ids) = include_row_ids { + let filter_id_list: Vec = include_ids.iter().collect(); + if filter_id_list.is_empty() { + return Ok(None); + } + let ek = std::cmp::min(effective_k, filter_id_list.len()); + let mut distances = vec![0.0f32; ek]; + let mut labels = vec![0u64; ek]; + let mut search_opts: HashMap = search_options_base.clone(); + search_opts.insert("search.thread_safe_filter".to_string(), "true".to_string()); + ensure_search_list_size(&mut search_opts, ek); + searcher.search_with_filter( + &vector_search.vector, + 1, + ek as i32, + &mut distances, + &mut labels, + &filter_id_list, + &search_opts, + )?; + (distances, labels) + } else { + let mut distances = vec![0.0f32; effective_k]; + let mut labels = vec![0u64; effective_k]; + let mut search_opts: HashMap = search_options_base.clone(); + ensure_search_list_size(&mut search_opts, effective_k); + searcher.search( + &vector_search.vector, + 1, + effective_k as i32, + &mut distances, + &mut labels, + &search_opts, + )?; + (distances, labels) + }; + + let id_to_scores = collect_results(&labels, 
&distances, effective_k, index_metric); + if id_to_scores.is_empty() { + return Ok(None); + } + + Ok(Some(Box::new(DictBasedScoredIndexResult::new( + id_to_scores, + )))) + } + + fn ensure_loaded( + &mut self, + stream_fn: impl FnOnce(&str) -> crate::Result, + ) -> crate::Result<()> { + if self.searcher.is_some() { + return Ok(()); + } + + let index_meta = LuminaIndexMeta::deserialize(&self.io_meta.metadata)?; + + let mut searcher_options = strip_lumina_options(&self.options); + for (k, v) in index_meta.options().iter() { + searcher_options.insert(k.to_string(), v.to_string()); + } + + let mut searcher = LuminaSearcher::create(&searcher_options)?; + + let stream = stream_fn(&self.io_meta.file_path)?; + searcher.open_stream(stream)?; + + self.search_options = Some(searcher_options); + self.index_meta = Some(index_meta); + self.searcher = Some(searcher); + Ok(()) + } + + pub fn close(&mut self) { + self.searcher = None; + self.index_meta = None; + self.search_options = None; + } +} + +impl Drop for LuminaVectorGlobalIndexReader { + fn drop(&mut self) { + self.close(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::globalindex::GlobalIndexIOMeta; + + #[test] + fn test_convert_distance_to_score() { + assert_eq!(convert_distance_to_score(0.0, LuminaVectorMetric::L2), 1.0); + assert_eq!(convert_distance_to_score(1.0, LuminaVectorMetric::L2), 0.5); + assert_eq!( + convert_distance_to_score(0.0, LuminaVectorMetric::Cosine), + 1.0 + ); + assert_eq!( + convert_distance_to_score(1.0, LuminaVectorMetric::Cosine), + 0.0 + ); + assert_eq!( + convert_distance_to_score(0.75, LuminaVectorMetric::InnerProduct), + 0.75 + ); + } + + #[test] + fn test_ensure_search_list_size() { + let mut opts = HashMap::new(); + ensure_search_list_size(&mut opts, 10); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), "16"); // max(15, 16) + + let mut opts = HashMap::new(); + ensure_search_list_size(&mut opts, 100); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), 
"150"); // 100*1.5 + + // does not override existing + let mut opts = HashMap::new(); + opts.insert("diskann.search.list_size".to_string(), "999".to_string()); + ensure_search_list_size(&mut opts, 100); + assert_eq!(opts.get("diskann.search.list_size").unwrap(), "999"); + } + + #[test] + fn test_collect_results() { + let labels = vec![0, 1, 2, SENTINEL, 3]; + let distances = vec![0.5, 0.3, 0.1, 0.0, 0.9]; + let result = collect_results(&labels, &distances, 2, LuminaVectorMetric::InnerProduct); + assert_eq!(result.len(), 2); + // top 2 by score: row 3 (0.9) and row 0 (0.5) + assert!(result.contains_key(&3)); + assert!(result.contains_key(&0)); + assert!(!result.contains_key(&2)); // 0.1 is lowest + } + + #[test] + fn test_reader_new() { + let m = GlobalIndexIOMeta::new("a".into(), 100, vec![]); + let reader = LuminaVectorGlobalIndexReader::new(m, HashMap::new()); + assert!(reader.searcher.is_none()); + } +} From 9b5caf14e693d8e5c49c91d4e0711b90b78f0360 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Mon, 20 Apr 2026 17:21:00 +0800 Subject: [PATCH 2/9] fix: resolve merge conflict in Cargo.toml --- crates/paimon/Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/paimon/Cargo.toml b/crates/paimon/Cargo.toml index 2906b619..0ebbf6c3 100644 --- a/crates/paimon/Cargo.toml +++ b/crates/paimon/Cargo.toml @@ -87,12 +87,9 @@ uuid = { version = "1", features = ["v4"] } urlencoding = "2.1" tantivy = { version = "0.22", optional = true } tempfile = { version = "3", optional = true } -<<<<<<< HEAD vortex = { version = "0.68", features = ["tokio"], optional = true } kanal = { version = "0.1.1", optional = true } -======= libloading = { version = "0.8", optional = true } ->>>>>>> 2c6608c (feat: add Lumina vector index read infrastructure) [dev-dependencies] axum = { version = "0.7", features = ["macros", "tokio", "http1", "http2"] } From 130f01824c1a12a55141f511bcc9ecf8d3e4a7ab Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 11:50:50 +0800 
Subject: [PATCH 3/9] fix comments and add data fusion support lumina --- crates/integrations/datafusion/src/lib.rs | 2 + crates/paimon/Cargo.toml | 3 +- crates/paimon/src/globalindex/mod.rs | 509 ---------------------- crates/paimon/src/lib.rs | 2 - crates/paimon/src/lumina/mod.rs | 221 ++++++++++ crates/paimon/src/lumina/reader.rs | 26 +- crates/paimon/src/table/mod.rs | 6 + 7 files changed, 240 insertions(+), 529 deletions(-) delete mode 100644 crates/paimon/src/globalindex/mod.rs diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs index 7dfa9b21..4d1688d3 100644 --- a/crates/integrations/datafusion/src/lib.rs +++ b/crates/integrations/datafusion/src/lib.rs @@ -49,6 +49,7 @@ mod sql_handler; mod system_tables; mod table; mod update; +mod vector_search; pub use catalog::{PaimonCatalogProvider, PaimonSchemaProvider}; pub use error::to_datafusion_error; @@ -58,3 +59,4 @@ pub use physical_plan::PaimonTableScan; pub use relation_planner::PaimonRelationPlanner; pub use sql_handler::PaimonSqlHandler; pub use table::PaimonTableProvider; +pub use vector_search::{register_vector_search, VectorSearchFunction}; diff --git a/crates/paimon/Cargo.toml b/crates/paimon/Cargo.toml index 0ebbf6c3..d1c428f9 100644 --- a/crates/paimon/Cargo.toml +++ b/crates/paimon/Cargo.toml @@ -39,7 +39,6 @@ storage-fs = ["opendal/services-fs"] storage-oss = ["opendal/services-oss"] storage-s3 = ["opendal/services-s3"] storage-hdfs = ["opendal/services-hdfs-native"] -lumina = ["libloading"] [dependencies] url = "2.5.2" @@ -89,7 +88,7 @@ tantivy = { version = "0.22", optional = true } tempfile = { version = "3", optional = true } vortex = { version = "0.68", features = ["tokio"], optional = true } kanal = { version = "0.1.1", optional = true } -libloading = { version = "0.8", optional = true } +libloading = "0.8" [dev-dependencies] axum = { version = "0.7", features = ["macros", "tokio", "http1", "http2"] } diff --git 
a/crates/paimon/src/globalindex/mod.rs b/crates/paimon/src/globalindex/mod.rs deleted file mode 100644 index 921127d5..00000000 --- a/crates/paimon/src/globalindex/mod.rs +++ /dev/null @@ -1,509 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use roaring::RoaringTreemap; -use std::collections::BinaryHeap; -use std::sync::Arc; - -pub type ScoreGetter = Arc f32 + Send + Sync>; - -pub trait GlobalIndexResult: Send + Sync { - fn results(&self) -> &RoaringTreemap; - - fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { - None - } - - fn offset(&self, start_offset: u64) -> Box { - let bitmap = self.results(); - let offset_bitmap = if start_offset == 0 { - bitmap.clone() - } else { - let mut offset_bitmap = RoaringTreemap::new(); - for row_id in bitmap.iter() { - offset_bitmap.insert(row_id + start_offset); - } - offset_bitmap - }; - - if let Some(scored) = self.as_scored() { - let sg = scored.clone_score_getter(); - let score_getter = if start_offset == 0 { - sg - } else { - Arc::new(move |row_id| sg(row_id - start_offset)) - }; - return Box::new(SimpleScoredGlobalIndexResult::new( - offset_bitmap, - score_getter, - )); - } - - Box::new(SimpleGlobalIndexResult::new_ready(offset_bitmap)) - } - - fn and(&self, other: &dyn GlobalIndexResult) -> crate::Result> { - if self.as_scored().is_some() || other.as_scored().is_some() { - return Err(crate::Error::DataInvalid { - message: "and() is not supported for scored global index results".to_string(), - source: None, - }); - } - let result = self.results() & other.results(); - Ok(Box::new(SimpleGlobalIndexResult::new_ready(result))) - } - - fn or(&self, other: &dyn GlobalIndexResult) -> crate::Result> { - match (self.as_scored(), other.as_scored()) { - (Some(this), Some(that)) => { - let this_row_ids = self.results().clone(); - let result_or = &this_row_ids | other.results(); - let this_sg = this.clone_score_getter(); - let other_sg = that.clone_score_getter(); - // For overlapping IDs, use left-side score - return Ok(Box::new(SimpleScoredGlobalIndexResult::new( - result_or, - Arc::new(move |row_id| { - if this_row_ids.contains(row_id) { - this_sg(row_id) - } else { - other_sg(row_id) - } - }), - ))); - } - (None, None) => {} - _ => { - return 
Err(crate::Error::DataInvalid { - message: "Cannot union scored and unscored global index results".to_string(), - source: None, - }); - } - } - - let result = self.results() | other.results(); - Ok(Box::new(SimpleGlobalIndexResult::new_ready(result))) - } - - fn is_empty(&self) -> bool { - self.results().is_empty() - } -} - -pub struct SimpleGlobalIndexResult { - bitmap: RoaringTreemap, -} - -impl SimpleGlobalIndexResult { - pub fn new_ready(bitmap: RoaringTreemap) -> Self { - Self { bitmap } - } - - pub fn create_empty() -> Self { - Self::new_ready(RoaringTreemap::new()) - } -} - -impl GlobalIndexResult for SimpleGlobalIndexResult { - fn results(&self) -> &RoaringTreemap { - &self.bitmap - } -} - -pub trait ScoredGlobalIndexResult: GlobalIndexResult { - fn score_getter(&self) -> &ScoreGetter; - - fn scored_offset(&self, offset: u64) -> Box { - if offset == 0 { - let bitmap = self.results().clone(); - let sg = self.clone_score_getter(); - return Box::new(SimpleScoredGlobalIndexResult::new(bitmap, sg)); - } - let bitmap = self.results(); - let mut offset_bitmap = RoaringTreemap::new(); - for row_id in bitmap.iter() { - offset_bitmap.insert(row_id + offset); - } - let sg = self.clone_score_getter(); - Box::new(SimpleScoredGlobalIndexResult::new( - offset_bitmap, - Arc::new(move |row_id| sg(row_id - offset)), - )) - } - - // For overlapping IDs, use left-side score - fn scored_or(&self, other: &dyn ScoredGlobalIndexResult) -> Box { - let this_row_ids = self.results().clone(); - let result_or = &this_row_ids | other.results(); - let this_sg = self.clone_score_getter(); - let other_sg = other.clone_score_getter(); - Box::new(SimpleScoredGlobalIndexResult::new( - result_or, - Arc::new(move |row_id| { - if this_row_ids.contains(row_id) { - this_sg(row_id) - } else { - other_sg(row_id) - } - }), - )) - } - - fn top_k(&self, k: usize) -> Box { - let row_ids = self.results(); - if row_ids.len() <= k as u64 { - let bitmap = row_ids.clone(); - let sg = 
self.clone_score_getter(); - return Box::new(SimpleScoredGlobalIndexResult::new(bitmap, sg)); - } - - let score_getter_fn = self.score_getter(); - - #[derive(PartialEq)] - struct ScoredEntry { - row_id: u64, - score: f32, - } - impl Eq for ScoredEntry {} - impl PartialOrd for ScoredEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - impl Ord for ScoredEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Min-heap: reverse order so smallest score is at top. - // Use total_cmp to handle NaN deterministically. - other.score.total_cmp(&self.score) - } - } - - let mut min_heap: BinaryHeap = BinaryHeap::with_capacity(k + 1); - for row_id in row_ids.iter() { - let score = score_getter_fn(row_id); - if min_heap.len() < k { - min_heap.push(ScoredEntry { row_id, score }); - } else if let Some(peek) = min_heap.peek() { - if score > peek.score { - min_heap.pop(); - min_heap.push(ScoredEntry { row_id, score }); - } - } - } - - let mut top_k_ids = RoaringTreemap::new(); - for entry in &min_heap { - top_k_ids.insert(entry.row_id); - } - - let sg = self.clone_score_getter(); - Box::new(SimpleScoredGlobalIndexResult::new(top_k_ids, sg)) - } - - fn clone_score_getter(&self) -> ScoreGetter; -} - -pub struct SimpleScoredGlobalIndexResult { - bitmap: RoaringTreemap, - score_getter: ScoreGetter, -} - -impl SimpleScoredGlobalIndexResult { - pub fn new(bitmap: RoaringTreemap, score_getter: ScoreGetter) -> Self { - Self { - bitmap, - score_getter, - } - } - - pub fn create_empty() -> Self { - Self { - bitmap: RoaringTreemap::new(), - score_getter: Arc::new(|_| 0.0), - } - } -} - -impl GlobalIndexResult for SimpleScoredGlobalIndexResult { - fn results(&self) -> &RoaringTreemap { - &self.bitmap - } - - fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { - Some(self) - } -} - -impl ScoredGlobalIndexResult for SimpleScoredGlobalIndexResult { - fn score_getter(&self) -> &ScoreGetter { - &self.score_getter - } - - fn 
clone_score_getter(&self) -> ScoreGetter { - self.score_getter.clone() - } -} - -pub struct DictBasedScoredIndexResult { - bitmap: RoaringTreemap, - score_getter_fn: ScoreGetter, -} - -impl DictBasedScoredIndexResult { - pub fn new(id_to_scores: std::collections::HashMap) -> Self { - let mut bitmap = RoaringTreemap::new(); - for &row_id in id_to_scores.keys() { - bitmap.insert(row_id); - } - let map = Arc::new(id_to_scores); - let score_getter_fn: ScoreGetter = - Arc::new(move |row_id| map.get(&row_id).copied().unwrap_or(0.0)); - Self { - bitmap, - score_getter_fn, - } - } -} - -impl GlobalIndexResult for DictBasedScoredIndexResult { - fn results(&self) -> &RoaringTreemap { - &self.bitmap - } - - fn as_scored(&self) -> Option<&dyn ScoredGlobalIndexResult> { - Some(self) - } -} - -impl ScoredGlobalIndexResult for DictBasedScoredIndexResult { - fn score_getter(&self) -> &ScoreGetter { - &self.score_getter_fn - } - - fn clone_score_getter(&self) -> ScoreGetter { - self.score_getter_fn.clone() - } -} - -pub struct VectorSearch { - pub vector: Vec, - pub limit: usize, - pub field_name: String, - pub include_row_ids: Option, -} - -impl VectorSearch { - pub fn new(vector: Vec, limit: usize, field_name: String) -> crate::Result { - if vector.is_empty() { - return Err(crate::Error::DataInvalid { - message: "Search vector cannot be empty".to_string(), - source: None, - }); - } - if limit == 0 || limit > i32::MAX as usize { - return Err(crate::Error::DataInvalid { - message: format!("Limit must be between 1 and {}, got: {}", i32::MAX, limit), - source: None, - }); - } - if field_name.is_empty() { - return Err(crate::Error::DataInvalid { - message: "Field name cannot be null or empty".to_string(), - source: None, - }); - } - Ok(Self { - vector, - limit, - field_name, - include_row_ids: None, - }) - } - - pub fn with_include_row_ids(mut self, include_row_ids: RoaringTreemap) -> Self { - self.include_row_ids = Some(include_row_ids); - self - } - - pub fn offset_range(&self, 
from: u64, to: u64) -> Self { - if let Some(ref include_row_ids) = self.include_row_ids { - let mut range_bitmap = RoaringTreemap::new(); - if to == u64::MAX { - range_bitmap.insert_range(from..u64::MAX); - range_bitmap.insert(u64::MAX); - } else { - range_bitmap.insert_range(from..to + 1); - } - let and_result = include_row_ids & &range_bitmap; - let mut offset_bitmap = RoaringTreemap::new(); - for row_id in and_result.iter() { - offset_bitmap.insert(row_id - from); - } - VectorSearch { - vector: self.vector.clone(), - limit: self.limit, - field_name: self.field_name.clone(), - include_row_ids: Some(offset_bitmap), - } - } else { - VectorSearch { - vector: self.vector.clone(), - limit: self.limit, - field_name: self.field_name.clone(), - include_row_ids: None, - } - } - } -} - -impl std::fmt::Display for VectorSearch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "VectorSearch(field_name={}, limit={})", - self.field_name, self.limit - ) - } -} - -pub struct GlobalIndexIOMeta { - pub file_path: String, - pub file_size: u64, - pub metadata: Vec, -} - -impl GlobalIndexIOMeta { - pub fn new(file_path: String, file_size: u64, metadata: Vec) -> Self { - Self { - file_path, - file_size, - metadata, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_vector_search_offset_range() { - let mut bitmap = RoaringTreemap::new(); - bitmap.insert_range(100..200); - let vs = VectorSearch::new(vec![1.0, 2.0], 10, "vec".to_string()) - .unwrap() - .with_include_row_ids(bitmap); - - let result = vs.offset_range(60, 150); - let ids = result.include_row_ids.unwrap(); - // [100,200) means row ids [100,199]. Inclusive [60,150] keeps [100,150]. 
- assert_eq!(ids.len(), 51); - assert!(ids.contains(40)); - assert!(ids.contains(90)); - assert!(ids.contains(89)); - assert!(!ids.contains(39)); - assert!(!ids.contains(91)); - } - - #[test] - fn test_invalid_top_k() { - assert!(VectorSearch::new(vec![1.0], 0, "f".to_string()).is_err()); - assert!(VectorSearch::new(vec![1.0], i32::MAX as usize + 1, "f".to_string()).is_err()); - } - - #[test] - fn test_offset_range_no_filter() { - let vs = VectorSearch::new(vec![1.0], 5, "f".to_string()).unwrap(); - let result = vs.offset_range(100, 200); - assert!(result.include_row_ids.is_none()); - } - - fn make_dict(entries: Vec<(u64, f32)>) -> DictBasedScoredIndexResult { - DictBasedScoredIndexResult::new(entries.into_iter().collect()) - } - - #[test] - fn test_top_k_selects_highest() { - let r = make_dict(vec![(1, 0.1), (2, 0.9), (3, 0.5), (4, 0.8), (5, 0.3)]); - let top = r.top_k(2); - assert_eq!(top.results().len(), 2); - assert!(top.results().contains(2)); - assert!(top.results().contains(4)); - } - - #[test] - fn test_scored_offset_preserves_scores() { - let r = make_dict(vec![(1, 0.5), (2, 0.6)]); - let o = r.scored_offset(100); - assert!(o.results().contains(101)); - assert_eq!(o.score_getter()(101), 0.5); - assert_eq!(o.score_getter()(102), 0.6); - } - - #[test] - fn test_base_offset_preserves_scores() { - let r = make_dict(vec![(1, 0.5), (2, 0.6)]); - let o = r.offset(100); - let scored = o - .as_scored() - .expect("offset should preserve scored results"); - assert!(o.results().contains(101)); - assert_eq!(scored.score_getter()(101), 0.5); - assert_eq!(scored.score_getter()(102), 0.6); - } - - #[test] - fn test_base_or_preserves_scores() { - let left = make_dict(vec![(1, 0.5), (2, 0.6)]); - let right = make_dict(vec![(3, 0.7), (4, 0.8)]); - let merged = left.or(&right).unwrap(); - let scored = merged - .as_scored() - .expect("or should preserve scored results"); - assert_eq!(merged.results().len(), 4); - assert_eq!(scored.score_getter()(1), 0.5); - 
assert_eq!(scored.score_getter()(4), 0.8); - } - - #[test] - fn test_or_overlapping_uses_left_score() { - let left = make_dict(vec![(1, 0.3), (2, 0.9)]); - let right = make_dict(vec![(1, 0.7), (2, 0.4)]); - let merged = left.or(&right).unwrap(); - let scored = merged - .as_scored() - .expect("or should preserve scored results"); - assert_eq!(merged.results().len(), 2); - assert_eq!(scored.score_getter()(1), 0.3); - assert_eq!(scored.score_getter()(2), 0.9); - } - - #[test] - fn test_scored_and_returns_error() { - let left = make_dict(vec![(1, 0.5), (2, 0.6)]); - let right = make_dict(vec![(1, 0.7), (3, 0.8)]); - assert!(left.and(&right).is_err()); - } - - #[test] - fn test_clone_score_getter() { - let r = make_dict(vec![(10, 1.5), (20, 2.5)]); - let cloned = r.clone_score_getter(); - assert_eq!(cloned(10), 1.5); - assert_eq!(cloned(20), 2.5); - } -} diff --git a/crates/paimon/src/lib.rs b/crates/paimon/src/lib.rs index f2fbfe05..612477a1 100644 --- a/crates/paimon/src/lib.rs +++ b/crates/paimon/src/lib.rs @@ -30,9 +30,7 @@ pub mod btree; pub mod catalog; mod deletion_vector; pub mod file_index; -pub mod globalindex; pub mod io; -#[cfg(feature = "lumina")] pub mod lumina; mod predicate_stats; pub mod spec; diff --git a/crates/paimon/src/lumina/mod.rs b/crates/paimon/src/lumina/mod.rs index d6eb6cdf..a686eee1 100644 --- a/crates/paimon/src/lumina/mod.rs +++ b/crates/paimon/src/lumina/mod.rs @@ -214,6 +214,73 @@ pub fn strip_lumina_options(paimon_options: &HashMap) -> HashMap result } +pub struct VectorSearch { + pub vector: Vec, + pub limit: usize, + pub field_name: String, + pub include_row_ids: Option, +} + +impl VectorSearch { + pub fn new(vector: Vec, limit: usize, field_name: String) -> crate::Result { + if vector.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Search vector cannot be empty".to_string(), + source: None, + }); + } + if limit == 0 || limit > i32::MAX as usize { + return Err(crate::Error::DataInvalid { + message: format!("Limit 
must be between 1 and {}, got: {}", i32::MAX, limit), + source: None, + }); + } + if field_name.is_empty() { + return Err(crate::Error::DataInvalid { + message: "Field name cannot be null or empty".to_string(), + source: None, + }); + } + Ok(Self { + vector, + limit, + field_name, + include_row_ids: None, + }) + } + + pub fn with_include_row_ids(mut self, include_row_ids: roaring::RoaringTreemap) -> Self { + self.include_row_ids = Some(include_row_ids); + self + } +} + +impl std::fmt::Display for VectorSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "VectorSearch(field_name={}, limit={})", + self.field_name, self.limit + ) + } +} + +pub struct GlobalIndexIOMeta { + pub file_path: String, + pub file_size: u64, + pub metadata: Vec, +} + +impl GlobalIndexIOMeta { + pub fn new(file_path: String, file_size: u64, metadata: Vec) -> Self { + Self { + file_path, + file_size, + metadata, + } + } +} + pub const KEY_DIMENSION: &str = "index.dimension"; pub const KEY_DISTANCE_METRIC: &str = "distance.metric"; pub const KEY_INDEX_TYPE: &str = "index.type"; @@ -298,6 +365,113 @@ impl LuminaIndexMeta { } } +#[derive(Debug, Clone)] +pub struct SearchResult { + pub row_ids: Vec, + pub scores: Vec, +} + +impl SearchResult { + pub fn new(row_ids: Vec, scores: Vec) -> Self { + assert_eq!(row_ids.len(), scores.len()); + Self { row_ids, scores } + } + + pub fn empty() -> Self { + Self { + row_ids: Vec::new(), + scores: Vec::new(), + } + } + + pub fn from_scored_map(map: HashMap) -> Self { + let mut row_ids = Vec::with_capacity(map.len()); + let mut scores = Vec::with_capacity(map.len()); + for (id, score) in map { + row_ids.push(id); + scores.push(score); + } + Self { row_ids, scores } + } + + pub fn len(&self) -> usize { + self.row_ids.len() + } + + pub fn is_empty(&self) -> bool { + self.row_ids.is_empty() + } + + pub fn offset(&self, offset: i64) -> Self { + if offset == 0 { + return self.clone(); + } + let row_ids = self + .row_ids + 
.iter() + .map(|&id| { + if offset >= 0 { + id.saturating_add(offset as u64) + } else { + id.saturating_sub(offset.unsigned_abs()) + } + }) + .collect(); + Self { + row_ids, + scores: self.scores.clone(), + } + } + + pub fn or(&self, other: &SearchResult) -> Self { + let mut row_ids = self.row_ids.clone(); + let mut scores = self.scores.clone(); + row_ids.extend_from_slice(&other.row_ids); + scores.extend_from_slice(&other.scores); + Self { row_ids, scores } + } + + pub fn top_k(&self, k: usize) -> Self { + if self.row_ids.len() <= k { + return self.clone(); + } + let mut indices: Vec = (0..self.row_ids.len()).collect(); + indices.sort_by(|&a, &b| { + self.scores[b] + .partial_cmp(&self.scores[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + indices.truncate(k); + let row_ids = indices.iter().map(|&i| self.row_ids[i]).collect(); + let scores = indices.iter().map(|&i| self.scores[i]).collect(); + Self { row_ids, scores } + } + + pub fn to_row_ranges(&self) -> Vec { + if self.row_ids.is_empty() { + return Vec::new(); + } + let mut sorted: Vec = self.row_ids.clone(); + sorted.sort_unstable(); + sorted.dedup(); + let mut ranges = Vec::new(); + let mut start = sorted[0] as i64; + let mut end = start; + for &id in &sorted[1..] 
{ + let id = id as i64; + if id == end + 1 { + end = id; + } else { + ranges.push(crate::table::RowRange::new(start, end)); + start = id; + end = id; + } + } + ranges.push(crate::table::RowRange::new(start, end)); + ranges + } +} + #[cfg(test)] mod tests { use super::*; @@ -446,4 +620,51 @@ mod tests { assert_eq!(lumina_opts.get("encoding.pq.m").unwrap(), "64"); assert_eq!(lumina_opts.get("search.parallel_number").unwrap(), "5"); } + + #[test] + fn test_search_result_from_scored_map() { + let mut map = HashMap::new(); + map.insert(1u64, 0.9f32); + map.insert(2, 0.5); + let result = SearchResult::from_scored_map(map); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_search_result_top_k() { + let result = SearchResult::new(vec![1, 2, 3, 4, 5], vec![0.1, 0.9, 0.5, 0.8, 0.3]); + let top = result.top_k(2); + assert_eq!(top.len(), 2); + assert!(top.row_ids.contains(&2)); + assert!(top.row_ids.contains(&4)); + } + + #[test] + fn test_search_result_offset() { + let result = SearchResult::new(vec![0, 1], vec![0.5, 0.6]); + let offset = result.offset(100); + assert_eq!(offset.row_ids, vec![100, 101]); + assert_eq!(offset.scores, vec![0.5, 0.6]); + } + + #[test] + fn test_search_result_or() { + let a = SearchResult::new(vec![1, 2], vec![0.5, 0.6]); + let b = SearchResult::new(vec![3], vec![0.7]); + let merged = a.or(&b); + assert_eq!(merged.len(), 3); + } + + #[test] + fn test_search_result_to_row_ranges() { + let result = SearchResult::new(vec![5, 1, 2, 3, 10], vec![0.1; 5]); + let ranges = result.to_row_ranges(); + assert_eq!(ranges.len(), 3); + assert_eq!(ranges[0].from(), 1); + assert_eq!(ranges[0].to(), 3); + assert_eq!(ranges[1].from(), 5); + assert_eq!(ranges[1].to(), 5); + assert_eq!(ranges[2].from(), 10); + assert_eq!(ranges[2].to(), 10); + } } diff --git a/crates/paimon/src/lumina/reader.rs b/crates/paimon/src/lumina/reader.rs index 542a7a33..9f2ade67 100644 --- a/crates/paimon/src/lumina/reader.rs +++ b/crates/paimon/src/lumina/reader.rs @@ -15,11 +15,10 @@ 
// specific language governing permissions and limitations // under the License. -use crate::globalindex::{ - DictBasedScoredIndexResult, GlobalIndexIOMeta, ScoredGlobalIndexResult, VectorSearch, -}; use crate::lumina::ffi::LuminaSearcher; -use crate::lumina::{strip_lumina_options, LuminaIndexMeta, LuminaVectorMetric}; +use crate::lumina::{ + strip_lumina_options, GlobalIndexIOMeta, LuminaIndexMeta, LuminaVectorMetric, VectorSearch, +}; use std::collections::BinaryHeap; use std::collections::HashMap; use std::io::{Read, Seek}; @@ -116,15 +115,12 @@ impl LuminaVectorGlobalIndexReader { &mut self, vector_search: &VectorSearch, stream_fn: impl FnOnce(&str) -> crate::Result, - ) -> crate::Result>> { + ) -> crate::Result>> { self.ensure_loaded(stream_fn)?; self.search(vector_search) } - fn search( - &self, - vector_search: &VectorSearch, - ) -> crate::Result>> { + fn search(&self, vector_search: &VectorSearch) -> crate::Result>> { let index_meta = self .index_meta .as_ref() @@ -211,9 +207,7 @@ impl LuminaVectorGlobalIndexReader { return Ok(None); } - Ok(Some(Box::new(DictBasedScoredIndexResult::new( - id_to_scores, - )))) + Ok(Some(id_to_scores)) } fn ensure_loaded( @@ -226,9 +220,9 @@ impl LuminaVectorGlobalIndexReader { let index_meta = LuminaIndexMeta::deserialize(&self.io_meta.metadata)?; - let mut searcher_options = strip_lumina_options(&self.options); - for (k, v) in index_meta.options().iter() { - searcher_options.insert(k.to_string(), v.to_string()); + let mut searcher_options = index_meta.options().clone(); + for (k, v) in strip_lumina_options(&self.options) { + searcher_options.insert(k, v); } let mut searcher = LuminaSearcher::create(&searcher_options)?; @@ -258,7 +252,7 @@ impl Drop for LuminaVectorGlobalIndexReader { #[cfg(test)] mod tests { use super::*; - use crate::globalindex::GlobalIndexIOMeta; + use crate::lumina::GlobalIndexIOMeta; #[test] fn test_convert_distance_to_score() { diff --git a/crates/paimon/src/table/mod.rs 
b/crates/paimon/src/table/mod.rs index 147239f8..1f58142a 100644 --- a/crates/paimon/src/table/mod.rs +++ b/crates/paimon/src/table/mod.rs @@ -49,6 +49,7 @@ mod table_read; mod table_scan; pub(crate) mod table_write; mod tag_manager; +mod vector_search_builder; mod write_builder; use crate::Result; @@ -71,6 +72,7 @@ pub use table_read::TableRead; pub use table_scan::TableScan; pub use table_write::TableWrite; pub use tag_manager::TagManager; +pub use vector_search_builder::VectorSearchBuilder; pub use write_builder::WriteBuilder; use crate::catalog::Identifier; @@ -149,6 +151,10 @@ impl Table { FullTextSearchBuilder::new(self) } + pub fn new_vector_search_builder(&self) -> VectorSearchBuilder<'_> { + VectorSearchBuilder::new(self) + } + /// Create a write builder for write/commit. /// /// Reference: [pypaimon FileStoreTable.new_write_builder](https://github.com/apache/paimon/blob/master/paimon-python/pypaimon/table/file_store_table.py). From 013478b7ba12097409c6bf0aa6f8903cfa51a96e Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 11:57:34 +0800 Subject: [PATCH 4/9] add vector search --- .../datafusion/src/vector_search.rs | 253 ++++++++++++++ .../paimon/src/table/vector_search_builder.rs | 328 ++++++++++++++++++ 2 files changed, 581 insertions(+) create mode 100644 crates/integrations/datafusion/src/vector_search.rs create mode 100644 crates/paimon/src/table/vector_search_builder.rs diff --git a/crates/integrations/datafusion/src/vector_search.rs b/crates/integrations/datafusion/src/vector_search.rs new file mode 100644 index 00000000..89575035 --- /dev/null +++ b/crates/integrations/datafusion/src/vector_search.rs @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; +use datafusion::catalog::Session; +use datafusion::catalog::TableFunctionImpl; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result as DFResult; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use paimon::catalog::{Catalog, Identifier}; + +use crate::error::to_datafusion_error; +use crate::runtime::{await_with_runtime, block_on_with_runtime}; +use crate::table::{PaimonScanBuilder, PaimonTableProvider}; + +pub fn register_vector_search( + ctx: &SessionContext, + catalog: Arc, + default_database: &str, +) { + ctx.register_udtf( + "vector_search", + Arc::new(VectorSearchFunction::new(catalog, default_database)), + ); +} + +pub struct VectorSearchFunction { + catalog: Arc, + default_database: String, +} + +impl Debug for VectorSearchFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VectorSearchFunction") + .field("default_database", &self.default_database) + .finish() + } +} + +impl VectorSearchFunction { + pub fn new(catalog: Arc, default_database: &str) -> Self { + Self { + catalog, + default_database: default_database.to_string(), + } 
+ } +} + +impl TableFunctionImpl for VectorSearchFunction { + fn call(&self, args: &[Expr]) -> DFResult> { + if args.len() != 4 { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search requires 4 arguments: (table_name, column_name, query_vector_json, limit)".to_string(), + )); + } + + let table_name = extract_string_literal(&args[0], "table_name")?; + let column_name = extract_string_literal(&args[1], "column_name")?; + let query_vector_json = extract_string_literal(&args[2], "query_vector_json")?; + let limit = extract_int_literal(&args[3], "limit")?; + + if limit <= 0 { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search: limit must be positive".to_string(), + )); + } + + let query_vector: Vec = serde_json::from_str(&query_vector_json).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "vector_search: query_vector_json must be a JSON array of floats, got '{}': {}", + query_vector_json, e + )) + })?; + + if query_vector.is_empty() { + return Err(datafusion::error::DataFusionError::Plan( + "vector_search: query vector cannot be empty".to_string(), + )); + } + + let identifier = parse_table_identifier(&table_name, &self.default_database)?; + + let catalog = Arc::clone(&self.catalog); + let table = block_on_with_runtime( + async move { catalog.get_table(&identifier).await }, + "vector_search: catalog access thread panicked", + ) + .map_err(to_datafusion_error)?; + + let inner = PaimonTableProvider::try_new(table)?; + + Ok(Arc::new(VectorSearchTableProvider { + inner, + column_name, + query_vector, + limit: limit as usize, + })) + } +} + +#[derive(Debug)] +struct VectorSearchTableProvider { + inner: PaimonTableProvider, + column_name: String, + query_vector: Vec, + limit: usize, +} + +#[async_trait] +impl TableProvider for VectorSearchTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> ArrowSchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + 
TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + limit: Option, + ) -> DFResult> { + let table = self.inner.table(); + + let row_ranges = await_with_runtime(async { + let mut builder = table.new_vector_search_builder(); + builder + .with_vector_column(&self.column_name) + .with_query_vector(self.query_vector.clone()) + .with_limit(self.limit); + builder.execute().await.map_err(to_datafusion_error) + }) + .await?; + + let mut read_builder = table.new_read_builder(); + if let Some(limit) = limit { + read_builder.with_limit(limit); + } + let scan = if row_ranges.is_empty() { + read_builder.new_scan() + } else { + read_builder.new_scan().with_row_ranges(row_ranges) + }; + let plan = await_with_runtime(scan.plan()) + .await + .map_err(to_datafusion_error)?; + + let target = state.config_options().execution.target_partitions; + PaimonScanBuilder { + table, + schema: &self.schema(), + plan: &plan, + projection, + pushed_predicate: None, + limit, + target_partitions: target, + filter_exact: false, + } + .build() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> DFResult> { + Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } +} + +fn extract_string_literal(expr: &Expr, name: &str) -> DFResult { + match expr { + Expr::Literal(scalar, _) => { + let s = scalar.try_as_str().flatten().ok_or_else(|| { + datafusion::error::DataFusionError::Plan(format!( + "vector_search: {name} must be a string literal, got: {expr}" + )) + })?; + Ok(s.to_string()) + } + _ => Err(datafusion::error::DataFusionError::Plan(format!( + "vector_search: {name} must be a literal, got: {expr}" + ))), + } +} + +fn extract_int_literal(expr: &Expr, name: &str) -> DFResult { + use datafusion::common::ScalarValue; + match expr { + Expr::Literal(scalar, _) => match scalar { + ScalarValue::Int8(Some(v)) => Ok(*v as i64), + ScalarValue::Int16(Some(v)) => Ok(*v as i64), + 
ScalarValue::Int32(Some(v)) => Ok(*v as i64), + ScalarValue::Int64(Some(v)) => Ok(*v), + ScalarValue::UInt8(Some(v)) => Ok(*v as i64), + ScalarValue::UInt16(Some(v)) => Ok(*v as i64), + ScalarValue::UInt32(Some(v)) => Ok(*v as i64), + ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { + datafusion::error::DataFusionError::Plan(format!( + "vector_search: {name} value {v} exceeds i64 range" + )) + }), + _ => Err(datafusion::error::DataFusionError::Plan(format!( + "vector_search: {name} must be an integer literal, got: {expr}" + ))), + }, + _ => Err(datafusion::error::DataFusionError::Plan(format!( + "vector_search: {name} must be a literal, got: {expr}" + ))), + } +} + +fn parse_table_identifier(name: &str, default_database: &str) -> DFResult { + let parts: Vec<&str> = name.split('.').collect(); + match parts.len() { + 1 => Ok(Identifier::new(default_database, parts[0])), + 2 => Ok(Identifier::new(parts[0], parts[1])), + 3 => Ok(Identifier::new(parts[1], parts[2])), + _ => Err(datafusion::error::DataFusionError::Plan(format!( + "vector_search: invalid table name '{name}', expected 'table', 'database.table', or 'catalog.database.table'" + ))), + } +} diff --git a/crates/paimon/src/table/vector_search_builder.rs b/crates/paimon/src/table/vector_search_builder.rs new file mode 100644 index 00000000..ab0d9868 --- /dev/null +++ b/crates/paimon/src/table/vector_search_builder.rs @@ -0,0 +1,328 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::lumina::reader::LuminaVectorGlobalIndexReader; +use crate::lumina::{GlobalIndexIOMeta, SearchResult, VectorSearch, LUMINA_VECTOR_ANN_IDENTIFIER}; +use crate::spec::{DataField, FileKind, IndexManifest}; +use crate::table::snapshot_manager::SnapshotManager; +use crate::table::{RowRange, Table}; +use std::collections::HashMap; +use std::io::Cursor; + +const INDEX_DIR: &str = "index"; + +pub struct VectorSearchBuilder<'a> { + table: &'a Table, + vector_column: Option, + query_vector: Option>, + limit: Option, +} + +impl<'a> VectorSearchBuilder<'a> { + pub(crate) fn new(table: &'a Table) -> Self { + Self { + table, + vector_column: None, + query_vector: None, + limit: None, + } + } + + pub fn with_vector_column(&mut self, name: &str) -> &mut Self { + self.vector_column = Some(name.to_string()); + self + } + + pub fn with_query_vector(&mut self, vector: Vec) -> &mut Self { + self.query_vector = Some(vector); + self + } + + pub fn with_limit(&mut self, limit: usize) -> &mut Self { + self.limit = Some(limit); + self + } + + pub async fn execute(&self) -> crate::Result> { + let vector_column = + self.vector_column + .as_deref() + .ok_or_else(|| crate::Error::ConfigInvalid { + message: "Vector column must be set via with_vector_column()".to_string(), + })?; + let query_vector = + self.query_vector + .as_ref() + .ok_or_else(|| crate::Error::ConfigInvalid { + message: "Query vector must be set via with_query_vector()".to_string(), + })?; + let limit = self.limit.ok_or_else(|| crate::Error::ConfigInvalid { + message: "Limit must be set via 
with_limit()".to_string(), + })?; + + let vector_search = + VectorSearch::new(query_vector.clone(), limit, vector_column.to_string())?; + + let snapshot_manager = SnapshotManager::new( + self.table.file_io().clone(), + self.table.location().to_string(), + ); + + let snapshot = match snapshot_manager.get_latest_snapshot().await? { + Some(s) => s, + None => return Ok(Vec::new()), + }; + + let index_manifest_name = match snapshot.index_manifest() { + Some(name) => name.to_string(), + None => return Ok(Vec::new()), + }; + + let manifest_path = format!( + "{}/manifest/{}", + self.table.location().trim_end_matches('/'), + index_manifest_name + ); + let index_entries = IndexManifest::read(self.table.file_io(), &manifest_path).await?; + + evaluate_vector_search( + self.table.file_io(), + self.table.location(), + self.table.schema().options(), + &index_entries, + &vector_search, + self.table.schema().fields(), + ) + .await + } +} + +async fn evaluate_vector_search( + file_io: &crate::io::FileIO, + table_path: &str, + table_options: &HashMap, + index_entries: &[crate::spec::IndexManifestEntry], + vector_search: &VectorSearch, + schema_fields: &[DataField], +) -> crate::Result> { + let table_path = table_path.trim_end_matches('/'); + + let field_id = match find_field_id_by_name(schema_fields, &vector_search.field_name) { + Some(id) => id, + None => return Ok(Vec::new()), + }; + + let lumina_entries: Vec<_> = index_entries + .iter() + .filter(|e| { + e.kind == FileKind::Add + && e.index_file.index_type == LUMINA_VECTOR_ANN_IDENTIFIER + && e.index_file + .global_index_meta + .as_ref() + .is_some_and(|m| m.index_field_id == field_id) + }) + .collect(); + + if lumina_entries.is_empty() { + return Ok(Vec::new()); + } + + let futures: Vec<_> = lumina_entries + .into_iter() + .map(|entry| { + let global_meta = entry.index_file.global_index_meta.as_ref().unwrap(); + let path = format!("{table_path}/{INDEX_DIR}/{}", entry.index_file.file_name); + let file_name = 
entry.index_file.file_name.clone(); + let file_size = entry.index_file.file_size as u64; + let index_meta_bytes = global_meta.index_meta.clone().unwrap_or_default(); + let row_range_start = global_meta.row_range_start; + let vector_search_clone = VectorSearch::new( + vector_search.vector.clone(), + vector_search.limit, + vector_search.field_name.clone(), + ); + let options = table_options.clone(); + let input = file_io.new_input(&path); + async move { + let input = input?; + let bytes = input.read().await.map_err(|e| crate::Error::DataInvalid { + message: format!("Failed to read Lumina index file '{}': {}", file_name, e), + source: None, + })?; + + let io_meta = + GlobalIndexIOMeta::new(file_name.clone(), file_size, index_meta_bytes); + let mut reader = LuminaVectorGlobalIndexReader::new(io_meta, options); + let vs = vector_search_clone?; + + let data = bytes.to_vec(); + let result = reader.visit_vector_search(&vs, |_| Ok(Cursor::new(data)))?; + + match result { + Some(scored_map) => Ok::<_, crate::Error>( + SearchResult::from_scored_map(scored_map).offset(row_range_start), + ), + None => Ok(SearchResult::empty()), + } + } + }) + .collect(); + + let results = futures::future::try_join_all(futures).await?; + let mut merged = SearchResult::empty(); + for r in &results { + merged = merged.or(r); + } + + Ok(merged.top_k(vector_search.limit).to_row_ranges()) +} + +fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { + fields.iter().find(|f| f.name() == name).map(|f| f.id()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{DataType, GlobalIndexMeta, IndexFileMeta, IndexManifestEntry, IntType}; + + fn make_field(id: i32, name: &str) -> DataField { + DataField::new(id, name.to_string(), DataType::Int(IntType::default())) + } + + #[test] + fn test_find_field_id_by_name() { + let fields = vec![make_field(1, "id"), make_field(2, "embedding")]; + assert_eq!(find_field_id_by_name(&fields, "embedding"), Some(2)); + 
assert_eq!(find_field_id_by_name(&fields, "nonexistent"), None); + } + + #[tokio::test] + async fn test_evaluate_no_matching_entries() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(1, "id"), make_field(2, "embedding")]; + let vs = VectorSearch::new(vec![1.0, 2.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Add, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: "btree".to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: None, + global_index_meta: None, + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + &HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_evaluate_no_matching_field() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(1, "id")]; + let vs = VectorSearch::new(vec![1.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Add, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: LUMINA_VECTOR_ANN_IDENTIFIER.to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: None, + global_index_meta: Some(GlobalIndexMeta { + row_range_start: 0, + row_range_end: 9, + index_field_id: 99, + extra_field_ids: None, + index_meta: None, + }), + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + &HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_evaluate_skips_delete_entries() { + let file_io = crate::io::FileIOBuilder::new("memory").build().unwrap(); + let fields = vec![make_field(2, "embedding")]; + let vs = 
VectorSearch::new(vec![1.0], 10, "embedding".to_string()).unwrap(); + + let entry = IndexManifestEntry { + kind: FileKind::Delete, + partition: vec![], + bucket: 0, + index_file: IndexFileMeta { + index_type: LUMINA_VECTOR_ANN_IDENTIFIER.to_string(), + file_name: "test.idx".to_string(), + file_size: 100, + row_count: 10, + deletion_vectors_ranges: None, + global_index_meta: Some(GlobalIndexMeta { + row_range_start: 0, + row_range_end: 9, + index_field_id: 2, + extra_field_ids: None, + index_meta: None, + }), + }, + version: 1, + }; + + let result = evaluate_vector_search( + &file_io, + "memory:///test_table", + &HashMap::new(), + &[entry], + &vs, + &fields, + ) + .await + .unwrap(); + assert!(result.is_empty()); + } +} From 35c4d7a97081d3d7b096ff48780c9e4452894274 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 12:40:39 +0800 Subject: [PATCH 5/9] add lumina case into CI --- .github/workflows/ci.yml | 5 + .../testdata/test_lumina_vector.tar.gz | Bin 0 -> 3549 bytes .../datafusion/tests/read_tables.rs | 90 ++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 93a13a3b..82e66a28 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -110,6 +110,11 @@ jobs: RUST_LOG: DEBUG RUST_BACKTRACE: full + - name: Install lumina native library + run: | + pip install lumina-data + echo "LUMINA_LIB_PATH=$(python3 -c 'import lumina_data; print(lumina_data.__path__[0])')/lib/liblumina_py.so" >> $GITHUB_ENV + - name: DataFusion Integration Test run: cargo test -p paimon-datafusion --all-targets env: diff --git a/crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz b/crates/integrations/datafusion/testdata/test_lumina_vector.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..3909f95cf38ec286371c4129b767f185b30fda25 GIT binary patch literal 3549 
zcmV<34I=U%iwFR$k?m;!1MOW4bQIMYp53fTz<|7VB{i+XQGtmVjZh>U7BbOgJAZ{jNF9GZ z#BwV}auy7I{0Wk#)bY;;1|5I74sH#Ey-Wa>>R^kP!VuPn!U#^XFu^heY$nVW7{SaW zN0S_mAzY$l6uYUdD!CRmVG@0sfJPP<71N=IK=-|&&x&jO`QO$M1?bR^G_A}DcuL>uX#FQbhH=89;C*gR1o z+eNiWGJ303n?O)<%!t`Ug$flaRH#s)LWPQagP1HIYn1X7@OkssTaU!bL>B&e3V*Ls zz7~A$ZvVuU>-{Z{!hc>V8^pHu-n=>{e?lrt0y-3aLP@B_V&{wYGr97wSNH{mzYKiX zV-0&^{Ijy~|4`xYQ}{*wKgE#aSp_R#`RH;c6snAJK7kL1DubNBB-e2vUzp`0$<=Kf zlcttuqpAEtq}$UT=0iftxd780?hyJW%G!u`&u-`68s^Q(H1~g_A*Mar78XYA5x|h{ z|Kp$@C2b?tD16~ykQZ8_TtpIudr8KRc{$ijBW9SOXcRUR3G*fxWaBhY14s1~2UcE{z~alJ%k(84PzSf zG}xV<1y)B(yb9iKuvQavXs|BIfy&w9aCmAPTbyxPE-{5V0s&J+zqU*y!gPbvyubyc z1Lx&_FYQ@c0dfI9(-9C%GsQDKOMBu)G}q3vHCR3Nx(o&ryEMC{w%+DW(Lx&h6cI@s z9X2nJ>#!}bIff^8xU4{`t8rnxNs?@uMefj;m-bANR-O3F(Kl%&aA_fyiLlVGHr=$( z`R=XnyxF;K*I&jzdQuQNhMsMGN2|+^Yx^|iKd-t7&D*4f?$zn;==bmxgJKqx!eJWm0q2gRENr%jcQz~- zLZCR!SrCxiA2n*Bso+m5z7OC3H&zQBIF3V8l7)@7Yg3T}a*F@C&5l83WQqR-LMF$5 z9H%ff{^tWJ@&Bc-elFsFz1>+i_}Ts-nkn!lMU0e2+>LIlL!oLOoJ`CFau4qcb($`Z zTt*|t1}UIt3AgpO#o5w_EF|FcQ1ct90||^*etsh+xl`woMkeNCDT~#%$nCNwCGUgvynp${H&E-cS*7gUZI4^2DMpm^5^dJX z`kN)nFK-xgujxnU!!N_fPhS7l@6QO6bf)?5 z9B!XgX#GL`>57xc=`Gjn-geCscaOXJCwDzuy!R`$z6(2xPDZwmDTa)bA1i`(6s%qS z%7HttD*QV%F&T@mbXh;0q9sWFj1|$(bep6S)FrjL(k{3=BF7Xr4E=M$<%v4I;fTX*x&w2du6q*@zI$N{p`fD|9)dyYs2|_Cg0c^JT|fSk@8&))i0bM zw`1Y1kCqmcOsN@v@9~<=yP(xIyDZ1f?`oufGowe?yz|=XCm!EczGMTn>DGlaFPeUF zWTnTlWX|^TB~!k;ZDZHwb4w<3nu5N=Pd|UXdH=sY-oNd-SDu<+ePG9l1Mjt3ouT@` ziQt&m*HSyGZKq$mP;>SBtq*&P$3ac_p*4pNk6lD}Yu0YOdBd(>ZaMPE{`=0}J){SPVr6BMrg{v$6?pZ}@P|1SHckox@ZbNc-6 zPuqjVMIW82Sjpahy7F9KFm?Tpu>T*<`+te=e<=PRQ?LK>0@eSk{$KU~UsV4eJn`Gw zWnJOY-zV-al@?m-L$Wg>-0K3 zqMxVV@b7)wwX;A2x)}2TkxAZYIkrJ}Y-3%K4bo%!+w|*SU3Yd&8E8QVK5v|yzk1|67<59fou-|t<=GYpBDGV;a6vD}vLFT3 zYxRYswoKe+jMPoliC0H@Eu=37)J0r?3vwZ$rFTnjuMRSll$MoNOoz&V0J0D&E)|>8 zml#S6%M5EY(~D=xy2~9gK;>~Ipq^f|kXGEz%y7IwZ!mnI)mK7dW5uP@3U#IC0gWUb zVzNQ|h+(hBa7r`%f@b;#O_^+o0U-k=sI&rDT_9bVoW@wpLLoSZno*iY z!9|1D&-r{b>gD_tI-|6rOy0n4twx@VpXqE@k&JY||Ci?f2>t_@p~Zhn{QoaBrvCqz 
zykPM0mk$|U>&+S2;y<1;{*=1^L!OXn{|EV}{>tqMP+npbsS{6&djW`-tUL|K7$^mZg7TUy8w^A*x&1e;eaaCPRglh|T zM7b*Ey1$b3R;4c8m^?)-mKSdX_&9Nc(g^R9GzefxMAAz>fD{-pHU~^Zj3yH9f`tg< zTP}8v3o+gRC+{yIL^`+_k#w2Bb%p~S%FdRT$R4S1)Nn#sW#Ktq@{uLN)RyUqd Xp+bcU6)IGy7&8732)b!H0C)fZiHQQ{ literal 0 HcmV?d00001 diff --git a/crates/integrations/datafusion/tests/read_tables.rs b/crates/integrations/datafusion/tests/read_tables.rs index daaee458..f847f087 100644 --- a/crates/integrations/datafusion/tests/read_tables.rs +++ b/crates/integrations/datafusion/tests/read_tables.rs @@ -1046,3 +1046,93 @@ mod fulltext_tests { assert!(ids.contains(&3), "Searching 'search' should match row 3"); } } + +// ======================= Vector Search Tests ======================= + +mod vector_search_tests { + use std::sync::Arc; + + use datafusion::arrow::array::Int32Array; + use datafusion::prelude::SessionContext; + use paimon::{Catalog, CatalogOptions, FileSystemCatalog, Options}; + use paimon_datafusion::{register_vector_search, PaimonCatalogProvider}; + + fn extract_test_warehouse() -> (tempfile::TempDir, String) { + let archive_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("testdata/test_lumina_vector.tar.gz"); + let file = std::fs::File::open(&archive_path) + .unwrap_or_else(|e| panic!("Failed to open {}: {e}", archive_path.display())); + let decoder = flate2::read::GzDecoder::new(file); + let mut archive = tar::Archive::new(decoder); + + let tmp = tempfile::tempdir().expect("Failed to create temp dir"); + let db_dir = tmp.path().join("default.db"); + std::fs::create_dir_all(&db_dir).unwrap(); + archive.unpack(&db_dir).unwrap(); + + let warehouse = format!("file://{}", tmp.path().display()); + (tmp, warehouse) + } + + async fn create_vector_search_context() -> (SessionContext, tempfile::TempDir) { + let (tmp, warehouse) = extract_test_warehouse(); + let mut options = Options::new(); + options.set(CatalogOptions::WAREHOUSE, warehouse); + let catalog = 
FileSystemCatalog::new(options).expect("Failed to create catalog"); + let catalog: Arc = Arc::new(catalog); + + let ctx = SessionContext::new(); + ctx.register_catalog( + "paimon", + Arc::new(PaimonCatalogProvider::new(Arc::clone(&catalog))), + ); + register_vector_search(&ctx, catalog, "default"); + (ctx, tmp) + } + + fn extract_ids(batches: &[datafusion::arrow::record_batch::RecordBatch]) -> Vec { + let mut ids = Vec::new(); + for batch in batches { + let id_array = batch + .column_by_name("id") + .and_then(|c| c.as_any().downcast_ref::()) + .expect("Expected Int32Array for id"); + for i in 0..batch.num_rows() { + ids.push(id_array.value(i)); + } + } + ids.sort(); + ids + } + + #[tokio::test] + async fn test_vector_search_top3() { + let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'embedding', '[1.0, 0.0, 0.0, 0.0]', 3)") + .await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let ids = extract_ids(&batches); + assert_eq!(ids.len(), 3); + assert!(ids.contains(&0), "exact match [1,0,0,0] should be in top 3"); + } + + #[tokio::test] + async fn test_vector_search_top6_returns_all() { + let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'embedding', '[1.0, 0.0, 0.0, 0.0]', 6)") + .await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let ids = extract_ids(&batches); + assert_eq!(ids, vec![0, 1, 2, 3, 4, 5]); + } +} From 50f6f494cf2de1d47e66ce48fe4a50e66f1b6fab Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 14:27:04 +0800 Subject: [PATCH 6/9] fix test case failure --- .../datafusion/src/vector_search.rs | 13 +-- .../datafusion/tests/read_tables.rs | 18 +++++ crates/paimon/src/lumina/ffi.rs | 36 +++++++++ crates/paimon/src/lumina/reader.rs | 80 ++++++++++++++++++- 4 
files changed, 139 insertions(+), 8 deletions(-) diff --git a/crates/integrations/datafusion/src/vector_search.rs b/crates/integrations/datafusion/src/vector_search.rs index 89575035..5ef0ea1f 100644 --- a/crates/integrations/datafusion/src/vector_search.rs +++ b/crates/integrations/datafusion/src/vector_search.rs @@ -23,9 +23,11 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::catalog::Session; use datafusion::catalog::TableFunctionImpl; +use datafusion::common::project_schema; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result as DFResult; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; +use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use paimon::catalog::{Catalog, Identifier}; @@ -160,15 +162,16 @@ impl TableProvider for VectorSearchTableProvider { }) .await?; + if row_ranges.is_empty() { + let schema = project_schema(&self.schema(), projection)?; + return Ok(Arc::new(EmptyExec::new(schema))); + } + let mut read_builder = table.new_read_builder(); if let Some(limit) = limit { read_builder.with_limit(limit); } - let scan = if row_ranges.is_empty() { - read_builder.new_scan() - } else { - read_builder.new_scan().with_row_ranges(row_ranges) - }; + let scan = read_builder.new_scan().with_row_ranges(row_ranges); let plan = await_with_runtime(scan.plan()) .await .map_err(to_datafusion_error)?; diff --git a/crates/integrations/datafusion/tests/read_tables.rs b/crates/integrations/datafusion/tests/read_tables.rs index f847f087..9318a9a0 100644 --- a/crates/integrations/datafusion/tests/read_tables.rs +++ b/crates/integrations/datafusion/tests/read_tables.rs @@ -1135,4 +1135,22 @@ mod vector_search_tests { let ids = extract_ids(&batches); assert_eq!(ids, vec![0, 1, 2, 3, 4, 5]); } + + #[tokio::test] + async fn test_vector_search_without_matching_index_returns_empty() { 
+ let (ctx, _tmp) = create_vector_search_context().await; + let batches = ctx + .sql("SELECT id FROM vector_search('paimon.default.test_lumina_vector', 'missing_embedding', '[1.0]', 10)") + .await + .expect("SQL should parse") + .collect() + .await + .expect("query should execute"); + + let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!( + total_rows, 0, + "vector_search without a matching Lumina index should not fall back to a full table scan" + ); + } } diff --git a/crates/paimon/src/lumina/ffi.rs b/crates/paimon/src/lumina/ffi.rs index 4891a131..b699b68d 100644 --- a/crates/paimon/src/lumina/ffi.rs +++ b/crates/paimon/src/lumina/ffi.rs @@ -169,6 +169,42 @@ impl LuminaSearcher { Ok(()) } + pub fn open_file(&mut self, path: &str) -> crate::Result<()> { + if self.stream_ctx_keepalive.is_some() { + return Err(crate::Error::DataInvalid { + message: "A stream is already open; close the searcher before opening a file" + .to_string(), + source: None, + }); + } + + let lib = load_library()?; + let c_path = CString::new(path).map_err(|e| crate::Error::DataInvalid { + message: format!("Invalid path: {}", e), + source: None, + })?; + let mut err_buf = [0u8; ERR_BUF_SIZE]; + + let ret: c_int = unsafe { + let func: Symbol< + unsafe extern "C" fn(*mut c_void, *const c_char, *mut c_char, c_int) -> c_int, + > = lib + .get(b"lumina_searcher_open") + .map_err(|e| crate::Error::DataInvalid { + message: format!("Symbol lumina_searcher_open not found: {}", e), + source: None, + })?; + func( + self.handle, + c_path.as_ptr(), + err_buf.as_mut_ptr() as *mut c_char, + ERR_BUF_SIZE as c_int, + ) + }; + + check_error(ret, &err_buf) + } + pub fn search( &self, query: &[f32], diff --git a/crates/paimon/src/lumina/reader.rs b/crates/paimon/src/lumina/reader.rs index 9f2ade67..09acff88 100644 --- a/crates/paimon/src/lumina/reader.rs +++ b/crates/paimon/src/lumina/reader.rs @@ -21,7 +21,8 @@ use crate::lumina::{ }; use std::collections::BinaryHeap; use 
std::collections::HashMap; -use std::io::{Read, Seek}; +use std::io::{Read, Seek, SeekFrom}; +use std::path::PathBuf; const MIN_SEARCH_LIST_SIZE: usize = 16; // C ABI returns int64_t -1 for invalid results, which casts to u64::MAX in Rust. @@ -98,6 +99,7 @@ pub struct LuminaVectorGlobalIndexReader { searcher: Option, index_meta: Option, search_options: Option>, + local_index_file: Option, } impl LuminaVectorGlobalIndexReader { @@ -108,6 +110,7 @@ impl LuminaVectorGlobalIndexReader { searcher: None, index_meta: None, search_options: None, + local_index_file: None, } } @@ -227,12 +230,27 @@ impl LuminaVectorGlobalIndexReader { let mut searcher = LuminaSearcher::create(&searcher_options)?; - let stream = stream_fn(&self.io_meta.file_path)?; - searcher.open_stream(stream)?; + let mut stream = stream_fn(&self.io_meta.file_path)?; + let local_index_file = write_temp_index_file(&mut stream)?; + let local_index_path = + local_index_file + .to_str() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "Temporary Lumina index path is not valid UTF-8: {}", + local_index_file.display() + ), + source: None, + })?; + if let Err(err) = searcher.open_file(local_index_path) { + let _ = std::fs::remove_file(&local_index_file); + return Err(err); + } self.search_options = Some(searcher_options); self.index_meta = Some(index_meta); self.searcher = Some(searcher); + self.local_index_file = Some(local_index_file); Ok(()) } @@ -240,9 +258,51 @@ impl LuminaVectorGlobalIndexReader { self.searcher = None; self.index_meta = None; self.search_options = None; + if let Some(path) = self.local_index_file.take() { + let _ = std::fs::remove_file(path); + } } } +fn write_temp_index_file(stream: &mut S) -> crate::Result { + stream + .seek(SeekFrom::Start(0)) + .map_err(|e| crate::Error::UnexpectedError { + message: format!("Failed to seek Lumina index stream to start: {}", e), + source: Some(Box::new(e)), + })?; + + let path = std::env::temp_dir().join(format!( + 
"paimon-lumina-index-{}.index", + uuid::Uuid::new_v4() + )); + let mut file = std::fs::File::create(&path).map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to create temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + std::io::copy(stream, &mut file).map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to write temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + file.sync_all().map_err(|e| crate::Error::UnexpectedError { + message: format!( + "Failed to sync temporary Lumina index file '{}': {}", + path.display(), + e + ), + source: Some(Box::new(e)), + })?; + Ok(path) +} + impl Drop for LuminaVectorGlobalIndexReader { fn drop(&mut self) { self.close(); @@ -253,6 +313,7 @@ impl Drop for LuminaVectorGlobalIndexReader { mod tests { use super::*; use crate::lumina::GlobalIndexIOMeta; + use std::io::Cursor; #[test] fn test_convert_distance_to_score() { @@ -307,4 +368,17 @@ mod tests { let reader = LuminaVectorGlobalIndexReader::new(m, HashMap::new()); assert!(reader.searcher.is_none()); } + + #[test] + fn test_write_temp_index_file_copies_stream() { + let bytes = b"lumina-index-bytes".to_vec(); + let mut stream = Cursor::new(bytes.clone()); + stream.seek(SeekFrom::End(0)).unwrap(); + + let path = write_temp_index_file(&mut stream).unwrap(); + let actual = std::fs::read(&path).unwrap(); + std::fs::remove_file(&path).unwrap(); + + assert_eq!(actual, bytes); + } } From b4b4f86e834401a1d15ca75e6db48a66ddaf9945 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 15:20:14 +0800 Subject: [PATCH 7/9] fix: preserve vector search filters and refactor table function args --- .../datafusion/src/full_text_search.rs | 73 +++---------------- crates/integrations/datafusion/src/lib.rs | 1 + .../datafusion/src/vector_search.rs | 73 ++++--------------- crates/paimon/src/lumina/ffi.rs | 26 ++++++- crates/paimon/src/lumina/mod.rs 
| 18 +++++ .../paimon/src/table/vector_search_builder.rs | 11 +-- 6 files changed, 70 insertions(+), 132 deletions(-) diff --git a/crates/integrations/datafusion/src/full_text_search.rs b/crates/integrations/datafusion/src/full_text_search.rs index f1b559b8..20ff38ff 100644 --- a/crates/integrations/datafusion/src/full_text_search.rs +++ b/crates/integrations/datafusion/src/full_text_search.rs @@ -37,11 +37,16 @@ use datafusion::error::Result as DFResult; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use paimon::catalog::{Catalog, Identifier}; +use paimon::catalog::Catalog; use crate::error::to_datafusion_error; use crate::runtime::{await_with_runtime, block_on_with_runtime}; use crate::table::{PaimonScanBuilder, PaimonTableProvider}; +use crate::table_function_args::{ + extract_int_literal, extract_string_literal, parse_table_identifier, +}; + +const FUNCTION_NAME: &str = "full_text_search"; /// Register the `full_text_search` table-valued function on a [`SessionContext`]. 
pub fn register_full_text_search( @@ -88,10 +93,10 @@ impl TableFunctionImpl for FullTextSearchFunction { )); } - let table_name = extract_string_literal(&args[0], "table_name")?; - let column_name = extract_string_literal(&args[1], "column_name")?; - let query_text = extract_string_literal(&args[2], "query_text")?; - let limit = extract_int_literal(&args[3], "limit")?; + let table_name = extract_string_literal(FUNCTION_NAME, &args[0], "table_name")?; + let column_name = extract_string_literal(FUNCTION_NAME, &args[1], "column_name")?; + let query_text = extract_string_literal(FUNCTION_NAME, &args[2], "query_text")?; + let limit = extract_int_literal(FUNCTION_NAME, &args[3], "limit")?; if limit <= 0 { return Err(datafusion::error::DataFusionError::Plan( @@ -99,7 +104,8 @@ impl TableFunctionImpl for FullTextSearchFunction { )); } - let identifier = parse_table_identifier(&table_name, &self.default_database)?; + let identifier = + parse_table_identifier(FUNCTION_NAME, &table_name, &self.default_database)?; let catalog = Arc::clone(&self.catalog); let table = block_on_with_runtime( @@ -201,58 +207,3 @@ impl TableProvider for FullTextSearchTableProvider { ]) } } - -fn extract_string_literal(expr: &Expr, name: &str) -> DFResult { - match expr { - Expr::Literal(scalar, _) => { - let s = scalar.try_as_str().flatten().ok_or_else(|| { - datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a string literal, got: {expr}" - )) - })?; - Ok(s.to_string()) - } - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn extract_int_literal(expr: &Expr, name: &str) -> DFResult { - use datafusion::common::ScalarValue; - match expr { - Expr::Literal(scalar, _) => match scalar { - ScalarValue::Int8(Some(v)) => Ok(*v as i64), - ScalarValue::Int16(Some(v)) => Ok(*v as i64), - ScalarValue::Int32(Some(v)) => Ok(*v as i64), - ScalarValue::Int64(Some(v)) => Ok(*v), - 
ScalarValue::UInt8(Some(v)) => Ok(*v as i64), - ScalarValue::UInt16(Some(v)) => Ok(*v as i64), - ScalarValue::UInt32(Some(v)) => Ok(*v as i64), - ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { - datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} value {v} exceeds i64 range" - )) - }), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be an integer literal, got: {expr}" - ))), - }, - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn parse_table_identifier(name: &str, default_database: &str) -> DFResult { - let parts: Vec<&str> = name.split('.').collect(); - match parts.len() { - 1 => Ok(Identifier::new(default_database, parts[0])), - 2 => Ok(Identifier::new(parts[0], parts[1])), - // 3-part name: catalog.database.table — ignore catalog prefix - 3 => Ok(Identifier::new(parts[1], parts[2])), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "full_text_search: invalid table name '{name}', expected 'table', 'database.table', or 'catalog.database.table'" - ))), - } -} diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs index 4d1688d3..fa094147 100644 --- a/crates/integrations/datafusion/src/lib.rs +++ b/crates/integrations/datafusion/src/lib.rs @@ -48,6 +48,7 @@ pub mod runtime; mod sql_handler; mod system_tables; mod table; +mod table_function_args; mod update; mod vector_search; diff --git a/crates/integrations/datafusion/src/vector_search.rs b/crates/integrations/datafusion/src/vector_search.rs index 5ef0ea1f..34daadc8 100644 --- a/crates/integrations/datafusion/src/vector_search.rs +++ b/crates/integrations/datafusion/src/vector_search.rs @@ -30,11 +30,16 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::ExecutionPlan; use 
datafusion::prelude::SessionContext; -use paimon::catalog::{Catalog, Identifier}; +use paimon::catalog::Catalog; use crate::error::to_datafusion_error; use crate::runtime::{await_with_runtime, block_on_with_runtime}; use crate::table::{PaimonScanBuilder, PaimonTableProvider}; +use crate::table_function_args::{ + extract_int_literal, extract_string_literal, parse_table_identifier, +}; + +const FUNCTION_NAME: &str = "vector_search"; pub fn register_vector_search( ctx: &SessionContext, @@ -77,10 +82,11 @@ impl TableFunctionImpl for VectorSearchFunction { )); } - let table_name = extract_string_literal(&args[0], "table_name")?; - let column_name = extract_string_literal(&args[1], "column_name")?; - let query_vector_json = extract_string_literal(&args[2], "query_vector_json")?; - let limit = extract_int_literal(&args[3], "limit")?; + let table_name = extract_string_literal(FUNCTION_NAME, &args[0], "table_name")?; + let column_name = extract_string_literal(FUNCTION_NAME, &args[1], "column_name")?; + let query_vector_json = + extract_string_literal(FUNCTION_NAME, &args[2], "query_vector_json")?; + let limit = extract_int_literal(FUNCTION_NAME, &args[3], "limit")?; if limit <= 0 { return Err(datafusion::error::DataFusionError::Plan( @@ -101,7 +107,8 @@ impl TableFunctionImpl for VectorSearchFunction { )); } - let identifier = parse_table_identifier(&table_name, &self.default_database)?; + let identifier = + parse_table_identifier(FUNCTION_NAME, &table_name, &self.default_database)?; let catalog = Arc::clone(&self.catalog); let table = block_on_with_runtime( @@ -200,57 +207,3 @@ impl TableProvider for VectorSearchTableProvider { ]) } } - -fn extract_string_literal(expr: &Expr, name: &str) -> DFResult { - match expr { - Expr::Literal(scalar, _) => { - let s = scalar.try_as_str().flatten().ok_or_else(|| { - datafusion::error::DataFusionError::Plan(format!( - "vector_search: {name} must be a string literal, got: {expr}" - )) - })?; - Ok(s.to_string()) - } - _ => 
Err(datafusion::error::DataFusionError::Plan(format!( - "vector_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn extract_int_literal(expr: &Expr, name: &str) -> DFResult { - use datafusion::common::ScalarValue; - match expr { - Expr::Literal(scalar, _) => match scalar { - ScalarValue::Int8(Some(v)) => Ok(*v as i64), - ScalarValue::Int16(Some(v)) => Ok(*v as i64), - ScalarValue::Int32(Some(v)) => Ok(*v as i64), - ScalarValue::Int64(Some(v)) => Ok(*v), - ScalarValue::UInt8(Some(v)) => Ok(*v as i64), - ScalarValue::UInt16(Some(v)) => Ok(*v as i64), - ScalarValue::UInt32(Some(v)) => Ok(*v as i64), - ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { - datafusion::error::DataFusionError::Plan(format!( - "vector_search: {name} value {v} exceeds i64 range" - )) - }), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "vector_search: {name} must be an integer literal, got: {expr}" - ))), - }, - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "vector_search: {name} must be a literal, got: {expr}" - ))), - } -} - -fn parse_table_identifier(name: &str, default_database: &str) -> DFResult { - let parts: Vec<&str> = name.split('.').collect(); - match parts.len() { - 1 => Ok(Identifier::new(default_database, parts[0])), - 2 => Ok(Identifier::new(parts[0], parts[1])), - 3 => Ok(Identifier::new(parts[1], parts[2])), - _ => Err(datafusion::error::DataFusionError::Plan(format!( - "vector_search: invalid table name '{name}', expected 'table', 'database.table', or 'catalog.database.table'" - ))), - } -} diff --git a/crates/paimon/src/lumina/ffi.rs b/crates/paimon/src/lumina/ffi.rs index b699b68d..75b89acb 100644 --- a/crates/paimon/src/lumina/ffi.rs +++ b/crates/paimon/src/lumina/ffi.rs @@ -19,16 +19,28 @@ use libloading::{Library, Symbol}; use std::collections::HashMap; use std::ffi::{c_char, c_float, c_int, c_void, CStr, CString}; use std::io::{Read, Seek, SeekFrom}; -use std::sync::OnceLock; +use std::sync::{Mutex, OnceLock}; 
const ERR_BUF_SIZE: usize = 4096; static LIBRARY: OnceLock = OnceLock::new(); +static LIBRARY_LOAD_LOCK: Mutex<()> = Mutex::new(()); fn load_library() -> crate::Result<&'static Library> { if let Some(lib) = LIBRARY.get() { return Ok(lib); } + + let _guard = LIBRARY_LOAD_LOCK + .lock() + .map_err(|_| crate::Error::UnexpectedError { + message: "Lumina library load lock poisoned".to_string(), + source: None, + })?; + if let Some(lib) = LIBRARY.get() { + return Ok(lib); + } + let lib_path = std::env::var("LUMINA_LIB_PATH").unwrap_or_else(|_| { if cfg!(target_os = "macos") { "liblumina_py.dylib".to_string() @@ -44,8 +56,16 @@ fn load_library() -> crate::Result<&'static Library> { source: None, })? }; - let _ = LIBRARY.set(lib); - Ok(LIBRARY.get().unwrap()) + LIBRARY + .set(lib) + .map_err(|_| crate::Error::UnexpectedError { + message: "Lumina library was initialized unexpectedly while holding load lock" + .to_string(), + source: None, + })?; + Ok(LIBRARY + .get() + .expect("Lumina library should be initialized after successful set")) } fn check_error(ret: c_int, err_buf: &[u8; ERR_BUF_SIZE]) -> crate::Result<()> { diff --git a/crates/paimon/src/lumina/mod.rs b/crates/paimon/src/lumina/mod.rs index a686eee1..e08f1c83 100644 --- a/crates/paimon/src/lumina/mod.rs +++ b/crates/paimon/src/lumina/mod.rs @@ -214,6 +214,7 @@ pub fn strip_lumina_options(paimon_options: &HashMap) -> HashMap result } +#[derive(Clone)] pub struct VectorSearch { pub vector: Vec, pub limit: usize, @@ -621,6 +622,23 @@ mod tests { assert_eq!(lumina_opts.get("search.parallel_number").unwrap(), "5"); } + #[test] + fn test_vector_search_clone_preserves_include_row_ids() { + let mut include_row_ids = roaring::RoaringTreemap::new(); + include_row_ids.insert(1); + include_row_ids.insert(3); + + let vector_search = VectorSearch::new(vec![1.0, 2.0], 10, "embedding".to_string()) + .unwrap() + .with_include_row_ids(include_row_ids.clone()); + + let cloned = vector_search.clone(); + assert_eq!(cloned.vector, 
vector_search.vector); + assert_eq!(cloned.limit, vector_search.limit); + assert_eq!(cloned.field_name, vector_search.field_name); + assert_eq!(cloned.include_row_ids.as_ref(), Some(&include_row_ids)); + } + #[test] fn test_search_result_from_scored_map() { let mut map = HashMap::new(); diff --git a/crates/paimon/src/table/vector_search_builder.rs b/crates/paimon/src/table/vector_search_builder.rs index ab0d9868..676d304d 100644 --- a/crates/paimon/src/table/vector_search_builder.rs +++ b/crates/paimon/src/table/vector_search_builder.rs @@ -151,11 +151,7 @@ async fn evaluate_vector_search( let file_size = entry.index_file.file_size as u64; let index_meta_bytes = global_meta.index_meta.clone().unwrap_or_default(); let row_range_start = global_meta.row_range_start; - let vector_search_clone = VectorSearch::new( - vector_search.vector.clone(), - vector_search.limit, - vector_search.field_name.clone(), - ); + let vector_search_clone = vector_search.clone(); let options = table_options.clone(); let input = file_io.new_input(&path); async move { @@ -168,10 +164,9 @@ async fn evaluate_vector_search( let io_meta = GlobalIndexIOMeta::new(file_name.clone(), file_size, index_meta_bytes); let mut reader = LuminaVectorGlobalIndexReader::new(io_meta, options); - let vs = vector_search_clone?; - let data = bytes.to_vec(); - let result = reader.visit_vector_search(&vs, |_| Ok(Cursor::new(data)))?; + let result = + reader.visit_vector_search(&vector_search_clone, |_| Ok(Cursor::new(data)))?; match result { Some(scored_map) => Ok::<_, crate::Error>( From 0543771f73edea00475154d5422612a79f0cdedd Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 15:33:45 +0800 Subject: [PATCH 8/9] add table_function_args.rs --- .../datafusion/src/table_function_args.rs | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 crates/integrations/datafusion/src/table_function_args.rs diff --git a/crates/integrations/datafusion/src/table_function_args.rs 
b/crates/integrations/datafusion/src/table_function_args.rs new file mode 100644 index 00000000..c0a6c833 --- /dev/null +++ b/crates/integrations/datafusion/src/table_function_args.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion::common::ScalarValue; +use datafusion::error::{DataFusionError, Result as DFResult}; +use datafusion::logical_expr::Expr; +use paimon::catalog::Identifier; + +pub(crate) fn extract_string_literal( + function_name: &str, + expr: &Expr, + name: &str, +) -> DFResult { + match expr { + Expr::Literal(scalar, _) => { + let s = scalar.try_as_str().flatten().ok_or_else(|| { + DataFusionError::Plan(format!( + "{function_name}: {name} must be a string literal, got: {expr}" + )) + })?; + Ok(s.to_string()) + } + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be a literal, got: {expr}" + ))), + } +} + +pub(crate) fn extract_int_literal(function_name: &str, expr: &Expr, name: &str) -> DFResult { + match expr { + Expr::Literal(scalar, _) => match scalar { + ScalarValue::Int8(Some(v)) => Ok(*v as i64), + ScalarValue::Int16(Some(v)) => Ok(*v as i64), + ScalarValue::Int32(Some(v)) => Ok(*v as i64), + ScalarValue::Int64(Some(v)) => Ok(*v), + ScalarValue::UInt8(Some(v)) => Ok(*v as i64), + ScalarValue::UInt16(Some(v)) => Ok(*v as i64), + ScalarValue::UInt32(Some(v)) => Ok(*v as i64), + ScalarValue::UInt64(Some(v)) => i64::try_from(*v).map_err(|_| { + DataFusionError::Plan(format!( + "{function_name}: {name} value {v} exceeds i64 range" + )) + }), + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be an integer literal, got: {expr}" + ))), + }, + _ => Err(DataFusionError::Plan(format!( + "{function_name}: {name} must be a literal, got: {expr}" + ))), + } +} + +pub(crate) fn parse_table_identifier( + function_name: &str, + name: &str, + default_database: &str, +) -> DFResult { + let parts: Vec<&str> = name.split('.').collect(); + match parts.len() { + 1 => Ok(Identifier::new(default_database, parts[0])), + 2 => Ok(Identifier::new(parts[0], parts[1])), + 3 => Ok(Identifier::new(parts[1], parts[2])), + _ => Err(DataFusionError::Plan(format!( + "{function_name}: invalid table name '{name}', expected 'table', 
'database.table', or 'catalog.database.table'" + ))), + } +} From 9b6c3d144d1c602f2df86e6ce2579afd6173f3d2 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sun, 26 Apr 2026 15:40:27 +0800 Subject: [PATCH 9/9] fix: harden vector search row range conversion --- crates/paimon/src/lumina/mod.rs | 39 +++++++++++++++---- .../src/table/full_text_search_builder.rs | 6 +-- .../paimon/src/table/global_index_scanner.rs | 10 ++--- crates/paimon/src/table/mod.rs | 6 ++- .../paimon/src/table/vector_search_builder.rs | 8 +--- 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/crates/paimon/src/lumina/mod.rs b/crates/paimon/src/lumina/mod.rs index e08f1c83..04b4f87e 100644 --- a/crates/paimon/src/lumina/mod.rs +++ b/crates/paimon/src/lumina/mod.rs @@ -448,19 +448,32 @@ impl SearchResult { Self { row_ids, scores } } - pub fn to_row_ranges(&self) -> Vec { + pub fn to_row_ranges(&self) -> crate::Result> { if self.row_ids.is_empty() { - return Vec::new(); + return Ok(Vec::new()); } - let mut sorted: Vec = self.row_ids.clone(); + + let mut sorted = self + .row_ids + .iter() + .copied() + .map(|id| { + i64::try_from(id).map_err(|_| crate::Error::DataInvalid { + message: format!( + "Lumina search row id {id} exceeds i64::MAX and cannot be converted to RowRange" + ), + source: None, + }) + }) + .collect::>>()?; + sorted.sort_unstable(); sorted.dedup(); let mut ranges = Vec::new(); - let mut start = sorted[0] as i64; + let mut start = sorted[0]; let mut end = start; for &id in &sorted[1..] 
{ - let id = id as i64; - if id == end + 1 { + if end.checked_add(1) == Some(id) { end = id; } else { ranges.push(crate::table::RowRange::new(start, end)); @@ -469,7 +482,7 @@ impl SearchResult { } } ranges.push(crate::table::RowRange::new(start, end)); - ranges + Ok(ranges) } } @@ -676,7 +689,7 @@ mod tests { #[test] fn test_search_result_to_row_ranges() { let result = SearchResult::new(vec![5, 1, 2, 3, 10], vec![0.1; 5]); - let ranges = result.to_row_ranges(); + let ranges = result.to_row_ranges().unwrap(); assert_eq!(ranges.len(), 3); assert_eq!(ranges[0].from(), 1); assert_eq!(ranges[0].to(), 3); @@ -685,4 +698,14 @@ mod tests { assert_eq!(ranges[2].from(), 10); assert_eq!(ranges[2].to(), 10); } + + #[test] + fn test_search_result_to_row_ranges_rejects_i64_overflow() { + let result = SearchResult::new(vec![i64::MAX as u64 + 1], vec![0.1]); + let err = result.to_row_ranges().unwrap_err(); + assert!( + err.to_string().contains("exceeds i64::MAX"), + "unexpected error: {err}" + ); + } } diff --git a/crates/paimon/src/table/full_text_search_builder.rs b/crates/paimon/src/table/full_text_search_builder.rs index 783bd572..41297b7a 100644 --- a/crates/paimon/src/table/full_text_search_builder.rs +++ b/crates/paimon/src/table/full_text_search_builder.rs @@ -21,7 +21,7 @@ use crate::spec::{DataField, FileKind, IndexManifest}; use crate::table::snapshot_manager::SnapshotManager; -use crate::table::{RowRange, Table}; +use crate::table::{find_field_id_by_name, RowRange, Table}; use crate::tantivy::full_text_search::{FullTextSearch, SearchResult}; use crate::tantivy::reader::TantivyFullTextReader; @@ -201,7 +201,3 @@ async fn evaluate_full_text_search( Ok(merged.top_k(search.limit).to_row_ranges()) } - -fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { - fields.iter().find(|f| f.name() == name).map(|f| f.id()) -} diff --git a/crates/paimon/src/table/global_index_scanner.rs b/crates/paimon/src/table/global_index_scanner.rs index 06aea9c3..cdbccaca 100644 
--- a/crates/paimon/src/table/global_index_scanner.rs +++ b/crates/paimon/src/table/global_index_scanner.rs @@ -369,12 +369,10 @@ impl GlobalIndexScanner { } fn find_field_id_by_name(&self, column: &str) -> Result> { - for field in &self.schema_fields { - if field.name() == column { - return Ok(Some(field.id())); - } - } - Ok(None) + Ok(crate::table::find_field_id_by_name( + &self.schema_fields, + column, + )) } fn entries_for_field(&self, field_id: i32) -> Option<&[GlobalIndexEntry]> { diff --git a/crates/paimon/src/table/mod.rs b/crates/paimon/src/table/mod.rs index 1f58142a..c8ef5ede 100644 --- a/crates/paimon/src/table/mod.rs +++ b/crates/paimon/src/table/mod.rs @@ -77,7 +77,7 @@ pub use write_builder::WriteBuilder; use crate::catalog::Identifier; use crate::io::FileIO; -use crate::spec::TableSchema; +use crate::spec::{DataField, TableSchema}; use std::collections::HashMap; /// Table represents a table in the catalog. @@ -177,3 +177,7 @@ impl Table { /// A stream of arrow [`RecordBatch`]es. 
pub type ArrowRecordBatchStream = BoxStream<'static, Result>; + +pub(crate) fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { + fields.iter().find(|f| f.name() == name).map(|f| f.id()) +} diff --git a/crates/paimon/src/table/vector_search_builder.rs b/crates/paimon/src/table/vector_search_builder.rs index 676d304d..1f79e540 100644 --- a/crates/paimon/src/table/vector_search_builder.rs +++ b/crates/paimon/src/table/vector_search_builder.rs @@ -19,7 +19,7 @@ use crate::lumina::reader::LuminaVectorGlobalIndexReader; use crate::lumina::{GlobalIndexIOMeta, SearchResult, VectorSearch, LUMINA_VECTOR_ANN_IDENTIFIER}; use crate::spec::{DataField, FileKind, IndexManifest}; use crate::table::snapshot_manager::SnapshotManager; -use crate::table::{RowRange, Table}; +use crate::table::{find_field_id_by_name, RowRange, Table}; use std::collections::HashMap; use std::io::Cursor; @@ -184,11 +184,7 @@ async fn evaluate_vector_search( merged = merged.or(r); } - Ok(merged.top_k(vector_search.limit).to_row_ranges()) -} - -fn find_field_id_by_name(fields: &[DataField], name: &str) -> Option { - fields.iter().find(|f| f.name() == name).map(|f| f.id()) + merged.top_k(vector_search.limit).to_row_ranges() } #[cfg(test)]