diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e7b8e7aba35..776ea8d0b8a 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -557,6 +557,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp ) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) @@ -574,6 +575,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp ) if (USE_NCCL) list(APPEND Caffe2_HIP_SRCS diff --git a/test/cpp/tensorexpr/gtest.cpp b/test/cpp/tensorexpr/gtest.cpp index 7c660eeeb78..62e299818b5 100644 --- a/test/cpp/tensorexpr/gtest.cpp +++ b/test/cpp/tensorexpr/gtest.cpp @@ -12,5 +12,14 @@ namespace jit { TH_FORALL_TESTS(TENSOREXPR_GTEST) #undef TENSOREXPR_GTEST +#ifdef USE_CUDA +#define TENSOREXPR_GTEST_CUDA(name) \ + TEST(TensorExprTest, name##_CUDA) { \ + test##name(); \ + } +TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA) +#undef TENSOREXPR_GTEST_CUDA +#endif + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp new file mode 100644 index 00000000000..7854d024647 --- /dev/null +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -0,0 +1,333 @@ +#ifdef USE_CUDA + +#include +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include + +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr::schedule; + +template +void testCudaTestVectorAdd01_impl() { + KernelScope kernel_scope; + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + Buffer a_buf("a", dtype, {num_iter, block_count, block_size}); + Buffer b_buf("b", dtype, {num_iter, block_count, block_size}); + Tensor* c = Compute( + "c", + { + {num_iter, "n"}, + {block_count, "b_id"}, + {block_size, "t_id"}, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + l.SetGPUBlockIndex(loops[1], 0); + l.SetGPUThreadIndex(loops[2], 0); + Stmt* stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (int i = 0; i < N; i++) { + a_v(i) = ctype(i); + b_v(i) = ctype(i * 3 + 7); + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + ctype* a_dev = nullptr; + cudaMalloc(&a_dev, N * sizeof(ctype)); + ctype* b_dev = nullptr; + cudaMalloc(&b_dev, N * sizeof(ctype)); + ctype* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(ctype)); + cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaMemcpy(c_dev, 
c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cuda_cg(c_dev, a_dev, b_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(c_v, c_ref, 1e-5); + + cudaFree(a_dev); + cudaFree(b_dev); + cudaFree(c_dev); +} + +void testCudaTestVectorAdd01() { + // floating types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + + // integer types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); +} + +static void testCudaTestVectorAdd02_impl(int N, int block_size) { + KernelScope kernel_scope; + Buffer a_buf("a", kFloat, {N}); + Buffer b_buf("b", kFloat, {N}); + Tensor* c = Compute( + "c", + { + {N, "N"}, + }, + [&](const VarHandle& n) { return a_buf(n) + b_buf(n); }); + LoopNest l({c}); + Stmt* n_outer; + Stmt* n_inner; + std::vector loops = l.getLoopStmtsFor(c); + l.SplitWithMask(loops[0], block_size, &n_outer, &n_inner); + l.SetGPUBlockIndex(n_outer, 0); + l.SetGPUThreadIndex(n_inner, 0); + Stmt* stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (int i = 0; i < N; i++) { + a_v(i) = i; + b_v(i) = i * 3 + 7; + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + cudaMalloc(&a_dev, N * sizeof(float)); + float* b_dev = nullptr; + cudaMalloc(&b_dev, N * sizeof(float)); + float* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(float)); + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cuda_cg(c_dev, a_dev, b_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(c_v, c_ref, 1e-5); + + cudaFree(a_dev); + cudaFree(b_dev); + cudaFree(c_dev); +} + +void testCudaTestVectorAdd02() { + testCudaTestVectorAdd02_impl(1024, 128); + testCudaTestVectorAdd02_impl(1030, 128); +} + +void testCudaDynamicShape2D() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {m, n}); + Buffer b(VarHandle("b", kHandle), kFloat, {m, n}); + Tensor* c = Compute( + "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { + return a(i, j) + b(i, j); + }); + LoopNest l({c}); + Stmt* s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, c, m, n}); + + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + float* aDev = nullptr; + float* bDev = nullptr; + float* cDev = nullptr; + cudaMalloc(&aDev, aData.size() * sizeof(aData[0])); + cudaMalloc(&bDev, bData.size() * sizeof(bData[0])); + cudaMalloc(&cDev, cData.size() * sizeof(cData[0])); + cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(bData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + cDev, + cData.data(), + cData.size() * sizeof(cData[0]), + cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, bDev, 
cDev, M, N}); + cudaDeviceSynchronize(); + + cudaMemcpy( + cData.data(), + cDev, + cData.size() * sizeof(cData[0]), + cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + + cudaFree(aDev); + cudaFree(bDev); + cudaFree(cDev); + }; + testWithSize(32, 32); + testWithSize(1, 16); + testWithSize(27, 13); +} + +void testCudaTestRand01() { + KernelScope kernel_scope; + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Tensor* c = Compute( + "c", + { + {num_iter, "n"}, + {block_count, "b_id"}, + {block_size, "t_id"}, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return Intrinsics::make(IntrinsicsOp::kRand, kFloat); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + l.SetGPUBlockIndex(loops[1], 0); + l.SetGPUThreadIndex(loops[2], 0); + Stmt* stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c); + const int N = block_count * block_size * num_iter; + PaddedBuffer c_v(N); + + // TODO: move gpu support into PaddedBuffer + float* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(float)); + cudaDeviceSynchronize(); + + cuda_cg(c_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + float sum1 = 0; + float sum2 = 0; + float sum3 = 0; + for (int i = 0; i < N; i++) { + float v = c_v.data()[i]; + sum1 += v; + sum2 += v * v; + sum3 += v * v * v; + EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v; + } + sum1 /= N; + sum2 /= N; + sum3 /= N; + float sum1_mean = 1.f / 2; + float sum2_mean = 1.f / 3; + float sum3_mean = 1.f / 4; + + EXPECT_NEAR(sum1, sum1_mean, 2e-2); + EXPECT_NEAR(sum2, sum2_mean, 2e-2); + EXPECT_NEAR(sum3, sum3_mean, 2e-2); + cudaFree(c_dev); +} + +void testCudaDynamicShapeSplit() { + KernelScope ks; + constexpr int N = 4096; + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Tensor* b = + Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; }); + LoopNest l({b}); + Stmt* outer; + Stmt* inner; + std::vector loops = l.getLoopStmtsFor(b); + l.SplitWithMask(loops[0], 1024, &outer, &inner); + l.SetGPUBlockIndex(outer, 0); + l.SetGPUThreadIndex(inner, 0); + Stmt* s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, n}); + + std::vector aData(N, 1.0f); + std::vector bData(N, 1.0f); + float* aDev = nullptr; + float* bDev = nullptr; + cudaMalloc(&aDev, aData.size() * sizeof(aData[0])); + cudaMalloc(&bDev, bData.size() * sizeof(bData[0])); + cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, bDev, N}); + cudaDeviceSynchronize(); + + cudaMemcpy( + bData.data(), + bDev, + bData.size() * sizeof(aData[0]), + cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); + + cudaFree(aDev); + cudaFree(bDev); +} + +} // namespace jit +} // namespace torch + +#endif diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 17aff26cebe..19642d6d971 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -87,6 +87,11 @@ namespace jit { _(ATenltInt) #define TH_FORALL_TESTS_CUDA(_) \ + _(CudaTestVectorAdd01) \ + _(CudaTestVectorAdd02) \ + _(CudaDynamicShape2D) \ + _(CudaTestRand01) \ + _(CudaDynamicShapeSplit) #define DECLARE_TENSOREXPR_TEST(name) void 
test##name(); TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 381d68a9b86..361340dacb3 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -34,6 +34,16 @@ class ExecutionCounter(object): return value - self.start_value +class CudaCodeGenCreated(ExecutionCounter): + def __init__(self): + super(CudaCodeGenCreated, self).__init__("cuda_codegen_created") + + +class CudaCodeGenExecuted(ExecutionCounter): + def __init__(self): + super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed") + + class SimpleIREvalExecuted(ExecutionCounter): def __init__(self): super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed") @@ -80,7 +90,7 @@ class TestTensorExprFuser(BaseTestClass): c = torch.addcmul(torch.add(x, y), z, w) return c - device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] for dev in device_options: rand_a = torch.rand(1024, dtype=torch.float, device=dev) rand_b = torch.rand(1024, dtype=torch.float, device=dev) @@ -102,6 +112,79 @@ class TestTensorExprFuser(BaseTestClass): np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) + def test_three_arg_cuda(self): + if not torch.cuda.is_available(): + return + cuda_cg_executed = CudaCodeGenExecuted() + cuda_cg_created = CudaCodeGenCreated() + + def test(x, y, z): + aaa = torch.add(x, y) + bbb = torch.add(aaa, z) + return bbb + + M = 32 + N = 32 + traced = torch.jit.trace( + test, + ( + torch.rand(M, N, device="cuda"), + torch.rand(M, N, device="cuda"), + torch.rand(M, N, device="cuda"), + ), + ) + + a = torch.rand(M, N, device="cuda") + b = torch.rand(M, N, device="cuda") + c = torch.rand(M, N, device="cuda") + x = traced(a, b, c) + npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() + np.testing.assert_allclose(npr, x.cpu().numpy()) + assert cuda_cg_executed.elapsed_value() >= 1 + assert cuda_cg_created.elapsed_value() >= 1 + + + def test_broadcast_cuda(self): + if not torch.cuda.is_available(): + return + + def test_body(M, N, L, K): + if not torch.cuda.is_available(): + return + cuda_cg_executed = CudaCodeGenExecuted() + cuda_cg_created = CudaCodeGenCreated() + + def test(x, y, z): + v1 = torch.add(x, y) + v2 = torch.add(v1, z) + return v2 + + a_shape = [M, N] + b_shape = [L, M, 1] + c_shape = [K, L, 1, 1] + traced = torch.jit.trace( + test, + ( + torch.rand(*a_shape, device="cuda"), + torch.rand(*b_shape, device="cuda"), + torch.rand(*c_shape, device="cuda"), + ), + ) + + a = torch.rand(*a_shape, device="cuda") + b = torch.rand(*b_shape, device="cuda") + c = torch.rand(*c_shape, device="cuda") + x = traced(a, b, c) + npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() + np.testing.assert_allclose(npr, x.cpu().numpy()) + assert cuda_cg_executed.elapsed_value() >= 1 + assert cuda_cg_created.elapsed_value() >= 1 + + test_configs = [[36, 17, 63, 33], [32, 32, 32, 32]] + for test_config in test_configs: + test_body(*test_config) + + def test_all_combos(self): def easy(x, y, z): a = torch.add(x, y) @@ -426,7 +509,7 @@ class TestTensorExprFuser(BaseTestClass): c = torch.lt(x, y) return c - device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] for dev in device_options: traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) a = torch.ones(1024, dtype=torch.int32, device=dev) @@ -451,7 +534,7 @@ class TestTensorExprFuser(BaseTestClass): def test(x): return torch.clamp(x + 3.0, 0.0, 6.0) - 
device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] for dev in device_options: traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) @@ -463,7 +546,7 @@ class TestTensorExprFuser(BaseTestClass): def test(x): return torch.clamp(F.relu(x), 0, 0.5) - device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] for dev in device_options: traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) a = 20.0 * torch.rand(1024, device=dev) - 10.0 @@ -598,7 +681,7 @@ class TestTensorExprFuser(BaseTestClass): # test_tanh_backward, test_type_as, } - device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] for torch_fn in fns: for dev in device_options: rand_a = torch.rand(1024, device=dev) @@ -776,7 +859,7 @@ class TestTensorExprFuser(BaseTestClass): test_neg, test_relu, } - device_options = ["cpu"] + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] for torch_fn in fns: for dev in device_options: @@ -797,6 +880,26 @@ class TestTensorExprFuser(BaseTestClass): np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + def test_rand_like(self): + devices = ["cuda"] if torch.cuda.is_available() else [] + N = 1 << 16 + + def run_rand_like(x, y): + return torch.rand_like(torch.add(x, y)) + + for device in devices: + x = torch.rand(N, device=device) + traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) + x_v = traced(x, x) + x_np = x.cpu().numpy() + x1_mean = np.mean(x_np) + x2_mean = np.mean(x_np ** 2) + x3_mean = np.mean(x_np ** 3) + np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2) + np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2) + np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2) + + def test_nans(self): def test_max(x, y): return torch.max(2 * x, 2 * y) @@ -898,6 +1001,10 @@ class TestTensorExprFuser(BaseTestClass): def test_cat_cpu(self): self._test_cat('cpu') + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_cat_cuda(self): + self._test_cat('cuda') + def test_scalar(self): @torch.jit.script def test_float(x, y, z, a, b): @@ -1001,8 +1108,66 @@ class TestTensorExprFuser(BaseTestClass): assert interp.elapsed_value() == 1 + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + @unittest.skip("dynamic shapes are not quite there yet") + def test_dynamic_shape(self): + with num_profiled_runs(2): + @torch.jit.script + def test(x, y, z): + return x * y * z + cuda = CudaCodeGenCreated() + x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)] + ref = test(x, y, z) + _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)]) + res = test(x, y, z) + np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy()) + assert cuda.elapsed_value() == 1 + + # A wild broadcast appears. + x = torch.rand(4, 8).cuda() + y = torch.rand(1, 8).cuda() + z = torch.rand(4, 1).cuda() + res = test(x, y, z) + xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] + np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) + assert cuda.elapsed_value() == 1 + + # Mismatched shapes shouldn't reach codegen. + x = torch.rand(4, 8).cuda() + y = torch.rand(4, 8).cuda() + z = torch.rand(5, 8).cuda() + try: + res = test(x, y, z) + except RuntimeError as e: + assert "The size of tensor a (4) must match" in e.args[0] + assert cuda.elapsed_value() == 1 + + # Changing a static dimension fails guards. 
+ # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)] + # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] + # res = test(x, y, z) + # print(test.graph_for(x, y, z)) + # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) + # assert cuda.elapsed_value() == 1 + + @unittest.skip("guarding on static shapes is not working") + def test_guard_fails(self): + @torch.jit.script + def test(x, y, z): + return x * y * z + cuda = CudaCodeGenExecuted() + r1 = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 0 + r2 = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 1 + r3 = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 2 + r4 = test(*[torch.rand(7).cuda() for _ in range(3)]) + print(test.graph_for(*[torch.rand(7).cuda() for _ in range(3)])) + assert cuda.elapsed_value() == 2 + def test_bitwise_ops(self): - devices = ["cpu"] + devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] def run_and(x, y): return x & (x & y) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 11ecc78c7fa..d548a35e642 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -219,6 +219,7 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", "torch/csrc/autograd/profiler_cuda.cpp", "torch/csrc/autograd/functions/comm.cpp", + "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", ] torch_cpp_srcs = [ diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 02a652dfad0..722c2c7d7a5 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -60,6 +60,7 @@ #include #include #include +#include #include #include @@ -414,6 +415,42 @@ void initJITBindings(PyObject* module) { ExecutionTriggerList::GetInstance().FindByName(trigger_name); return trigger->value(); }) + .def( + "_jit_get_te_cuda_pointwise_loop_levels", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseLoopLevels(); + }) + .def( + "_jit_set_te_cuda_pointwise_loop_levels", + [](int level) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseLoopLevels() = level; + }) + .def( + "_jit_get_te_cuda_pointwise_block_count", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockCount(); + }) + .def( + "_jit_set_te_cuda_pointwise_block_count", + [](int block_count) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockCount() = block_count; + }) + .def( + "_jit_get_te_cuda_pointwise_block_size", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockSize(); + }) + .def( + "_jit_set_te_cuda_pointwise_block_size", + [](int block_size) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockSize() = block_size; + }) .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled) .def( "_jit_fuser_get_fused_kernel_code", diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp new file mode 100644 index 00000000000..c3f13325d3b --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -0,0 +1,695 @@ +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" +#include "torch/csrc/jit/tensorexpr/cuda_half_support.h" + +#include "ATen/CUDAGenerator.h" +#include "c10/cuda/CUDAFunctions.h" +#include "torch/csrc/jit/tensorexpr/cuda_random.h" +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/execution_counter.h" + +#define 
DEBUG_PRINT 0 + +namespace torch { +namespace jit { +namespace tensorexpr { + +DEFINE_TRIGGER(cuda_codegen_created); +DEFINE_TRIGGER(cuda_codegen_executed); + +// A RAII wrapper to manage a variable and name pair in the look-up table. +// TODO: move this to a more shared place. +class ScopedVarName { + public: + ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name) + : mapping_(mapping), var_(var) { + auto iter = mapping->find(var); + if (iter != mapping->end()) { + throw std::runtime_error("Duplicate var entry: " + var->name_hint()); + } + mapping->insert(std::make_pair(var, name)); + } + + ScopedVarName( + UniqueNameManager* manager, + const Var* var, + const std::string& name) + : ScopedVarName(&manager->unique_name_mapping_, var, name) {} + + ScopedVarName(const ScopedVarName&) = delete; + ScopedVarName& operator=(const ScopedVarName&) = delete; + + ~ScopedVarName() noexcept(false) { + mapping_->erase(var_); + } + + private: + VarNameMap* mapping_ = nullptr; + const Var* var_ = nullptr; +}; + +static int as_int(const Expr* expr) { + auto v = dynamic_cast(expr); + TORCH_CHECK(v, "Expression is not an integer constant"); + return v->value(); +} + +static bool is_zero(const Expr* expr) { + return as_int(expr) == 0; +} + +static const at::cuda::NVRTC& nvrtc() { + return at::globalContext().getNVRTC(); +} + +static void getMajorMinor( + const cudaDeviceProp* const prop, + int& major, + int& minor) { + using CudaVersion = std::pair; + CudaVersion nvrtc_version; + AT_CUDA_NVRTC_CHECK( + nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second)); + + AT_ASSERT(nvrtc_version.first >= 6); + + CudaVersion dev_version = CudaVersion(prop->major, prop->minor); + CudaVersion max_dev_version(dev_version); + if (nvrtc_version.first <= 7) { // 7 supports 2-5.x + max_dev_version = CudaVersion(5, 0); + } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x + max_dev_version = CudaVersion(6, 0); + } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2 + max_dev_version = CudaVersion(7, 2); + } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5 + max_dev_version = CudaVersion(7, 5); + } + if (dev_version > max_dev_version) { + dev_version = max_dev_version; + } + major = dev_version.first; + minor = dev_version.second; +} + +void CudaPrinter::visit(const For* v) { + const LoopOptions& loop_options = v->loop_options(); + if (loop_options.is_gpu_block_index()) { + ScopedVarName var_name( + name_manager(), v->var(), loop_options.gpu_block_index_str()); + v->body()->accept(this); + int gpu_block_index = loop_options.gpu_block_index(); + if (gpu_block_extents_.size() <= gpu_block_index) { + gpu_block_extents_.resize(gpu_block_index + 1); + } + if (!is_zero(v->start())) { + throw std::runtime_error( + "start must be zero for gpu_block_index: " + + std::to_string(ExprHandle(v->start()))); + } + gpu_block_extents_[gpu_block_index] = v->stop(); + } else if (loop_options.is_gpu_thread_index()) { + ScopedVarName var_name( + name_manager(), v->var(), loop_options.gpu_thread_index_str()); + v->body()->accept(this); + int gpu_thread_index = loop_options.gpu_thread_index(); + if (gpu_thread_extents_.size() <= gpu_thread_index) { + gpu_thread_extents_.resize(gpu_thread_index + 1); + } + if (!is_zero(v->start())) { + throw std::runtime_error( + "start must be zero for gpu_block_index: " + + std::to_string(ExprHandle(v->start()))); + } + gpu_thread_extents_[gpu_thread_index] = v->stop(); + } else { + IRPrinter::visit(v); + } +} + +void CudaPrinter::visit(const Intrinsics* v) { + 
if (v->op_type() == IntrinsicsOp::kRand) { + os() << "Uint32ToFloat(" << *rand_func_ << "())"; + return; + } + + std::string func_name = v->func_name(); + + // get type of resulting expression. + ScalarType returnType = v->param(0)->dtype().scalar_type(); + for (int i = 1; i < v->nparams(); ++i) { + returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type()); + } + + if (returnType == ScalarType::Half || returnType == ScalarType::Float) { + func_name = func_name + "f"; + } + + os() << func_name << "("; + for (int i = 0; i < v->nparams(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *v->param(i); + } + os() << ")"; +} + +void CudaPrinter::visit(const Load* v) { + // TODO: find a better metric in using ldg or not. Support different dtypes. + if (v->dtype().scalar_type() == ScalarType::Half) { + os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])"; + } else { + os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")"; + } +} + +void CudaPrinter::visit(const Store* v) { + os() << *v->base_handle() << "[" << *v->index() << "] = "; + if (v->value()->dtype().scalar_type() == ScalarType::Half) { + os() << "__float2half(" << *v->value() << ");"; + } else { + os() << *v->value() << ";"; + } +} + +void CudaPrinter::visit(const Max* v) { + auto dtype = v->dtype().scalar_type(); + switch (dtype) { + case ScalarType::Half: + // doing Half math in float. + case ScalarType::Float: + os() << "fmaxf"; + break; + case ScalarType::Double: + os() << "fmax"; + break; + default: + os() << "max"; + break; + } + os() << "("; + v->lhs()->accept(this); + os() << ","; + v->rhs()->accept(this); + os() << ")"; +} + +void CudaPrinter::visit(const Min* v) { + auto dtype = v->dtype().scalar_type(); + switch (dtype) { + case ScalarType::Half: + // doing Half math in float. + case ScalarType::Float: + os() << "fminf"; + break; + case ScalarType::Double: + os() << "fmin"; + break; + default: + os() << "min"; + break; + } + os() << "("; + v->lhs()->accept(this); + os() << ","; + v->rhs()->accept(this); + os() << ")"; +} + +std::string cudaDtypeCppString(const Dtype& dtype) { + switch (dtype.scalar_type()) { + case ScalarType::Half: + return "half"; + case ScalarType::Char: + return "char"; + case ScalarType::Byte: + return "unsigned char"; + case ScalarType::Short: + return "short"; + case ScalarType::Long: + return "long"; + default:; /* nothing */ + } + return dtype.ToCppString(); +} + +void CudaPrinter::visit(const LetStmt* v) { + const Var* var = v->var(); + if (var->dtype().scalar_type() == ScalarType::Half) { + // we do math in floats so use that. + os() << "float"; + } else { + os() << cudaDtypeCppString(var->dtype()); + } + os() << " " << *var << " = " << *v->value() << "; " << std::endl; + v->body()->accept(this); +} + +void CudaPrinter::visit(const IfThenElse* v) { + os() << "(("; + v->condition()->accept(this); + os() << ") ? "; + v->true_value()->accept(this); + os() << " : "; + v->false_value()->accept(this); + os() << ")"; +} + +class PrioritizeLoad : public IRMutator { + public: + const Expr* mutate(const Load* v) override { + // Look at the declaration of this variable for more details. + if (nested_if_then_else_ > 0) { + return IRMutator::mutate(v); + } + MemLoadList& load_list = load_stack_.back(); + const Var* load_new_var = new Var("v", v->dtype()); + const Expr* new_value = IRMutator::mutate(v); + load_list.push_back(std::make_pair(load_new_var, new_value)); + return load_new_var; + } + + // TODO: merge this with the IRMutator::mutate version. 
+ Stmt* mutate(const For* v) override { + const Var* var = v->var(); + const Expr* start = v->start(); + const Expr* stop = v->stop(); + Stmt* body = v->body(); + LoopOptions loop_options = v->loop_options(); + const Var* var_new = dynamic_cast(var->accept_mutator(this)); + const Expr* start_new = start->accept_mutator(this); + const Expr* stop_new = stop->accept_mutator(this); + PushList(); + Stmt* body_new = body->accept_mutator(this); + if (!body_new) { + return nullptr; + } + Stmt* body_with_loads = AddMemLoadsFromList(body_new); + PopList(); + if (var == var_new && start == start_new && stop == stop_new && + body == body_with_loads) { + return (Stmt*)v; + } + return new For(var_new, start_new, stop_new, body_with_loads, loop_options); + } + + Stmt* mutate(const LetStmt* v) override { + const Var* var = v->var(); + const Expr* value = v->value(); + Stmt* body = v->body(); + const Var* var_new = dynamic_cast(var->accept_mutator(this)); + if (var_new == nullptr) { + throw std::runtime_error("LetStmt var must be variable"); + } + const Expr* value_new = value->accept_mutator(this); + PushList(); + Stmt* body_new = body->accept_mutator(this); + Stmt* body_with_loads = AddMemLoadsFromList(body_new); + PopList(); + if (var == var_new && value == value_new && body == body_with_loads) { + return (Stmt*)v; + } + return new LetStmt(var_new, value_new, body_with_loads); + } + + Stmt* mutate(const Cond* v) override { + const Expr* cond_old = v->condition(); + Stmt* true_old = v->true_stmt(); + Stmt* false_old = v->false_stmt(); + + const Expr* cond_new = cond_old->accept_mutator(this); + PushList(); + Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old; + Stmt* true_with_loads = AddMemLoadsFromList(true_new); + PopList(); + PushList(); + Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old; + Stmt* false_with_loads = AddMemLoadsFromList(false_new); + PopList(); + + if (cond_old == cond_new && true_old == true_with_loads && + false_old == false_with_loads) { + return (Stmt*)v; + } + return new Cond(cond_new, true_with_loads, false_with_loads); + } + + const Expr* mutate(const IfThenElse* v) override { + nested_if_then_else_++; + const Expr* new_v = IRMutator::mutate(v); + nested_if_then_else_--; + return new_v; + } + + Stmt* Process(Stmt* stmt) { + this->PushList(); + Stmt* stmt_v = stmt; + Stmt* stmt_new = stmt_v->accept_mutator(this); + Stmt* stmt_with_loads = AddMemLoadsFromList(stmt_new); + this->PopList(); + return stmt_with_loads; + } + + private: + using MemLoadEntry = std::pair; + using MemLoadList = std::vector; + using MemoryLoadStack = std::vector; + + void PushList() { + load_stack_.push_back(MemLoadList()); + } + + void PopList() { + load_stack_.pop_back(); + } + + Stmt* AddMemLoadsFromList(Stmt* stmt) { + MemLoadList& load_list = load_stack_.back(); + Stmt* stmt_v = stmt; + for (auto iter = load_list.rbegin(); iter != load_list.rend(); iter++) { + const MemLoadEntry& entry = *iter; + const Var* var_ptr = entry.first; + stmt_v = new LetStmt(var_ptr, entry.second, stmt_v); + } + return stmt_v; + } + + MemoryLoadStack load_stack_; + // TODO: For now, we are not moving the loads with the IfThenElse. 
+  // Eventually, we should switch to a more generic structure like:
+  // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
+  //
+  // int v;
+  // if (cond) {
+  //   v = true_v;
+  // } else {
+  //   v = false_v;
+  // }
+  // int v2 = v + 2;
+  int nested_if_then_else_ = 0;
+};
+
+class HasRand : public IRVisitor {
+ public:
+  HasRand(Stmt* stmt) : stmt_(stmt) {
+    stmt_->accept(this);
+  }
+
+  bool has_rand() const {
+    return has_rand_;
+  }
+
+ private:
+  void visit(const Intrinsics* v) override {
+    if (v->op_type() == IntrinsicsOp::kRand) {
+      has_rand_ = true;
+    } else {
+      IRVisitor::visit(v);
+    }
+  }
+  Stmt* stmt_;
+  bool has_rand_ = false;
+};
+
+std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
+  // We are using a global counter here to make sure different instances
+  // within CudaCodeGen have different names.
+  static int64_t counter = 0;
+  ++counter;
+  int64_t value = counter;
+  return func_prefix + "_" + std::to_string(value);
+}
+
+void CudaCodeGen::Initialize() {
+  // TODO: handle multiple kernels.
+  // TODO: handle dynamic dimension.
+  // TODO: call nvrtc.
+  HasRand has_rand_func(stmt());
+  has_random_ = has_rand_func.has_rand();
+  printer_ = std::make_unique<CudaPrinter>(&oss_, has_random_);
+
+  os() << "#define NAN __int_as_float(0x7fffffff)\n"
+          "#define POS_INFINITY __int_as_float(0x7f800000)\n"
+          "#define NEG_INFINITY __int_as_float(0xff800000)\n";
+  if (has_random_) {
+    os() << philox_random_string << std::endl;
+  }
+
+  // Check whether the statement uses the Half type, if so add the
+  // half_support_literal.
+  CudaHalfChecker halfChecker;
+  stmt()->accept(&halfChecker);
+  if (halfChecker.hasHalf()) {
+    os() << fuser::cuda::half_support_literal << std::endl;
+  }
+
+  std::string func_name = GetUniqueFuncName("func");
+  os() << "extern \"C\" __global__" << std::endl
+       << "void " << func_name << "(";
+  const std::vector<BufferArg> buffer_args = this->buffer_args();
+  for (size_t i = 0; i < buffer_args.size(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    const BufferArg& buffer_arg = buffer_args[i];
+    const Var* var = buffer_arg.var();
+    Dtype dtype = buffer_arg.dtype();
+
+    os() << cudaDtypeCppString(dtype) << (buffer_arg.isVar() ? " " : "* ")
+         << name_manager()->get_unique_name(var);
+  }
+  const Var* rand_seed;
+  const Var* rand_offset;
+  if (has_random_) {
+    // TODO: switch to kUint64 when it is available.
+    rand_seed = new Var("rand_seed", kInt);
+    rand_offset = new Var("rand_offset", kInt);
+    std::string uint64_str = "unsigned long long";
+    os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str
+         << " " << *rand_offset;
+  }
+  os() << ") {";
+  os() << std::endl;
+
+  if (has_random_) {
+    const Var* idx = new Var("idx", kInt);
+    os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
+         << std::endl;
+    const Var* rand_func = printer_->rand_func();
+    os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx
+         << ", " << *rand_offset << ");" << std::endl;
+    os() << std::endl;
+  }
+
+  Stmt* stmt_v = stmt();
+  PrioritizeLoad prioritize_load;
+  stmt_v = prioritize_load.Process(stmt_v);
+  stmt_v->accept(printer_.get());
+  os() << std::endl;
+  os() << "}";
+
+  // Check that all block extents have been set.
+ const std::vector& gpu_block_extents = + printer_->gpu_block_extents(); + const std::vector& gpu_thread_extents = + printer_->gpu_thread_extents(); + for (size_t i = 0; i < gpu_block_extents.size(); i++) { + if (!gpu_block_extents[i]) { + throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i)); + } + } + +#if DEBUG_PRINT + std::cout << "stmt: " << std::endl; + std::cout << oss_.str() << std::endl; + std::cout << "block("; + for (size_t i = 0; i < gpu_block_extents.size(); i++) { + if (i > 0) { + std::cout << ", "; + } + std::cout << *gpu_block_extents[i]; + } + std::cout << "), thread("; + for (size_t i = 0; i < gpu_thread_extents.size(); i++) { + if (i > 0) { + std::cout << ", "; + } + std::cout << *gpu_thread_extents[i]; + } + std::cout << ")" << std::endl; + ; +#endif + + CompileToNVRTC(oss_.str(), func_name); + USE_TRIGGER(cuda_codegen_created); +} + +void CudaCodeGen::call(const std::vector& args) { + CHECK_EQ(args.size(), buffer_args().size()); + + // TODO: move as much of this into the constructors. + const std::vector& gpu_block_extents = + printer_->gpu_block_extents(); + const std::vector& gpu_thread_extents = + printer_->gpu_thread_extents(); + CHECK(gpu_block_extents.size() <= 3); + CHECK(gpu_thread_extents.size() <= 3); + std::vector gpu_block_extents_v(3, 1); + std::vector gpu_thread_extents_v(3, 1); + // evaluate all the block/thread extents into values + // TODO: eventually, codegen these calculations and make them part of the + // module. + for (size_t i = 0; i < gpu_block_extents.size(); i++) { + ExprEval eval( + ExprHandle(gpu_block_extents[i]), buffer_args()); + gpu_block_extents_v[i] = eval.value(args); + } + for (size_t i = 0; i < gpu_thread_extents.size(); i++) { + ExprEval eval( + ExprHandle(gpu_thread_extents[i]), buffer_args()); + gpu_thread_extents_v[i] = eval.value(args); + } + + // Skip launching the kernel if there are no elements to process. + for (int extent : gpu_block_extents_v) { + if (extent == 0) { + return; + } + } + + // Bind the buffer addresses into arguments + auto const& buffer_args = this->buffer_args(); + int ptr_count = buffer_args.size(); + if (has_random_) { + ptr_count += 2; + } + std::vector args_data(buffer_args.size()); + std::vector ptr_to_args(ptr_count); + uint64_t rand_seed = uint64_t(-1); + uint64_t rand_offset = uint64_t(-1); + for (size_t i = 0; i < buffer_args.size(); i++) { + auto const& bufferArg = buffer_args[i]; + if (bufferArg.isVar()) { + auto stype = bufferArg.dtype().scalar_type(); + switch (stype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + ptr_to_args[i] = args[i].Name##Ptr(); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unhandled dtype in argument"; + } + } else { + args_data[i] = args[i].data(); + ptr_to_args[i] = &args_data[i]; + } + } + + if (has_random_) { + auto gen = at::cuda::detail::getDefaultCUDAGenerator(); + // TODO: total hack. Switch to numel when it is available. 
+ int64_t total_elements_per_thread = (1LL << 28); + { + std::lock_guard lock(gen->mutex_); + auto philox_engine_inputs = + gen->philox_engine_inputs(total_elements_per_thread); + rand_seed = philox_engine_inputs.first; + rand_offset = philox_engine_inputs.second; + } + ptr_to_args[buffer_args.size()] = &rand_seed; + ptr_to_args[buffer_args.size() + 1] = &rand_offset; + } + + // Launch the kernels + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( + function_, + gpu_block_extents_v[0], + gpu_block_extents_v[1], + gpu_block_extents_v[2], + gpu_thread_extents_v[0], + gpu_thread_extents_v[1], + gpu_thread_extents_v[2], + 0, + stream, + ptr_to_args.data(), + nullptr)); + USE_TRIGGER(cuda_codegen_executed); +} + +void CudaCodeGen::CompileToNVRTC( + const std::string& code, + const std::string& func_name) { + // Initializes driver's API context (if necessary) + CUdevice device = 0; + CUcontext pctx = 0; + AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); + if (!pctx) { + std::unique_lock cudaFreeMutexLock( + *(c10::cuda::CUDACachingAllocator::getFreeMutex())); + cudaFree(0); + } + + // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work + // properly in some scenarios + const auto prior_device = at::cuda::current_device(); + at::cuda::set_device(device); + + // Acquires device and NVRTC properties (for compile arch and occupancy + // calculations) + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + int major, minor; + getMajorMinor(prop, major, minor); + +#if DEBUG_PRINT + std::cout << "major: " << major << ", " + << "minor: " << minor << std::endl; +#endif + + // Creates the NVRTC program + nvrtcProgram program; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( + &program, code.c_str(), nullptr, 0, nullptr, nullptr)); + +#ifdef __HIP_PLATFORM_HCC__ + std::vector args = {}; +#else + const std::string compute = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + const std::vector args = { + "--std=c++14", compute.c_str(), "-default-device"}; +#endif + + const auto result = + nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); + if (result != NVRTC_SUCCESS) { + size_t logsize; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize)); + std::vector log(logsize); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data())); + std::stringstream cu; + cu << log.data() << std::endl; + cu << "nvrtc compilation failed: " << std::endl; + cu << code << std::endl; + throw std::runtime_error(cu.str()); + } + ResourceGuard holdProgram( + [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); + AT_CUDA_NVRTC_CHECK(result); + size_t ptx_size; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size)); + std::vector ptx; + ptx.resize(ptx_size); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data())); + + CUmodule module; + AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data())); + AT_CUDA_DRIVER_CHECK( + nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str())); +} + +RegisterCodeGen cuda_codegen_reg("cuda_codegen"); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h new file mode 100644 index 00000000000..7afa9a07a0b --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h @@ -0,0 +1,123 @@ +#pragma once + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/cuda/CUDAContext.h" 
+#include "ATen/cuda/nvrtc_stub/ATenNVRTC.h" +#include "c10/cuda/CUDACachingAllocator.h" +#include "c10/cuda/CUDAGuard.h" +#include "torch/csrc/jit/resource_guard.h" +#include "torch/csrc/jit/tensorexpr/codegen.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/unique_name_manager.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// A class that overrides the underlying IRPrinter to produce Cuda C. +class CudaPrinter : public IRPrinter { + public: + explicit CudaPrinter(std::ostream* os, bool has_random) : IRPrinter(*os) { + if (has_random) { + rand_func_ = new Var("rand", kHandle); + } + } + + void visit(const Cast* v) override { + auto dtype = v->dtype(); + if (dtype == kHalf) { + os() << "half"; + } else { + os() << dtype; + } + os() << "("; + v->src_value()->accept(this); + os() << ")"; + } + + void visit(const Intrinsics* v) override; + void visit(const For* v) override; + + void visit(const Load* v) override; + void visit(const Store* v) override; + void visit(const Max* v) override; + void visit(const Min* v) override; + void visit(const LetStmt* v) override; + void visit(const IfThenElse* v) override; + + const std::vector& gpu_block_extents() const { + return gpu_block_extents_; + } + + const std::vector& gpu_thread_extents() const { + return gpu_thread_extents_; + } + + const Var* rand_func() const { + return rand_func_; + } + + using IRPrinter::name_manager; + using IRPrinter::visit; + + private: + std::vector gpu_block_extents_; + std::vector gpu_thread_extents_; + const Var* rand_func_; +}; + +// Construct Cuda C from the buffer and tensor input, and invoke the kernel +// when real arguments are provided. +class TORCH_CUDA_API CudaCodeGen : public CodeGen { + public: + template + CudaCodeGen(Stmt* stmt, Ts... ts) : CodeGen(stmt, std::forward(ts)...) { + Initialize(); + } + + CudaCodeGen(Stmt* stmt, const std::vector& buffer_args) + : CodeGen(stmt, buffer_args) { + Initialize(); + } + + ~CudaCodeGen() override {} + + void call(const std::vector& args) override; + + template + void operator()(const Ts&... ts) { + call(std::vector({CallArg(ts)...})); + } + + private: + void Initialize(); + + void CompileToNVRTC(const std::string& code, const std::string& func_name); + + UniqueNameManager* name_manager() { + if (!printer_) { + throw std::runtime_error("Null IRPrinter is not expected"); + } + return printer_->name_manager(); + } + + std::ostream& os() { + return printer_->os(); + } + + std::ostringstream oss_; + std::unique_ptr printer_; + CUfunction function_; + bool has_random_ = false; + + std::string GetUniqueFuncName(const std::string& func_prefix); +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/cuda_half_support.h new file mode 100644 index 00000000000..91d4a5f68cc --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_half_support.h @@ -0,0 +1,30 @@ +#pragma once + +#include "torch/csrc/jit/codegen/fuser/cuda/resource_strings.h" +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// Walk the Statment looking for Half size loads/stores. 
+class CudaHalfChecker : public IRVisitor { + public: + bool hasHalf() { + return hasHalf_; + } + + void visit(const Load* v) override { + hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + } + void visit(const Store* v) override { + hasHalf_ |= v->value()->dtype().scalar_type() == ScalarType::Half; + } + + private: + bool hasHalf_{false}; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_random.h b/torch/csrc/jit/tensorexpr/cuda_random.h new file mode 100644 index 00000000000..987ac5211d9 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_random.h @@ -0,0 +1,104 @@ +#pragma once + +namespace torch { +namespace jit { +namespace tensorexpr { + +constexpr auto philox_random_string = R"( + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + counter = make_uint4(0, 0, 0, 0); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + STATE = 0; + incr_n(offset / 4); + } + + __device__ inline unsigned long operator()() { + if(STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; + for(int i = 0; i < 9; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); key_.y += (kPhilox10B); + } + output = single_round(counter_, key_); + incr(); + } + unsigned long ret; + switch(STATE) { + case 0: ret = output.x; break; + case 1: ret = output.y; break; + case 2: ret = output.z; break; + case 3: ret = output.w; break; + } + STATE = (STATE + 1) % 4; + return ret; + } + +private: + uint4 counter; + uint4 output; + uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ inline void incr() { + if (++counter.x) + return; + if (++counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int *result_high) { + *result_high = __umulhi(a, b); + return a*b; + } + + __device__ inline uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +// Inverse of 2^32. 
+#define M_RAN_INVM32 2.3283064e-10f +__device__ __inline__ float Uint32ToFloat(unsigned int x) { + return x * M_RAN_INVM32; +} + +)"; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index f4c5adc6337..f26ba0b5602 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -5,6 +5,30 @@ using namespace torch::jit; using namespace torch::jit::tensorexpr; +namespace torch { +namespace jit { +namespace tensorexpr { + +static int te_cuda_pointwise_loop_levels = -1; +static int te_cuda_pointwise_block_count = -1; +static int te_cuda_pointwise_block_size = -1; + +int& GetTECudaPointwiseLoopLevels() { + return te_cuda_pointwise_loop_levels; +} + +int& GetTECudaPointwiseBlockCount() { + return te_cuda_pointwise_block_count; +} + +int& GetTECudaPointwiseBlockSize() { + return te_cuda_pointwise_block_size; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + static at::ScalarType tensorType(Tensor* t) { return static_cast(t->body()->dtype().scalar_type()); } @@ -883,12 +907,96 @@ Tensor* TensorExprKernel::ComputeValue(const torch::jit::Value* v) { void TensorExprKernel::LowerToBackend(BackendType backend_type) { std::vector tensor_outputs(tensor_outputs_); + if (backend_type == BackendType::kCudaCodeGen) { + for (size_t tensor_idx = 0; tensor_idx < tensor_outputs_.size(); + tensor_idx++) { + Tensor* tensor = tensor_outputs_[tensor_idx]; + ExprHandle total_count = ExprHandle(tensor->dim(0)); + for (int i = 1; i < tensor->ndim(); i++) { + const IntImm* total_count_i = total_count.AsNode(); + const IntImm* tensor_dim_i = + dynamic_cast(tensor->dim(i)); + if (total_count_i && tensor_dim_i) { + // TODO: switch to real constant folding when it is available. + total_count = + ExprHandle(total_count_i->value() * tensor_dim_i->value()); + } else { + total_count = total_count * ExprHandle(tensor->dim(i)); + } + } + // Flatten the index for GPU kernels. + // TODO: move this to fusing axis when it is ready. + Tensor* new_out = Compute( + tensor->func_var()->name_hint() + "_flat", + {total_count}, + [tensor](const VarHandle& index) -> ExprHandle { + std::vector dims; + ExprHandle value = index; + for (int i = tensor->ndim() - 1; i >= 0; i--) { + ExprHandle idx = value; + if (i > 0) { + idx = Mod::make(value, ExprHandle(tensor->dim(i))); + } + dims.push_back(idx); + value = value / ExprHandle(tensor->dim(i)); + } + std::reverse(dims.begin(), dims.end()); + return tensor->call(dims); + }); + tensor_outputs[tensor_idx] = new_out; + } + } + torch::jit::tensorexpr::schedule::LoopNest l(tensor_outputs); // Compute non-output tensors_ inline for (auto& p : tensors_) { l.ComputeInline(l.getLoopBodyFor(p.second)); } + if (backend_type == kCudaCodeGen) { + for (size_t i = 0; i < tensor_outputs_.size(); i++) { + l.ComputeInline(l.getLoopBodyFor(tensor_outputs_[i])); + + Tensor* tensor = tensor_outputs[i]; + const Var* index = tensor->arg(0); + int loop_levels = GetTECudaPointwiseLoopLevels(); + const int kDefaultLoopLevels = 2; + loop_levels = (loop_levels > 0) ? 
loop_levels : kDefaultLoopLevels; + int block_count = GetTECudaPointwiseBlockCount(); + int block_size = GetTECudaPointwiseBlockSize(); + + if (loop_levels == 2) { + Stmt* outer; + Stmt* inner; + const int kDefaultBlockSize = 512; + if (block_size < 0) { + block_size = kDefaultBlockSize; + } + std::vector loops = l.getLoopStmtsFor(tensor); + l.SplitWithMask(loops[0], block_size, &outer, &inner); + l.SetGPUBlockIndex(outer, 0); + l.SetGPUThreadIndex(inner, 0); + } else if (loop_levels == 3) { + Stmt* outer; + Stmt* inner; + Stmt* inner_1; + Stmt* inner_2; + // TODO: change the number of microprocessors + const int kDefaultBlockCount = 1280; + const int kDefaultBlockSize = 256; + block_count = (block_count > 0) ? block_count : kDefaultBlockCount; + block_size = (block_size > 0) ? block_size : kDefaultBlockSize; + std::vector loops = l.getLoopStmtsFor(tensor); + l.SplitWithMask(loops[0], block_count * block_size, &outer, &inner); + l.SplitWithMask(inner, block_size, &inner_1, &inner_2); + l.SetGPUBlockIndex(inner_1, 0); + l.SetGPUThreadIndex(inner_2, 0); + } else { + throw std::runtime_error( + "Invalid loop-level: " + std::to_string(loop_levels)); + } + } + } l.ApplyInlines(); Stmt* stmt = l.root_stmt(); @@ -911,6 +1019,9 @@ void TensorExprKernel::LowerToBackend(BackendType backend_type) { // Generate code. std::string codegen_name; switch (backend_type_) { + case kCudaCodeGen: + codegen_name = "cuda_codegen"; + break; case kSimpleIREval: codegen_name = "simple_ir_eval"; break; @@ -933,7 +1044,9 @@ void TensorExprKernel::PickAndCheckBackendType( throw std::runtime_error("No tensor inputs"); }(); BackendType backend_type = BackendType::kUninitialized; - if (device.type() == at::kCPU) { + if (device.type() == at::kCUDA) { + backend_type = kCudaCodeGen; + } else if (device.type() == at::kCPU) { backend_type = kSimpleIREval; } else { throw std::runtime_error("Invalid device type"); @@ -956,6 +1069,7 @@ void TensorExprKernel::CodeGenRun( const std::vector& run_args) { switch (backend_type_) { case kSimpleIREval: + case kCudaCodeGen: codegen_->call(run_args); break; default: diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index bbaf212a4a5..f3dcf6587bb 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -52,6 +52,7 @@ class TensorExprKernel { enum BackendType { kUninitialized, kSimpleIREval, + kCudaCodeGen, }; ExprHandle constant(const torch::jit::Value* v); @@ -205,6 +206,10 @@ class TensorExprKernel { at::Device device_ = at::kCPU; }; +TORCH_API int& GetTECudaPointwiseLoopLevels(); +TORCH_API int& GetTECudaPointwiseBlockCount(); +TORCH_API int& GetTECudaPointwiseBlockSize(); + } // namespace tensorexpr } // namespace jit } // namespace torch
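
A note on the new knobs (not part of the patch): the initJITBindings additions expose the TE CUDA pointwise schedule parameters to Python, and test_tensorexpr.py reads the cuda_codegen_created / cuda_codegen_executed triggers through its ExecutionCounter helpers. Below is a minimal usage sketch, assuming the bindings are reachable as torch._C._jit_*, a CUDA device is available, and the usual TE-fuser test setup (profiling executor plus _jit_set_texpr_fuser_enabled(True)) is in place; everything outside what the diff itself adds is an assumption.

# Hypothetical sketch only; binding placement on torch._C is assumed.
import torch

torch._C._jit_set_texpr_fuser_enabled(True)

# Request the 3-level pointwise schedule with explicit launch parameters;
# negative values fall back to the defaults chosen in kernel.cpp.
torch._C._jit_set_te_cuda_pointwise_loop_levels(3)
torch._C._jit_set_te_cuda_pointwise_block_count(1280)
torch._C._jit_set_te_cuda_pointwise_block_size(256)
assert torch._C._jit_get_te_cuda_pointwise_block_size() == 256

def f(a, b):
    return a * b + b

scripted = torch.jit.script(f)
x = torch.rand(1 << 20, device="cuda")
y = torch.rand(1 << 20, device="cuda")
ref = f(x, y)
out = None
for _ in range(3):  # repeated runs so profiling and fusion can kick in
    out = scripted(x, y)
assert torch.allclose(out, ref)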