diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 8a1d96c57c5..697d1017490 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -1257,5 +1257,36 @@ TEST_F(Kernel, ConstantTensorsNonContiguous) { ASSERT_TRUE(at::allclose(o, ref)); } +TEST_F(Kernel, RunFast) { +#ifdef TORCH_ENABLE_LLVM + // TODO: Implement call_raw in IREval and remove the ifdef + KernelScope kernel_scope; + + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + + k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + for (size_t i = 0; i < 5 * 3; i++) { + CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 4b2bcca0e95..3f50b4df8fa 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -3185,3 +3185,21 @@ void TensorExprKernel::runKernel(Stack& stack) { push_one(stack, std::move(o)); } } + +void TensorExprKernel::runFast( + const std::vector& inputs, + const std::vector& outputs) { + KernelScope kernelScope(&kernelArena_); + + std::vector args(inputs); + args.reserve(inputs.size() + outputs.size() + constants_.size()); + args.insert(args.end(), outputs.begin(), outputs.end()); + + // TODO: we can consider preallocating and pre-filling the args vector. + for (auto c : constants_) { + args.push_back(c.ptr); + } + + // Call the kernel. + codegen_->call_raw(args); +} diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index adad3f6a177..46a074d01c2 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -70,6 +70,9 @@ class TORCH_API TensorExprKernel { explicit TensorExprKernel(const std::shared_ptr& subgraph); void run(Stack& stack); + void runFast( + const std::vector& inputs, + const std::vector& outputs); void fallback(Stack& stack) { InterpreterState(code_).run(stack);