diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 7348dfe6c7..ad5668a19a 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -9,17 +9,24 @@ namespace core { namespace conversion { std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { - os << "Settings requested for TensorRT engine:" \ - << "\n Operating Precision: " << s.op_precision \ - << "\n Make Refittable Engine: " << s.refit \ - << "\n Debuggable Engine: " << s.debug \ - << "\n Strict Type: " << s.strict_type \ - << "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \ - << "\n Min Timing Iterations: " << s.num_min_timing_iters \ - << "\n Avg Timing Iterations: " << s.num_avg_timing_iters \ - << "\n Max Workspace Size: " << s.workspace_size \ - << "\n Device Type: " << s.device \ - << "\n Engine Capability: " << s.capability \ + os << "Settings requested for TensorRT engine:" \ + << "\n Operating Precision: " << s.op_precision \ + << "\n Make Refittable Engine: " << s.refit \ + << "\n Debuggable Engine: " << s.debug \ + << "\n Strict Type: " << s.strict_types \ + << "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \ + << "\n Min Timing Iterations: " << s.num_min_timing_iters \ + << "\n Avg Timing Iterations: " << s.num_avg_timing_iters \ + << "\n Max Workspace Size: " << s.workspace_size; + + if (s.max_batch_size != 0) { + os << "\n Max Batch Size: " << s.max_batch_size; + } else { + os << "\n Max Batch Size: Not set"; + } + + os << "\n Device Type: " << s.device \ + << "\n Engine Capability: " << s.capability \ << "\n Calibrator Created: " << (s.calibrator != nullptr); return os; } @@ -62,7 +69,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG); } - if (settings.strict_type) { + if (settings.strict_types) { cfg->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES); } @@ -70,6 +77,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) cfg->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); } + if (settings.max_batch_size != 0) { + builder->setMaxBatchSize(settings.max_batch_size); + } + cfg->setMinTimingIterations(settings.num_min_timing_iters); cfg->setAvgTimingIterations(settings.num_avg_timing_iters); cfg->setMaxWorkspaceSize(settings.workspace_size); diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index d8adef5ff0..b7922a319d 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -20,7 +20,7 @@ struct BuilderSettings { nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT; bool refit = false; bool debug = false; - bool strict_type = false; + bool strict_types = false; bool allow_gpu_fallback = true; nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU; nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT; @@ -28,6 +28,7 @@ struct BuilderSettings { uint64_t num_min_timing_iters = 2; uint64_t num_avg_timing_iters = 1; uint64_t workspace_size = 0; + uint64_t max_batch_size = 0; BuilderSettings() = default; BuilderSettings(const BuilderSettings& other) = default; diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index 14d36cdbfc..46014186ea 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -175,7 +175,7 @@ struct TRTORCH_API ExtraInfo { /** * Restrict operating type to only set default operation precision (op_precision) */ - bool strict_type = false; + bool strict_types = false; /** * (Only used when targeting DLA (device)) @@ -205,7 +205,12 @@ struct TRTORCH_API ExtraInfo { /** * Maximum size of workspace given to TensorRT */ - uint64_t workspace_size = 1 << 20; + uint64_t workspace_size = 0; + + /** + * Maximum batch size (must be =< 1 to be set, 0 means not set) + */ + uint64_t max_batch_size = 0; /** * Calibration dataloaders for each input for post training quantizatiom diff --git a/cpp/api/src/extra_info.cpp b/cpp/api/src/extra_info.cpp index 23a47e7801..f7fc5709e9 100644 --- a/cpp/api/src/extra_info.cpp +++ b/cpp/api/src/extra_info.cpp @@ -91,8 +91,9 @@ core::ExtraInfo to_internal_extra_info(ExtraInfo external) { internal.convert_info.engine_settings.refit = external.refit; internal.convert_info.engine_settings.debug = external.debug; - internal.convert_info.engine_settings.strict_type = external.strict_type; + internal.convert_info.engine_settings.strict_types = external.strict_types; internal.convert_info.engine_settings.allow_gpu_fallback = external.allow_gpu_fallback; + internal.convert_info.engine_settings.max_batch_size = external.max_batch_size; switch(external.device) { case ExtraInfo::DeviceType::kDLA: