-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//py): Initial compliant implementation of the to_backend api for
PyTorch Users can now use a direct PyTorch integration by just importing the trtorch package. The only difference between torch._C._jit_to_tensorrt and trtorch.compile is that you need to use the trtorch.TensorRTCompileSpec constructor to build a wrapper around your spec dictionary Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
- Loading branch information
1 parent
b24c0d8
commit 59113cf
Showing
15 changed files
with
573 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#include "tensorrt_classes.h" | ||
|
||
namespace trtorch { | ||
namespace backend { | ||
namespace { | ||
void RegisterTRTCompileSpec() { | ||
#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \ | ||
(registry).def("set_"#field_name, &class_name::set_##field_name); \ | ||
(registry).def("get_"#field_name, &class_name::get_##field_name); | ||
|
||
static auto TRTORCH_UNUSED TRTInputRangeTSRegistrtion = torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange") | ||
.def(torch::init<>()); | ||
|
||
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, min); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, opt); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistrtion, trtorch::pyapi::InputRange, max); | ||
|
||
static auto TRTORCH_UNUSED TRTCompileSpecTSRegistrtion = torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec") | ||
.def(torch::init<>()) | ||
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) | ||
.def("__str__", &trtorch::pyapi::CompileSpec::stringify); | ||
|
||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, op_precision); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, refit); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, debug); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, strict_types); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, allow_gpu_fallback); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, device); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, capability); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_min_timing_iters); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, num_avg_timing_iters); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, workspace_size); | ||
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistrtion, trtorch::pyapi::CompileSpec, max_batch_size); | ||
} | ||
|
||
// Static-initialization hook: constructing the single file-local instance
// below runs RegisterTRTCompileSpec() once at shared-library load time, so
// the TorchBind classes exist before any user code can reference them.
struct TRTTSRegistrations {
  TRTTSRegistrations() {
    RegisterTRTCompileSpec();
  }
};

// File-local instance whose constructor performs the registration.
static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
} | ||
} // namespace backend | ||
} // namespace trtorch | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#include "torch/csrc/jit/passes/lower_graph.h" | ||
|
||
#include "tensorrt_backend.h" | ||
#include "tensorrt_classes.h" | ||
|
||
#include "core/compiler.h" | ||
#include "core/lowering/lowering.h" | ||
#include "core/runtime/runtime.h" | ||
|
||
namespace trtorch { | ||
namespace backend { | ||
|
||
// to_backend stage 1: prepare a module for TensorRT compilation.
// Puts the module in eval mode, runs TRTorch's module-level lowering, then
// verifies every method named in the compile spec is convertible and runs
// graph-level lowering on each. Returns the lowered module's IValue.
//
// @param mod                 TorchScript module to preprocess (as an IValue)
// @param method_compile_spec dict of method name -> compile settings
// @return the lowered module wrapped in an IValue
c10::IValue TensorRTBackend::preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) {
  auto mod_ = mod.toModule();
  LOG_DEBUG("Placing module in eval mode if not already");
  mod_.eval();
  mod_ = core::lowering::LowerModule(mod_);

  auto spec =
      c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);

