Allow CodeGen to take Var args (interpreter support only) (pytorch#78)
* Test demonstrating dynamic shape

* Allow binding of Vars to args in interpreter

* Pass BufferArgs to LLVMCodeGen

* clang-format-diff
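
Concretely, the commit lets a kernel's extent stay symbolic and be bound to a concrete size at call time. A minimal sketch of the new interpreter-side usage, condensed from the testExprDynamicShapeAdd test added below (aData, bData, and cData are assumed to be float vectors of length size):

Var n("n", kInt32);
Buffer a(Var("a", kHandle), kFloat32, {n});
Buffer b(Var("b", kHandle), kFloat32, {n});
Buffer c(Var("c", kHandle), kFloat32, {n});
Var i("i", kInt32);
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
// `n` is declared as a formal argument and bound to the runtime value `size`.
SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size);

LLVMCodeGen does not yet accept a Var argument; the corresponding LLVM test is added below but left disabled.
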
bertmaher authored and Mikhail Zolotukhin committed Feb 18, 2020
1 parent fd2439b commit cc15703
Showing 8 changed files with 134 additions and 66 deletions.
20 changes: 20 additions & 0 deletions test/cpp/tensorexpr/test_expr.cpp
@@ -269,5 +269,25 @@ void testExprBinaryMath01() {
EXPECT_NEAR(eval.value().as<float>(), v_ref, 1e-6) << "fail: " << v_expr;
}
}

void testExprDynamicShapeAdd() {
auto testWithSize = [](int32_t size) {
Var n("n", kInt32);
Buffer a(Var("a", kHandle), kFloat32, {n});
Buffer b(Var("b", kHandle), kFloat32, {n});
Buffer c(Var("c", kHandle), kFloat32, {n});
Var i("i", kInt32);
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size);
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
};
testWithSize(1);
testWithSize(16);
testWithSize(37);
}

} // namespace jit
} // namespace torch
85 changes: 56 additions & 29 deletions test/cpp/tensorexpr/test_llvm.cpp
@@ -101,7 +101,7 @@ void testLLVMBufferTest() {
std::vector<int32_t> v(5);
std::vector<void*> args({v.data()});
auto rv = IntImm::make(0);
LLVMCodeGen cg(rv, {&a});
LLVMCodeGen cg(rv, {a});
EXPECT_EQ(cg.value<int>(args), 0);
}

@@ -116,7 +116,7 @@ void testLLVMBlockTest() {
Store::make(a, IntImm::make(0), IntImm::make(4), IntImm::make(1)),
});

LLVMCodeGen cg(block, {&a});
LLVMCodeGen cg(block, {a});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(v[0], 4);
EXPECT_EQ(v[1], 4);
@@ -133,7 +133,7 @@ void testLLVMLoadStoreTest() {
IntImm::make(0),
Load::make(a, IntImm::make(0), IntImm::make(1)),
IntImm::make(1));
LLVMCodeGen cg(store, {&a, &b});
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(a_buffer[0], 42);
@@ -151,7 +151,7 @@ void testLLVMVecLoadStoreTest() {
Ramp::make(0, 1, 4),
Load::make(a, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
Broadcast::make(IntImm::make(1), 4));
LLVMCodeGen cg(store, {&a, &b});
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(a_buffer[0], 1);
@@ -176,7 +176,7 @@ void testLLVMMemcpyTest() {
auto expr =
For::make(i, 0, N, Store::make(b, i, Load::make(a, i, mask), mask));

LLVMCodeGen cg(expr, {&a, &b});
LLVMCodeGen cg(expr, {a, b});

std::vector<void*> args({a_buffer.data(), b_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -194,10 +194,9 @@ void testLLVMBzeroTest() {

auto mask = IntImm::make(1);
Var i("i", kInt32);
auto expr =
For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));
auto expr = For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));

LLVMCodeGen cg(expr, {&b});
LLVMCodeGen cg(expr, {b});

std::vector<void*> args({b_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -227,7 +226,7 @@ void testLLVMElemwiseAdd() {
Add::make(Load::make(a, i, mask), Load::make(b, i, mask)),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -257,7 +256,7 @@ void testLLVMElemwiseAddFloat() {
N,
Store::make(c, i, Load::make(a, i, mask) + Load::make(b, i, mask), mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -282,10 +281,14 @@ void testLLVMElemwiseLog10Float() {
auto expr = For::make(
i,
0,
N/4,
Store::make(b, Ramp::make(i * 4, 1, 4), log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)), mask));
N / 4,
Store::make(
b,
Ramp::make(i * 4, 1, 4),
log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)),
mask));

LLVMCodeGen cg(expr, {&a, &b});
LLVMCodeGen cg(expr, {a, b});

std::vector<void*> args({a_buffer.data(), b_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -317,7 +320,7 @@ void testLLVMElemwiseMaxInt() {
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -351,7 +354,7 @@ void testLLVMElemwiseMinInt() {
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -385,7 +388,7 @@ void testLLVMElemwiseMaxNumFloat() {
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -419,7 +422,7 @@ void testLLVMElemwiseMaxNumNaNFloat() {
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -452,7 +455,7 @@ void testLLVMElemwiseMinNumFloat() {
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -486,7 +489,7 @@ void testLLVMElemwiseMinNumNaNFloat() {
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -520,7 +523,7 @@ void testLLVMElemwiseMaximumFloat() {
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -554,7 +557,7 @@ void testLLVMElemwiseMaximumNaNFloat() {
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -589,7 +592,7 @@ void testLLVMElemwiseMinimumFloat() {
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -623,7 +626,7 @@ void testLLVMElemwiseMinimumNaNFloat() {
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -668,7 +671,7 @@ void testLLVMCompareSelectIntEQ() {
CompareSelectOperation::kEQ),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -707,7 +710,7 @@ void testLLVMCompareSelectFloatEQ() {
CompareSelectOperation::kEQ),
mask));

LLVMCodeGen cg(expr, {&a, &b, &c});
LLVMCodeGen cg(expr, {a, b, c});

std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -726,7 +729,7 @@ void testLLVMStoreFloat() {
std::vector<float> result_buffer = {0.0f};
auto expr = Store::make(
result, IntImm::make(0), FloatImm::make(3.14f), IntImm::make(1));
LLVMCodeGen cg(expr, {&result});
LLVMCodeGen cg(expr, {result});
std::vector<void*> args({result_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(result_buffer[0], 3.14f);
@@ -739,7 +742,7 @@ void testLLVMSimpleMath01() {
Schedule sch = Schedule::make({tensor});
Stmt stmt = sch.Lower();
Buffer f_buf(tensor.function().func_var(), kFloat32, {N});
LLVMCodeGen cg(stmt, {&f_buf});
LLVMCodeGen cg(stmt, {f_buf});

PaddedBuffer<float> f_v(N, "f_v");
std::vector<void*> args({f_v.data()});
@@ -764,7 +767,7 @@ void testLLVMComputeMul() {
Schedule sch = Schedule::make({c});
Stmt s = sch.Lower();

LLVMCodeGen cg(s, {&a, &b, &c_buf});
LLVMCodeGen cg(s, {a, b, c_buf});

std::vector<float> a_vec(N, 21.0f);
std::vector<float> b_vec(N, 2.0f);
@@ -789,7 +792,7 @@ void testLLVMBroadcastAdd() {
Schedule sch = Schedule::make({c});
Stmt s = sch.Lower();

LLVMCodeGen cg(s, {&a, &b, &c_buf});
LLVMCodeGen cg(s, {a, b, c_buf});

std::vector<float> av(M * N);
std::iota(av.begin(), av.end(), 0);
@@ -805,6 +808,30 @@ void testLLVMBroadcastAdd() {
}
}
}

void testLLVMDynamicShapeAdd() {
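// Disabled: LLVMCodeGen cannot yet bind a Var (dynamic shape) argument;
// this commit adds interpreter support only (see testExprDynamicShapeAdd above).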
#if 0
auto testWithSize = [](int32_t size) {
Var n("n", kInt32);
Buffer a(Var("a", kHandle), kFloat32, {n});
Buffer b(Var("b", kHandle), kFloat32, {n});
Buffer c(Var("c", kHandle), kFloat32, {n});
Var i("i", kInt32);
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
LLVMCodeGen cg(s, {a, b, c, n});
std::vector<void*> args({aData.data(), bData.data(), cData.data(), size});
cg.value<float>(args);
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
};
testWithSize(1);
testWithSize(16);
testWithSize(37);
#endif
}

} // namespace jit
} // namespace torch

2 changes: 2 additions & 0 deletions test/cpp/tensorexpr/tests.h
@@ -19,6 +19,7 @@ namespace jit {
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
@@ -69,6 +70,7 @@ namespace jit {
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
_(LLVMBroadcastAdd) \
_(LLVMDynamicShapeAdd) \
_(CudaTestVectorAdd01) \
_(ATen_cast_Float) \
_(ATennegInt) \
13 changes: 4 additions & 9 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -686,17 +686,12 @@ struct TensorExprKernel {
}
}
Stmt stmt = sch.Lower();

#ifdef ENABLE_LLVM
// Set up formal params (inputs, then outputs) for kernel.
std::vector<Buffer*> params;
for (auto& b : buffer_args) {
params.push_back(&b);
}
Buffer outbuf(
tensor_output->function().func_var(),
tensor_output->dtype(),
tensor_output->dims());
params.push_back(&outbuf);
std::vector<CodeGen::BufferArg> params(
buffer_args.begin(), buffer_args.end());
params.push_back(*tensor_output);

// Generate code.
codegen = std::make_unique<LLVMCodeGen>(stmt, params);
25 changes: 24 additions & 1 deletion torch/csrc/jit/tensorexpr/codegen.h
@@ -66,6 +66,8 @@ class CodeGen::BufferArg {
dtype_(tensor.function().body().dtype()) {}
BufferArg(const Function& func)
: var_(func.func_var()), dtype_(func.body().dtype()) {}
BufferArg(const Var& var) : var_(var), dtype_(var.dtype()), isVar_(true) {}

const Var& var() const {
return var_;
}
@@ -76,9 +78,14 @@
return dtype_;
}

bool isVar() const {
return isVar_;
}

private:
Var var_;
Dtype dtype_;
bool isVar_{false};
};

class CodeGen::CallArg {
@@ -91,12 +98,28 @@

CallArg(void* ptr) : ptr_(ptr) {}

CallArg(int32_t i) : ival_(i) {}

CallArg(float f) : fval_(f) {}

void* data() const {
return ptr_;
}

int32_t intData() const {
return ival_;
}

float floatData() const {
return fval_;
}

private:
void* ptr_ = nullptr;
union {
void* ptr_;
float fval_;
int32_t ival_;
};
};
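
Taken together, BufferArg::isVar() and the CallArg union let a scalar extent travel through the same argument list as buffer pointers. A rough illustration of the intended pairing (a sketch only, assuming the usual tensorexpr headers; the actual binding loop lives in the interpreter, whose hunks are not shown above):

// One buffer plus one scalar extent as formals; a pointer plus an int32_t as actuals.
Var n("n", kInt32);
Buffer buf(Var("buf", kHandle), kFloat32, {n});
std::vector<float> storage(1024, 0.0f);
std::vector<CodeGen::BufferArg> formals = {buf, n};             // formals[1].isVar() == true
std::vector<CodeGen::CallArg> actuals = {storage.data(), 1024}; // void* and int32_t overloads
// A backend can then dispatch per argument:
//   formals[i].isVar() ? bind the Var to actuals[i].intData()
//                      : bind the buffer to actuals[i].data()
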

} // namespace tensorexpr
[Diffs for the remaining 3 changed files are not shown in this view.]