2020-11-18 20:17:04 +00:00
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
|
2020-12-18 15:55:56 +00:00
|
|
|
#include <test/cpp/tensorexpr/test_base.h>
|
2020-02-21 21:06:13 +00:00
|
|
|
|
2021-10-19 04:58:26 +00:00
|
|
|
#include <c10/util/irange.h>
|
2020-12-18 15:55:56 +00:00
|
|
|
#include <test/cpp/tensorexpr/padded_buffer.h>
|
|
|
|
|
#include <test/cpp/tensorexpr/test_utils.h>
|
|
|
|
|
#include <torch/csrc/jit/tensorexpr/eval.h>
|
|
|
|
|
#include <torch/csrc/jit/tensorexpr/ir.h>
|
|
|
|
|
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
|
2021-03-02 04:35:17 +00:00
|
|
|
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
|
2020-12-18 15:55:56 +00:00
|
|
|
#include <torch/csrc/jit/tensorexpr/loopnest.h>
|
|
|
|
|
#include <torch/csrc/jit/tensorexpr/tensor.h>
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
|
|
2020-02-21 21:06:13 +00:00
|
|
|
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, BasicValueTest) {
  // 2 + 3 built from explicit integer immediates must evaluate to 5.
  ExprHandle lhs = IntImm::make(2);
  ExprHandle rhs = IntImm::make(3);
  ExprHandle sum = Add::make(lhs, rhs);

  SimpleIRExprEval evaluator(sum);
  ASSERT_EQ(evaluator.value<int>(), 5);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, BasicValueTest02) {
  // (2 + 3) - (4 + 5) == -4 using float immediates and ExprHandle operators.
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);

  ExprHandle left = two + three;
  ExprHandle right = four + five;
  ExprHandle difference = left - right;

  SimpleIRExprEval evaluator(difference);
  ASSERT_EQ(evaluator.value<float>(), -4.0f);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, LetTest01) {
  // Evaluates 2 + (x * 3 + 4) with the free float variable x bound to 3.
  VarHandle x("x", kFloat);
  ExprHandle scaled = x * ExprHandle(3.f);
  ExprHandle body = ExprHandle(2.f) + (scaled + ExprHandle(4.f));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(evaluator.value<float>(), 2 + (3 * 3 + 4));
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, LetTest02) {
  // Evaluates 2 + (x * 3 + 4 * y) with x bound to 3 and y bound to 6.
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle x_term = x * ExprHandle(3.f);
  ExprHandle y_term = ExprHandle(4.f) * y;
  ExprHandle body = ExprHandle(2.f) + (x_term + y_term);

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(3.f));
  evaluator.bindVar(y, ExprHandle(6.f));
  ASSERT_EQ(evaluator.value<float>(), 2 + (3 * 3 + 4 * 6));
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, LetStmtTest01) {
  // Runs the two-statement program
  //   v = a[0];   (Let)
  //   b[0] = v;
  // and checks that b[0] ends up holding a[0].
  BufHandle a_buf("a", {1}, kFloat);
  BufHandle b_buf("b", {1}, kFloat);

  ExprHandle load_a = a_buf.load(0);
  VarHandle var = VarHandle("v", kFloat);
  StmtPtr let_stmt = Let::make(var, load_a);
  StmtPtr copy_stmt = b_buf.store({0}, var);
  BlockPtr program = Block::make({let_stmt, copy_stmt});

  SimpleIREvaluator evaluator(program, {a_buf, b_buf});

  constexpr float kValue = 23;
  PaddedBuffer<float> a_v(1);
  PaddedBuffer<float> b_v(1);
  PaddedBuffer<float> b_ref(1);
  a_v(0) = kValue;
  b_ref(0) = kValue;

  evaluator(a_v, b_v);
  ExpectAllNear(b_v, b_ref, 1e-5);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, IntTest) {
  // int: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kInt);
  ExprHandle product = x * ExprHandle(3);
  ExprHandle body = ExprHandle(2) + (product + ExprHandle(4));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(3));
  ASSERT_EQ(evaluator.value<int>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, FloatTest) {
  // float: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kFloat);
  ExprHandle product = x * ExprHandle(3.f);
  ExprHandle body = ExprHandle(2.f) + (product + ExprHandle(4.f));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(evaluator.value<float>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, ByteTest) {
  // uint8_t: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kByte);
  ExprHandle product = x * ExprHandle(static_cast<uint8_t>(3));
  ExprHandle body = ExprHandle(static_cast<uint8_t>(2)) +
      (product + ExprHandle(static_cast<uint8_t>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<uint8_t>(3)));
  ASSERT_EQ(evaluator.value<uint8_t>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, CharTest) {
  // int8_t: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kChar);
  ExprHandle product = x * ExprHandle(static_cast<int8_t>(3));
  ExprHandle body = ExprHandle(static_cast<int8_t>(2)) +
      (product + ExprHandle(static_cast<int8_t>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<int8_t>(3)));
  ASSERT_EQ(evaluator.value<int8_t>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, ShortTest) {
  // int16_t: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kShort);
  ExprHandle product = x * ExprHandle(static_cast<int16_t>(3));
  ExprHandle body = ExprHandle(static_cast<int16_t>(2)) +
      (product + ExprHandle(static_cast<int16_t>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<int16_t>(3)));
  ASSERT_EQ(evaluator.value<int16_t>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, LongTest) {
  // int64_t: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kLong);
  ExprHandle product = x * ExprHandle(static_cast<int64_t>(3));
  ExprHandle body = ExprHandle(static_cast<int64_t>(2)) +
      (product + ExprHandle(static_cast<int64_t>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<int64_t>(3)));
  ASSERT_EQ(evaluator.value<int64_t>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, HalfTest) {
  // at::Half: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kHalf);
  ExprHandle product = x * ExprHandle(static_cast<at::Half>(3));
  ExprHandle body = ExprHandle(static_cast<at::Half>(2)) +
      (product + ExprHandle(static_cast<at::Half>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<at::Half>(3)));
  ASSERT_EQ(evaluator.value<at::Half>(), kExpected);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, DoubleTest) {
  // double: 2 + (x * 3 + 4) with x = 3 must give 15.
  constexpr int kExpected = 2 + (3 * 3 + 4);
  VarHandle x("x", kDouble);
  ExprHandle product = x * ExprHandle(static_cast<double>(3));
  ExprHandle body = ExprHandle(static_cast<double>(2)) +
      (product + ExprHandle(static_cast<double>(4)));

  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, ExprHandle(static_cast<double>(3)));
  ASSERT_EQ(evaluator.value<double>(), kExpected);
}
|
2020-05-09 23:21:46 +00:00
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, VectorAdd01) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {kTotalSize}, kFloat);
  BufHandle b_buf("B", {kTotalSize}, kFloat);
  BufHandle c_buf("C", {kTotalSize}, kFloat);

  /*
   Builds the vectorized loop (IR pseudocode):
     for (index = 0; index < kVectorCount; index++) {
       store(c_buf, ramp(index * kVectorSize, 1, kVectorSize),
             load(a_buf, ramp(index * kVectorSize, 1, kVectorSize)) +
             load(b_buf, ramp(index * kVectorSize, 1, kVectorSize)))
     }
   */
  VarHandle index = VarHandle("index", kInt);
  // Each call builds a fresh kVectorSize-lane ramp starting at
  // index * kVectorSize, so every load/store gets its own IR node.
  auto lane_ramp = [&]() {
    return Ramp::make(index * kVectorSize, 1, kVectorSize);
  };

  ExprHandle load_a = a_buf.load({lane_ramp()});
  ExprHandle load_b = b_buf.load({lane_ramp()});
  ExprHandle value = load_a + load_b;
  StmtPtr store_c = c_buf.store({lane_ramp()}, value);
  StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);

  // Loads of a ramp and their sum must be kVectorSize-wide float vectors.
  ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
  ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
  ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);
  PaddedBuffer<float> c_v(kTotalSize);
  PaddedBuffer<float> c_ref(kTotalSize);
  for (const auto i : c10::irange(kTotalSize)) {
    const int square = i * i;
    a_v(i) = square;
    b_v(i) = square * 4;
    c_ref(i) = a_v(i) + b_v(i);
  }

  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
  ir_eval(a_v, b_v, c_v);
  ExpectAllNear(c_v, c_ref, 1e-5);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, CompareSelectEQ) {
  // C[i] = (A[i] == B[i]) for i in [0, N). A and B are all ones, so every
  // C[i] must become 1 and the inputs must be left untouched.
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<int> c_buffer(N, 0);

  VarHandle i("i", kInt);
  // Element-wise equality select (formerly misnamed `memcpy_expr`).
  auto select_expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(select_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  // Compare as size_t to avoid a signed/unsigned comparison.
  ASSERT_EQ(a_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(b_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(c_buffer.size(), static_cast<size_t>(N));

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, CompareSelectDtypes) {
  // CompareSelect may compare in one dtype and return another. The operands
  // share a dtype (int here) and so do both return values (float here):
  //   C[i] = (A[i] == B[i]) ? 3.14f : 2.78f
  // A == B everywhere, so every C[i] must be 3.14f.
  constexpr int N = 1024;
  constexpr float kIfEqual = 3.14f;
  constexpr float kIfNotEqual = 2.78f;

  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kFloat);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<float> c_buffer(N, 0.0f);
  std::vector<float> c_ref(N, kIfEqual);

  VarHandle i("i", kInt);
  ExprHandle selected = CompareSelect::make(
      a.load(i),
      b.load(i),
      FloatImm::make(kIfEqual),
      FloatImm::make(kIfNotEqual),
      CompareSelectOperation::kEQ);
  auto select_expr = For::make(i, 0, N, c.store({i}, selected));

  SimpleIREvaluator ir_eval(select_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  ExpectAllNear(c_buffer, c_ref, 1e-7);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, IntrinsicsDtypes) {
  // abs() over a double buffer: B[i] = |A[i]| with A filled with -10.
  constexpr int N = 256;
  constexpr double kInput = -10.0;
  constexpr double kExpected = 10.0;

  BufHandle a("A", {N}, kDouble);
  BufHandle b("B", {N}, kDouble);
  std::vector<double> a_buffer(N, kInput);
  std::vector<double> b_buffer(N, 0.0);
  std::vector<double> b_ref(N, kExpected);

  VarHandle i("i", kInt);
  ExprHandle magnitude = tensorexpr::abs(a.load(i));
  auto abs_expr = For::make(i, 0, N, b.store({i}, magnitude));

  SimpleIREvaluator ir_eval(abs_expr, {a, b});
  ir_eval(a_buffer, b_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);

  assertAllEqual(a_buffer, kInput);
  ExpectAllNear(b_buffer, b_ref, 1e-7);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, Substitute01) {
  // e = (x - 1) * (x + y). Substituting x := z + 5 must produce
  // ((z + 5) - 1) * ((z + 5) + y); the two IRs are compared by printed form.
  VarPtr x = alloc<Var>("x", kFloat);
  VarPtr y = alloc<Var>("y", kFloat);
  ExprPtr x_minus_one = alloc<Sub>(x, alloc<FloatImm>(1.0f));
  ExprPtr x_plus_y = alloc<Add>(x, y);
  ExprPtr e = alloc<Mul>(x_minus_one, x_plus_y);

  VarPtr z = alloc<Var>("z", kFloat);
  // Builds a fresh `z + 5` node on each call.
  auto z_plus_five = [&z]() { return alloc<Add>(z, alloc<FloatImm>(5.0f)); };

  ExprPtr e2 = Substitute(e, {{x, z_plus_five()}});
  ExprPtr e2_ref = alloc<Mul>(
      alloc<Sub>(z_plus_five(), alloc<FloatImm>(1.0f)),
      alloc<Add>(z_plus_five(), y));

  auto render = [](const ExprPtr& expr) {
    std::ostringstream oss;
    oss << *expr;
    return oss.str();
  };
  ASSERT_EQ(render(e2), render(e2_ref));
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, Math01) {
  // sin(1.0f) must print as an intrinsic call and evaluate like std::sin.
  ExprHandle v = sin(ExprHandle(1.0f));

  std::ostringstream printed;
  printed << v;
  ASSERT_EQ(printed.str(), "sin(1.f)");

  SimpleIRExprEval evaluator(v);
  ASSERT_NEAR(evaluator.value<float>(), std::sin(1.0f), 1e-6);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, UnaryMath01) {
  // Pairs each tensorexpr unary intrinsic builder with the <cmath> function
  // its evaluated result is compared against.
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&)> func;
    std::function<float(float)> ref_func;
  };

  std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& v) { return sin(v); },
       [](float v) { return std::sin(v); }},
      // This slot used to repeat `sin`, so `cos` was never exercised.
      {[](const ExprHandle& v) { return cos(v); },
       [](float v) { return std::cos(v); }},
      {[](const ExprHandle& v) { return tan(v); },
       [](float v) { return std::tan(v); }},
      {[](const ExprHandle& v) { return asin(v); },
       [](float v) { return std::asin(v); }},
      {[](const ExprHandle& v) { return acos(v); },
       [](float v) { return std::acos(v); }},
      {[](const ExprHandle& v) { return atan(v); },
       [](float v) { return std::atan(v); }},
      {[](const ExprHandle& v) { return sinh(v); },
       [](float v) { return std::sinh(v); }},
      {[](const ExprHandle& v) { return cosh(v); },
       [](float v) { return std::cosh(v); }},
      {[](const ExprHandle& v) { return tanh(v); },
       [](float v) { return std::tanh(v); }},
      {[](const ExprHandle& v) { return exp(v); },
       [](float v) { return std::exp(v); }},
      {[](const ExprHandle& v) { return tensorexpr::abs(v); },
       [](float v) { return std::fabs(v); }},
      {[](const ExprHandle& v) { return log(v); },
       [](float v) { return std::log(v); }},
      {[](const ExprHandle& v) { return log2(v); },
       [](float v) { return std::log2(v); }},
      {[](const ExprHandle& v) { return log10(v); },
       [](float v) { return std::log10(v); }},
      {[](const ExprHandle& v) { return erf(v); },
       [](float v) { return std::erf(v); }},
      {[](const ExprHandle& v) { return sqrt(v); },
       [](float v) { return std::sqrt(v); }},
      {[](const ExprHandle& v) { return rsqrt(v); },
       [](float v) { return 1.0f / std::sqrt(v); }},
      {[](const ExprHandle& v) { return ceil(v); },
       [](float v) { return std::ceil(v); }},
      {[](const ExprHandle& v) { return floor(v); },
       [](float v) { return std::floor(v); }},
      {[](const ExprHandle& v) { return round(v); },
       [](float v) { return std::round(v); }},
      {[](const ExprHandle& v) { return trunc(v); },
       [](float v) { return std::trunc(v); }},
  };

  // 0.8765 lies in the domain of every function above: asin/acos need
  // |x| <= 1, and log/log2/log10/sqrt/rsqrt need x > 0.
  for (const TestConfig& test_config : test_configs) {
    const float input_v = 0.8765f;
    ExprHandle v = test_config.func(ExprHandle(input_v));
    float v_ref = test_config.ref_func(input_v);
    SimpleIRExprEval eval(v);
    ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
  }

  // kIsNan must evaluate to 1 for NaN and 0 otherwise. Float literals avoid
  // the double->float narrowing the old initializer list needed a NOLINT for.
  for (const float input_v : {std::nanf("1"), 0.f, .5f}) {
    ExprHandle v = FloatImm::make(input_v);
    SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
    ASSERT_EQ(eval.value<int>(), std::isnan(input_v) ? 1 : 0);
  }
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, BinaryMath01) {
  // Each case pairs a tensorexpr binary builder with its <cmath> reference.
  struct BinaryCase {
    std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> build;
    std::function<float(float, float)> reference;
  };

  const std::vector<BinaryCase> cases = {
      {[](const ExprHandle& lhs, const ExprHandle& rhs) {
         return pow(lhs, rhs);
       },
       [](float lhs, float rhs) { return std::pow(lhs, rhs); }},
      {[](const ExprHandle& lhs, const ExprHandle& rhs) {
         return fmod(lhs, rhs);
       },
       [](float lhs, float rhs) { return std::fmod(lhs, rhs); }},
  };

  const float lhs_value = 0.8765f;
  const float rhs_value = 1.2345f;
  for (const BinaryCase& tc : cases) {
    ExprHandle expr = tc.build(ExprHandle(lhs_value), ExprHandle(rhs_value));
    const float expected = tc.reference(lhs_value, rhs_value);
    SimpleIRExprEval evaluator(expr);
    ASSERT_NEAR(evaluator.value<float>(), expected, 1e-6);
  }
}
|
|
|
|
|
|
2021-05-02 01:42:57 +00:00
|
|
|
TEST(Expr, LogicalOps01) {
  // a > b and c > d both hold, so every && / || combination below has a
  // fixed truth value (1 or 0).
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  struct LogicalCase {
    ExprHandle expr;
    int expected;
  };
  const std::vector<LogicalCase> cases = {
      {(a > b) && (c > d), 1},
      {(a > b) && (c < d), 0},
      {(a < b) && (c > d), 0},
      {(a < b) && (c < d), 0},
      {(a < b) || (c > d), 1},
      {(a < b) || (c < d), 0},
      {(a > b) || (c < d), 1},
      {(a > b) || (c > d), 1},
  };

  for (const LogicalCase& tc : cases) {
    SimpleIRExprEval evaluator(tc.expr);
    ASSERT_EQ(evaluator.value<int>(), tc.expected);
  }
}
|
|
|
|
|
|
|
|
|
|
TEST(Expr, LogicalOps02) {
  // c == d here, so c > d is false and c <= d is true; a > b is true.
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.72f);

  ExprHandle either_greater = (a > b) || (c > d); // true
  ExprHandle greater_and_le = (a > b) && (c <= d); // true
  ExprHandle both_greater = (a > b) && (c > d); // false

  // Nested logical ops built from the sub-expressions above.
  SimpleIRExprEval conjunction(either_greater && greater_and_le);
  SimpleIRExprEval disjunction(greater_and_le || both_greater);
  ASSERT_EQ(conjunction.value<int>(), 1);
  ASSERT_EQ(disjunction.value<int>(), 1);
}
|
|
|
|
|
|
|
|
|
|
TEST(Expr, LogicalOps03) {
  // Logical ops mixing a comparison result with an immediate of various
  // dtypes. a > b is true and c <= d is false, and every immediate is
  // true/1, so every combination below evaluates to true/1.
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  // bool immediates
  SimpleIRExprEval bool_and((a > b) && BoolImm::make(true));
  SimpleIRExprEval bool_or((c <= d) || BoolImm::make(true));
  // int immediates
  SimpleIRExprEval int_and((a > b) && IntImm::make(1));
  SimpleIRExprEval int_or((c <= d) || IntImm::make(1));
  // short immediates
  SimpleIRExprEval short_and((a > b) && ShortImm::make(1));
  SimpleIRExprEval short_or((c <= d) || ShortImm::make(1));
  // long immediates
  SimpleIRExprEval long_and((a > b) && LongImm::make(1));
  SimpleIRExprEval long_or((c <= d) || LongImm::make(1));
  // char immediates
  SimpleIRExprEval char_and((a > b) && CharImm::make(1));
  SimpleIRExprEval char_or((c <= d) || CharImm::make(1));
  // byte immediates
  SimpleIRExprEval byte_and((a > b) && ByteImm::make(1));
  SimpleIRExprEval byte_or((c <= d) || ByteImm::make(1));

  ASSERT_EQ(bool_and.value<bool>(), true);
  ASSERT_EQ(bool_or.value<bool>(), true);
  ASSERT_EQ(int_and.value<int>(), 1);
  ASSERT_EQ(int_or.value<int>(), 1);
  ASSERT_EQ(short_and.value<int16_t>(), 1);
  ASSERT_EQ(short_or.value<int16_t>(), 1);
  ASSERT_EQ(long_and.value<int64_t>(), 1);
  ASSERT_EQ(long_or.value<int64_t>(), 1);
  ASSERT_EQ(char_and.value<int8_t>(), 1);
  ASSERT_EQ(char_or.value<int8_t>(), 1);
  ASSERT_EQ(byte_and.value<uint8_t>(), 1);
  ASSERT_EQ(byte_or.value<uint8_t>(), 1);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, BitwiseOps) {
  // ((59 ^ (11 << 1)) & 101) >> 2 | 2
  //   = ((59 ^ 22) & 101) >> 2 | 2 = (45 & 101) >> 2 | 2
  //   = 37 >> 2 | 2 = 9 | 2 = 11
  ExprHandle a(59);
  ExprHandle b(11);
  ExprHandle c(101);
  ExprHandle d(2);

  ExprHandle shifted_b = b << 1;
  ExprHandle mixed = a ^ shifted_b;
  ExprHandle masked = mixed & c;
  ExprHandle result = (masked >> 2) | d;

  SimpleIRExprEval evaluator(result);
  ASSERT_EQ(evaluator.value<int>(), 11);
}
|
|
|
|
|
|
2020-11-18 20:17:04 +00:00
|
|
|
TEST(Expr, DynamicShapeAdd) {
  // c[i] = a[i] + b[i], where the extent n is passed as a runtime argument.
  // The same kind of statement is built and run for several sizes.
  auto run_with_size = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));

    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    SimpleIREvaluator evaluator(s, {a, b, c, n});
    evaluator(aData, bData, cData, size);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };

  for (const int32_t size : {1, 16, 37}) {
    run_with_size(size);
  }
}
|
|
|
|
|
|
2021-12-16 17:25:35 +00:00
|
|
|
TEST(Expr, OutOfBounds) {
  // Stores X[i] = i for i in [0, 15) into a 10-element buffer X; the
  // evaluator must throw on the out-of-range stores.
  ExprHandle N(10);
  ExprHandle start(0);
  ExprHandle stop(15);
  VarHandle i("i", kInt);
  BufHandle X("X", {N}, kInt);

  auto stmt = For::make(i, start, stop, Store::make(X, {i}, i));

  // NOTE(review): the backing storage (20) is larger than the loop range,
  // presumably so the throw comes from the IR-level bounds check rather than
  // from a real out-of-bounds memory access.
  PaddedBuffer<int> data(20);
  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}
