diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b86642..4f2aeb6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,19 +35,19 @@ jobs: wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt-get install -y ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt-get update - sudo apt-get install -y libarrow-dev + sudo apt-get install -y libarrow-dev libboost-coroutine-dev libboost-context-dev - name: Install dependencies (macOS) if: matrix.os == 'macos-latest' run: | - brew install apache-arrow protobuf + brew install apache-arrow protobuf boost - uses: msys2/setup-msys2@v2 if: matrix.os == 'windows-latest' with: msystem: ucrt64 path-type: inherit - install: mingw-w64-ucrt-x86_64-arrow mingw-w64-ucrt-x86_64-protobuf + install: mingw-w64-ucrt-x86_64-arrow mingw-w64-ucrt-x86_64-protobuf mingw-w64-ucrt-x86_64-boost - name: Configure CMake if: matrix.os != 'windows-latest' @@ -163,7 +163,7 @@ jobs: cmake .. -G Ninja \ -DCMAKE_TOOLCHAIN_FILE=$GITHUB_WORKSPACE/toolchain.cmake \ -DCMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/arrow-install \ - -DBUILD_EXAMPLES=ON -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Release + -DBUILD_EXAMPLES=ON -DCMAKE_BUILD_TYPE=Release - name: Build run: ninja -C build diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index c8ffed4..a1ca685 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -29,7 +29,7 @@ jobs: wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt-get install -y ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt-get update - sudo apt-get install -y libarrow-dev + sudo apt-get install -y libarrow-dev libboost-coroutine-dev libboost-context-dev - name: Configure CMake run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index c889702..cf94e00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -351,6 +351,57 @@ if(BUILD_TESTS) ) list(APPEND TEST_TARGETS lancedb_index_tests) + # Add test executable for async table tests (runs without valgrind) + add_executable(lancedb_table_async_tests + tests/test_main.cpp + tests/test_common.cpp + tests/test_table_async.cpp + ) + target_link_libraries(lancedb_table_async_tests + PRIVATE + lancedb + Catch2::Catch2 + Threads::Threads + ${ARROW_LIBRARIES} + ) + target_include_directories(lancedb_table_async_tests + PRIVATE ${ARROW_INCLUDE_DIRS} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tests + ) + target_compile_options(lancedb_table_async_tests PRIVATE ${ARROW_CFLAGS_OTHER}) + set_target_properties(lancedb_table_async_tests PROPERTIES + BUILD_RPATH ${RUST_TARGET_DIR} + ) + list(APPEND TEST_TARGETS lancedb_table_async_tests) + + # Add test executable for coroutine table tests (runs with valgrind) + find_package(Boost REQUIRED COMPONENTS coroutine context) + add_executable(lancedb_table_coro_tests + tests/test_main.cpp + tests/test_common.cpp + tests/test_table_coro.cpp + ) + target_link_libraries(lancedb_table_coro_tests + PRIVATE + lancedb + Catch2::Catch2 + Threads::Threads + ${ARROW_LIBRARIES} + Boost::coroutine + Boost::context + ) + target_include_directories(lancedb_table_coro_tests + PRIVATE ${ARROW_INCLUDE_DIRS} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tests + ) + target_compile_options(lancedb_table_coro_tests PRIVATE ${ARROW_CFLAGS_OTHER}) + set_target_properties(lancedb_table_coro_tests PROPERTIES + BUILD_RPATH ${RUST_TARGET_DIR} + ) + list(APPEND TEST_TARGETS lancedb_table_coro_tests) + # Add test executable for vector tests (runs without valgrind) add_executable(lancedb_vector_index_tests tests/test_main.cpp @@ -503,6 +554,32 @@ if(BUILD_TESTS) --suppressions=${CMAKE_CURRENT_SOURCE_DIR}/valgrind.supp $ ) + # Run async table tests with valgrind + add_test(NAME lancedb_table_async_tests + COMMAND ${TEST_ENV_PREFIX} ${VALGRIND_EXECUTABLE} + --tool=memcheck + --leak-check=full + --show-leak-kinds=definite + --errors-for-leak-kinds=definite + --track-origins=yes + --error-exitcode=1 + --log-file=${CMAKE_BINARY_DIR}/valgrind_table_async.txt + --suppressions=${CMAKE_CURRENT_SOURCE_DIR}/valgrind.supp + $ + ) + # Run coroutine table tests with valgrind + add_test(NAME lancedb_table_coro_tests + COMMAND ${TEST_ENV_PREFIX} ${VALGRIND_EXECUTABLE} + --tool=memcheck + --leak-check=full + --show-leak-kinds=definite + --errors-for-leak-kinds=definite + --track-origins=yes + --error-exitcode=1 + --log-file=${CMAKE_BINARY_DIR}/valgrind_table_coro.txt + --suppressions=${CMAKE_CURRENT_SOURCE_DIR}/valgrind.supp + $ + ) else() message(WARNING "Valgrind not found, running tests without memory checking") add_test(NAME lancedb_connection_tests COMMAND ${TEST_ENV_PREFIX} $) @@ -510,6 +587,8 @@ if(BUILD_TESTS) add_test(NAME lancedb_table_meta_tests COMMAND ${TEST_ENV_PREFIX} $) add_test(NAME lancedb_index_tests COMMAND ${TEST_ENV_PREFIX} $) add_test(NAME lancedb_query_tests COMMAND ${TEST_ENV_PREFIX} $) + add_test(NAME lancedb_table_async_tests COMMAND ${TEST_ENV_PREFIX} $) + add_test(NAME lancedb_table_coro_tests COMMAND ${TEST_ENV_PREFIX} $) endif() # Run vector index tests WITHOUT valgrind (too slow under valgrind) diff --git a/examples/full.cpp b/examples/full.cpp index dcf4db1..bc14ae1 100644 --- a/examples/full.cpp +++ b/examples/full.cpp @@ -101,7 +101,7 @@ int main() { std::cerr << "failed to create connection builder" << std::endl; return 1; } - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); if (!db) { std::cerr << "failed to connect to database" << std::endl; return 1; @@ -135,7 +135,7 @@ int main() { LanceDBTable* table = nullptr; if (const LanceDBError result = lancedb_table_create(db, table_name.c_str(), reinterpret_cast(&c_schema), - nullptr, &table, nullptr); result != LANCEDB_SUCCESS) { + nullptr, &table, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error creating table: " << table_name << ", error: " << lancedb_error_to_message(result) << std::endl; lancedb_connection_free(db); if (c_schema.release) { @@ -149,7 +149,7 @@ int main() { // try to create a table that already exists if (const LanceDBError result = lancedb_table_create(db, "my_table", reinterpret_cast(&c_schema), - nullptr, &table, &error_message); result != LANCEDB_SUCCESS) { + nullptr, &table, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cout << "failed to create table that already exists (expected), error: '" << lancedb_error_to_message(result) << "' message: " << std::endl << error_message << std::endl; @@ -162,7 +162,7 @@ int main() { // try to create a table with invalid name if (const LanceDBError result = lancedb_table_create(db, "invalid table name", reinterpret_cast(&c_schema), - nullptr, &table, &error_message); result != LANCEDB_SUCCESS) { + nullptr, &table, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cout << "failed to create table with invalid name (expected), error: '" << lancedb_error_to_message(result) << "' message: " << std::endl << error_message << std::endl; @@ -179,7 +179,7 @@ int main() { // try to create a table with invalid input (null schema) if (const LanceDBError result = lancedb_table_create(db, "invalid_table", nullptr, - nullptr, &table, &error_message); result != LANCEDB_SUCCESS) { + nullptr, &table, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cout << "failed to create table with null schema (expected), error: '" << lancedb_error_to_message(result) << "' message: " << std::endl << error_message << std::endl; @@ -190,7 +190,7 @@ int main() { } // open the table to work with it - LanceDBTable* tbl = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* tbl = lancedb_connection_open_table(db, table_name.c_str(), nullptr); if (!tbl) { std::cerr << "failed to open table: " << table_name << std::endl; lancedb_connection_free(db); @@ -204,7 +204,7 @@ int main() { .force_update_statistics = 0 // don't force update statistics }; if (const LanceDBError result = lancedb_table_create_scalar_index( - tbl, key_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, nullptr); result != LANCEDB_SUCCESS) { + tbl, key_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to create scalar index on 'key' column, error: '" << lancedb_error_to_message(result) << "'" << std::endl; } else { @@ -214,7 +214,7 @@ int main() { // try to create the same index again without replace flag scalar_config.replace = 0; if (const LanceDBError result = lancedb_table_create_scalar_index( - tbl, key_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, &error_message); result != LANCEDB_SUCCESS) { + tbl, key_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cout << "failed to create scalar index on 'key' column (expected), error: '" << lancedb_error_to_message(result) << "' message: " << std::endl << error_message << std::endl; @@ -266,7 +266,7 @@ int main() { reinterpret_cast(&c_schema), &batch_reader, nullptr); error == LANCEDB_SUCCESS) { // add data to table - if (const LanceDBError result = lancedb_table_add(tbl, batch_reader, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_add(tbl, batch_reader, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to write record batch to table, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "wrote " << num_rows << " rows to table" << std::endl; @@ -295,7 +295,7 @@ int main() { .replace = 1 // replace existing index }; if (const LanceDBError result = lancedb_table_create_vector_index( - tbl, data_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, nullptr); result != LANCEDB_SUCCESS) { + tbl, data_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to create vector index on 'data' column, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "created vector index on 'data' column" << std::endl; @@ -309,7 +309,7 @@ int main() { }; for (size_t i = 0; i < 3; i++) { if (const LanceDBError result = lancedb_table_create_scalar_index( - tbl, &tag_columns[i], 1, LANCEDB_INDEX_BITMAP, &bitmap_config, nullptr); result != LANCEDB_SUCCESS) { + tbl, &tag_columns[i], 1, LANCEDB_INDEX_BITMAP, &bitmap_config, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to create bitmap index on '" << tag_columns[i] << "' column, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "created bitmap index on '" << tag_columns[i] << "' column" << std::endl; @@ -336,7 +336,7 @@ int main() { "data", reinterpret_cast(&c_arrays_ptr), reinterpret_cast(&c_schema_ptr), - &count_out, nullptr); result != LANCEDB_SUCCESS) { + &count_out, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error querying nearest to vector, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "query returned " << count_out << " results" << std::endl; @@ -375,7 +375,7 @@ int main() { } else { std::cout << "set query distance type to: L2" << std::endl; // execute the query - if (LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); query_result) { + if (LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); query_result) { std::cout << "executed query" << std::endl; // get the result as arrow arrays struct ArrowArray** c_arrays_ptr; @@ -386,6 +386,7 @@ int main() { reinterpret_cast(&c_arrays_ptr), reinterpret_cast(&c_schema_ptr), &count_out, + nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error converting query result to arrow, error: " << lancedb_error_to_message(result) << std::endl; lancedb_query_result_free(query_result); @@ -410,19 +411,19 @@ int main() { // list all tables in the database and loop through them char** table_names; size_t name_count; - if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error listing table names, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << name_count << " tables found" << std::endl; for (size_t i = 0; i < name_count; i++) { - if (LanceDBTable* tbl = lancedb_connection_open_table(db, table_names[i]); tbl) { + if (LanceDBTable* tbl = lancedb_connection_open_table(db, table_names[i], nullptr); tbl) { // get the schema of the table struct ArrowSchema* c_schema_ptr; if (const LanceDBError result = lancedb_table_arrow_schema( tbl, reinterpret_cast(&c_schema_ptr), - nullptr); result == LANCEDB_SUCCESS) { + nullptr, nullptr); result == LANCEDB_SUCCESS) { if (auto schema = arrow::ImportSchema(c_schema_ptr); schema.ok()) { std::cout << "table: " << table_names[i] << ", schema:" << std::endl; std::cout << (*schema)->ToString() << std::endl; @@ -437,14 +438,14 @@ int main() { // list all indices of the table char** indices; size_t indices_count; - if (const LanceDBError result = lancedb_table_list_indices(tbl, &indices, &indices_count, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_list_indices(tbl, &indices, &indices_count, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to list indices, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "found " << indices_count << " indices:" << std::endl; for (size_t i = 0; i < indices_count; i++) { std::cout << " - " << indices[i] << std::endl; // delete the index - if (const LanceDBError result = lancedb_table_drop_index(tbl, indices[i], nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_drop_index(tbl, indices[i], nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << " error dropping index: " << indices[i] << ", error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << " dropped index: " << indices[i] << std::endl; @@ -454,27 +455,27 @@ int main() { } // optimize the table after index deletion - if (const LanceDBError result = lancedb_table_optimize(tbl, LANCEDB_OPTIMIZE_ALL, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_optimize(tbl, LANCEDB_OPTIMIZE_ALL, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error optimizing table: " << table_names[i] << ", error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "optimized table: " << table_names[i] << std::endl; } // number of rows in the table - auto row_count = lancedb_table_count_rows(tbl); + auto row_count = lancedb_table_count_rows(tbl, nullptr); std::cout << "table: " << table_names[i] << " has: " << row_count << " rows" << std::endl; // delete some rows const auto delete_predicates = {"key = \"key_10\"", "key = \"key_20\"", "key = \"key_30\"", "key = \"kaboom\""}; for (const auto& predicate : delete_predicates) { - if (const LanceDBError result = lancedb_table_delete(tbl, predicate, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_delete(tbl, predicate, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error deleting row with predicate: " << predicate << ", error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "deleted row with predicate: " << predicate << std::endl; } } // check number of rows in the table after deletion - row_count = lancedb_table_count_rows(tbl); + row_count = lancedb_table_count_rows(tbl, nullptr); std::cout << "after deletion table: " << table_names[i] << " has: " << row_count << " rows" << std::endl; // perform table upsert with 3 new rows and 3 updated rows @@ -535,6 +536,7 @@ int main() { on_columns.data(), 1, &config, + nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cerr << "failed to upsert record batch to table, error: " << lancedb_error_to_message(result) << ", message: " << error_message << std::endl; lancedb_free_string(error_message); @@ -554,11 +556,11 @@ int main() { } // check number of rows in the table after upsert - row_count = lancedb_table_count_rows(tbl); + row_count = lancedb_table_count_rows(tbl, nullptr); std::cout << "after upsert table: " << table_names[i] << " has: " << row_count << " rows" << std::endl; // drop the table - if (LanceDBError result = lancedb_connection_drop_table(db, table_names[i], nullptr, nullptr); result != LANCEDB_SUCCESS) { + if (LanceDBError result = lancedb_connection_drop_table(db, table_names[i], nullptr, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error dropping table: " << table_names[i] << ", error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "dropped table: " << table_names[i] << std::endl; @@ -571,7 +573,7 @@ int main() { } lancedb_free_table_names(table_names, name_count); - if (const LanceDBError result = lancedb_connection_drop_all_tables(db, nullptr, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_drop_all_tables(db, nullptr, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error dropping all tables, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "dropped all tables" << std::endl; diff --git a/examples/s3.cpp b/examples/s3.cpp index 336e495..1167846 100644 --- a/examples/s3.cpp +++ b/examples/s3.cpp @@ -35,7 +35,7 @@ LanceDBTable* create_empty_table(LanceDBConnection* db) { char* error_message = nullptr; if (const LanceDBError result = lancedb_table_create(db, table_name.c_str(), reinterpret_cast(&c_schema), - nullptr, &tbl, &error_message); result != LANCEDB_SUCCESS) { + nullptr, &tbl, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cerr << "error creating table: " << table_name << ", error: " << error_message << std::endl; lancedb_connection_free(db); lancedb_free_string(error_message); @@ -84,7 +84,7 @@ int main(int argc, char** argv) { builder = lancedb_connect_builder_storage_option(builder, "allow_http", "true"); builder = lancedb_connect_builder_storage_option(builder, "aws_s3_addressing_style", "path"); - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); if (!db) { std::cerr << "failed to connect to database" << std::endl; return 1; @@ -101,7 +101,7 @@ int main(int argc, char** argv) { char** table_names; size_t name_count; char* error_message = nullptr; - if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, &error_message); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cerr << "error listing table names, error: " << error_message << std::endl; lancedb_free_string(error_message); } else { @@ -112,7 +112,7 @@ int main(int argc, char** argv) { lancedb_free_table_names(table_names, name_count); } - if (const LanceDBError result = lancedb_connection_drop_table(db, "empty_table", nullptr, &error_message); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_drop_table(db, "empty_table", nullptr, nullptr, &error_message); result != LANCEDB_SUCCESS) { std::cerr << "error dropping table, error: " << error_message << std::endl; lancedb_free_string(error_message); } else { diff --git a/examples/simple.cpp b/examples/simple.cpp index 8168898..4f5c3ac 100644 --- a/examples/simple.cpp +++ b/examples/simple.cpp @@ -85,7 +85,7 @@ LanceDBTable* create_table(LanceDBConnection* db) { LanceDBTable* tbl; LanceDBError result = lancedb_table_create(db, table_name.c_str(), reinterpret_cast(&c_schema), - reader, &tbl, nullptr); + reader, &tbl, nullptr, nullptr); if (c_schema.release) { c_schema.release(&c_schema); @@ -116,7 +116,7 @@ LanceDBTable* create_table(LanceDBConnection* db) { return nullptr; } - if (const LanceDBError result = lancedb_table_add(tbl, reader, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_add(tbl, reader, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to write record batch to table, error: " << lancedb_error_to_message(result) << std::endl; } std::cout << "wrote rows to table" << std::endl; @@ -138,7 +138,7 @@ LanceDBTable* create_empty_table(LanceDBConnection* db) { LanceDBTable* tbl = nullptr; if (const LanceDBError result = lancedb_table_create(db, table_name.c_str(), reinterpret_cast(&c_schema), - nullptr, &tbl, nullptr); result != LANCEDB_SUCCESS) { + nullptr, &tbl, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error creating table: " << table_name << ", error: " << lancedb_error_to_message(result) << std::endl; lancedb_connection_free(db); } else { @@ -163,7 +163,7 @@ void create_index(LanceDBTable* tbl) { .replace = 1 // replace existing index }; if (const LanceDBError result = lancedb_table_create_vector_index( - tbl, data_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, nullptr); result != LANCEDB_SUCCESS) { + tbl, data_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "failed to create vector index on 'item' column, error: " << lancedb_error_to_message(result) << std::endl; return; } @@ -187,7 +187,7 @@ SearchResult search(LanceDBTable* tbl) { "item", reinterpret_cast(&c_arrays), reinterpret_cast(&c_schema), - &count_out, nullptr); result != LANCEDB_SUCCESS) { + &count_out, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error querying nearest to vector, error: " << lancedb_error_to_message(result) << std::endl; return {nullptr, nullptr, 0}; } @@ -229,7 +229,7 @@ int main() { std::cerr << "failed to create connection builder" << std::endl; return 1; } - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); if (!db) { std::cerr << "failed to connect to database" << std::endl; return 1; @@ -238,7 +238,7 @@ int main() { // list table names char** table_names; size_t name_count; - if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_table_names(db, &table_names, &name_count, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error listing table names, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << name_count << " tables found" << std::endl; @@ -254,7 +254,7 @@ int main() { print_query_result(c_arrays, c_schema); lancedb_free_arrow_arrays(reinterpret_cast(c_arrays), count_out); lancedb_free_arrow_schema(reinterpret_cast(c_schema)); - if (const LanceDBError result = lancedb_table_delete(tbl, "id > 24", nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_table_delete(tbl, "id > 24", nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error deleting rows from table, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "deleted rows where id > 24" << std::endl; @@ -264,7 +264,7 @@ int main() { auto empty_table = create_empty_table(db); lancedb_table_free(empty_table); - if (const LanceDBError result = lancedb_connection_drop_table(db, "my_table", nullptr, nullptr); result != LANCEDB_SUCCESS) { + if (const LanceDBError result = lancedb_connection_drop_table(db, "my_table", nullptr, nullptr, nullptr); result != LANCEDB_SUCCESS) { std::cerr << "error dropping table, error: " << lancedb_error_to_message(result) << std::endl; } else { std::cout << "dropped table my_table" << std::endl; diff --git a/include/lancedb.h b/include/lancedb.h index 11fb5a7..1960039 100644 --- a/include/lancedb.h +++ b/include/lancedb.h @@ -54,6 +54,14 @@ typedef struct LanceDBQueryResult LanceDBQueryResult; */ typedef struct LanceDBSession LanceDBSession; +/** + * Opaque handle to a Runtime + * + * When passed as NULL to API functions, the global runtime is used. + * Create with lancedb_runtime_new(), free with lancedb_runtime_free(). + */ +typedef struct LanceDBRuntime LanceDBRuntime; + /** * Opaque handle to Arrow RecordBatchReader */ @@ -279,7 +287,7 @@ LanceDBConnectBuilder* lancedb_connect(const char* uri); * On success, the builder is consumed by this function and must not be used after calling. * The returned connection must be freed with lancedb_connection_free(). */ -LanceDBConnection* lancedb_connect_builder_execute(LanceDBConnectBuilder* builder); +LanceDBConnection* lancedb_connect_builder_execute(LanceDBConnectBuilder* builder, const LanceDBRuntime* runtime); /** @@ -345,6 +353,7 @@ LanceDBError lancedb_connection_table_names( const LanceDBConnection* connection, char*** names_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -417,6 +426,7 @@ LanceDBError lancedb_table_names_builder_execute( LanceDBTableNamesBuilder* builder, char*** names_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -440,7 +450,8 @@ void lancedb_table_names_builder_free(LanceDBTableNamesBuilder* builder); */ LanceDBTable* lancedb_connection_open_table( const LanceDBConnection* connection, - const char* table_name + const char* table_name, + const LanceDBRuntime* runtime ); /** @@ -459,6 +470,7 @@ LanceDBError lancedb_connection_drop_table( const LanceDBConnection* connection, const char* table_name, const char* _namespace, + const LanceDBRuntime* runtime, char** error_message ); @@ -483,6 +495,7 @@ LanceDBError lancedb_connection_rename_table( const char* new_name, const char* cur_namespace, const char* new_namespace, + const LanceDBRuntime* runtime, char** error_message ); @@ -500,6 +513,7 @@ LanceDBError lancedb_connection_rename_table( LanceDBError lancedb_connection_drop_all_tables( const LanceDBConnection* connection, const char* _namespace, + const LanceDBRuntime* runtime, char** error_message ); @@ -517,6 +531,7 @@ LanceDBError lancedb_connection_drop_all_tables( LanceDBError lancedb_connection_create_namespace( const LanceDBConnection* connection, const char* namespace_name, + const LanceDBRuntime* runtime, char** error_message ); @@ -534,6 +549,7 @@ LanceDBError lancedb_connection_create_namespace( LanceDBError lancedb_connection_drop_namespace( const LanceDBConnection* connection, const char* namespace_name, + const LanceDBRuntime* runtime, char** error_message ); @@ -556,6 +572,7 @@ LanceDBError lancedb_connection_list_namespaces( const char* namespace_parent, char*** namespaces_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -602,6 +619,7 @@ LanceDBSession* lancedb_session_new(const LanceDBSessionOptions* options); LanceDBError lancedb_session_index_cache_stats( const LanceDBSession* session, LanceDBSessionCacheStats* out_stats, + const LanceDBRuntime* runtime, char** error_message ); @@ -619,6 +637,7 @@ LanceDBError lancedb_session_index_cache_stats( LanceDBError lancedb_session_metadata_cache_stats( const LanceDBSession* session, LanceDBSessionCacheStats* out_stats, + const LanceDBRuntime* runtime, char** error_message ); @@ -631,6 +650,25 @@ LanceDBError lancedb_session_metadata_cache_stats( */ void lancedb_session_free(LanceDBSession* session); +/** + * Create a new Runtime + * + * @return Non-null pointer to LanceDBRuntime on success, NULL on failure + * + * The returned runtime must be freed with lancedb_runtime_free(). + * When passed to API functions, this runtime is used instead of the global one. + */ +LanceDBRuntime* lancedb_runtime_new(void); + +/** + * Free a Runtime + * + * @param runtime - pointer to LanceDBRuntime returned from lancedb_runtime_new() + * + * After calling this function, the runtime pointer must not be used. + */ +void lancedb_runtime_free(LanceDBRuntime* runtime); + /** * Free a Table * @@ -663,6 +701,7 @@ LanceDBError lancedb_table_create( const FFI_ArrowSchema* schema_ptr, LanceDBRecordBatchReader* reader, LanceDBTable** table_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -681,6 +720,7 @@ LanceDBError lancedb_table_create( LanceDBError lancedb_table_arrow_schema( const LanceDBTable* table, FFI_ArrowSchema** schema_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -690,7 +730,7 @@ LanceDBError lancedb_table_arrow_schema( * @param table - pointer to LanceDBTable * @return Table version number on success, 0 on failure */ -unsigned long long lancedb_table_version(const LanceDBTable* table); +unsigned long long lancedb_table_version(const LanceDBTable* table, const LanceDBRuntime* runtime); /** * Count rows in table @@ -698,7 +738,7 @@ unsigned long long lancedb_table_version(const LanceDBTable* table); * @param table - pointer to LanceDBTable * @return Number of rows in table on success, 0 on failure (or empty table) */ -unsigned long long lancedb_table_count_rows(const LanceDBTable* table); +unsigned long long lancedb_table_count_rows(const LanceDBTable* table, const LanceDBRuntime* runtime); /** * Add data to table using Arrow RecordBatchReader @@ -716,6 +756,7 @@ unsigned long long lancedb_table_count_rows(const LanceDBTable* table); LanceDBError lancedb_table_add( const LanceDBTable* table, LanceDBRecordBatchReader* reader, + const LanceDBRuntime* runtime, char** error_message ); @@ -743,6 +784,7 @@ LanceDBError lancedb_table_merge_insert( const char* const* on_columns, size_t num_columns, const LanceDBMergeInsertConfig* config, + const LanceDBRuntime* runtime, char** error_message ); @@ -760,6 +802,7 @@ LanceDBError lancedb_table_merge_insert( LanceDBError lancedb_table_delete( const LanceDBTable* table, const char* predicate, + const LanceDBRuntime* runtime, char** error_message ); @@ -1080,7 +1123,7 @@ LanceDBError lancedb_vector_query_ef( * @return Pointer to LanceDBQueryResult on success, NULL on failure * Caller must free with lancedb_query_result_free() */ -LanceDBQueryResult* lancedb_query_execute(LanceDBQuery* query); +LanceDBQueryResult* lancedb_query_execute(LanceDBQuery* query, const LanceDBRuntime* runtime); /** * Execute vector query and return streaming result @@ -1089,7 +1132,7 @@ LanceDBQueryResult* lancedb_query_execute(LanceDBQuery* query); * @return Pointer to LanceDBQueryResult on success, NULL on failure * Caller must free with lancedb_query_result_free() */ -LanceDBQueryResult* lancedb_vector_query_execute(LanceDBVectorQuery* query); +LanceDBQueryResult* lancedb_vector_query_execute(LanceDBVectorQuery* query, const LanceDBRuntime* runtime); /** * Convert query result to Arrow RecordBatch arrays @@ -1110,6 +1153,7 @@ LanceDBError lancedb_query_result_to_arrow( struct FFI_ArrowArray*** result_arrays, struct FFI_ArrowSchema** result_schema, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1161,6 +1205,7 @@ LanceDBError lancedb_table_nearest_to( struct FFI_ArrowArray*** result_arrays, struct FFI_ArrowSchema** result_schema, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1221,6 +1266,7 @@ LanceDBError lancedb_table_create_vector_index( size_t num_columns, LanceDBIndexType index_type, const LanceDBVectorIndexConfig* config, + const LanceDBRuntime* runtime, char** error_message ); @@ -1244,6 +1290,7 @@ LanceDBError lancedb_table_create_scalar_index( size_t num_columns, LanceDBIndexType index_type, const LanceDBScalarIndexConfig* config, + const LanceDBRuntime* runtime, char** error_message ); @@ -1265,6 +1312,7 @@ LanceDBError lancedb_table_create_fts_index( const char* const* columns, size_t num_columns, const LanceDBFtsIndexConfig* config, + const LanceDBRuntime* runtime, char** error_message ); @@ -1285,6 +1333,7 @@ LanceDBError lancedb_table_list_indices( const LanceDBTable* table, char*** indices_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1302,6 +1351,7 @@ LanceDBError lancedb_table_list_indices( LanceDBError lancedb_table_drop_index( const LanceDBTable* table, const char* index_name, + const LanceDBRuntime* runtime, char** error_message ); @@ -1319,6 +1369,7 @@ LanceDBError lancedb_table_drop_index( LanceDBError lancedb_table_optimize( const LanceDBTable* table, LanceDBOptimizeType optimize_type, + const LanceDBRuntime* runtime, char** error_message ); @@ -1356,6 +1407,7 @@ LanceDBError lancedb_table_index_stats( const LanceDBTable* table, const char* index_name, LanceDBIndexStats* stats_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1393,6 +1445,7 @@ LanceDBError lancedb_table_list_versions( LanceDBVersion** versions_out, LanceDBVersionMetadata** metadata_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1425,6 +1478,7 @@ LanceDBError lancedb_table_get_metadata( char*** keys_out, char*** values_out, size_t* count_out, + const LanceDBRuntime* runtime, char** error_message ); @@ -1447,6 +1501,7 @@ LanceDBError lancedb_table_set_metadata( const char* const* keys, const char* const* values, size_t count, + const LanceDBRuntime* runtime, char** error_message ); @@ -1467,6 +1522,7 @@ LanceDBError lancedb_table_delete_metadata( const LanceDBTable* table, const char* const* keys, size_t count, + const LanceDBRuntime* runtime, char** error_message ); diff --git a/src/connection.rs b/src/connection.rs index a91aee1..0761def 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -71,13 +71,55 @@ pub struct LanceDBSessionCacheStats { const DEFAULT_INDEX_CACHE_SIZE_BYTES: usize = 6 * 1024 * 1024 * 1024; const DEFAULT_METADATA_CACHE_SIZE_BYTES: usize = 1024 * 1024 * 1024; -/// Runtime to handle async operations +/// Opaque handle to a tokio Runtime +#[repr(C)] +pub struct LanceDBRuntime { + inner: tokio::runtime::Runtime, +} + +/// Global Runtime to handle async operations static RUNTIME: OnceLock = OnceLock::new(); pub(crate) fn get_runtime() -> &'static tokio::runtime::Runtime { RUNTIME.get_or_init(|| tokio::runtime::Runtime::new().expect("Failed to create tokio runtime")) } +/// Resolve runtime: use provided one if non-null, otherwise fall back to global +pub(crate) unsafe fn resolve_runtime<'a>( + runtime: *const LanceDBRuntime, +) -> &'a tokio::runtime::Runtime { + if runtime.is_null() { + get_runtime() + } else { + &(*runtime).inner + } +} + +/// Create a new tokio Runtime +/// +/// # Returns +/// - Non-null pointer to LanceDBRuntime on success +/// - Null pointer on failure +#[no_mangle] +pub extern "C" fn lancedb_runtime_new() -> *mut LanceDBRuntime { + match tokio::runtime::Runtime::new() { + Ok(rt) => Box::into_raw(Box::new(LanceDBRuntime { inner: rt })), + Err(_) => ptr::null_mut(), + } +} + +/// Free a Runtime +/// +/// # Safety +/// - `runtime` must be a valid pointer returned from `lancedb_runtime_new` +/// - `runtime` must not be used after calling this function +#[no_mangle] +pub unsafe extern "C" fn lancedb_runtime_free(runtime: *mut LanceDBRuntime) { + if !runtime.is_null() { + let _ = Box::from_raw(runtime); + } +} + /// Create a ConnectBuilder for the given URI /// /// # Safety @@ -117,6 +159,7 @@ pub unsafe extern "C" fn lancedb_connect(uri: *const c_char) -> *mut LanceDBConn #[no_mangle] pub unsafe extern "C" fn lancedb_connect_builder_execute( builder: *mut LanceDBConnectBuilder, + runtime: *const LanceDBRuntime, ) -> *mut LanceDBConnection { if builder.is_null() { return ptr::null_mut(); @@ -125,7 +168,7 @@ pub unsafe extern "C" fn lancedb_connect_builder_execute( let builder_box = Box::from_raw(builder); let connect_builder = *builder_box.inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(connect_builder.execute()) { Ok(connection) => { let boxed_connection = Box::new(LanceDBConnection { @@ -309,6 +352,7 @@ pub unsafe extern "C" fn lancedb_session_new( pub unsafe extern "C" fn lancedb_session_index_cache_stats( session: *const LanceDBSession, out_stats: *mut LanceDBSessionCacheStats, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if session.is_null() || out_stats.is_null() { @@ -317,7 +361,7 @@ pub unsafe extern "C" fn lancedb_session_index_cache_stats( } let session_ref = &*session; - let stats = get_runtime().block_on(session_ref.inner.index_cache_stats()); + let stats = resolve_runtime(runtime).block_on(session_ref.inner.index_cache_stats()); *out_stats = LanceDBSessionCacheStats { hits: stats.hits, misses: stats.misses, @@ -340,6 +384,7 @@ pub unsafe extern "C" fn lancedb_session_index_cache_stats( pub unsafe extern "C" fn lancedb_session_metadata_cache_stats( session: *const LanceDBSession, out_stats: *mut LanceDBSessionCacheStats, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if session.is_null() || out_stats.is_null() { @@ -348,7 +393,7 @@ pub unsafe extern "C" fn lancedb_session_metadata_cache_stats( } let session_ref = &*session; - let stats = get_runtime().block_on(session_ref.inner.metadata_cache_stats()); + let stats = resolve_runtime(runtime).block_on(session_ref.inner.metadata_cache_stats()); *out_stats = LanceDBSessionCacheStats { hits: stats.hits, misses: stats.misses, @@ -389,6 +434,7 @@ pub unsafe extern "C" fn lancedb_table_create( schema_ptr: *const arrow_schema::ffi::FFI_ArrowSchema, reader: *mut LanceDBRecordBatchReader, table_out: *mut *mut LanceDBTable, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || table_name.is_null() || schema_ptr.is_null() || table_out.is_null() { @@ -402,7 +448,7 @@ pub unsafe extern "C" fn lancedb_table_create( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(async { // Import schema from Arrow C ABI @@ -457,6 +503,7 @@ pub unsafe extern "C" fn lancedb_connection_table_names( connection: *const LanceDBConnection, names_out: *mut *mut *mut c_char, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || names_out.is_null() || count_out.is_null() { @@ -465,7 +512,7 @@ pub unsafe extern "C" fn lancedb_connection_table_names( } let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(conn.table_names().execute()) { Ok(names) => { @@ -639,6 +686,7 @@ pub unsafe extern "C" fn lancedb_table_names_builder_execute( builder: *mut LanceDBTableNamesBuilder, names_out: *mut *mut *mut c_char, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if builder.is_null() { @@ -655,7 +703,7 @@ pub unsafe extern "C" fn lancedb_table_names_builder_execute( let table_names_builder = *builder_box.inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(table_names_builder.execute()) { Ok(names) => { @@ -743,6 +791,7 @@ pub unsafe extern "C" fn lancedb_free_namespace_list(namespaces: *mut *mut c_cha pub unsafe extern "C" fn lancedb_connection_open_table( connection: *const LanceDBConnection, table_name: *const c_char, + runtime: *const LanceDBRuntime, ) -> *mut LanceDBTable { if connection.is_null() || table_name.is_null() { return ptr::null_mut(); @@ -753,7 +802,7 @@ pub unsafe extern "C" fn lancedb_connection_open_table( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(conn.open_table(table_name_str).execute()) { Ok(table) => { @@ -779,6 +828,7 @@ pub unsafe extern "C" fn lancedb_connection_drop_table( connection: *const LanceDBConnection, table_name: *const c_char, namespace: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || table_name.is_null() { @@ -792,7 +842,7 @@ pub unsafe extern "C" fn lancedb_connection_drop_table( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let result = if namespace.is_null() { runtime.block_on(conn.drop_table(table_name_str, &[])) @@ -827,6 +877,7 @@ pub unsafe extern "C" fn lancedb_connection_rename_table( new_name: *const c_char, cur_namespace: *const c_char, new_namespace: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || old_name.is_null() || new_name.is_null() { @@ -845,7 +896,7 @@ pub unsafe extern "C" fn lancedb_connection_rename_table( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let cur_namespace_vec = if cur_namespace.is_null() { Vec::new() @@ -891,6 +942,7 @@ pub unsafe extern "C" fn lancedb_connection_rename_table( pub unsafe extern "C" fn lancedb_connection_drop_all_tables( connection: *const LanceDBConnection, namespace: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() { @@ -899,7 +951,7 @@ pub unsafe extern "C" fn lancedb_connection_drop_all_tables( } let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let result = if namespace.is_null() { runtime.block_on(conn.drop_all_tables(&[])) @@ -930,6 +982,7 @@ pub unsafe extern "C" fn lancedb_connection_drop_all_tables( pub unsafe extern "C" fn lancedb_connection_create_namespace( connection: *const LanceDBConnection, namespace_name: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || namespace_name.is_null() { @@ -943,7 +996,7 @@ pub unsafe extern "C" fn lancedb_connection_create_namespace( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let request = CreateNamespaceRequest { namespace: vec![namespace_str.to_string()], @@ -967,6 +1020,7 @@ pub unsafe extern "C" fn lancedb_connection_create_namespace( pub unsafe extern "C" fn lancedb_connection_drop_namespace( connection: *const LanceDBConnection, namespace_name: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || namespace_name.is_null() { @@ -980,7 +1034,7 @@ pub unsafe extern "C" fn lancedb_connection_drop_namespace( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let request = DropNamespaceRequest { namespace: vec![namespace_str.to_string()], @@ -1009,6 +1063,7 @@ pub unsafe extern "C" fn lancedb_connection_list_namespaces( namespace_parent: *const c_char, namespaces_out: *mut *mut *mut c_char, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if connection.is_null() || namespaces_out.is_null() || count_out.is_null() { @@ -1027,7 +1082,7 @@ pub unsafe extern "C" fn lancedb_connection_list_namespaces( }; let conn = &(*connection).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let request = ListNamespacesRequest { namespace: parent_namespace, diff --git a/src/index.rs b/src/index.rs index 529c077..1625973 100644 --- a/src/index.rs +++ b/src/index.rs @@ -17,7 +17,7 @@ use lancedb::index::vector::{ }; use lancedb::index::Index; -use crate::connection::{get_runtime, LanceDBTable}; +use crate::connection::{resolve_runtime, LanceDBRuntime, LanceDBTable}; use crate::error::{ handle_error, set_invalid_argument_message, set_unknown_error_message, LanceDBError, }; @@ -139,6 +139,7 @@ pub unsafe extern "C" fn lancedb_table_create_vector_index( num_columns: usize, index_type: LanceDBIndexType, config: *const LanceDBVectorIndexConfig, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || columns.is_null() || num_columns == 0 { @@ -166,7 +167,7 @@ pub unsafe extern "C" fn lancedb_table_create_vector_index( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Use default config if none provided let cfg = if config.is_null() { @@ -287,6 +288,7 @@ pub unsafe extern "C" fn lancedb_table_create_scalar_index( num_columns: usize, index_type: LanceDBIndexType, config: *const LanceDBScalarIndexConfig, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || columns.is_null() || num_columns == 0 { @@ -314,7 +316,7 @@ pub unsafe extern "C" fn lancedb_table_create_scalar_index( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Use default config if none provided let cfg = if config.is_null() { @@ -367,6 +369,7 @@ pub unsafe extern "C" fn lancedb_table_create_fts_index( columns: *const *const c_char, num_columns: usize, config: *const LanceDBFtsIndexConfig, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || columns.is_null() || num_columns == 0 { @@ -394,7 +397,7 @@ pub unsafe extern "C" fn lancedb_table_create_fts_index( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Use default config if none provided let cfg = if config.is_null() { @@ -462,6 +465,7 @@ pub unsafe extern "C" fn lancedb_table_list_indices( table: *const LanceDBTable, indices_out: *mut *mut *mut c_char, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || indices_out.is_null() || count_out.is_null() { @@ -470,7 +474,7 @@ pub unsafe extern "C" fn lancedb_table_list_indices( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.list_indices()) { Ok(indices) => { @@ -530,6 +534,7 @@ pub unsafe extern "C" fn lancedb_table_list_indices( pub unsafe extern "C" fn lancedb_table_drop_index( table: *const LanceDBTable, index_name: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || index_name.is_null() { @@ -543,7 +548,7 @@ pub unsafe extern "C" fn lancedb_table_drop_index( }; let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.drop_index(index_name_str)) { Ok(_) => LanceDBError::Success, @@ -563,6 +568,7 @@ pub unsafe extern "C" fn lancedb_table_drop_index( pub unsafe extern "C" fn lancedb_table_optimize( table: *const LanceDBTable, optimize_type: LanceDBOptimizeType, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() { @@ -571,7 +577,7 @@ pub unsafe extern "C" fn lancedb_table_optimize( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); use lancedb::table::{OptimizeAction, OptimizeOptions}; @@ -639,6 +645,7 @@ pub unsafe extern "C" fn lancedb_table_index_stats( table: *const LanceDBTable, index_name: *const c_char, stats_out: *mut LanceDBIndexStats, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || index_name.is_null() || stats_out.is_null() { @@ -652,7 +659,7 @@ pub unsafe extern "C" fn lancedb_table_index_stats( }; let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.index_stats(name_str)) { Ok(Some(stats)) => { diff --git a/src/query.rs b/src/query.rs index e6a934a..5d85281 100644 --- a/src/query.rs +++ b/src/query.rs @@ -18,7 +18,7 @@ use futures::TryStreamExt; use lancedb::query::{ExecutableQuery, QueryBase, Select}; use lancedb::{DistanceType, Table}; -use crate::connection::{get_runtime, LanceDBTable}; +use crate::connection::{resolve_runtime, LanceDBRuntime, LanceDBTable}; use crate::error::{set_invalid_argument_message, set_unknown_error_message, LanceDBError}; use crate::types::LanceDBDistanceType; @@ -471,13 +471,14 @@ pub unsafe extern "C" fn lancedb_vector_query_ef( #[no_mangle] pub unsafe extern "C" fn lancedb_query_execute( query: *mut LanceDBQuery, + runtime: *const LanceDBRuntime, ) -> *mut LanceDBQueryResult { if query.is_null() { return ptr::null_mut(); } let query_box = Box::from_raw(query); - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(async { let mut rust_query = query_box.table.query(); @@ -515,13 +516,14 @@ pub unsafe extern "C" fn lancedb_query_execute( #[no_mangle] pub unsafe extern "C" fn lancedb_vector_query_execute( query: *mut LanceDBVectorQuery, + runtime: *const LanceDBRuntime, ) -> *mut LanceDBQueryResult { if query.is_null() { return ptr::null_mut(); } let query_box = Box::from_raw(query); - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(async { let mut rust_query = match query_box @@ -587,6 +589,7 @@ pub unsafe extern "C" fn lancedb_query_result_to_arrow( batches_out: *mut *mut *mut arrow_array::ffi::FFI_ArrowArray, schema_out: *mut *mut arrow_schema::ffi::FFI_ArrowSchema, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if result.is_null() || batches_out.is_null() || schema_out.is_null() || count_out.is_null() { @@ -595,7 +598,7 @@ pub unsafe extern "C" fn lancedb_query_result_to_arrow( } let result_box = Box::from_raw(result); - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(async { let batches: Vec = result_box.inner.try_collect().await?; diff --git a/src/table.rs b/src/table.rs index d615eb0..b10bda6 100644 --- a/src/table.rs +++ b/src/table.rs @@ -16,7 +16,7 @@ use futures::TryStreamExt; use lance::dataset::transaction::UpdateMapEntry; use lancedb::query::{ExecutableQuery, QueryBase}; -use crate::connection::{get_runtime, LanceDBTable}; +use crate::connection::{resolve_runtime, LanceDBRuntime, LanceDBTable}; use crate::error::{ handle_error, set_invalid_argument_message, set_not_supported_message, set_unknown_error_message, LanceDBError, @@ -37,6 +37,7 @@ use crate::types::{LanceDBMergeInsertConfig, LanceDBRecordBatchReader}; pub unsafe extern "C" fn lancedb_table_arrow_schema( table: *const LanceDBTable, schema_out: *mut *mut arrow_schema::ffi::FFI_ArrowSchema, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || schema_out.is_null() { @@ -45,7 +46,7 @@ pub unsafe extern "C" fn lancedb_table_arrow_schema( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.schema()) { Ok(schema) => { @@ -72,6 +73,7 @@ pub unsafe extern "C" fn lancedb_table_arrow_schema( pub unsafe extern "C" fn lancedb_table_add( table: *const LanceDBTable, reader: *mut LanceDBRecordBatchReader, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || reader.is_null() { @@ -80,7 +82,7 @@ pub unsafe extern "C" fn lancedb_table_add( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Take ownership of the reader let reader_box = Box::from_raw(reader); @@ -109,6 +111,7 @@ pub unsafe extern "C" fn lancedb_table_merge_insert( on_columns: *const *const c_char, num_columns: usize, config: *const LanceDBMergeInsertConfig, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || data.is_null() || on_columns.is_null() || num_columns == 0 { @@ -133,7 +136,7 @@ pub unsafe extern "C" fn lancedb_table_merge_insert( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Take ownership of the data reader let data_box = Box::from_raw(data); @@ -178,6 +181,7 @@ pub unsafe extern "C" fn lancedb_table_cleanup_old_versions( table: *const LanceDBTable, older_than_days: u16, delete_unverified: i32, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() { @@ -186,7 +190,7 @@ pub unsafe extern "C" fn lancedb_table_cleanup_old_versions( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let older_than = chrono::Duration::days(older_than_days as i64); let delete_unverified_bool = delete_unverified != 0; @@ -207,6 +211,7 @@ pub unsafe extern "C" fn lancedb_table_cleanup_old_versions( #[no_mangle] pub unsafe extern "C" fn lancedb_table_compact_files( table: *const LanceDBTable, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() { @@ -215,7 +220,7 @@ pub unsafe extern "C" fn lancedb_table_compact_files( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Note: Compaction API is simplified in this implementation let _ = (tbl, runtime); @@ -235,6 +240,7 @@ pub unsafe extern "C" fn lancedb_table_compact_files( pub unsafe extern "C" fn lancedb_table_restore_version( table: *const LanceDBTable, version: u64, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() { @@ -243,7 +249,7 @@ pub unsafe extern "C" fn lancedb_table_restore_version( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); // Note: restore API is simplified in this implementation let _ = (version, tbl, runtime); @@ -372,13 +378,16 @@ pub unsafe extern "C" fn lancedb_free_arrow_schema( /// # Returns /// - Table version number on success, 0 on failure #[no_mangle] -pub unsafe extern "C" fn lancedb_table_version(table: *const LanceDBTable) -> u64 { +pub unsafe extern "C" fn lancedb_table_version( + table: *const LanceDBTable, + runtime: *const LanceDBRuntime, +) -> u64 { if table.is_null() { return 0; } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); runtime.block_on(tbl.version()).unwrap_or(0) } @@ -391,13 +400,16 @@ pub unsafe extern "C" fn lancedb_table_version(table: *const LanceDBTable) -> u6 /// # Returns /// - Number of rows in table on success, 0 on failure (or empty table) #[no_mangle] -pub unsafe extern "C" fn lancedb_table_count_rows(table: *const LanceDBTable) -> u64 { +pub unsafe extern "C" fn lancedb_table_count_rows( + table: *const LanceDBTable, + runtime: *const LanceDBRuntime, +) -> u64 { if table.is_null() { return 0; } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.count_rows(None)) { Ok(count) => count as u64, @@ -418,6 +430,7 @@ pub unsafe extern "C" fn lancedb_table_count_rows(table: *const LanceDBTable) -> pub unsafe extern "C" fn lancedb_table_delete( table: *const LanceDBTable, predicate: *const c_char, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || predicate.is_null() { @@ -431,7 +444,7 @@ pub unsafe extern "C" fn lancedb_table_delete( }; let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); match runtime.block_on(tbl.delete(predicate_str)) { Ok(_) => LanceDBError::Success, @@ -463,6 +476,7 @@ pub unsafe extern "C" fn lancedb_table_nearest_to( result_arrays: *mut *mut *mut arrow_array::ffi::FFI_ArrowArray, result_schema: *mut *mut arrow_schema::ffi::FFI_ArrowSchema, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() @@ -478,7 +492,7 @@ pub unsafe extern "C" fn lancedb_table_nearest_to( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let vec_slice = std::slice::from_raw_parts(vector, dimension); let vec_data: Vec = vec_slice.to_vec(); @@ -585,6 +599,7 @@ pub unsafe extern "C" fn lancedb_table_list_versions( versions_out: *mut *mut LanceDBVersion, metadata_out: *mut *mut LanceDBVersionMetadata, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || versions_out.is_null() || count_out.is_null() { @@ -593,7 +608,7 @@ pub unsafe extern "C" fn lancedb_table_list_versions( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let include_metadata = !metadata_out.is_null(); match runtime.block_on(tbl.list_versions()) { @@ -768,6 +783,7 @@ pub unsafe extern "C" fn lancedb_table_get_metadata( keys_out: *mut *mut *mut c_char, values_out: *mut *mut *mut c_char, count_out: *mut usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() @@ -782,7 +798,7 @@ pub unsafe extern "C" fn lancedb_table_get_metadata( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let ds = match tbl.dataset() { Some(ds) => ds, @@ -876,6 +892,7 @@ pub unsafe extern "C" fn lancedb_table_set_metadata( keys: *const *const c_char, values: *const *const c_char, count: usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || keys.is_null() || values.is_null() || count == 0 { @@ -909,7 +926,7 @@ pub unsafe extern "C" fn lancedb_table_set_metadata( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let ds = match tbl.dataset() { Some(ds) => ds, @@ -944,6 +961,7 @@ pub unsafe extern "C" fn lancedb_table_delete_metadata( table: *const LanceDBTable, keys: *const *const c_char, count: usize, + runtime: *const LanceDBRuntime, error_message: *mut *mut c_char, ) -> LanceDBError { if table.is_null() || keys.is_null() || count == 0 { @@ -972,7 +990,7 @@ pub unsafe extern "C" fn lancedb_table_delete_metadata( } let tbl = &(*table).inner; - let runtime = get_runtime(); + let runtime = resolve_runtime(runtime); let ds = match tbl.dataset() { Some(ds) => ds, diff --git a/tests/test_common.cpp b/tests/test_common.cpp index 73c296e..64231bc 100644 --- a/tests/test_common.cpp +++ b/tests/test_common.cpp @@ -34,6 +34,7 @@ void LanceDBFixture::create_empty_table(const std::string& table_name) { reinterpret_cast(&c_schema), nullptr, &table, + nullptr, &error_message ); @@ -72,6 +73,7 @@ LanceDBTable* LanceDBFixture::create_table_with_data(const std::string& table_na reinterpret_cast(&c_schema), reader, &table, + nullptr, &error_message ); diff --git a/tests/test_common.h b/tests/test_common.h index 5848890..98578f6 100644 --- a/tests/test_common.h +++ b/tests/test_common.h @@ -64,7 +64,7 @@ class LanceDBFixture : public BaseFixture { LanceDBFixture() { LanceDBConnectBuilder* builder = lancedb_connect(uri.c_str()); REQUIRE(builder != nullptr); - db = lancedb_connect_builder_execute(builder); + db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db != nullptr); } @@ -88,7 +88,7 @@ class LanceDBSessionFixture : public LanceDBFixture { session = lancedb_session_new(&session_options); REQUIRE(session != nullptr); builder = lancedb_connect_builder_session(builder, session); - db = lancedb_connect_builder_execute(builder); + db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db != nullptr); } diff --git a/tests/test_connection.cpp b/tests/test_connection.cpp index 30eee54..d16a5a6 100644 --- a/tests/test_connection.cpp +++ b/tests/test_connection.cpp @@ -26,7 +26,7 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Connection Builder", "[connection]") { REQUIRE(builder != nullptr); builder = lancedb_connect_builder_session(builder, nullptr); REQUIRE(builder != nullptr); - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db != nullptr); lancedb_connection_free(db); } @@ -70,7 +70,7 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Connection Builder", "[connection]") { REQUIRE(builder != nullptr); builder = lancedb_connect_builder_session(builder, session); REQUIRE(builder != nullptr); - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db != nullptr); lancedb_connection_free(db); lancedb_session_free(session); @@ -80,7 +80,7 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Connection Builder", "[connection]") { REQUIRE(builder != nullptr); builder = lancedb_connect_builder_session(builder, nullptr); REQUIRE(builder != nullptr); - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db != nullptr); lancedb_connection_free(db); } @@ -114,10 +114,10 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Session", "[connection]") { LanceDBSessionCacheStats index_stats{}; LanceDBSessionCacheStats metadata_stats{}; char* error_message = nullptr; - auto index_result = lancedb_session_index_cache_stats(session, &index_stats, &error_message); + auto index_result = lancedb_session_index_cache_stats(session, &index_stats, nullptr, &error_message); REQUIRE(index_result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - auto metadata_result = lancedb_session_metadata_cache_stats(session, &metadata_stats, &error_message); + auto metadata_result = lancedb_session_metadata_cache_stats(session, &metadata_stats, nullptr, &error_message); REQUIRE(metadata_result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); lancedb_session_free(session); @@ -127,12 +127,12 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Session", "[connection]") { REQUIRE(session != nullptr); LanceDBSessionCacheStats stats{}; char* error_message = nullptr; - auto result = lancedb_session_index_cache_stats(nullptr, &stats, &error_message); + auto result = lancedb_session_index_cache_stats(nullptr, &stats, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; - result = lancedb_session_metadata_cache_stats(session, nullptr, &error_message); + result = lancedb_session_metadata_cache_stats(session, nullptr, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); @@ -145,7 +145,7 @@ TEST_CASE_METHOD(BaseFixture, "LanceDB Session", "[connection]") { blocker.close(); LanceDBConnectBuilder* builder = lancedb_connect(uri.c_str()); - LanceDBConnection* db = lancedb_connect_builder_execute(builder); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, nullptr); REQUIRE(db == nullptr); } } @@ -159,7 +159,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Tables", "[connection]") { char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_connection_table_names(db, &names_out, &count_out, &error_message); + auto result = lancedb_connection_table_names(db, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(count_out == num_tables); @@ -175,7 +175,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Tables", "[connection]") { } SECTION("Open Tables") { for (size_t i = 0; i < count_out; ++i) { - auto tbl = lancedb_connection_open_table(db, names_out[i]); + auto tbl = lancedb_connection_open_table(db, names_out[i], nullptr); REQUIRE(tbl != nullptr); lancedb_table_free(tbl); } @@ -183,10 +183,10 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Tables", "[connection]") { SECTION("Drop Tables") { for (size_t i = 0; i < count_out; ++i) { char* error_message = nullptr; - auto result = lancedb_connection_drop_table(db, names_out[i], _namespace, &error_message); + auto result = lancedb_connection_drop_table(db, names_out[i], _namespace, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); - auto tbl = lancedb_connection_open_table(db, names_out[i]); + auto tbl = lancedb_connection_open_table(db, names_out[i], nullptr); REQUIRE(tbl == nullptr); } } @@ -199,24 +199,25 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Tables", "[connection]") { new_name.c_str(), _namespace, _namespace, + nullptr, &error_message); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); REQUIRE(result == LANCEDB_NOT_SUPPORTED); - auto tbl = lancedb_connection_open_table(db, new_name.c_str()); + auto tbl = lancedb_connection_open_table(db, new_name.c_str(), nullptr); REQUIRE(tbl == nullptr); - tbl = lancedb_connection_open_table(db, names_out[i]); + tbl = lancedb_connection_open_table(db, names_out[i], nullptr); REQUIRE(tbl != nullptr); lancedb_table_free(tbl); } } SECTION("Drop All Tables") { char* error_message = nullptr; - auto result = lancedb_connection_drop_all_tables(db, _namespace, &error_message); + auto result = lancedb_connection_drop_all_tables(db, _namespace, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); for (size_t i = 0; i < count_out; ++i) { - auto tbl = lancedb_connection_open_table(db, names_out[i]); + auto tbl = lancedb_connection_open_table(db, names_out[i], nullptr); REQUIRE(tbl == nullptr); } } @@ -237,7 +238,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); @@ -258,7 +259,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); @@ -278,7 +279,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); @@ -303,7 +304,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); @@ -327,7 +328,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(builder, &names_out, &count_out, nullptr, &error_message); REQUIRE(error_message == nullptr); REQUIRE(result == LANCEDB_SUCCESS); @@ -351,7 +352,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** names_out = nullptr; size_t count_out = 0; char* error_message = nullptr; - auto result = lancedb_table_names_builder_execute(nullptr, &names_out, &count_out, &error_message); + auto result = lancedb_table_names_builder_execute(nullptr, &names_out, &count_out, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -401,7 +402,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** all_names = nullptr; size_t all_count = 0; char* error_message = nullptr; - auto result = lancedb_connection_table_names(db, &all_names, &all_count, &error_message); + auto result = lancedb_connection_table_names(db, &all_names, &all_count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(all_count == num_tables); @@ -427,7 +428,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") char** page_names = nullptr; size_t page_count = 0; char* page_error = nullptr; - result = lancedb_table_names_builder_execute(builder, &page_names, &page_count, &page_error); + result = lancedb_table_names_builder_execute(builder, &page_names, &page_count, nullptr, &page_error); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(page_error == nullptr); @@ -458,7 +459,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Names Builder", "[connection]") TEST_CASE_METHOD(LanceDBFixture, "LanceDB Namespaces", "[connection]") { char* error_message = nullptr; const char* _namespace = "myspace"; - auto result = lancedb_connection_create_namespace(db, _namespace, &error_message); + auto result = lancedb_connection_create_namespace(db, _namespace, nullptr, &error_message); REQUIRE(error_message != nullptr); REQUIRE(result == LANCEDB_NOT_SUPPORTED); lancedb_free_string(error_message); @@ -471,6 +472,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Namespaces", "[connection]") { _namespace, &names_out, &count_out, + nullptr, &error_message); REQUIRE(error_message != nullptr); REQUIRE(result == LANCEDB_NOT_SUPPORTED); @@ -483,6 +485,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Namespaces", "[connection]") { char* error_message = nullptr; auto result = lancedb_connection_drop_namespace(db, _namespace, + nullptr, &error_message); REQUIRE(error_message != nullptr); REQUIRE(result == LANCEDB_NOT_SUPPORTED); diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 999f1ce..9af0759 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -22,7 +22,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -30,7 +30,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { // List indices (should have one index) char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -49,13 +49,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify total row count - REQUIRE(lancedb_table_count_rows(table) == 150); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 150); lancedb_table_free(table); } @@ -63,7 +63,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { SECTION("Create BTREE index on empty table then add data") { // Create empty table create_empty_table(table_name); - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); // Create BTREE index on the "key" column @@ -75,7 +75,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -83,7 +83,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { // List indices (should have one index) char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -102,13 +102,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify row count - REQUIRE(lancedb_table_count_rows(table) == 100); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 100); lancedb_table_free(table); } @@ -127,7 +127,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -135,7 +135,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { // Replace the index config.replace = 1; result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -143,7 +143,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index", "[index]") { // List indices (should still have one index after replacement) char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -177,13 +177,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); // Get the index name char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(count == 1); std::string index_name = indices[0]; @@ -191,7 +191,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { // Get index stats LanceDBIndexStats stats = {}; - result = lancedb_table_index_stats(table, index_name.c_str(), &stats, &error_message); + result = lancedb_table_index_stats(table, index_name.c_str(), &stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -219,13 +219,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); // Get the index name char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(count == 1); std::string index_name = indices[0]; @@ -235,12 +235,12 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { auto batch = create_test_record_batch(50, 100); auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); // Get index stats - should show unindexed rows LanceDBIndexStats stats = {}; - result = lancedb_table_index_stats(table, index_name.c_str(), &stats, &error_message); + result = lancedb_table_index_stats(table, index_name.c_str(), &stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -260,7 +260,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { LanceDBIndexStats stats = {}; char* error_message = nullptr; LanceDBError result = lancedb_table_index_stats( - table, "no_such_index", &stats, &error_message); + table, "no_such_index", &stats, nullptr, &error_message); REQUIRE(result == LANCEDB_INDEX_NOT_FOUND); @@ -275,7 +275,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Index Stats", "[index]") { LanceDBIndexStats stats = {}; char* error_message = nullptr; LanceDBError result = lancedb_table_index_stats( - nullptr, "some_index", &stats, &error_message); + nullptr, "some_index", &stats, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); @@ -297,7 +297,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" char** indices = nullptr; size_t count = 0; char* error_message = nullptr; - LanceDBError result = lancedb_table_list_indices(table, &indices, &count, &error_message); + LanceDBError result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -321,7 +321,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -329,7 +329,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" // List indices to get the index name char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -344,7 +344,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" lancedb_free_index_list(indices, count); // Drop the index - result = lancedb_table_drop_index(table, index_name.c_str(), &error_message); + result = lancedb_table_drop_index(table, index_name.c_str(), nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -352,7 +352,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" // List indices again (should be empty) indices = nullptr; count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -368,7 +368,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Scalar Index List and Drop", "[index]" // Try to drop an index that doesn't exist char* error_message = nullptr; - LanceDBError result = lancedb_table_drop_index(table, "non_existent_index", &error_message); + LanceDBError result = lancedb_table_drop_index(table, "non_existent_index", nullptr, &error_message); // Should fail REQUIRE(result != LANCEDB_SUCCESS); diff --git a/tests/test_query.cpp b/tests/test_query.cpp index 8072e88..e858cf7 100644 --- a/tests/test_query.cpp +++ b/tests/test_query.cpp @@ -12,7 +12,7 @@ void verify_query_result(LanceDBQueryResult* query_result, size_t expected_rows) size_t count = 0; char* error_message = nullptr; const auto result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -45,7 +45,7 @@ void create_key_index(LanceDBTable* table) { char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, index_columns, 1, LANCEDB_INDEX_BTREE, &config, &error_message); + table, index_columns, 1, LANCEDB_INDEX_BTREE, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -73,7 +73,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - all entries", "[query]") { REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); verify_query_result(query_result, total_rows); } @@ -109,7 +109,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - all entries", "[query]") { REQUIRE(error_message == nullptr); // Execute query (consumes the query object) - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Verify this page has the expected number of rows @@ -151,7 +151,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter", "[query]") { REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); verify_query_result(query_result, 1); @@ -176,7 +176,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter", "[query]") { REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); verify_query_result(query_result, 5); @@ -211,7 +211,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter no Index", "[quer REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); verify_query_result(query_result, 1); @@ -236,7 +236,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter no Index", "[quer REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); verify_query_result(query_result, 5); @@ -261,7 +261,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter no Index", "[quer REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow @@ -270,7 +270,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Where Filter no Index", "[quer size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -306,7 +306,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Query - Filter on non-existent column" REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // error should be caught at execution time - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result == nullptr); error_message = nullptr; @@ -319,7 +319,7 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Query - repeated queries popula LanceDBSessionCacheStats final_index_stats{}; char* error_message = nullptr; - LanceDBError result = lancedb_session_index_cache_stats(session, &initial_index_stats, &error_message); + LanceDBError result = lancedb_session_index_cache_stats(session, &initial_index_stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -349,12 +349,12 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Query - repeated queries popula REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - LanceDBQueryResult* query_result = lancedb_query_execute(query); + LanceDBQueryResult* query_result = lancedb_query_execute(query, nullptr); REQUIRE(query_result != nullptr); lancedb_query_result_free(query_result); } - result = lancedb_session_index_cache_stats(session, &final_index_stats, &error_message); + result = lancedb_session_index_cache_stats(session, &final_index_stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); diff --git a/tests/test_table.cpp b/tests/test_table.cpp index 4a0b7f3..9ffe7f5 100644 --- a/tests/test_table.cpp +++ b/tests/test_table.cpp @@ -14,7 +14,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Creation", "[table]") { SECTION("Create table with data") { constexpr auto row_num = 10; LanceDBTable* table = create_table_with_data("table_with_data", row_num, 0); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); lancedb_table_free(table); } @@ -22,13 +22,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Creation", "[table]") { const std::string table_name = "table_reopen_test"; constexpr auto row_num = 15; LanceDBTable* table = create_table_with_data(table_name, row_num, 0); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); lancedb_table_free(table); // Reopen the table - LanceDBTable* reopened_table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* reopened_table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(reopened_table != nullptr); - REQUIRE(lancedb_table_count_rows(reopened_table) == row_num); + REQUIRE(lancedb_table_count_rows(reopened_table, nullptr) == row_num); lancedb_table_free(reopened_table); } @@ -57,6 +57,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Creation", "[table]") { reinterpret_cast(&c_schema), reader, &table2, + nullptr, &error_message ); @@ -81,15 +82,15 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { create_empty_table(table_name); // Open the table - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); SECTION("Add data to empty table") { // Verify table is initially empty - REQUIRE(lancedb_table_count_rows(table) == 0); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 0); // Initial version should be 1 (empty table) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 1); constexpr auto row_num = 10; @@ -100,22 +101,22 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { REQUIRE(reader != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, reader, &error_message); + LanceDBError result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify row count - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); // Version should increment to 2 - version = lancedb_table_version(table); + version = lancedb_table_version(table, nullptr); REQUIRE(version == 2); } SECTION("Add multiple batches of data") { // Initial version should be 1 (empty table) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 1); // Add first batch @@ -125,14 +126,14 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { REQUIRE(reader1 != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, reader1, &error_message); + LanceDBError result = lancedb_table_add(table, reader1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num1); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num1); // Version should increment to 2 - version = lancedb_table_version(table); + version = lancedb_table_version(table, nullptr); REQUIRE(version == 2); // Add second batch @@ -141,14 +142,14 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { auto reader2 = create_reader_from_batch(batch2); REQUIRE(reader2 != nullptr); - result = lancedb_table_add(table, reader2, &error_message); + result = lancedb_table_add(table, reader2, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num1+row_num2); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num1+row_num2); // Version should increment to 3 - version = lancedb_table_version(table); + version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } @@ -160,11 +161,11 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { REQUIRE(reader1 != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, reader1, &error_message); + LanceDBError result = lancedb_table_add(table, reader1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); // Add data with overlapping keys (5-14) // Keys 5-9 already exist in the table @@ -174,7 +175,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { auto reader2 = create_reader_from_batch(batch2); REQUIRE(reader2 != nullptr); - result = lancedb_table_add(table, reader2, &error_message); + result = lancedb_table_add(table, reader2, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -182,16 +183,16 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { // table_add adds all rows // So we should have 10 (original) + 10 (new batch) = 20 rows // Even though keys 5-9 exist in both batches - REQUIRE(lancedb_table_count_rows(table) == 20); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 20); // Version should increment - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } SECTION("Add data with null reader should fail") { char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, nullptr, &error_message); + LanceDBError result = lancedb_table_add(table, nullptr, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -206,7 +207,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Add", "[table]") { REQUIRE(reader != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(nullptr, reader, &error_message); + LanceDBError result = lancedb_table_add(nullptr, reader, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -227,7 +228,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { create_empty_table(table_name); // Open the table - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); // Add initial data @@ -237,14 +238,14 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { REQUIRE(initial_reader != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, initial_reader, &error_message); + LanceDBError result = lancedb_table_add(table, initial_reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); // Initial version after add should be 2 (1 for empty table creation, 2 after add) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 2); SECTION("Merge insert with update and insert") { @@ -291,16 +292,16 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, on_columns, 1, &config, &error_message); + table, merge_reader, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Should have 10 (original) - 5 (overlapping) + 10 (total in merge) = 15 rows - REQUIRE(lancedb_table_count_rows(table) == 15); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 15); // Version should increment to 3 (was 2 before merge insert) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } @@ -337,16 +338,16 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, on_columns, 1, &config, &error_message); + table, merge_reader, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Should still have 10 rows (only updates, no inserts) - REQUIRE(lancedb_table_count_rows(table) == 10); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 10); // Version should increment to 3 (was 2 before merge insert) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } @@ -383,16 +384,16 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, on_columns, 1, &config, &error_message); + table, merge_reader, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Should have 10 + 5 = 15 rows (only inserts, no updates) - REQUIRE(lancedb_table_count_rows(table) == 15); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 15); // Version should increment to 3 (was 2 before merge insert) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } @@ -405,22 +406,22 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, on_columns, 1, nullptr, &error_message); + table, merge_reader, on_columns, 1, nullptr, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Default behavior should handle the merge - REQUIRE(lancedb_table_count_rows(table) >= 10); + REQUIRE(lancedb_table_count_rows(table, nullptr) >= 10); // Version should increment to 3 (was 2 before merge insert) - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 3); } SECTION("Merge insert with no actual changes") { // Get current version - auto version = lancedb_table_version(table); + auto version = lancedb_table_version(table, nullptr); REQUIRE(version == 2); // Create data with same keys and same values as existing data @@ -456,16 +457,16 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, on_columns, 1, &config, &error_message); + table, merge_reader, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Row count should remain 10 (no new rows) - REQUIRE(lancedb_table_count_rows(table) == 10); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 10); // Check if version changed even though data is identical - version = lancedb_table_version(table); + version = lancedb_table_version(table, nullptr); // Version increments even if data doesn't actually change REQUIRE(version == 3); } @@ -479,7 +480,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, nullptr, on_columns, 1, &config, &error_message); + table, nullptr, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -501,7 +502,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - nullptr, merge_reader, on_columns, 1, &config, &error_message); + nullptr, merge_reader, on_columns, 1, &config, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -525,7 +526,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Merge Insert", "[table]") { char* error_message = nullptr; LanceDBError result = lancedb_table_merge_insert( - table, merge_reader, nullptr, 1, &config, &error_message); + table, merge_reader, nullptr, 1, &config, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); @@ -544,7 +545,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Delete", "[table]") { const std::string table_name = "test_delete_table"; create_empty_table(table_name); - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); // Add initial data (keys key_0 through key_9) @@ -554,61 +555,61 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Delete", "[table]") { REQUIRE(initial_reader != nullptr); char* error_message = nullptr; - LanceDBError result = lancedb_table_add(table, initial_reader, &error_message); + LanceDBError result = lancedb_table_add(table, initial_reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); SECTION("Delete row matching predicate") { char* error_message = nullptr; - LanceDBError result = lancedb_table_delete(table, "key = 'key_0'", &error_message); + LanceDBError result = lancedb_table_delete(table, "key = 'key_0'", nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num - 1); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num - 1); } SECTION("Delete multiple rows matching predicate") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, "key IN ('key_0', 'key_1', 'key_2')", &error_message); + table, "key IN ('key_0', 'key_1', 'key_2')", nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num - 3); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num - 3); result = lancedb_table_delete( - table, "key = 'key_10' OR key = 'key_11' OR key = 'key_12')", &error_message); + table, "key = 'key_10' OR key = 'key_11' OR key = 'key_12')", nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num - 6); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num - 6); } SECTION("Delete with predicate matching no rows") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, "key = 'nonexistent'", &error_message); + table, "key = 'nonexistent'", nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == row_num); + REQUIRE(lancedb_table_count_rows(table, nullptr) == row_num); } SECTION("Delete all rows") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, "key IS NOT NULL", &error_message); + table, "key IS NOT NULL", nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - REQUIRE(lancedb_table_count_rows(table) == 0); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 0); } SECTION("Delete with unknown column should fail") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, "unknown = 'key_0'", &error_message); + table, "unknown = 'key_0'", nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); REQUIRE(result != LANCEDB_SUCCESS); @@ -619,7 +620,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Delete", "[table]") { SECTION("Delete with empty predicate should fail") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, "", &error_message); + table, "", nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); REQUIRE(result != LANCEDB_SUCCESS); @@ -630,7 +631,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Delete", "[table]") { SECTION("Delete with null table should fail") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - nullptr, "key = 'key_0'", &error_message); + nullptr, "key = 'key_0'", nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); REQUIRE(error_message != nullptr); @@ -640,7 +641,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Delete", "[table]") { SECTION("Delete with null predicate should fail") { char* error_message = nullptr; LanceDBError result = lancedb_table_delete( - table, nullptr, &error_message); + table, nullptr, nullptr, &error_message); REQUIRE(result != LANCEDB_SUCCESS); REQUIRE(error_message != nullptr); @@ -782,34 +783,34 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Table CRUD with same session ac REQUIRE(table_b != nullptr); // Read after create - REQUIRE(lancedb_table_count_rows(table_a) == table_a_rows); - REQUIRE(lancedb_table_count_rows(table_b) == table_b_rows); + REQUIRE(lancedb_table_count_rows(table_a, nullptr) == table_a_rows); + REQUIRE(lancedb_table_count_rows(table_b, nullptr) == table_b_rows); // Reopen and read again lancedb_table_free(table_a); lancedb_table_free(table_b); - table_a = lancedb_connection_open_table(db, table_a_name.c_str()); - table_b = lancedb_connection_open_table(db, table_b_name.c_str()); + table_a = lancedb_connection_open_table(db, table_a_name.c_str(), nullptr); + table_b = lancedb_connection_open_table(db, table_b_name.c_str(), nullptr); REQUIRE(table_a != nullptr); REQUIRE(table_b != nullptr); - REQUIRE(lancedb_table_count_rows(table_a) == table_a_rows); - REQUIRE(lancedb_table_count_rows(table_b) == table_b_rows); + REQUIRE(lancedb_table_count_rows(table_a, nullptr) == table_a_rows); + REQUIRE(lancedb_table_count_rows(table_b, nullptr) == table_b_rows); // Delete lancedb_table_free(table_a); lancedb_table_free(table_b); char* error_message = nullptr; - LanceDBError result = lancedb_connection_drop_table(db, table_a_name.c_str(), _namespace, &error_message); + LanceDBError result = lancedb_connection_drop_table(db, table_a_name.c_str(), _namespace, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - result = lancedb_connection_drop_table(db, table_b_name.c_str(), _namespace, &error_message); + result = lancedb_connection_drop_table(db, table_b_name.c_str(), _namespace, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify delete - table_a = lancedb_connection_open_table(db, table_a_name.c_str()); - table_b = lancedb_connection_open_table(db, table_b_name.c_str()); + table_a = lancedb_connection_open_table(db, table_a_name.c_str(), nullptr); + table_b = lancedb_connection_open_table(db, table_b_name.c_str(), nullptr); REQUIRE(table_a == nullptr); REQUIRE(table_b == nullptr); } diff --git a/tests/test_table_async.cpp b/tests/test_table_async.cpp new file mode 100644 index 0000000..d4b03c3 --- /dev/null +++ b/tests/test_table_async.cpp @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright The LanceDB Authors + */ + +#include +#include +#include +#include +#include "test_common.h" + +using AsyncResult = std::tuple; + +TEST_CASE_METHOD(LanceDBFixture, "LanceDB Concurrent Table Add", "[table][async]") { + const std::string table_name = "test_concurrent_add"; + constexpr auto initial_rows = 10; + LanceDBTable* table = create_table_with_data(table_name, initial_rows, 0); + REQUIRE(table != nullptr); + REQUIRE(lancedb_table_count_rows(table, nullptr) == initial_rows); + + SECTION("Multiple batched adds via std::async") { + constexpr int num_batches = 8; + constexpr int rows_per_batch = 10; + + std::vector> futures; + futures.reserve(num_batches); + + for (int i = 0; i < num_batches; i++) { + int start_index = initial_rows + i * rows_per_batch; + futures.push_back(std::async(std::launch::async, [&table, start_index]() -> AsyncResult { + char* error_message = nullptr; + auto batch = create_test_record_batch(rows_per_batch, start_index); + auto reader = create_reader_from_batch(batch); + auto error = lancedb_table_add(table, reader, nullptr, &error_message); + return {error, error_message}; + })); + } + + for (auto& f : futures) { + auto [error, error_message] = f.get(); + REQUIRE(error == LANCEDB_SUCCESS); + REQUIRE(error_message == nullptr); + } + + auto total_rows = lancedb_table_count_rows(table, nullptr); + REQUIRE(total_rows == initial_rows + num_batches * rows_per_batch); + } + + SECTION("Multiple batched merge inserts with overlapping keys via std::async") { + constexpr int num_batches = 4; + constexpr int rows_per_batch = 10; + // Each batch overlaps the previous by half + constexpr int stride = rows_per_batch / 2; + + std::vector> futures; + futures.reserve(num_batches); + + for (int i = 0; i < num_batches; i++) { + int start_index = initial_rows + i * stride; + futures.push_back(std::async(std::launch::async, [&table, start_index]() -> AsyncResult { + char* error_message = nullptr; + auto batch = create_test_record_batch(rows_per_batch, start_index); + auto reader = create_reader_from_batch(batch); + + const char* on_columns[] = {"key"}; + LanceDBMergeInsertConfig config = { + .when_matched_update_all = 1, + .when_not_matched_insert_all = 1 + }; + + auto error = lancedb_table_merge_insert( + table, reader, on_columns, 1, &config, nullptr, &error_message); + return {error, error_message}; + })); + } + + for (auto& f : futures) { + auto [error, error_message] = f.get(); + REQUIRE(error == LANCEDB_SUCCESS); + REQUIRE(error_message == nullptr); + } + + // Concurrent merge inserts don't see each other's in-flight changes, + // so overlapping keys between batches are not deduplicated + auto total_rows = lancedb_table_count_rows(table, nullptr); + REQUIRE(total_rows == initial_rows + num_batches * rows_per_batch); + } + + lancedb_table_free(table); +} diff --git a/tests/test_table_coro.cpp b/tests/test_table_coro.cpp new file mode 100644 index 0000000..fdae6e4 --- /dev/null +++ b/tests/test_table_coro.cpp @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright The LanceDB Authors + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#if BOOST_VERSION >= 108400 +#include +#include +#else +#include +#endif +#include "test_common.h" + +// the default coroutine stack is not enough when block_on runs inside a coroutine +constexpr size_t CORO_STACK_SIZE = 8 * 1024 * 1024; // 8 MB + +// Helper to spawn a coroutine with a custom stack size across Boost versions +template +void spawn_with_stack(Executor& ex, F&& f) { +#if BOOST_VERSION >= 108400 + boost::asio::spawn(ex, boost::asio::allocator_arg_t{}, + boost::context::fixedsize_stack(CORO_STACK_SIZE), + std::forward(f), boost::asio::detached); +#else + boost::asio::spawn(ex, std::forward(f), + boost::coroutines::attributes(CORO_STACK_SIZE)); +#endif +} + +using AsyncResult = std::tuple; + +TEST_CASE_METHOD(BaseFixture, "LanceDB Coroutine Table Add", "[table][coro]") { + // program crash when asio coroutines are used with the global tokio runtime + LanceDBRuntime* runtime = nullptr; // lancedb_runtime_new(); + //REQUIRE(runtime != nullptr); + + // Create connection with explicit runtime + LanceDBConnectBuilder* builder = lancedb_connect(uri.c_str()); + REQUIRE(builder != nullptr); + LanceDBConnection* db = lancedb_connect_builder_execute(builder, runtime); + REQUIRE(db != nullptr); + + // Create table with data using explicit runtime + const std::string table_name = "test_coro_add"; + constexpr auto initial_rows = 10; + + auto schema = create_test_schema(); + auto batch = create_test_record_batch(initial_rows, 0); + auto reader = create_reader_from_batch(batch); + REQUIRE(reader != nullptr); + + struct ArrowSchema c_schema; + REQUIRE(arrow::ExportSchema(*schema, &c_schema).ok()); + + LanceDBTable* table = nullptr; + char* error_message = nullptr; + LanceDBError result = lancedb_table_create( + db, + table_name.c_str(), + reinterpret_cast(&c_schema), + reader, + &table, + runtime, + &error_message + ); + REQUIRE(result == LANCEDB_SUCCESS); + REQUIRE(error_message == nullptr); + REQUIRE(table != nullptr); + + if (c_schema.release) { + c_schema.release(&c_schema); + } + + REQUIRE(lancedb_table_count_rows(table, runtime) == initial_rows); + + constexpr int num_threads = 4; + + SECTION("Multiple batched adds via boost::asio::spawn") { + constexpr int num_batches = 8; + constexpr int rows_per_batch = 10; + + boost::asio::io_context io; + auto work_guard = boost::asio::make_work_guard(io); + std::vector results(num_batches); + + // Start threads first so they are waiting for work + std::vector threads; + for (int t = 0; t < num_threads; t++) { + threads.emplace_back([&io]() { io.run(); }); + } + + // Spawn coroutines — threads will pick them up concurrently + for (int i = 0; i < num_batches; i++) { + int start_index = initial_rows + i * rows_per_batch; + spawn_with_stack(io, [&table, runtime, &results, i, start_index](boost::asio::yield_context) { + char* error_message = nullptr; + auto batch = create_test_record_batch(rows_per_batch, start_index); + auto reader = create_reader_from_batch(batch); + auto error = lancedb_table_add(table, reader, runtime, &error_message); + results[i] = {error, error_message}; + }); + } + + // Release the guard so threads can finish once all coroutines complete + work_guard.reset(); + for (auto& t : threads) { + t.join(); + } + + for (const auto& [error, error_message] : results) { + REQUIRE(error == LANCEDB_SUCCESS); + REQUIRE(error_message == nullptr); + } + + auto total_rows = lancedb_table_count_rows(table, runtime); + REQUIRE(total_rows == initial_rows + num_batches * rows_per_batch); + } + + SECTION("Multiple batched merge inserts with overlapping keys via boost::asio::spawn") { + constexpr int num_batches = 4; + constexpr int rows_per_batch = 10; + // Each batch overlaps the previous by half + constexpr int stride = rows_per_batch / 2; + + boost::asio::io_context io; + auto work_guard = boost::asio::make_work_guard(io); + std::vector results(num_batches); + + // Start threads first so they are waiting for work + std::vector threads; + for (int t = 0; t < num_threads; t++) { + threads.emplace_back([&io]() { io.run(); }); + } + + // Spawn coroutines — threads will pick them up concurrently + for (int i = 0; i < num_batches; i++) { + int start_index = initial_rows + i * stride; + spawn_with_stack(io, [&table, runtime, &results, i, start_index](boost::asio::yield_context) { + char* error_message = nullptr; + auto batch = create_test_record_batch(rows_per_batch, start_index); + auto reader = create_reader_from_batch(batch); + + const char* on_columns[] = {"key"}; + LanceDBMergeInsertConfig config = { + .when_matched_update_all = 1, + .when_not_matched_insert_all = 1 + }; + + auto error = lancedb_table_merge_insert( + table, reader, on_columns, 1, &config, runtime, &error_message); + results[i] = {error, error_message}; + }); + } + + // Release the guard so threads can finish once all coroutines complete + work_guard.reset(); + for (auto& t : threads) { + t.join(); + } + + for (const auto& [error, error_message] : results) { + REQUIRE(error == LANCEDB_SUCCESS); + REQUIRE(error_message == nullptr); + } + + // Concurrent merge inserts don't see each other's in-flight changes, + // so overlapping keys between batches are not deduplicated + auto total_rows = lancedb_table_count_rows(table, runtime); + REQUIRE(total_rows == initial_rows + num_batches * rows_per_batch); + } + + lancedb_table_free(table); + lancedb_connection_free(db); + lancedb_runtime_free(runtime); +} diff --git a/tests/test_table_meta.cpp b/tests/test_table_meta.cpp index 43b1ca5..d990d1c 100644 --- a/tests/test_table_meta.cpp +++ b/tests/test_table_meta.cpp @@ -18,7 +18,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { size_t count = 0; char* error_message = nullptr; - LanceDBError result = lancedb_table_list_versions(table, &versions, nullptr, &count, &error_message); + LanceDBError result = lancedb_table_list_versions(table, &versions, nullptr, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -41,7 +41,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { SECTION("List versions after adding data") { const std::string table_name = "versions_add_test"; create_empty_table(table_name); - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); // Add data 3 times to create versions 2, 3, 4 @@ -51,7 +51,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - LanceDBError add_result = lancedb_table_add(table, reader, &error_message); + LanceDBError add_result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(add_result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); } @@ -59,7 +59,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { LanceDBVersion* versions = nullptr; size_t count = 0; - LanceDBError result = lancedb_table_list_versions(table, &versions, nullptr, &count, &error_message); + LanceDBError result = lancedb_table_list_versions(table, &versions, nullptr, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -81,7 +81,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { size_t count = 0; char* error_message = nullptr; - LanceDBError result = lancedb_table_list_versions(table, &versions, &metadata, &count, &error_message); + LanceDBError result = lancedb_table_list_versions(table, &versions, &metadata, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -102,7 +102,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { size_t count = 0; char* error_message = nullptr; - LanceDBError result = lancedb_table_list_versions(nullptr, &versions, nullptr, &count, &error_message); + LanceDBError result = lancedb_table_list_versions(nullptr, &versions, nullptr, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); @@ -117,7 +117,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table List Versions", "[table]") { char* error_message = nullptr; size_t count = 0; - LanceDBError result = lancedb_table_list_versions(table, nullptr, nullptr, &count, &error_message); + LanceDBError result = lancedb_table_list_versions(table, nullptr, nullptr, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); @@ -143,7 +143,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Schema", "[table]") { FFI_ArrowSchema* schema_out = nullptr; char* error_message = nullptr; - LanceDBError result = lancedb_table_arrow_schema(table, &schema_out, &error_message); + LanceDBError result = lancedb_table_arrow_schema(table, &schema_out, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -169,7 +169,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Schema", "[table]") { FFI_ArrowSchema* schema_out = nullptr; char* error_message = nullptr; - LanceDBError result = lancedb_table_arrow_schema(nullptr, &schema_out, &error_message); + LanceDBError result = lancedb_table_arrow_schema(nullptr, &schema_out, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); @@ -183,7 +183,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Schema", "[table]") { char* error_message = nullptr; - LanceDBError result = lancedb_table_arrow_schema(table, nullptr, &error_message); + LanceDBError result = lancedb_table_arrow_schema(table, nullptr, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); @@ -210,7 +210,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { size_t count = 0; char* error_message = nullptr; - LanceDBError result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, &error_message); + LanceDBError result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -232,7 +232,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { }); char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys.get(), set_values.get(), expected.size(), &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys.get(), set_values.get(), expected.size(), nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -241,7 +241,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(count == expected.size()); @@ -276,7 +276,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { REQUIRE(expected.size() == filter_length-1); char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -285,7 +285,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, filter_keys, filter_length, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, filter_keys, filter_length, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(count == expected.size()); @@ -316,7 +316,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { } char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -324,7 +324,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, filter_keys, filter_length, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, filter_keys, filter_length, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(count == expected.size()); @@ -343,7 +343,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { const char* set_values[] = {"valid_value"}; char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, 1, &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, 1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -354,7 +354,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, filter_keys, 1, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, filter_keys, 1, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); @@ -365,13 +365,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { const char* set_values[] = {"red"}; char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, 1, &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, 1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Update the value const char* update_values[] = {"blue"}; - result = lancedb_table_set_metadata(table, set_keys, update_values, 1, &error_message); + result = lancedb_table_set_metadata(table, set_keys, update_values, 1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -380,7 +380,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(count == 1); @@ -413,11 +413,11 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char* error_message = nullptr; - LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, &error_message); + LanceDBError result = lancedb_table_set_metadata(table, set_keys, set_values, set_length, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - result = lancedb_table_delete_metadata(table, del_keys, del_length, &error_message); + result = lancedb_table_delete_metadata(table, del_keys, del_length, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -426,7 +426,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, &error_message); + result = lancedb_table_get_metadata(table, nullptr, 0, &keys, &values, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); REQUIRE(count == expected.size()); @@ -445,7 +445,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { const char* del_keys[] = {"nonexistent"}; char* error_message = nullptr; - LanceDBError result = lancedb_table_delete_metadata(table, del_keys, 1, &error_message); + LanceDBError result = lancedb_table_delete_metadata(table, del_keys, 1, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); } @@ -456,25 +456,25 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { // Null table const char* keys[] = {"k"}; const char* values[] = {"v"}; - REQUIRE(lancedb_table_set_metadata(nullptr, keys, values, 1, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_set_metadata(nullptr, keys, values, 1, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; // Null keys - REQUIRE(lancedb_table_set_metadata(table, nullptr, values, 1, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_set_metadata(table, nullptr, values, 1, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; // Null values - REQUIRE(lancedb_table_set_metadata(table, keys, nullptr, 1, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_set_metadata(table, keys, nullptr, 1, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; // Zero count - REQUIRE(lancedb_table_set_metadata(table, keys, values, 0, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_set_metadata(table, keys, values, 0, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); } @@ -485,25 +485,25 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { char** values = nullptr; size_t count = 0; - REQUIRE(lancedb_table_get_metadata(nullptr, nullptr, 0, &keys, &values, &count, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_get_metadata(nullptr, nullptr, 0, &keys, &values, &count, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; - REQUIRE(lancedb_table_get_metadata(table, nullptr, 0, nullptr, &values, &count, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_get_metadata(table, nullptr, 0, nullptr, &values, &count, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; // Null filter_keys with non-zero filter_count - REQUIRE(lancedb_table_get_metadata(table, nullptr, 1, &keys, &values, &count, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_get_metadata(table, nullptr, 1, &keys, &values, &count, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; // Non-null filter_keys with zero filter_count const char* filter[] = {"k"}; - REQUIRE(lancedb_table_get_metadata(table, filter, 0, &keys, &values, &count, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_get_metadata(table, filter, 0, &keys, &values, &count, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); } @@ -511,12 +511,12 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Table Metadata", "[table]") { SECTION("Delete metadata with null arguments should fail") { char* error_message = nullptr; - REQUIRE(lancedb_table_delete_metadata(nullptr, nullptr, 1, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_delete_metadata(nullptr, nullptr, 1, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); error_message = nullptr; - REQUIRE(lancedb_table_delete_metadata(table, nullptr, 1, &error_message) == LANCEDB_INVALID_ARGUMENT); + REQUIRE(lancedb_table_delete_metadata(table, nullptr, 1, nullptr, &error_message) == LANCEDB_INVALID_ARGUMENT); REQUIRE(error_message != nullptr); lancedb_free_string(error_message); } diff --git a/tests/test_vector_index.cpp b/tests/test_vector_index.cpp index bbaea30..bcf4050 100644 --- a/tests/test_vector_index.cpp +++ b/tests/test_vector_index.cpp @@ -27,7 +27,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -35,7 +35,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { // List indices (should have one index) char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -54,13 +54,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify total row count - REQUIRE(lancedb_table_count_rows(table) == 306); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 306); lancedb_table_free(table); } @@ -84,7 +84,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_PQ, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_PQ, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -94,13 +94,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify total row count - REQUIRE(lancedb_table_count_rows(table) == 306); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 306); lancedb_table_free(table); } @@ -124,7 +124,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_HNSW_PQ, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_HNSW_PQ, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -134,13 +134,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify total row count - REQUIRE(lancedb_table_count_rows(table) == 306); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 306); lancedb_table_free(table); } @@ -164,7 +164,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_HNSW_SQ, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_HNSW_SQ, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -174,13 +174,13 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { auto reader = create_reader_from_batch(batch); REQUIRE(reader != nullptr); - result = lancedb_table_add(table, reader, &error_message); + result = lancedb_table_add(table, reader, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // Verify total row count - REQUIRE(lancedb_table_count_rows(table) == 306); + REQUIRE(lancedb_table_count_rows(table, nullptr) == 306); lancedb_table_free(table); } @@ -188,7 +188,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { SECTION("Create IVF_FLAT index on empty table should fail") { // Create empty table create_empty_table(table_name); - LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str()); + LanceDBTable* table = lancedb_connection_open_table(db, table_name.c_str(), nullptr); REQUIRE(table != nullptr); // Try to create IVF_FLAT index on empty table (should fail - needs training data) @@ -205,7 +205,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); // Vector index creation on empty table should fail REQUIRE(result != LANCEDB_SUCCESS); @@ -237,7 +237,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -245,7 +245,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index", "[vector_index]") { // Replace with IVF_PQ index config.replace = 1; result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_PQ, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_PQ, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -276,7 +276,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -284,7 +284,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ // List indices to get the index name char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -299,7 +299,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ lancedb_free_index_list(indices, count); // Drop the index - result = lancedb_table_drop_index(table, index_name.c_str(), &error_message); + result = lancedb_table_drop_index(table, index_name.c_str(), nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -307,7 +307,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ // List indices again (should be empty) indices = nullptr; count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -330,7 +330,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ char* error_message = nullptr; LanceDBError result = lancedb_table_create_scalar_index( - table, scalar_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, &error_message); + table, scalar_columns, 1, LANCEDB_INDEX_BTREE, &scalar_config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -348,7 +348,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ }; result = lancedb_table_create_vector_index( - table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, &error_message); + table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &vector_config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -356,7 +356,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ // List indices (should have two indices) char** indices = nullptr; size_t count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -376,7 +376,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ lancedb_free_index_list(indices, count); // Drop the first index - result = lancedb_table_drop_index(table, first_index_name.c_str(), &error_message); + result = lancedb_table_drop_index(table, first_index_name.c_str(), nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -384,7 +384,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Index List and Drop", "[vector_ // List indices again (should have one index remaining) indices = nullptr; count = 0; - result = lancedb_table_list_indices(table, &indices, &count, &error_message); + result = lancedb_table_list_indices(table, &indices, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); diff --git a/tests/test_vector_query.cpp b/tests/test_vector_query.cpp index 01a6c26..a91af21 100644 --- a/tests/test_vector_query.cpp +++ b/tests/test_vector_query.cpp @@ -60,6 +60,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - nearest_to without inde &result_arrays, &result_schema, &count, + nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); @@ -100,6 +101,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - nearest_to without inde &result_arrays, &result_schema, &count, + nullptr, &error_message); // Should succeed because the API finds the "data" vector column @@ -171,7 +173,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - paged query with limit REQUIRE(error_message == nullptr); // Execute query (consumes the query object) - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow @@ -180,7 +182,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - paged query with limit size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -239,7 +241,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - nearest_to with IVF_FLA char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -263,6 +265,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - nearest_to with IVF_FLA &result_arrays, &result_schema, &count, + nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); @@ -303,6 +306,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - nearest_to with IVF_FLA &result_arrays, &result_schema, &count, + nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); @@ -348,7 +352,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -376,7 +380,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -385,7 +389,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -427,7 +431,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -436,7 +440,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -477,7 +481,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -486,7 +490,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -527,7 +531,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -536,7 +540,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -589,7 +593,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -598,7 +602,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - configuration parameter size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -641,7 +645,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - HNSW parameters", "[vec char* error_message = nullptr; LanceDBError result = lancedb_table_create_vector_index( - table, vector_columns, 1, LANCEDB_INDEX_IVF_HNSW_SQ, &config, &error_message); + table, vector_columns, 1, LANCEDB_INDEX_IVF_HNSW_SQ, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -669,7 +673,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - HNSW parameters", "[vec REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -678,7 +682,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - HNSW parameters", "[vec size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -725,7 +729,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - HNSW parameters", "[vec REQUIRE(error_message == nullptr); // Execute query - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); // Convert to Arrow and verify results @@ -734,7 +738,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - HNSW parameters", "[vec size_t count = 0; error_message = nullptr; result = lancedb_query_result_to_arrow( - query_result, &result_arrays, &result_schema, &count, &error_message); + query_result, &result_arrays, &result_schema, &count, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -780,6 +784,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - error cases", "[vector_ &result_arrays, &result_schema, &count, + nullptr, &error_message); // Should fail @@ -809,6 +814,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - error cases", "[vector_ &result_arrays, &result_schema, &count, + nullptr, &error_message); // Should fail @@ -852,7 +858,7 @@ TEST_CASE_METHOD(LanceDBFixture, "LanceDB Vector Query - Filter on non-existent REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); // error should be caught at execution time - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result == nullptr); lancedb_table_free(table); @@ -863,7 +869,7 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Vector Query - repeated queries LanceDBSessionCacheStats final_index_stats{}; char* error_message = nullptr; - LanceDBError result = lancedb_session_index_cache_stats(session, &initial_index_stats, &error_message); + LanceDBError result = lancedb_session_index_cache_stats(session, &initial_index_stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -884,7 +890,7 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Vector Query - repeated queries }; result = lancedb_table_create_vector_index( - table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, &error_message); + table, vector_columns, 1, LANCEDB_INDEX_IVF_FLAT, &config, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); @@ -904,12 +910,12 @@ TEST_CASE_METHOD(LanceDBSessionFixture, "LanceDB Vector Query - repeated queries REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr); - LanceDBQueryResult* query_result = lancedb_vector_query_execute(query); + LanceDBQueryResult* query_result = lancedb_vector_query_execute(query, nullptr); REQUIRE(query_result != nullptr); lancedb_query_result_free(query_result); } - result = lancedb_session_index_cache_stats(session, &final_index_stats, &error_message); + result = lancedb_session_index_cache_stats(session, &final_index_stats, nullptr, &error_message); REQUIRE(result == LANCEDB_SUCCESS); REQUIRE(error_message == nullptr);