
Commit

V3.4.1 Release (#34)
* support multi-hot cat input

* Update CI.DockerFile to latest v3.5-integration
yingcanw authored Mar 1, 2022
1 parent f8f69cc commit fc9ddc4
Showing 2 changed files with 18 additions and 11 deletions.
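For context on the changes below: HugeCTR hands categorical features to the embedding lookup in a CSR-like layout, where a row-pointer array records where each sample's keys for each slot begin and end. With one-hot input, every slot holds exactly one key and a single array of batch_size * slot_num + 1 offsets suffices; with multi-hot input, a slot may hold several keys and each embedding table carries its own row-pointer segment with its own terminating offset. That is why the buffer sizing and the sample-count math in this commit gain a per-table term. A minimal sketch of building one table's row pointers for multi-hot input (illustrative names, not the backend's actual API):

    #include <cstddef>
    #include <vector>

    // CSR-style row pointers: row_ptrs[i + 1] - row_ptrs[i] is the number of
    // keys in the i-th (sample, slot) cell. One-hot input makes every gap 1,
    // which collapses this to the old batch_size * slot_num + 1 sizing.
    std::vector<int> build_row_ptrs(const std::vector<int>& keys_per_cell) {
      std::vector<int> row_ptrs(keys_per_cell.size() + 1, 0);
      for (std::size_t i = 0; i < keys_per_cell.size(); ++i)
        row_ptrs[i + 1] = row_ptrs[i] + keys_per_cell[i];
      return row_ptrs;  // batch_size * slots_in_table + 1 offsets per table
    }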
src/hugectr.cc (17 additions, 10 deletions)
@@ -1382,7 +1382,8 @@ ModelState::Create_EmbeddingCache()
         "Please confirm that device ", gpu_shape[i],
         " is added to 'deployed_device_list' in the ps configuration file");
 
-    if (embedding_cache_map.find(gpu_shape[i]) == embedding_cache_map.end()) {
+    if (embedding_cache_map.find(gpu_shape[i]) == embedding_cache_map.end() &&
+        support_gpu_cache_) {
       HCTR_TRITON_LOG(
           INFO, "******Creating Embedding Cache for model ", name_,
           " in device ", gpu_shape[i]);
@@ -1466,6 +1467,7 @@ class ModelInstanceState {
   const std::string& Name() const { return name_; }
   TRITONSERVER_InstanceGroupKind Kind() const { return kind_; }
   int32_t DeviceId() const { return device_id_; }
+  size_t EmbeddingTableCount() { return num_embedding_tables; }
 
   // Get the state of the model that corresponds to this instance.
   ModelState* StateForModel() const { return model_state_; }
@@ -1510,6 +1512,7 @@ class ModelInstanceState {
   const std::string name_;
   const TRITONSERVER_InstanceGroupKind kind_;
   const int32_t device_id_;
+  size_t num_embedding_tables;
 
   // HugeCTR Model buffer for input and output
   // These buffers will be shared for all the requests
@@ -1598,7 +1601,10 @@ ModelInstanceState::ModelInstanceState(
   HCTR_TRITON_LOG(INFO, "Categorical Row Index buffer allocation: ");
   row_ptr_buf = HugeCTRBuffer<int>::create();
   std::vector<size_t> row_ptrs_dims = {static_cast<size_t>(
-      model_state_->BatchSize() * model_state_->SlotNum() + 1)};
+      model_state_->BatchSize() * model_state_->SlotNum() +
+      model_state_->GetEmbeddingCache(device_id_)
+          ->get_cache_config()
+          .num_emb_table_)};
   row_ptr_buf->reserve(row_ptrs_dims);
   row_ptr_buf->allocate();
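A worked check of the new sizing, with hypothetical shapes: each table's CSR segment needs batch_size * slots_in_table + 1 offsets, so pooling the segments gives batch_size * slot_num + num_emb_table_ entries rather than the old single +1.

    #include <cstddef>

    int main() {
      // Hypothetical model: batch 64, two embedding tables holding 26 and 2 slots.
      std::size_t batch = 64;
      std::size_t slots_per_table[] = {26, 2};     // slot_num = 28 in total
      std::size_t pooled = 0;
      for (std::size_t t = 0; t < 2; ++t)
        pooled += batch * slots_per_table[t] + 1;  // 1665 + 129
      return pooled == 64 * 28 + 2 ? 0 : 1;        // pooled == 1794, as in the diff's formula
    }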

@@ -1626,6 +1632,7 @@ ModelInstanceState::LoadHugeCTRModel()
       INFO, "The model origin json configuration file path is: ",
       model_state_->HugeCTRJsonConfig());
   embedding_cache = model_state_->GetEmbeddingCache(device_id_);
+  num_embedding_tables = embedding_cache->get_cache_config().num_emb_table_;
   hugectrmodel_ = HugeCTR::HugeCTRModel::load_model(
       type, model_state_->HugeCTRJsonConfig(), instance_params_,
       embedding_cache);
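The table count is read once from the embedding-cache configuration at model-load time and stored on the instance, so the per-request execute path further down can call EmbeddingTableCount() without touching the cache config. A simplified sketch of that pattern, with types reduced to what the diff shows (the struct shape is an assumption based on the num_emb_table_ field name):

    #include <cstddef>

    // Assumed shape of the cache config, per the field name in the diff.
    struct CacheConfig { std::size_t num_emb_table_; };

    class InstanceSketch {
     public:
      void OnLoad(const CacheConfig& cfg) { num_tables_ = cfg.num_emb_table_; }
      std::size_t EmbeddingTableCount() const { return num_tables_; }  // hot path, no config lookup
     private:
      std::size_t num_tables_ = 0;
    };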
@@ -2227,11 +2234,8 @@ TRITONBACKEND_ModelInstanceExecute(
       TRITONBACKEND_Output* output;
 
       numofdes = des_byte_size / sizeof(float);
-      if (instance_state->StateForModel()->SupportLongEmbeddingKey()) {
-        numofcat = cat_byte_size / sizeof(long long);
-      } else {
-        numofcat = cat_byte_size / sizeof(unsigned int);
-      }
+      numofcat = row_byte_size / sizeof(int);
+
 
       if (instance_state->StateForModel()->DeseNum() != 0 &&
           numofdes % instance_state->StateForModel()->DeseNum() != 0) {
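A note on why the count's source changed: with multi-hot input the number of categorical keys per sample is variable, so the byte size of the key tensor no longer determines the batch size, and the key width (long long vs. unsigned int) no longer matters for counting. The row-pointer tensor, by contrast, has a fixed per-sample length: slot_num entries per sample plus one terminating offset per table. A sketch of the derivation, with names mirroring the diff and sizes hypothetical:

    #include <cstddef>

    // Recover the batch size from the fixed-size row-pointer tensor rather
    // than the variable-size key tensor.
    std::size_t batch_from_row_ptrs(
        std::size_t row_byte_size, std::size_t num_tables, std::size_t slot_num) {
      std::size_t entries = row_byte_size / sizeof(int);  // "numofcat" in the diff
      return (entries - num_tables) / slot_num;
    }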
@@ -2243,7 +2247,9 @@
               "configuration. The input sample size to be an integer "
               "multiple of the configuration."));
       }
-      if (numofcat % instance_state->StateForModel()->CatNum() != 0) {
+      if ((numofcat - instance_state->EmbeddingTableCount()) %
+              instance_state->StateForModel()->SlotNum() !=
+          0) {
         GUARDED_RESPOND_IF_ERROR(
             responses, r,
             TRITONSERVER_ErrorNew(
@@ -2257,8 +2263,9 @@
           floor(numofdes / instance_state->StateForModel()->DeseNum());
       }
 
-      num_of_sample_cat =
-          floor(numofcat / instance_state->StateForModel()->CatNum());
+      num_of_sample_cat = floor(
+          (numofcat - instance_state->EmbeddingTableCount()) /
+          instance_state->StateForModel()->SlotNum());
 
       if (instance_state->StateForModel()->DeseNum() != 0 &&
           num_of_sample_des != num_of_sample_cat) {
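Checking the arithmetic with the same hypothetical shapes as above (slot_num = 28, two tables): a request carrying 64 samples supplies 64 * 28 + 2 = 1794 row-pointer entries, and the code above inverts that exactly, while the earlier modulo test rejects requests where the subtraction does not divide evenly.

    #include <cassert>
    #include <cstddef>

    int main() {
      std::size_t entries = 64 * 28 + 2;               // 1794 row-pointer entries
      std::size_t num_tables = 2, slot_num = 28;       // hypothetical shapes
      assert((entries - num_tables) % slot_num == 0);  // the guard checked earlier
      std::size_t samples = (entries - num_tables) / slot_num;
      return samples == 64 ? 0 : 1;                    // recovers the batch size
    }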
test/CI.DockerFile (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 FROM gitlab-master.nvidia.com:5005/dl/hugectr/hugectr:devel_inference
 # for testing
-ARG HUGECTR_BRANCH=v3.4-integration
+ARG HUGECTR_BRANCH=v3.5-integration
 ARG INFERENCE_BRANCH=main
 ARG TRITON_BRANCH=r21.09
 ARG INFERENCE_MODE=ON
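For reference, these build arguments can also be overridden at build time without editing the file, e.g. with a standard invocation such as docker build -f test/CI.DockerFile --build-arg HUGECTR_BRANCH=v3.5-integration . (the build context path here is a placeholder); the committed defaults simply pin CI to the v3.5-integration branch.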
