Skip to content

Commit

Permalink
feature: add baddbmm_ to torch (ivy-llc#21228)
Browse files Browse the repository at this point in the history
Co-authored-by: paulaehab<eng.paulaehab@gmail.com>
  • Loading branch information
RawaaRajab authored Aug 10, 2023
1 parent 835e436 commit 7ec637b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,12 @@ def baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
self, batch1=batch1, batch2=batch2, beta=beta, alpha=alpha
)

def baddbmm_(self, batch1, batch2, *, beta=1, alpha=1):
self.ivy_array = torch_frontend.baddbmm(
self, batch1=batch1, batch2=batch2, beta=beta, alpha=alpha
).ivy_array
return self

def bmm(self, mat2):
return torch_frontend.bmm(self, mat2=mat2)

Expand Down
49 changes: 49 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10578,6 +10578,55 @@ def test_torch_baddbmm(
)


@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
method_name="baddbmm_",
dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True),
beta=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
alpha=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
)
def test_torch_baddbmm_(
dtype_and_matrices,
beta,
alpha,
frontend,
frontend_method_data,
init_flags,
method_flags,
on_device,
):
input_dtype, x, batch1, batch2 = dtype_and_matrices
helpers.test_frontend_method(
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={"data": x[0]},
method_input_dtypes=input_dtype,
method_all_as_kwargs_np={
"batch1": batch1,
"batch2": batch2,
"beta": beta,
"alpha": alpha,
},
frontend=frontend,
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
on_device=on_device,
)


@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
Expand Down

0 comments on commit 7ec637b

Please sign in to comment.