#include #include "test/cpp/jit/test_utils.h" namespace torch { namespace jit { class TypeCheckTest : public ::testing::Test { protected: TypeCheckTest() : interp(makeInterp()) {} InterpreterState interp; private: static InterpreterState makeInterp() { auto graph = std::make_shared(); std::unordered_map vmap; parseIR( R"IR( graph(%a.1 : Tensor, %b.1 : Tensor): %t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1) return (%t0, %t1, %type_matched) )IR", &*graph, vmap); Code function(graph, ""); return InterpreterState(function); } }; TEST_F(TypeCheckTest, MatchingType) { // TypeCheck yields to true! Shape, grad and device matches. auto a = at::zeros({2, 2}, at::kFloat); auto b = at::ones({3, 3}, at::kFloat); a.set_requires_grad(true); a = a.to(at::kCPU); std::vector stack({a, b}); interp.run(stack); ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a)); ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b)); ASSERT_TRUE(stack[2].toBool()); } TEST_F(TypeCheckTest, SizeMismatch) { auto a = at::zeros({2, 2}, at::kFloat); auto b = at::ones({2, 2}, at::kFloat); // Size mismatch a.set_requires_grad(true); a = a.to(at::kCPU); std::vector stack({a, b}); interp.run(stack); ASSERT_FALSE(stack[2].toBool()); } TEST_F(TypeCheckTest, GradientMismatch) { auto a = at::zeros({2, 2}, at::kFloat); auto b = at::ones({3, 3}, at::kFloat); a = a.to(at::kCPU); a.set_requires_grad(false); // Gradient mismatch std::vector stack({a, b}); interp.run(stack); ASSERT_FALSE(stack[2].toBool()); } TEST_F(TypeCheckTest, ScalarTypeMismatch) { auto a = at::zeros({2, 2}, at::kFloat); auto b = at::ones({3, 3}, at::kFloat); a = a.to(at::kCPU); a.set_requires_grad(true); a = a.to(at::kInt); // Scalar type mismatch std::vector stack({a, b}); interp.run(stack); ASSERT_FALSE(stack[2].toBool()); } TEST_F(TypeCheckTest, DeviceMismatch_CUDA) { auto a = at::zeros({2, 2}, at::kFloat); auto b = at::ones({3, 3}, at::kFloat); a.set_requires_grad(true); a = a.to(at::kCUDA); // Device mismatch std::vector stack({a, b}); interp.run(stack); ASSERT_FALSE(stack[2].toBool()); } // TODO: These tests weren't doing anything. // TEST(TypeCheckErrorTest, EmptyCheckRaises) { // // Test empty Typecheck raises an internal assertion // auto graph = std::make_shared(); // std::unordered_map vmap; // EXPECT_ANY_THROW(parseIR( // R"IR( // graph(%a.1 : Tensor, // %b.1 : Tensor): // %type_matched : bool = prim::TypeCheck() // return (%type_matched) // )IR", // &*graph, // vmap)); // } // TODO: These tests weren't doing anything. // TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) { // // Test for assertion if num_inputs + 1 != num_outputs // auto graph = std::make_shared(); // std::unordered_map vmap; // EXPECT_ANY_THROW(parseIR( // R"IR( // graph(%a.1 : Tensor, // %b.1 : Tensor): // %type_matched : bool = prim::TypeCheck(%a.1) // return (%type_matched) // )IR", // &*graph, // vmap)); // } TEST(InterpreterTest, Basic_CUDA) { constexpr int batch_size = 4; constexpr int input_size = 256; constexpr int seq_len = 32; int hidden_size = 2 * input_size; auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA); auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); auto lstm_g = build_lstm(); Code lstm_function(lstm_g, ""); InterpreterState lstm_interp(lstm_function); auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}); std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); ASSERT_TRUE(exactlyEqual(outputs[0], hx)); ASSERT_TRUE(exactlyEqual(outputs[1], cx)); } } // namespace jit } // namespace torch