Merge pull request #1201 from pytorch/squashed_collections
feat: support for grouped inputs
narendasan authored Aug 9, 2022
2 parents b62df15 + 223dfd1 commit 48a7f28
Showing 44 changed files with 1,694 additions and 326 deletions.
94 changes: 61 additions & 33 deletions core/compiler.cpp
@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
// update the input ranges for each segment
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

// TODO: map the Inputs IValue to the flattened form here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
@@ -306,57 +307,80 @@ void MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::TypeMap& first_use_type_map) {
// Associate input specs with inputs
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));

for (auto& in : g->inputs()) {
if (static_params.find(in) == static_params.end()) {
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
auto est_type_opt = first_use_type_map.find(in)->second;
if (est_type_opt && !spec.dtype_is_user_defined) {
ir::CollectionTypeMap& first_use_type_map) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));

auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
<< g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size());

for (auto in : collection_inputs) {
std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
std::vector<c10::optional<at::ScalarType>> est_type_opt;

auto est_it = first_use_type_map.find(in);
if (est_it != first_use_type_map.end()) {
est_type_opt = first_use_type_map.find(in)->second;
}
// traverse elements in est_type_opt and spec
for (size_t i = 0; i < est_type_opt.size(); i++) {
if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
// type
LOG_INFO(
"Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
<< in->debugName() << " has type " << est_type_opt.value()
<< ". If this is incorrect explicitly set dtype for input and file a bug");
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
"Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
<< in->debugName() << " has type " << est_type_opt[i].value());
spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec.dtype = nvinfer1::DataType::kFLOAT;
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
if (!est_type_opt) {
LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
spec[i].dtype = nvinfer1::DataType::kFLOAT;
} else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ". The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.inputs.find(in)->second.dtype;
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt.value() << std::endl;
ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
// Overwrite type map with user settings
// We use this map for partitioning since we need c10::ScalarTypes not nvinfer::DataTypes
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
}
} else {
// The user defined the type so no changes are necessary
}
}
}
// }
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -370,7 +394,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -395,23 +419,26 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partition_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection)) {
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
auto graph_and_mapping =
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
auto collection_input_ivalues_map =
partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(
new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
// rename the graph inputs after fallback to ensure PyTorch deserializes them correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
@@ -429,6 +456,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
TORCHTRT_CHECK(
conversion::VerifyConverterSupportForBlock(g->block()),
"Not all operations in graph are supported by the compiler");
// TODO find the right
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
}
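
As a rough summary of the dtype handling introduced above: each element of a collection input now gets its own resolved type. The following is a minimal, self-contained sketch of that per-element rule; the names and the plain enum standing in for nvinfer1::DataType are hypothetical, not the compiler's actual API:

```cpp
#include <optional>

// Stand-in for nvinfer1::DataType in this sketch.
enum class DType { kFLOAT, kHALF, kINT32 };

// Mirrors the branch structure of MapInputsAndDetermineDTypes for one element
// of a collection input:
//   - user set a dtype                    -> the user setting wins
//   - no user dtype, inference succeeded  -> use the inferred type
//   - no user dtype, inference failed     -> default to FP32
// (When partitioning is enabled the real code also overwrites the first-use
// type map with the user setting and warns on PyTorch/TensorRT conflicts.)
DType ResolveInputDType(std::optional<DType> inferred_type, std::optional<DType> user_type) {
  if (user_type) {
    return *user_type;
  }
  if (inferred_type) {
    return *inferred_type;
  }
  return DType::kFLOAT;
}
```
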
6 changes: 4 additions & 2 deletions core/compiler.h
@@ -8,13 +8,15 @@
#include "core/partitioning/partitioning.h"
#include "core/runtime/runtime.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
namespace core {

struct CompileSpec {
CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
std::vector<ir::Input> inputs;
CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(inputs) {}
CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
ir::GraphInputs graph_inputs;
conversion::ConversionInfo convert_info;
lowering::LowerInfo lower_info;
partitioning::PartitionInfo partition_info;
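
The graph_inputs member and the new torch::jit::IValue constructor are what let a grouped (tuple/list) input signature flow into compilation. A minimal sketch of the flat-input overload, assuming ir::Input can be constructed from a shape vector as it is elsewhere in the codebase (the IValue overload is populated by the frontend from a nested input signature):

```cpp
#include <vector>
#include "core/compiler.h"

namespace core = torch_tensorrt::core;

core::CompileSpec flat_spec_example() {
  // Flat path: a plain list of tensor inputs, stored into graph_inputs.
  std::vector<core::ir::Input> inputs = {
      core::ir::Input({1, 3, 224, 224}), // assumed shape-vector constructor
      core::ir::Input({1, 10})};
  return core::CompileSpec(inputs);
}
```
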
31 changes: 24 additions & 7 deletions core/conversion/conversion.cpp
@@ -135,10 +135,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
<< "please report this error to https://www.github.com/NVIDIA/Torch-TensorRT/issues");
}

void AddInputs(
ConversionCtx* ctx,
c10::ArrayRef<const torch::jit::Value*> inputs,
std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> inputs, ConversionInfo& conversion_info) {
std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs = conversion_info.inputs;
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec =
conversion_info.collection_input_spec_map;

std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
@@ -166,9 +167,15 @@ void AddInputs(
for (auto input : input_tensors) {
const torch::jit::Value* in = input;
TORCHTRT_CHECK(
input_specs.find(in) != input_specs.end(),
input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(),
"Cannot find an input spec associated with input: " << in->debugName());
ir::Input& spec = input_specs.find(in)->second;
ir::Input spec;
if (input_specs.find(in) != input_specs.end()) {
spec = input_specs.find(in)->second;
} else {
spec = collection_input_spec.find(in)->second[0]; // assume input is tensor
}
// ir::Input& spec = input_specs.find(in)->second;

std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
@@ -408,7 +415,7 @@ void ConvertBlockToNetDef(

auto inputs = b->inputs();
AddParamsToCtxValueMap(ctx, static_params);
AddInputs(ctx, inputs, build_info.inputs);
AddInputs(ctx, inputs, build_info);

auto nodes = b->nodes();

@@ -549,6 +556,16 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool OutputIsCollection(const torch::jit::Block* b) {
for (auto out : b->outputs()) {
if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
out->type()->kind() == torch::jit::TypeKind::ListType) {
return true;
}
}
return false;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
if (unsupported_ops.size() != 0) {
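
OutputIsCollection is what CompileGraph uses to force the partitioning path when a graph returns a tuple or list. A small sketch of how it might be exercised against a hand-written graph, assuming the standard TorchScript IR parser from torch/csrc/jit/ir/irparser.h:

```cpp
#include <memory>
#include <string>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include "core/conversion/conversion.h"

bool returns_tuple_example() {
  // A graph that returns a tuple of tensors.
  const std::string src = R"IR(
    graph(%x : Tensor, %y : Tensor):
      %out : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
      return (%out))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());
  // Tuple (or list) typed outputs trigger the collection handling path.
  return torch_tensorrt::core::conversion::OutputIsCollection(g->block());
}
```
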
3 changes: 3 additions & 0 deletions core/conversion/conversion.h
@@ -13,6 +13,7 @@ namespace conversion {

struct ConversionInfo {
ir::InputSpecMap inputs;
ir::CollectionInputSpecMap collection_input_spec_map;
BuilderSettings engine_settings;
};

@@ -25,6 +26,8 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool OutputIsCollection(const torch::jit::Block* b);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

c10::optional<torch::jit::IValue> EvaluateNode(
7 changes: 7 additions & 0 deletions core/conversion/converters/converter_util.cpp
@@ -65,6 +65,13 @@ nvinfer1::ILayer* add_elementwise(
nvinfer1::ITensor* self,
nvinfer1::ITensor* other,
const std::string& name) {
if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) {
LOG_DEBUG("Type mismatch, casting other to " << self->getType());
other = castITensor(ctx, other, self->getType());
} else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) {
LOG_DEBUG("Type mismatch, casting self to " << other->getType());
self = castITensor(ctx, self, other->getType());
}
// ensure self to have larger number of dimension
bool swapSelfOther = false;
if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
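
The new casts make add_elementwise follow PyTorch's eager-mode type promotion when one operand is FP32 and the other INT32. For reference, a minimal libtorch snippet showing the convention being mirrored (sketch only, not part of the change):

```cpp
#include <torch/torch.h>

void promotion_example() {
  auto f = torch::ones({2, 2});               // defaults to kFloat
  auto i = torch::ones({2, 2}, torch::kInt);  // kInt32
  auto out = f * i;                           // eager mode promotes to kFloat
  TORCH_CHECK(out.scalar_type() == torch::kFloat, "expected float result");
}
```
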
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
@@ -412,6 +412,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
// Should implement self * other
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);

auto mul =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
@@ -426,6 +427,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());

auto mul =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
8 changes: 0 additions & 8 deletions core/conversion/evaluators/aten.cpp
@@ -19,14 +19,6 @@ namespace conversion {
namespace evaluators {
namespace {

int64_t normalizeIndex(int64_t idx, int64_t list_size) {
if (idx < 0) {
// Handle negative indexing
idx = list_size + idx;
}
return idx;
}

DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
eq,
"aten::eq",
9 changes: 9 additions & 0 deletions core/conversion/evaluators/eval_util.cpp
@@ -12,6 +12,15 @@ namespace core {
namespace conversion {
namespace evaluators {

int64_t normalizeIndex(int64_t idx, int64_t list_size) {
if (idx < 0) {
// Handle negative indexing
idx = list_size + idx;
}
return idx;
}


// TODO: Switch back to PyTorch canonical implementation
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
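
normalizeIndex is now shared through eval_util.h rather than being file-local to aten.cpp. Its behavior matches Python-style negative indexing; a tiny usage sketch, assuming the evaluators namespace shown above:

```cpp
#include <cassert>
#include "core/conversion/evaluators/eval_util.h"

void normalize_index_example() {
  using torch_tensorrt::core::conversion::evaluators::normalizeIndex;
  assert(normalizeIndex(-1, 4) == 3); // negative indices wrap around list_size
  assert(normalizeIndex(2, 4) == 2);  // non-negative indices pass through
}
```
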
2 changes: 2 additions & 0 deletions core/conversion/evaluators/eval_util.h
@@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
const torch::jit::IValue& dtype,
const torch::jit::IValue& device);

int64_t normalizeIndex(int64_t idx, int64_t list_size);

at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);

} // namespace evaluators