Skip to content

Commit

Permalink
add addmv_ to torch tensor (ivy-llc#21200)
Browse files Browse the repository at this point in the history
Co-authored-by: Felix Hirwa Nshuti <hirwanshutiflx@gmail.com>
  • Loading branch information
a0m0rajab and fnhirwa authored Aug 6, 2023
1 parent e1339a7 commit 356e6e5
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
def addmv(self, mat, vec, *, beta=1, alpha=1):
return torch_frontend.addmv(self, mat, vec, beta=beta, alpha=alpha)

@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch")
def addmv_(self, mat, vec, *, beta=1, alpha=1):
self.ivy_array = torch_frontend.addmv(
self, mat, vec, beta=beta, alpha=alpha
).ivy_array
return self

@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch")
def addbmm(self, batch1, batch2, *, beta=1, alpha=1):
return torch_frontend.addbmm(self, batch1, batch2, beta=beta, alpha=alpha)
Expand Down
55 changes: 55 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,61 @@ def test_torch_addmv(
)


# addmv_
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
method_name="addmv_",
dtype_and_matrices=_get_dtype_input_and_mat_vec(with_input=True),
beta=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
alpha=st.floats(
min_value=-5,
max_value=5,
allow_nan=False,
allow_subnormal=False,
allow_infinity=False,
),
)
def test_torch_addmv_(
dtype_and_matrices,
beta,
alpha,
frontend,
frontend_method_data,
init_flags,
method_flags,
on_device,
backend_fw,
):
input_dtype, x, mat, vec = dtype_and_matrices
helpers.test_frontend_method(
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={
"data": x,
},
method_input_dtypes=input_dtype,
backend_to_test=backend_fw,
method_all_as_kwargs_np={
"mat": mat,
"vec": vec,
"beta": beta,
"alpha": alpha,
},
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
frontend=frontend,
atol_=1e-02,
on_device=on_device,
)


# addbmm
@handle_frontend_method(
class_tree=CLASS_TREE,
Expand Down

0 comments on commit 356e6e5

Please sign in to comment.