Skip to content

Commit

Permalink
Remove Tensor constructor of Scalar. (#10852)
Browse files Browse the repository at this point in the history
Summary:
This is a step toward removing Tensor as a member of the tagged union in Scalar.  Doing so simplifies ordering dependencies, because currently Scalar and Tensor each depend on the other (which is why a TensorBase was introduced).  Also, this API isn't particularly useful publicly: we can't autograd through Scalars, so you still need a Tensor overload basically everywhere anyway.

I'm undecided about what the final API should be here.  We could keep a Tensor constructor on Scalar and have it produce a local scalar; that would be convenient, but given that this API used to be non-synchronizing, it may not be the best choice.

For now, I'm just using _local_scalar, whose name is clear, although we should drop the leading underscore if this is the API we intend to promote.
Pull Request resolved: pytorch/pytorch#10852

Reviewed By: ezyang

Differential Revision: D9496766

Pulled By: gchanan

fbshipit-source-id: 16f39b57536b9707132a5a4d915650c381bb57db
  • Loading branch information
gchanan authored and facebook-github-bot committed Aug 24, 2018
1 parent 2fa4dc6 commit 2bdc0f8
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 40 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/Scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Scalar Scalar::operator-() const {
} else if (isIntegral()) {
return Scalar(-v.i);
} else {
return Scalar(-Tensor(t));
return -Tensor(t)._local_scalar();
}
}

Expand Down
6 changes: 0 additions & 6 deletions aten/src/ATen/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ class AT_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}

explicit Scalar(const detail::TensorBase & t)
: tag(Tag::HAS_t), t(t) {
AT_CHECK(t.defined(), "Attempting to create a Scalar from an undefined tensor");
AT_CHECK(t.dim() == 0, "Attempting to create a Scalar from a ", t.dim(), " dim tensor");
}

#define DEFINE_IMPLICIT_CTOR(type,name,member) \
Scalar(type vv) \
: tag(Tag::HAS_##member) { \
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorOperators.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ inline Tensor Tensor::operator[](Tensor index) const {
index.dim() == 0,
"Can only index with tensors that are scalars (zero-dim)");
// The Scalar(Tensor) constructor is explicit, so we need to call it.
return this->operator[](Scalar(index));
return this->operator[](index._local_scalar());
}
inline Tensor Tensor::operator[](int64_t index) const {
return select(0, index);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def handle_zero_dim(env, option):
if broadcasts_arg:
return []
zero_dim_actuals = [arg['name']
if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name'])
if arg['name'] != zero_dim_dispatch else "{}._local_scalar()".format(arg['name'])
for arg in option['formals_list']]
return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]

Expand Down
3 changes: 1 addition & 2 deletions aten/src/ATen/native/cuda/SummaryOps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ Tensor _bincount_cuda_template(
AT_ERROR("input and weights should have the same length");
}
auto maxScalarGpu = Scalar(self.max());
auto nbins = maxScalarGpu.local().to<int64_t>() + 1L;
auto nbins = self.max().toCLong() + 1L;
nbins = std::max(nbins, minlength);
// alloc output counter on GPU
Tensor output;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/atest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void trace() {
trace += foo_a[i][i];
}

REQUIRE(Scalar(foo.trace()).toFloat() == Approx(trace));
REQUIRE(foo.trace().toCFloat() == Approx(trace));
}

