Add copy_ (pytorch#2432)
Summary: Pull Request resolved: pytorch#2432

Reviewed By: JacobSzwejbka

Differential Revision: D54909962

fbshipit-source-id: d3cca3e8c6d3406dc502ca62fa58f0bce2a3ae4d
manuelcandales authored and facebook-github-bot committed Mar 14, 2024
1 parent 63a1fde commit 9cb0be6
Showing 3 changed files with 59 additions and 0 deletions.
27 changes: 27 additions & 0 deletions kernels/portable/cpu/op_copy.cpp
@@ -57,6 +57,33 @@ Tensor& copy_out(
  return out;
}

Tensor&
copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);

  ScalarType in_type = in.scalar_type();
  ScalarType src_type = src.scalar_type();

  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy_", CTYPE, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy_", CTYPE_SRC, [&]() {
      apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
          [](const CTYPE val_in, const CTYPE_SRC val_src) {
            return convert<CTYPE, CTYPE_SRC>(val_src);
          },
          in,
          src,
          in);
    });
  });

  return in;
}

} // namespace native
} // namespace executor
} // namespace torch
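
The new copy_ kernel overwrites in with src element-wise: src is broadcast to in's shape and each source value is converted to in's dtype, while the previous contents of in are ignored (the lambda above discards val_in). Below is a minimal standalone sketch of that behavior, an illustration rather than the ExecuTorch implementation, assuming a contiguous row-major destination and a single source row broadcast across rows:

// Standalone illustration of the copy_ semantics above (not ExecuTorch code):
// broadcast a 1-D source row across a contiguous 2-D destination and convert
// the element type, discarding the destination's old values.
#include <cstddef>
#include <iostream>
#include <vector>

template <typename DstT, typename SrcT>
void broadcast_row_copy(
    std::vector<DstT>& dst,       // rows * cols elements, row-major ("in")
    const std::vector<SrcT>& src, // cols elements, broadcast over rows ("src")
    std::size_t rows,
    std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      // As in the kernel's lambda, the destination value is ignored and the
      // source value is converted to the destination element type.
      dst[r * cols + c] = static_cast<DstT>(src[c]);
    }
  }
}

int main() {
  std::vector<int> dst(4, 0);            // a 2x2 Int destination of zeros
  std::vector<float> src = {3.0f, 4.0f}; // a 1x2 Float source
  broadcast_row_copy(dst, src, /*rows=*/2, /*cols=*/2);
  for (int v : dst) {
    std::cout << v << ' '; // prints: 3 4 3 4
  }
  std::cout << '\n';
  return 0;
}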
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
@@ -253,6 +253,11 @@
    - arg_meta: null
      kernel_name: torch::executor::copy_out

- op: copy_
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::copy_

- op: cos.out
  kernels:
    - arg_meta: null
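The functions.yaml entry above registers the operator name copy_ against the portable kernel torch::executor::copy_, so the generated registration code routes calls to this implementation. A fragment sketch of such a call, mirroring the wrapper used in op_copy_test.cpp below; ctx, in, and src stand for a runtime context and two Tensors created elsewhere, so this is not a standalone program:

// Fragment sketch: ctx, in, and src come from the surrounding runtime/test code.
Tensor& result = torch::executor::aten::copy_(ctx, in, src, /*non_blocking=*/false);
// copy_ writes into in and returns a reference to it, so result aliases in.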
27 changes: 27 additions & 0 deletions kernels/test/op_copy_test.cpp
@@ -115,6 +115,13 @@ class OpCopyTest : public OperatorTest {
  }
};

class OpCopyInplaceTest : public OperatorTest {
 protected:
  Tensor& op_copy_(Tensor& self, const Tensor& src, bool non_blocking) {
    return torch::executor::aten::copy_(context_, self, src, non_blocking);
  }
};

// regular test for copy.out
TEST_F(OpCopyTest, AllRealDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
@@ -255,3 +262,23 @@ TEST_F(OpCopyTest, DynamicShapeUnbound) {
  test_dynamic_shape(
      {1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}

TEST_F(OpCopyInplaceTest, SmokeTest) {
  TensorFactory<ScalarType::Int> tf;
  Tensor in = tf.zeros({2, 2});
  Tensor src = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
  bool non_blocking = false;
  op_copy_(in, src, non_blocking);
  Tensor expected = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
  EXPECT_TENSOR_EQ(in, expected);
}

TEST_F(OpCopyInplaceTest, BroadCastSrcSupported) {
  TensorFactory<ScalarType::Int> tf;
  Tensor in = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
  Tensor src = tf.make(/*sizes=*/{1, 2}, /*data=*/{3, 3});
  bool non_blocking = false;
  op_copy_(in, src, non_blocking);
  Tensor expected = tf.make(/*sizes=*/{2, 2}, /*data=*/{3, 3, 3, 3});
  EXPECT_TENSOR_EQ(in, expected);
}
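
Because the kernel switches on both the destination and the source dtypes, a cross-dtype case could be exercised as well. The following is a sketch of such an additional test, not part of this commit; it assumes convert behaves like a plain numeric cast for real types and uses TensorFactory the same way as the tests above:

TEST_F(OpCopyInplaceTest, FloatSrcIntoIntSelf) {
  // Sketch only, not part of this commit: copy a Float source into an Int
  // destination, assuming convert<> performs a plain numeric cast.
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Float> tf_float;
  Tensor in = tf_int.zeros({2, 2});
  Tensor src = tf_float.make(/*sizes=*/{2, 2}, /*data=*/{1.0, 2.0, 3.0, 4.0});
  op_copy_(in, src, /*non_blocking=*/false);
  Tensor expected = tf_int.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
  EXPECT_TENSOR_EQ(in, expected);
}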
