Skip to content

Commit

Permalink
Expose exception text to C, Python and Rust API's (#39)
Browse files Browse the repository at this point in the history
Previously when something went wrong in the Rust or C API - all the end user would see is a `CUVS_ERROR` return code with no extra indication of what went wrong.

This change exposes the exception text to both the C and Rust api's, and provides a convenience method to automatically catch c++ exceptions, and convert the exception into an error code with the text set appropiately.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #39
  • Loading branch information
benfred authored Mar 5, 2024
1 parent 8e6979f commit b94c1b5
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 100 deletions.
10 changes: 10 additions & 0 deletions cpp/include/cuvs/core/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);
*/
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);

/** @brief Returns a string describing the last seen error on this thread, or
* NULL if the last function succeeded.
*/
const char* cuvsGetLastErrorText();

/**
* @brief Sets a string describing an error seen on the thread. Passing NULL
* clears any previously seen error message.
*/
void cuvsSetLastErrorText(const char* error);
#ifdef __cplusplus
}
#endif
Expand Down
45 changes: 45 additions & 0 deletions cpp/include/cuvs/core/exceptions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "c_api.h"

#include <exception>

namespace cuvs::core {

/**
* @brief Translates C++ exceptions into cuvs C-API error codes
*/
template <typename Fn>
cuvsError_t translate_exceptions(Fn func)
{
cuvsError_t status;
try {
func();
status = CUVS_SUCCESS;
cuvsSetLastErrorText(NULL);
} catch (const std::exception& e) {
cuvsSetLastErrorText(e.what());
status = CUVS_ERROR;
} catch (...) {
cuvsSetLastErrorText("unknown exception");
status = CUVS_ERROR;
}
return status;
}
} // namespace cuvs::core
38 changes: 17 additions & 21 deletions cpp/src/core/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,42 @@

#include <cstdint>
#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <memory>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <thread>

extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = new raft::resources{};
*res = reinterpret_cast<uintptr_t>(res_ptr);
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

extern "C" cuvsError_t cuvsResourcesDestroy(cuvsResources_t res)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
delete res_ptr;
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
raft::resource::set_cuda_stream(*res_ptr, static_cast<rmm::cuda_stream_view>(stream));
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

thread_local std::string last_error_text = "";

extern "C" const char* cuvsGetLastErrorText()
{
return last_error_text.empty() ? NULL : last_error_text.c_str();
}

extern "C" void cuvsSetLastErrorText(const char* error) { last_error_text = error ? error : ""; }
62 changes: 14 additions & 48 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resources.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/cagra.h>
#include <cuvs/neighbors/cagra.hpp>
Expand Down Expand Up @@ -96,17 +97,12 @@ void _search(cuvsResources_t res,

extern "C" cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index)
{
try {
*index = new cuvsCagraIndex{};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { *index = new cuvsCagraIndex{}; });
}

extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
{
try {
return cuvs::core::translate_exceptions([=] {
auto index = *index_c_ptr;

if (index.dtype.code == kDLFloat) {
Expand All @@ -123,18 +119,15 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
delete index_ptr;
}
delete index_c_ptr;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
cuvsCagraIndexParams_t params,
DLManagedTensor* dataset_tensor,
cuvsCagraIndex_t index)
{
try {
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
Expand All @@ -151,13 +144,7 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
dataset.dtype.code,
dataset.dtype.bits);
}
return CUVS_SUCCESS;
} catch (const std::exception& ex) {
std::cerr << "Error occurred: " << ex.what() << std::endl;
return CUVS_ERROR;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
Expand All @@ -167,7 +154,7 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
{
try {
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
auto neighbors = neighbors_tensor->dl_tensor;
auto distances = distances_tensor->dl_tensor;
Expand Down Expand Up @@ -198,57 +185,36 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
queries.dtype.code,
queries.dtype.bits);
}
return CUVS_SUCCESS;
} catch (const std::exception& ex) {
std::cerr << "Error occurred: " << ex.what() << std::endl;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params)
{
try {
return cuvs::core::translate_exceptions([=] {
*params = new cuvsCagraIndexParams{.intermediate_graph_degree = 128,
.graph_degree = 64,
.build_algo = IVF_PQ,
.nn_descent_niter = 20};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraIndexParamsDestroy(cuvsCagraIndexParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsCagraSearchParamsCreate(cuvsCagraSearchParams_t* params)
{
try {
return cuvs::core::translate_exceptions([=] {
*params = new cuvsCagraSearchParams{.itopk_size = 64,
.search_width = 1,
.hashmap_max_fill_rate = 0.5,
.num_random_samplings = 1,
.rand_xor_mask = 0x128394};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources cydlpack.pyx)
set(cython_sources cydlpack.pyx exceptions.pyx)
set(linked_libraries cuvs::cuvs cuvs_c)

# Build all of the Cython targets
Expand Down
1 change: 1 addition & 0 deletions python/cuvs/cuvs/common/c_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ cdef extern from "cuvs/core/c_api.h":
cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
cuvsError_t cuvsResourcesDestroy(cuvsResources_t res)
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
const char * cuvsGetLastErrorText()
37 changes: 37 additions & 0 deletions python/cuvs/cuvs/common/exceptions.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: language_level=3

from cuvs.common.c_api cimport cuvsError_t, cuvsGetLastErrorText


class CuvsException(Exception):
pass


def get_last_error_text():
""" returns the last error description from the cuvs c-api """
cdef const char* c_err = cuvsGetLastErrorText()
if c_err is NULL:
return
cdef bytes err = c_err
return err.decode("utf8")


def check_cuvs(status: cuvsError_t):
""" Converts a status code into an exception """
if status == cuvsError_t.CUVS_ERROR:
raise CuvsException(get_last_error_text())
Loading

0 comments on commit b94c1b5

Please sign in to comment.