From 993a35a8cbbd2d2f79c87b48e1b3df7ef8ceb419 Mon Sep 17 00:00:00 2001
From: Hao Lu
Date: Thu, 13 May 2021 17:46:00 -0700
Subject: [PATCH] [Static Runtime] Support clamp.Tensor (#58191)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58191

There are two clamp overloads: clamp.Scalar and clamp.Tensor. SR needs to
support both or has checks in place to avoid runtime errors. Supporting both
is not too hard so here we are.

Reviewed By: edvgha

Differential Revision: D28371949

fbshipit-source-id: 0ec6b8a0b8c6277e50d8e51e4e7a45aa62211e22
---
 benchmarks/static_runtime/test_scripts.h         | 12 ++++++++++++
 benchmarks/static_runtime/test_static_runtime.cc |  9 +++++++++
 torch/csrc/jit/runtime/static/ops.cpp            | 16 ++++++++++++----
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 9dfe5ec95ad..72eb3c09a5f 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -337,3 +337,15 @@ const std::string repeat = R"JIT(
   def forward(self, a: Tensor, repeats: List[int]):
       return torch.repeat(a, repeats)
 )JIT";
+
+const auto clamp_script_1 = R"JIT(
+  def forward(self, inp: Tensor, min: int, max: int):
+      a = torch.clamp(inp, min, max)
+      return (a)
+)JIT";
+
+const auto clamp_script_2 = R"JIT(
+  def forward(self, inp: Tensor, min: Tensor, max: Tensor):
+      a = torch.clamp(inp, min, max)
+      return (a)
+)JIT";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 5e93d58838c..48e64839d81 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -141,6 +141,15 @@ TEST(StaticRuntime, Clone) {
   testStaticRuntime(clone_script_1, args_1);
 }
 
+TEST(StaticRuntime, Clamp) {
+  auto a = at::randn({2, 3});
+  auto max_t = at::full_like(a, 1);
+  auto min_t = at::full_like(a, -1);
+
+  testStaticRuntime(clamp_script_1, {a, -1, 1});
+  testStaticRuntime(clamp_script_2, {a, min_t, max_t});
+}
+
 TEST(StaticRuntime, Logit) {
   auto a = at::ones({2, 3});
   double b = 1e-6;
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 88c50d6da61..c6fa7d35089 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -307,7 +307,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
   };
 });
 
-// TODO: support
+// clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
 // clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
@@ -316,14 +316,22 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
   }
   return [](ProcessedNode* p_node) {
     const auto& in0_t = p_node->Input(0).toTensor();
-    const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
-    const auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
+
     if (p_node->Output(0).isNone()) {
       p_node->Output(0) = create_empty_from(in0_t);
     }
     auto& out_t = p_node->Output(0).toTensor();
     fastResizeToZero(out_t);
-    at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
+
+    if (p_node->Input(1).isTensor()) {
+      auto in1_t = p_node->Input(1).toOptional<at::Tensor>();
+      auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
+      at::native::clamp_out(in0_t, in1_t, in2_t, out_t);
+    } else {
+      auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
+      auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
+      at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
+    }
   };
 });
 