|
|
|
|
|
|
|
|
|
|
TEST(Expr, OutOfBounds2d) {
  // A 15x15 loop nest writes into an N x M buffer in which one dimension is
  // only 10, so each configuration must throw.
  std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
  for (const auto& dims : size_options) {
    ExprHandle N(dims.first);
    ExprHandle M(dims.second);
    ExprHandle start(0);
    ExprHandle stopInner(15);
    ExprHandle stopOuter(15);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    BufHandle X("X", {N, M}, kInt);

    auto inner = For::make(j, start, stopInner, Store::make(X, {i, j}, i));
    auto stmt = For::make(i, start, stopOuter, inner);

    PaddedBuffer<int> data(400);
    EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
  }
}
|
|
|
|
|
|
|
|
|
|
TEST(Expr, OutOfBounds2dFlattenedIndex) {
  // A 10x15 loop nest writes X[i * 15 + j]. The largest index is
  // 9 * 15 + 14 = 149, but X has only 149 elements (0..148), so the last
  // store is one past the end and the evaluator must throw.
  ExprHandle buf_size(149);
  ExprHandle start(0);
  ExprHandle stopInner(15);
  ExprHandle stopOuter(10);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle X("X", {buf_size}, kInt);

  auto flat_index = Add::make(Mul::make(i, stopInner), j);
  auto inner =
      For::make(j, start, stopInner, Store::make(X, {flat_index}, i));
  auto stmt = For::make(i, start, stopOuter, inner);

  PaddedBuffer<int> data(400);
  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}
