Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triton data converter (CMSSW_11_2_0_pre9) #2

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions HeterogeneousCore/SonicTriton/interface/TritonData.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class TritonData {
converterName_ = conf.getParameter<std::string>("converterName");
drankincms marked this conversation as resolved.
Show resolved Hide resolved
}
template <typename DT>
std::unique_ptr<TritonConverterBase<DT>> createConverter() const { return TritonConverterFactory<DT>::get()->create(converterName_); }
void createConverter() const {
drankincms marked this conversation as resolved.
Show resolved Hide resolved
if (!converter_.has_value()) converter_ = std::shared_ptr<TritonConverterBase<DT>>(TritonConverterFactory<DT>::get()->create(converterName_));
drankincms marked this conversation as resolved.
Show resolved Hide resolved
}

//io accessors
template <typename DT>
Expand Down Expand Up @@ -102,7 +104,7 @@ class TritonData {
int64_t byteSize_;
std::any holder_;
std::shared_ptr<Result> result_;
std::any converter_;
mutable std::any converter_;
std::string converterName_;
};

Expand Down
21 changes: 8 additions & 13 deletions HeterogeneousCore/SonicTriton/src/TritonData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr) {
//shape must be specified for variable dims or if batch size changes
data_->SetShape(fullShape_);

std::unique_ptr<TritonConverterBase<DT>> converter = createConverter<DT>();
createConverter<DT>();
drankincms marked this conversation as resolved.
Show resolved Hide resolved

if (byteSize_ != converter->byteSize())
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << converter->byteSize()
if (byteSize_ != std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize())
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize()
<< " (should be " << byteSize_ << " for " << dname_ << ")";

int64_t nInput = sizeShape();
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
const DT* arr = data_in[i0].data();
triton_utils::throwIfError(data_->AppendRaw(converter->convertIn(arr), nInput * byteSize_),
triton_utils::throwIfError(data_->AppendRaw(std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertIn(arr), nInput * byteSize_),
name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
}

Expand All @@ -141,7 +141,8 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
throw cms::Exception("TritonDataError") << name_ << " output(): missing result";
}

std::unique_ptr<TritonConverterBase<DT>> converter = createConverter<DT>();
createConverter<DT>();
//std::unique_ptr<TritonConverterBase<DT>> converter = std::any_cast<converter>;

if (byteSize_ != sizeof(DT)) {
throw cms::Exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT)
Expand All @@ -152,14 +153,14 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
TritonOutput<DT> dataOut;
const uint8_t* r0;
size_t contentByteSize;
size_t expectedContentByteSize = nOutput * converter->byteSize() * batchSize_;
size_t expectedContentByteSize = nOutput * std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize() * batchSize_;
triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw");
if (contentByteSize != expectedContentByteSize) {
throw cms::Exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize
<< " (expected " << expectedContentByteSize << ")";
}

const DT* r1 = converter->convertOut(r0);
const DT* r1 = std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertOut(r0);
dataOut.reserve(batchSize_);
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
auto offset = i0 * nOutput;
Expand Down Expand Up @@ -188,9 +189,3 @@ template void TritonInputData::toServer(std::shared_ptr<TritonInput<float>> data
template void TritonInputData::toServer(std::shared_ptr<TritonInput<int64_t>> data_in);

template TritonOutput<float> TritonOutputData::fromServer() const;

template std::unique_ptr<TritonConverterBase<float>> TritonInputData::createConverter() const;
template std::unique_ptr<TritonConverterBase<int64_t>> TritonInputData::createConverter() const;

template std::unique_ptr<TritonConverterBase<float>> TritonOutputData::createConverter() const;
template std::unique_ptr<TritonConverterBase<int64_t>> TritonOutputData::createConverter() const;