From 69e49e8be48a4a816f430821772183367308aba2 Mon Sep 17 00:00:00 2001
From: inocsin
Date: Mon, 22 Mar 2021 13:58:03 +0800
Subject: [PATCH] feat: update truncate long/double python api

Signed-off-by: inocsin
---
 core/conversion/var/Var.cpp                   | 8 +++++---
 py/trtorch/_compile_spec.py                   | 6 ++++++
 py/trtorch/csrc/register_tensorrt_classes.cpp | 1 +
 py/trtorch/csrc/tensorrt_classes.cpp          | 2 ++
 py/trtorch/csrc/tensorrt_classes.h            | 2 ++
 py/trtorch/csrc/trtorch_py.cpp                | 3 ++-
 6 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index ad31b50d8f..e743b3df5a 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -97,12 +97,14 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
   auto weights = converters::Weights();
   if (isIValue()) {
     auto tensor = ptr_.ivalue->toTensor();
-    if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
+    if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
+      TRTORCH_CHECK(0, "Unable to freeze tensor of type kLong/kDouble into constant layer, try to compile model with truncate_long_and_double enabled");
+    } else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
       weights = converters::Weights(ctx, tensor.toType(at::kInt));
-      LOG_WARNING("Truncate kLong to kInt for IValue");
+      LOG_WARNING("Truncating weight (constant in the graph) from kLong to kInt");
     } else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
       weights = converters::Weights(ctx, tensor.toType(at::kFloat));
-      LOG_WARNING("Truncate kDouble to kFloat for IValue");
+      LOG_WARNING("Truncating weight (constant in the graph) from kDouble to kFloat");
     } else {
       weights = converters::Weights(ctx, tensor);
     }
diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index 4abab01790..3f1e135234 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -176,6 +176,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     if "max_batch_size" in compile_spec:
         assert type(compile_spec["max_batch_size"]) is int
         info.max_batch_size = compile_spec["max_batch_size"]
+
+    if "truncate_long_and_double" in compile_spec:
+        assert type(compile_spec["truncate_long_and_double"]) is bool
+        info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
 
     return info
 
@@ -217,6 +221,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
             "num_avg_timing_iters": 1,  # Number of averaging timing iterations used to select kernels
             "workspace_size": 0,  # Maximum size of workspace given to TensorRT
             "max_batch_size": 0,  # Maximum batch size (must be >= 1 to be set, 0 means not set)
+            "truncate_long_and_double": False,  # Truncate long and double into int and float
         })
     }
 
@@ -257,6 +262,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
     backend_spec.set_workspace_size(parsed_spec.workspace_size)
     backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
+    backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
     backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
 
     return backend_spec
diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index 36ff34931d..cb40c35c30 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -42,6 +42,7 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
 }
 
 struct TRTTSRegistrations {
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index e0b36c4463..520576edbd 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -108,6 +108,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+  info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
 
   TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
@@ -143,6 +144,7 @@ std::string CompileSpec::stringify() {
  ss << "    \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
  ss << "    \"Workspace Size\": " << workspace_size << std::endl;
  ss << "    \"Max Batch Size\": " << max_batch_size << std::endl;
+ ss << "    \"Truncate long and double\": " << truncate_long_and_double << std::endl;
  ss << "}";
  return ss.str();
 }
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index 5371b93f75..c1eed6537a 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -115,6 +115,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
   ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
   ADD_FIELD_GET_SET(workspace_size, int64_t);
+  ADD_FIELD_GET_SET(truncate_long_and_double, bool);
   ADD_FIELD_GET_SET(max_batch_size, int64_t);
   ADD_FIELD_GET_SET(device, Device);
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -126,6 +127,7 @@
   bool refit = false;
   bool debug = false;
   bool strict_types = false;
+  bool truncate_long_and_double = false;
   Device device;
   EngineCapability capability = EngineCapability::kDEFAULT;
   int64_t num_min_timing_iters = 2;
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index e1a5b14eb4..cb3d1d4e39 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -246,7 +246,8 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("num_min_timing_iters", &CompileSpec::num_min_timing_iters)
       .def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters)
       .def_readwrite("workspace_size", &CompileSpec::workspace_size)
-      .def_readwrite("max_batch_size", &CompileSpec::max_batch_size);
+      .def_readwrite("max_batch_size", &CompileSpec::max_batch_size)
+      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
 
   py::class_<Device>(m, "Device")
       .def(py::init<>())