Skip to content

Commit

Permalink
fix tensor dimensions
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Bielak <jbielak@nvidia.com; git config --global format.signoff true>
  • Loading branch information
Jan Bielak committed Jul 31, 2023
1 parent 65bd1cd commit c8aff5f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/sequential_new/common_back/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ def describe_supplementary_tensors_training(
) -> dict[str, TensorDescriptor]:
return {
"act": TensorDescriptor(self.input_shape, None, self.output_type),
"mu": TensorDescriptor((self.features,), None, DType.FP32),
"rsigma": TensorDescriptor((self.features,), None, DType.FP32),
"mu": TensorDescriptor((self.input_shape[-2],), None, DType.FP32),
"rsigma": TensorDescriptor((self.input_shape[-2],), None, DType.FP32),
}

def describe_supplementary_tensors_inference(
Expand Down

0 comments on commit c8aff5f

Please sign in to comment.