|
|
|
|
|
|
2020-03-16 18:38:29 +00:00
|
|
|
void testCond01() {
|
|
|
|
|
const int N = 16;
|
|
|
|
|
PaddedBuffer<float> a_v(N);
|
2021-09-14 07:19:57 +00:00
|
|
|
BufHandle a_buf("a", {N}, kFloat);
|
2020-03-16 18:38:29 +00:00
|
|
|
VarHandle index = VarHandle("index", kInt);
|
2021-08-17 20:39:36 +00:00
|
|
|
StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
|
|
|
|
|
StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
|
2020-03-16 18:38:29 +00:00
|
|
|
ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
|
2021-08-17 20:39:36 +00:00
|
|
|
StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
|
|
|
|
|
StmtPtr for_stmt = For::make(index, 0, N, assign);
|
2020-12-22 04:15:34 +00:00
|
|
|
SimpleIREvaluator(for_stmt, {a_buf})(a_v);
|
2020-03-16 18:38:29 +00:00
|
|
|
|
|
|
|
|
PaddedBuffer<float> a_ref(N);
|
2021-10-19 04:58:26 +00:00
|
|
|
for (const auto i : c10::irange(N)) {
|
2020-03-16 18:38:29 +00:00
|
|
|
if (i % 2 == 0) {
|
|
|
|
|
a_ref(i) = i * 2;
|
|
|
|
|
} else {
|
|
|
|
|
a_ref(i) = i * 3;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ExpectAllNear(a_v, a_ref, 1e-5);
|
|
|
|
|
}
|
|
|
|
|
|
2020-02-21 21:06:13 +00:00
|
|
|
void testIfThenElse01() {
|
2020-03-16 18:38:29 +00:00
|
|
|
ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << v;
|
2020-03-16 18:38:29 +00:00
|
|
|
ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
SimpleIRExprEval eval(v);
|
|
|
|
|
ASSERT_EQ(eval.value<float>(), 1.0f);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testIfThenElse02() {
|
2020-03-16 18:38:29 +00:00
|
|
|
ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << v;
|
2020-03-16 18:38:29 +00:00
|
|
|
ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
SimpleIRExprEval eval(v);
|
|
|
|
|
ASSERT_EQ(eval.value<float>(), 2.0f);
|
|
|
|
|
}
|
|
|
|
|
|
2020-08-04 19:17:09 +00:00
|
|
|
void testIfThenElse03() {
|
|
|
|
|
ExprHandle v =
|
|
|
|
|
ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));
|
|
|
|
|
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << v;
|
|
|
|
|
ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
|
|
|
|
|
|
|
|
|
|
SimpleIRExprEval eval(v);
|
|
|
|
|
ASSERT_EQ(eval.value<float>(), 2.0f);
|
|
|
|
|
}
|
|
|
|
|
|
2020-03-17 17:54:43 +00:00
|
|
|
void testStmtClone() {
|
|
|
|
|
const int N = 16;
|
|
|
|
|
|
2021-09-14 07:19:57 +00:00
|
|
|
BufHandle a_buf("a", {N}, kInt);
|
2020-03-17 17:54:43 +00:00
|
|
|
VarHandle index = VarHandle("index", kInt);
|
2021-08-17 20:39:36 +00:00
|
|
|
StmtPtr body = a_buf.store({index}, 5);
|
|
|
|
|
StmtPtr loop = For::make(index, 0, N, body);
|
2020-03-17 17:54:43 +00:00
|
|
|
|
2021-08-17 20:39:36 +00:00
|
|
|
StmtPtr cloned_loop = Stmt::clone(loop);
|
2020-03-17 17:54:43 +00:00
|
|
|
std::vector<int> orig_loop_results(N);
|
|
|
|
|
std::vector<int> cloned_loop_results(N);
|
2020-12-22 04:15:34 +00:00
|
|
|
SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
|
|
|
|
|
SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);
|
2020-03-17 17:54:43 +00:00
|
|
|
|
|
|
|
|
assertAllEqual(orig_loop_results, 5);
|
|
|
|
|
assertAllEqual(cloned_loop_results, 5);
|
|
|
|
|
|
|
|
|
|
// Let's add another assign to the body in the cloned loop and verify that the
|
|
|
|
|
// original statement hasn't changed while the cloned one has.
|
2021-08-17 20:39:36 +00:00
|
|
|
StmtPtr body_addition = a_buf.store({index}, 33);
|
|
|
|
|
BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
|
2020-03-17 17:54:43 +00:00
|
|
|
cloned_body->append_stmt(body_addition);
|
|
|
|
|
|
|
|
|
|
std::vector<int> orig_loop_results_after_mutation(N);
|
|
|
|
|
std::vector<int> cloned_loop_results_after_mutation(N);
|
2020-12-22 04:15:34 +00:00
|
|
|
SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
|
|
|
|
|
SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);
|
2020-03-17 17:54:43 +00:00
|
|
|
|
|
|
|
|
assertAllEqual(orig_loop_results_after_mutation, 5);
|
|
|
|
|
assertAllEqual(cloned_loop_results_after_mutation, 33);
|
|
|
|
|
}
|
|
|
|
|
|
2020-02-21 21:06:13 +00:00
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|