-
Notifications
You must be signed in to change notification settings - Fork 355
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//core/conversion/var): created ITensorOrFreeze() method, to rep…
…lace functionality of Var::ITensor() Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): updates to some comments on the PR Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): addressed PR comment Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): addressed PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): addressed PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): addressing PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): addressing PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): Addressed PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> fix(): bug in test_serialization, need to fix Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> Delete converters.h.orig delete .orig file Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Update activation.cpp addressing PR comments Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com> Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Delete converters.h.orig delete .orig file Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Update activation.cpp addressing PR comments Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
- Loading branch information
1 parent
362c932
commit 2ccf8d0
Showing
17 changed files
with
167 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include "core/util/prelude.h" | ||
#include "core/conversion/conversionctx/ConversionCtx.h" | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace conversion { | ||
namespace converters { | ||
|
||
struct Weights { | ||
nvinfer1::Weights data; | ||
nvinfer1::Dims kernel_shape; | ||
nvinfer1::Dims shape; | ||
int64_t num_input_maps; | ||
int64_t num_output_maps; | ||
|
||
Weights(); | ||
Weights(ConversionCtx* ctx, at::Tensor t); | ||
Weights(ConversionCtx* ctx, float val); | ||
Weights(ConversionCtx* ctx, int32_t val); | ||
friend std::ostream& operator<<(std::ostream& os, const Weights& w); | ||
}; | ||
|
||
inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) { | ||
auto t_weights = Weights(ctx, t); | ||
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data); | ||
TRTORCH_CHECK(const_layer, "Unable to freeze tensor"); | ||
|
||
auto out = const_layer->getOutput(0); | ||
|
||
std::ostringstream tensor_id; | ||
tensor_id << reinterpret_cast<int*>(out); | ||
|
||
LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer"); | ||
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str()); | ||
|
||
return out; | ||
} | ||
|
||
|
||
} // namespace converters | ||
} // namespace conversion | ||
} // namespace core | ||
} // namespace trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.