  // Operator-support check runs against the original (pre-lowering) module,
  // materialized once instead of per iteration.
  auto orig_mod = mod.toModule();
  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
    // Fixed message: original lacked the space before "cannot".
    TRTORCH_CHECK(core::CheckMethodOperatorSupport(orig_mod, it->key()),
                  "Method " << it->key() << " cannot be compiled by TRTorch");
  }

  // Lower each requested method's graph in place on the lowered module.
  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
    const auto& method_name = it->key();
    auto method = mod_.get_method(method_name);
    auto graph = method.graph();
    core::lowering::LowerGraph(graph);
  }

  return mod_._ivalue();
}
|
||
// to_backend stage 2: convert each preprocessed method into a serialized
// TensorRT engine. Returns a dict of method name -> TRTEngine custom-class
// handle, consumed later by execute().
//
// @param processed_mod       output of preprocess() (lowered module IValue)
// @param method_compile_spec dict of method name -> compile settings
// @return GenericDict mapping method names to TRTEngine handles
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) {
  auto mod = processed_mod.toModule();
  auto spec =
      c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);

  // Result dict is typed string -> TRTEngine custom class.
  auto handles = c10::impl::GenericDict(c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());

  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
    const auto& method_name = it->key();
    auto method = mod.get_method(method_name);
    auto g = method.graph();

    // NOTE(review): the spec value is itself a dict, indexed again by the
    // same method name — presumably the Python side nests
    // {method: {method: CompileSpec}}; verify against the caller.
    auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
    LOG_DEBUG(raw_spec->stringify());
    auto cfg = raw_spec->toInternalCompileSpec();
    auto convert_cfg = std::move(cfg.convert_info);
    // Lower to a graph + flattened parameter list so weights can be bound
    // by the converter.
    auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());

    g = graph_and_ivalues.first;
    auto params = graph_and_ivalues.second;
    auto named_params = core::conversion::get_named_params(g->inputs(), params);

    // Convert the whole graph block into a serialized TensorRT engine and
    // wrap it in the runtime's custom-class handle.
    auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
    auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine);
    handles.insert(method.name(), at::IValue(engine_handle));
  }

  return c10::impl::toGenericDict(handles);
}
|
||
|
||
// to_backend stage 3: run a compiled method. Unwraps the TRTEngine handle
// produced by compile(), feeds the (tensor-only) inputs through it, and
// returns the engine outputs as a generic list.
//
// @param handle TRTEngine custom-class handle for one compiled method
// @param inputs non-empty list of input tensors
// @return list of output tensors from the engine
c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::GenericList inputs) {
  TRTORCH_ASSERT(inputs.size() > 0, "Trying to execute on empty list of arguments");
  auto engine = handle.toCustomClass<core::runtime::TRTEngine>();

  // Unpack the generic list into a typed tensor vector, rejecting any
  // non-tensor argument up front.
  std::vector<at::Tensor> tensor_inputs;
  tensor_inputs.reserve(inputs.size());
  for (const auto item : inputs) {
    c10::IValue value = item;
    TRTORCH_CHECK(value.isTensor(), "TensorRT currently only accepts Tensors as inputs");
    tensor_inputs.push_back(value.toTensor());
  }

  auto results = core::runtime::execute_engine(tensor_inputs, engine);
  return c10::impl::toList(c10::List<at::Tensor>(results));
}
|
||
namespace {
// Registers this backend with TorchScript under the name "tensorrt" at
// load time, making it reachable via torch._C._jit_to_backend("tensorrt", ...).
static auto reg = torch::jit::backend<TensorRTBackend>("tensorrt");
}
|
||
} // namespace backend | ||
} // namespace trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once | ||
#include "torch/csrc/jit/api/module.h" | ||
#include "torch/csrc/jit/backends/backend.h" | ||
|
||
namespace trtorch { | ||
namespace backend { | ||
|
||
// TorchScript to_backend delegate for TensorRT. Implements the three-stage
// backend contract: preprocess (lower the module), compile (build one
// TensorRT engine per requested method), and execute (run a compiled
// method on tensor inputs). Definitions live in the accompanying .cpp.
class TensorRTBackend: public torch::jit::PyTorchBackendInterface {
 public:
  explicit TensorRTBackend() {}
  virtual ~TensorRTBackend() = default;

  // Lowers `mod` and validates the methods named in `method_compile_spec`.
  c10::IValue preprocess(c10::IValue mod, c10::impl::GenericDict method_compile_spec) override;
  // Builds a method-name -> engine-handle dict from the preprocessed module.
  c10::impl::GenericDict compile(c10::IValue processed_mod, c10::impl::GenericDict method_compile_spec) override;
  // Runs one compiled method; `inputs` must be a non-empty list of tensors.
  c10::impl::GenericList execute(c10::IValue handle, c10::impl::GenericList inputs) override;
};
|
||
} // namespace backend | ||
} // namespace trtorch |
Oops, something went wrong.