diff --git a/core/conversion/converters/impl/lstm_cell.cpp b/core/conversion/converters/impl/lstm_cell.cpp
index 6daca547de..edb8388dd0 100755
--- a/core/conversion/converters/impl/lstm_cell.cpp
+++ b/core/conversion/converters/impl/lstm_cell.cpp
@@ -99,16 +99,57 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n);
             auto add3_out = add3->getOutput(0);
 
-
-
-
-
-            auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
-            TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
-            mm_layer->setName(util::node_info(n).c_str());
-            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
+            // chunk Tensor along dim 1 into 4 equal parts (ingate, forgetgate, cellgate, outgate) and apply activation functions
+            auto dims = util::toVec(add3_out->getDimensions());
+            auto batch = dims[0];
+            auto hidden = dims[1]/4;
+
+            auto size = util::toDims(std::vector<int64_t>({batch, hidden}));
+            auto stride = util::toDims(std::vector<int64_t>({1, 1}));
+
+            auto slice1 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 0})), size, stride);
+            TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
+            auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
+            TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
+            auto ingate = activ1->getOutput(0);
+
+            auto slice2 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, hidden})), size, stride);
+            TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
+            auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
+            TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
+            auto forgetgate = activ2->getOutput(0);
+
+            auto slice3 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 2*hidden})), size, stride);
+            TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
+            auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
+            TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
+            auto cellgate = activ3->getOutput(0);
+
+            auto slice4 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 3*hidden})), size, stride);
+            TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
+            auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
+            TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);
+            auto outgate = activ4->getOutput(0);
+
+            // compute cy = forgetgate * state[1] + ingate * cellgate (the two products are summed, hence kSUM)
+            auto forget_cx = ctx->net->addElementWise(*forgetgate, *state[1], nvinfer1::ElementWiseOperation::kPROD);
+            TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
+            auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
+            TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
+            auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+            TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
+            auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));
+
+            // compute hy = outgate * tanh(cy)
+            auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
+            TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
+            auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
+            TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
+            auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0));
+
+            LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
+            LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());
 
-            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
             return true;
         }
     });