diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp
index bedc5e1a0f..ab8f337fdb 100644
--- a/core/conversion/converters/impl/layer_norm.cpp
+++ b/core/conversion/converters/impl/layer_norm.cpp
@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
       }
 
       auto power = Weights(ctx, at::ones(expand_size));
-      auto scale_nd = ctx->net->addScaleNd(
-          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
-      auto scale_nd_out = scale_nd->getOutput(0);
-      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+      auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
+      auto scale_l = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
+
+      auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
+      auto shift_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kSUM,
+          scale_l->getOutput(0),
+          beta_tensor,
+          (util::node_info(n) + "_shift").c_str());
+
+      auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
+      auto power_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kPOW,
+          shift_l->getOutput(0),
+          power_tensor,
+          (util::node_info(n) + "_power").c_str());
+
+      power_l->setName((util::node_info(n) + "_scale_nd").c_str());
+      auto power_l_out = power_l->getOutput(0);
+
+      ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
 
       return true;
     }});
 
diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index e5e35a4ddc..7580da73d7 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -1,3 +1,4 @@
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
 
@@ -16,10 +17,12 @@ auto mm_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("self tensor shape: " << self->getDimensions());
 
                     auto other = args[1].ITensorOrFreeze(ctx);
-                    LOG_DEBUG("other tensor shape: " << other->getDimensions());
+                    // "other" tensor should have same nbDims as self
+                    auto wt_tensor = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
+                    LOG_DEBUG("other tensor shape: " << wt_tensor->getDimensions());
 
                     auto mm_layer = ctx->net->addMatrixMultiply(
-                        *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+                        *self, nvinfer1::MatrixOperation::kNONE, *wt_tensor, nvinfer1::MatrixOperation::kNONE);
                     TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
                     mm_layer->setName(util::node_info(n).c_str());
                     auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
@@ -73,4 +76,4 @@ auto mm_registrations TRTORCH_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
diff --git a/core/lowering/passes/linear_to_addmm.cpp b/core/lowering/passes/linear_to_addmm.cpp
index 67e2b79a65..8dea042bc1 100644
--- a/core/lowering/passes/linear_to_addmm.cpp
+++ b/core/lowering/passes/linear_to_addmm.cpp
@@ -1,23 +1,55 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include <torch/csrc/jit/runtime/operator.h>
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
 
 #include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
 
 namespace trtorch {
 namespace core {
 namespace lowering {
 namespace passes {
 
+void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+  // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
+  static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
+    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
+        return torch.matmul(self, mat1.t())
+    )SCRIPT");
+
+  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
+  auto block = graph->block();
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+    if (n->kind().toQualString() == std::string("aten::linear")) {
+      auto input_values = n->inputs();
+      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+      if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
+        continue;
+      } else {
+        torch::jit::WithInsertPoint guard(*it);
+        std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
   std::string flatten_linear_pattern = R"IR(
     graph(%input, %weight, %bias):
       %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
-  std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
 
   std::string fused_linear = R"IR(
     graph(%input, %weight_t, %bias):
@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
       %b_f: Tensor = trt::const(%bias)
       %out: Tensor = aten::add(%b_f, %mm, %1)
       return (%out))IR";
-  std::string fused_linear_bias_none = R"IR(
-    graph(%input, %weight_t):
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      return (%mm))IR";
+
+  // First find and replace aten::linear nodes with non-tensor bias values.
+  replaceLinearWithBiasNonePattern(graph);
 
   torch::jit::SubgraphRewriter flatten_linear_to_linear;
   flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
   flatten_linear_to_linear.runOnGraph(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
-  flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
-  flatten_linear_bias_none_to_linear.runOnGraph(graph);
-
   LOG_GRAPH("Post linear to addmm: " << *graph);
 }
 } // namespace passes