Skip to content

Commit

Permalink
refactor: Apply linting
Browse files — Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 7, 2021
1 parent b3589c5 commit 24c3a22
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 21 deletions.
3 changes: 1 addition & 2 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
LOG_INFO(*g << "(LoweringGraph)\n");

// segment the graph and convert segmented TensorRT block
auto segmented_blocks =
partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
return mod;
}
Expand Down
5 changes: 4 additions & 1 deletion core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
<< "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}

void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::InputRange>& input_dims) {
void AddInputs(
ConversionCtx* ctx,
at::ArrayRef<const torch::jit::Value*> inputs,
std::vector<ir::InputRange>& input_dims) {
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
Expand Down
10 changes: 6 additions & 4 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
if (isIValue()) {
LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
}

TRTORCH_CHECK(
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
Expand All @@ -100,8 +100,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
if (ptr_.ivalue->isTensor()) {
auto weights = converters::Weights();
auto tensor = ptr_.ivalue->toTensor();
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR("Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) &&
!ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR(
"Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kInt));
LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
Expand All @@ -111,7 +113,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
} else {
weights = converters::Weights(ctx, tensor);
}

auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
out = const_layer->getOutput(0);
Expand Down
8 changes: 4 additions & 4 deletions core/partitioning/PartitionInfo.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#pragma once

#include <cstdint>
#include <vector>
#include <string>
#include <vector>

namespace trtorch {
namespace core {
namespace partitioning {

struct PartitionInfo {
bool enabled = false;
uint64_t min_block_size = 1;
std::vector<std::string> forced_fallback_operators;
bool enabled = false;
uint64_t min_block_size = 1;
std::vector<std::string> forced_fallback_operators;
};

std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);
Expand Down
2 changes: 0 additions & 2 deletions core/partitioning/SegmentedBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
return new_node;
}



} // namespace partitioning
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/partitioning/SegmentedBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@ struct SegmentedBlock {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
};

} // namespace ir
} // namespace partitioning
} // namespace core
} // namespace trtorch
6 changes: 1 addition & 5 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
return;
}

std::vector<SegmentedBlock> segment_graph(
std::shared_ptr<torch::jit::Graph> g,
const PartitionInfo& partition_info) {
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info) {
auto min_block_size = partition_info.min_block_size;
std::unordered_set<std::string> forced_fallback_operators(
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
Expand Down Expand Up @@ -199,12 +197,10 @@ std::vector<SegmentedBlock> segment_graph(
return std::move(segmented_blocks);
}


std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<ir::InputRange>& input_ranges,
const PartitionInfo& partition_info) {

LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
Expand Down
3 changes: 1 addition & 2 deletions core/partitioning/shape_analysis.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "core/partitioning/SegmentedBlock.h"
#include "core/ir/ir.h"

#include "core/partitioning/SegmentedBlock.h"

namespace trtorch {
namespace core {
Expand Down

0 comments on commit 24c3a22

Please sign in to comment.