Support Translator with persistent cpu cache #1645

Merged: 2 commits, Mar 25, 2024
python/cpp/translator.cc: 60 changes (38 additions & 22 deletions)
@@ -307,14 +307,18 @@ namespace ctranslate2 {
       // If the lock is not acquired immediately it means the model is being used
       // in another thread and we can't unload it at this time.
       std::unique_lock lock(_mutex, std::try_to_lock);
-      if (!lock || !_model_is_loaded)
+      if (!lock)
         return;
 
+      std::vector<std::shared_ptr<const models::Model>> loaded_models;
+      if (_model_is_loaded)
+        loaded_models = _pool->detach_models();
+
-      _cached_models = _pool->detach_models();
-      if (to_cpu)
-        move_cached_models(Device::CPU, std::vector<int>(_cached_models.size(), 0));
-      else
+      if (to_cpu && _cached_models.empty())
+        _cached_models = clone_models(Device::CPU, std::vector<int>(loaded_models.size(), 0), loaded_models);
+      else if (!to_cpu)
         _cached_models.clear();
+      loaded_models.clear();
 
       // We clear the CUDA allocator cache to further reduce the memory after unloading the model.
       if (_device == Device::CUDA)
@@ -323,19 +327,20 @@ namespace ctranslate2 {
       _model_is_loaded = false;
     }
 
-    void load_model() {
+    void load_model(const bool keep_cache) {
       std::unique_lock lock(_mutex);
       if (_model_is_loaded)
         return;
 
-      if (_cached_models.empty()) {
-        _cached_models = _model_loader.load();
-      } else {
-        move_cached_models(_device, _device_index, _num_replicas_per_device);
-      }
-
-      _pool->set_models(_cached_models);
-      _cached_models.clear();
-
+      std::vector<std::shared_ptr<const models::Model>> loaded_models;
+      if (_cached_models.empty())
+        loaded_models = _model_loader.load();
+      else
+        loaded_models = clone_models(_device, _device_index, _cached_models, _num_replicas_per_device);
+
+      _pool->set_models(loaded_models);
+      if (!keep_cache)
+        _cached_models.clear();
+
       _model_is_loaded = true;
     }

@@ -357,13 +362,18 @@ namespace ctranslate2 {
         throw std::runtime_error("The model for this translator was unloaded");
     }
 
-    void move_cached_models(Device device,
-                            const std::vector<int>& device_index,
-                            size_t num_models_per_device = 1) {
-      for (size_t i = 0; i < _cached_models.size(); ++i) {
-        auto& model = const_cast<models::Model&>(*_cached_models[i]);
-        model.set_device(device, device_index[i / num_models_per_device]);
+    std::vector<std::shared_ptr<const models::Model>> clone_models(Device device,
+                                                                   const std::vector<int>& device_index,
+                                                                   std::vector<std::shared_ptr<const models::Model>> cached_models,
+                                                                   size_t num_models_per_device = 1) {
+      std::vector<std::shared_ptr<const models::Model>> copied_models;
+      for (size_t i = 0; i < cached_models.size(); ++i) {
+        auto& model = const_cast<models::Model&>(*cached_models[i]);
+        auto copied_model = model.copy_to(device, device_index[i / num_models_per_device]);
+        copied_models.push_back(copied_model);
       }
+      return copied_models;
     }
   };

@@ -684,8 +694,14 @@ namespace ctranslate2 {
           )pbdoc")
 
         .def("load_model", &TranslatorWrapper::load_model,
+             py::arg("keep_cache")=false,
              py::call_guard<py::gil_scoped_release>(),
-             "Loads the model back to the initial device.")
+             R"pbdoc(
+                 Loads the model back to the initial device.
+
+                 Arguments:
+                   keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
+             )pbdoc")
 
         .def_property_readonly("model_is_loaded", &TranslatorWrapper::model_is_loaded,
             "Whether the model is loaded on the initial device and ready to be used.")