fix(//tests/cpp): Fix the BERT C++ test
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Aug 15, 2024
1 parent 842d7df commit ee21db0
Showing 3 changed files with 7 additions and 5 deletions.
core/conversion/converters/impl/layer_norm.cpp (2 changes: 0 additions, 2 deletions)

@@ -32,7 +32,6 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
       gamma = tensor_to_const(ctx, gamma_torch_tensor);
     } else {
       gamma = args[2].ITensorOrFreeze(ctx);
-      // gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
       gamma = add_expand(ctx, gamma, input_shape);
     }

@@ -43,7 +42,6 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
       beta = tensor_to_const(ctx, beta_torch_tensor);
     } else {
       beta = args[3].ITensorOrFreeze(ctx);
-      // beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
       beta = add_expand(ctx, beta, input_shape);
     }
tests/cpp/test_compiled_modules.cpp (6 changes: 5 additions, 1 deletion)

@@ -5,7 +5,11 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
   std::vector<torch::jit::IValue> trt_inputs_ivalues;
   std::vector<torch_tensorrt::Input> shapes;
   for (uint64_t i = 0; i < input_shapes.size(); i++) {
-    auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]);
+    auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]);
+    if (input_types[i] == at::kInt || input_types[i] == at::kLong) {
+      auto in = at::randint(0, 2, input_shapes[i], {at::kCUDA}).to(input_types[i]);
+    }
+
     jit_inputs_ivalues.push_back(in.clone());
     trt_inputs_ivalues.push_back(in.clone());
     auto in_spec = torch_tensorrt::Input(input_shapes[i]);
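For context, the updated C++ test now picks its random-input strategy from the requested dtype: floating-point inputs come from at::randn, while integer inputs (such as the token-id and mask tensors a BERT module expects) come from at::randint over {0, 1}. The sketch below restates that pattern as a small standalone helper; it is illustrative only and not part of the commit, and the helper name make_test_input is an assumption.

    // Illustrative sketch of the dtype-dependent input generation in the test.
    // The helper name make_test_input is hypothetical and not part of the commit.
    #include <torch/torch.h>
    #include <vector>

    at::Tensor make_test_input(const std::vector<int64_t>& shape, at::ScalarType dtype) {
      if (dtype == at::kInt || dtype == at::kLong) {
        // Integer inputs (e.g. token ids or attention masks) are drawn from {0, 1}.
        return at::randint(0, 2, shape, {at::kCUDA}).to(dtype);
      }
      // Floating-point inputs are drawn from a standard normal distribution.
      return at::randn(shape, {at::kCUDA}).to(dtype);
    }

The test then clones the generated tensor into both the TorchScript and the TensorRT input vectors so that both paths see identical data.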
tests/py/ts/models/test_models.py (4 changes: 2 additions, 2 deletions)

@@ -93,7 +93,7 @@ def test_efficientnet_b0(self):
         )

     def test_bert_base_uncased(self):
-        self.model = cm.BertModule().cuda()
+        self.model = cm.BertModule()
         self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")

         compile_spec = {

@@ -116,7 +116,7 @@ def test_bert_base_uncased(self):
             "enabled_precisions": {torch.float},
             "truncate_long_and_double": True,
         }
-        with torchtrt.logging.errors():
+        with torchtrt.logging.debug():
             trt_mod = torchtrt.ts.compile(self.model, **compile_spec)

         model_outputs = self.model(self.input, self.input)
