2020-02-21 21:06:13 +00:00
|
|
|
#include <stdexcept>
|
2020-03-16 18:38:29 +00:00
|
|
|
#include "test/cpp/tensorexpr/test_base.h"
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
#include "torch/csrc/jit/tensorexpr/expr.h"
|
|
|
|
|
#include "torch/csrc/jit/tensorexpr/ir.h"
|
|
|
|
|
#include "torch/csrc/jit/tensorexpr/ir_printer.h"
|
|
|
|
|
|
|
|
|
|
#include <sstream>
|
|
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
|
|
|
|
|
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
|
|
|
|
|
|
void testIRPrinterBasicValueTest() {
|
|
|
|
|
KernelScope kernel_scope;
|
2020-03-16 18:38:29 +00:00
|
|
|
ExprHandle a = IntImm::make(2), b = IntImm::make(3);
|
|
|
|
|
ExprHandle c = Add::make(a, b);
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << c;
|
2020-03-26 03:30:45 +00:00
|
|
|
EXPECT_EQ(ss.str(), "2 + 3");
|
2020-02-21 21:06:13 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testIRPrinterBasicValueTest02() {
|
|
|
|
|
KernelScope kernel_scope;
|
2020-03-16 18:38:29 +00:00
|
|
|
ExprHandle a(2.0f);
|
|
|
|
|
ExprHandle b(3.0f);
|
|
|
|
|
ExprHandle c(4.0f);
|
|
|
|
|
ExprHandle d(5.0f);
|
|
|
|
|
ExprHandle f = (a + b) - (c + d);
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << f;
|
2020-03-26 03:30:45 +00:00
|
|
|
EXPECT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
|
2020-02-21 21:06:13 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testIRPrinterLetTest01() {
|
|
|
|
|
KernelScope kernel_scope;
|
2020-03-16 18:38:29 +00:00
|
|
|
VarHandle x("x", kFloat);
|
|
|
|
|
ExprHandle value = ExprHandle(3.f);
|
|
|
|
|
ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
|
|
|
|
|
ExprHandle result = Let::make(x, ExprHandle(3.f), body);
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << result;
|
2020-03-26 03:30:45 +00:00
|
|
|
EXPECT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
|
2020-02-21 21:06:13 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testIRPrinterLetTest02() {
|
|
|
|
|
KernelScope kernel_scope;
|
2020-03-16 18:38:29 +00:00
|
|
|
VarHandle x("x", kFloat);
|
|
|
|
|
VarHandle y("y", kFloat);
|
|
|
|
|
ExprHandle value = ExprHandle(3.f);
|
|
|
|
|
ExprHandle body =
|
|
|
|
|
ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y);
|
|
|
|
|
ExprHandle e1 = Let::make(x, ExprHandle(3.f), body);
|
|
|
|
|
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << e2;
|
2020-03-26 03:30:45 +00:00
|
|
|
EXPECT_EQ(
|
2020-03-16 18:38:29 +00:00
|
|
|
ss.str(), "let y = 6.f in (let x = 3.f in 2.f + (x * 3.f + 4.f * y))");
|
2020-02-21 21:06:13 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testIRPrinterCastTest() {
|
|
|
|
|
KernelScope kernel_scope;
|
2020-03-16 18:38:29 +00:00
|
|
|
VarHandle x("x", kFloat);
|
|
|
|
|
VarHandle y("y", kFloat);
|
|
|
|
|
ExprHandle value = ExprHandle(3.f);
|
|
|
|
|
ExprHandle body =
|
|
|
|
|
ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y);
|
|
|
|
|
ExprHandle e1 = Let::make(x, Cast::make(kInt, ExprHandle(3.f)), body);
|
|
|
|
|
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
|
2020-02-21 21:06:13 +00:00
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << e2;
|
2020-03-26 03:30:45 +00:00
|
|
|
EXPECT_EQ(
|
2020-02-21 21:06:13 +00:00
|
|
|
ss.str(),
|
2020-03-16 18:38:29 +00:00
|
|
|
"let y = 6.f in (let x = int(3.f) in 2.f + (x * 3.f + 4.f * y))");
|
2020-02-21 21:06:13 +00:00
|
|
|
}
|
|
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|