TEST_CASE( "atest", "[]" ) {
Expand Down
11 changes: 5 additions & 6 deletions aten/src/ATen/test/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ static void test(Type & type) {
auto z = b.sort(1);
auto z_sorted = std::get<0>(z);

REQUIRE(Scalar(z_sorted[0][0]).toFloat() < Scalar(z_sorted[0][1]).toFloat());
REQUIRE(z_sorted[0][0].toCFloat() < z_sorted[0][1].toCFloat());
}

if(type.backend() != Backend::CUDA)
SECTION( "randperm" ) {
Tensor b = randperm(15, type);
Tensor rv, ri;
std::tie(rv, ri) = sort(b, 0);
REQUIRE(Scalar(rv[0]).toFloat() <= Scalar(rv[1]).toFloat());
REQUIRE(rv[0].toCFloat() <= rv[1].toCFloat());
}

SECTION( "context" ) {
Expand Down Expand Up @@ -154,7 +154,7 @@ static void test(Type & type) {

SECTION( "abs(value)" ) {
Tensor r = at::abs(type.scalarTensor(-3));
REQUIRE(Scalar(r).toInt() == 3);
REQUIRE(r.toCInt() == 3);
}

//TODO(zach): operator overloads
Expand Down Expand Up @@ -184,7 +184,6 @@ static void test(Type & type) {
SECTION( "zero-dim" ) {
Tensor a = type.scalarTensor(4); //rand(type, {1});

REQUIRE_NOTHROW(Scalar(a));
Tensor b = rand({3,4}, type);
REQUIRE((a + a).dim() == 0);
REQUIRE((1 + a).dim() == 0);
Expand All @@ -196,7 +195,7 @@ static void test(Type & type) {
auto f = rand({3,4}, type);
f[2] = zeros({4}, type);
f[1][0] = -1;
REQUIRE(Scalar(f[2][0]).toDouble() == 0);
REQUIRE(f[2][0].toCDouble() == 0);
}

SECTION( "tensor from TH" ) {
Expand Down Expand Up @@ -256,7 +255,7 @@ static void test(Type & type) {
REQUIRE_THROWS_WITH(
tensor[ones({}) * 3.14].equal(one),
StartsWith(
"Can only index tensors with integral scalars (got CPUFloatType)"));
"Can only index tensors with integral scalars (got CPUDoubleType)"));
REQUIRE_THROWS_WITH(
tensor[Tensor()].equal(one),
StartsWith("Can only index with tensors that are defined"));
Expand Down
24 changes: 2 additions & 22 deletions aten/src/ATen/test/scalar_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,6 @@ struct Foo<Half> {
static void apply(Tensor a, Tensor b) {}
};

void test_ctors() {
// create scalars backed by tensors
auto s1 = Scalar(CPU(kFloat).scalarTensor(1));
auto s2 = Scalar(CPU(kFloat).scalarTensor(2));
Scalar{s1};
Scalar{std::move(s2)};
REQUIRE(s2.isBackedByTensor());
REQUIRE(!s2.toTensor().defined());
s2 = s1;
REQUIRE(s2.isBackedByTensor());
REQUIRE(s2.toFloat() == 1.0);
Scalar s3;
s3 = std::move(s2);
REQUIRE(s2.isBackedByTensor());
REQUIRE(!s2.toTensor().defined());
REQUIRE(s3.isBackedByTensor());
REQUIRE(s3.toFloat() == 1.0);
}

void test_overflow() {
auto s1 = Scalar(M_PI);
REQUIRE(s1.toFloat() == static_cast<float>(M_PI));
Expand Down Expand Up @@ -109,9 +90,8 @@ TEST_CASE( "scalar test", "[]" ) {
Tensor next_h = i2h.add(h2h);
next_h = next_h.tanh();

REQUIRE_THROWS(Scalar{Tensor{}});
REQUIRE_THROWS(Tensor{}._local_scalar());

test_ctors();
test_overflow();

if(at::hasCUDA()) {
Expand All @@ -123,7 +103,7 @@ TEST_CASE( "scalar test", "[]" ) {
// check Scalar.toTensor on Scalars backed by different data types
REQUIRE(bar.toTensor().type().scalarType() == kDouble);
REQUIRE(what.toTensor().type().scalarType() == kLong);
REQUIRE(Scalar(ones({})).toTensor().type().scalarType() == kFloat);
REQUIRE(ones({})._local_scalar().toTensor().type().scalarType() == kDouble);

if (x.type().scalarType() != ScalarType::Half) {
AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] {
Expand Down

0 comments on commit 2bdc0f8

Please sign in to comment.