diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 345bc022194..e141134a632 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -343,8 +343,8 @@ TEST(StaticRuntime, IndividualOps_Mul) { std::vector scalar_args1{a, 42}; std::vector scalar_args2{c, 42}; - testStaticRuntime(mul_scalar, scalar_args1, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); - testStaticRuntime(mul_scalar, scalar_args1, scalar_args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); + testStaticRuntime(mul_scalar, scalar_args1); + testStaticRuntime(mul_scalar, scalar_args1, scalar_args2); } TEST(StaticRuntime, IndividualOps_Log) { @@ -455,8 +455,8 @@ TEST(StaticRuntime, IndividualOps_Norm) { auto dtype = at::ScalarType::Float; std::vector args2{a, 2}; - testStaticRuntime(norm_2arg, args2, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); - testStaticRuntime(norm_2arg, args2, {b, 2}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); + testStaticRuntime(norm_2arg, args2); + testStaticRuntime(norm_2arg, args2, {b, 2}); std::vector args3{a, 2, dtype}; testStaticRuntime(norm_3arg, args3); @@ -485,8 +485,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) { testStaticRuntime(reshape_script_3, args); testStaticRuntime(reshape_script_4, args); testStaticRuntime(reshape_script_5, args); - // tensor.sigmoid_ is delegated to the interpreter. - testStaticRuntime(reshape_inplace_script, args, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); + testStaticRuntime(reshape_inplace_script, args); testStaticRuntime(reshape_incontiguous_script, args); testStaticRuntime(reshape_script_1, args, args1); @@ -494,8 +493,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) { testStaticRuntime(reshape_script_3, args, args1); testStaticRuntime(reshape_script_4, args, args1); testStaticRuntime(reshape_script_5, args, args1); - // tensor.sigmoid_ is delegated to the interpreter. - testStaticRuntime(reshape_inplace_script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true); + testStaticRuntime(reshape_inplace_script, args, args1); testStaticRuntime(reshape_incontiguous_script, args, args1); } diff --git a/benchmarks/static_runtime/test_utils.cc b/benchmarks/static_runtime/test_utils.cc index 574698ded22..68f2af271da 100644 --- a/benchmarks/static_runtime/test_utils.cc +++ b/benchmarks/static_runtime/test_utils.cc @@ -7,7 +7,6 @@ #include #include #include -#include #include #include @@ -121,17 +120,6 @@ void compareTensorLists( } } -void checkNoInterpreterOp(const StaticModule& smodule) { - static const std::unordered_set fb_only_ops{"aten::add"}; - for (const auto& pnode : smodule.nodes()) { - auto op_name = pnode.node()->kind().toQualString(); - if (disableUnsafeMathOp(op_name) || fb_only_ops.count(op_name) > 0) { - continue; - } - EXPECT_TRUE(pnode.has_out_variant() || pnode.has_native_op()); - } -} - void compareResults( const IValue& expect, const IValue& actual, @@ -191,8 +179,7 @@ void testStaticRuntime( const std::vector& args, const std::vector& args2, const bool use_allclose, - const bool use_equalnan, - const bool expect_fallback) { + const bool use_equalnan) { auto test_context = makeTestContext(source); std::vector args_tensors, args_copy; @@ -209,9 +196,6 @@ void testStaticRuntime( for (bool enable_out_variant : {true, false}) { auto smodule = test_context->makeStaticModule( {true, enable_out_variant, enable_out_variant}); - if (enable_out_variant && !expect_fallback) { - checkNoInterpreterOp(smodule); - } auto actual = smodule(args, {}); if (actual.isTensor()) { EXPECT_GE(smodule.nodes().size(), 2) diff --git a/benchmarks/static_runtime/test_utils.h b/benchmarks/static_runtime/test_utils.h index e3808ac298f..8c616f255f7 100644 --- a/benchmarks/static_runtime/test_utils.h +++ b/benchmarks/static_runtime/test_utils.h @@ -24,8 +24,7 @@ void testStaticRuntime( const std::vector& args, const std::vector& args2 = {}, const bool use_allclose = false, - const bool use_equalnan = false, - const bool expect_fallback = false); + const bool use_equalnan = false); } // namespace test } // namespace jit diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index e73be076de5..bf28dfc70b6 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -409,10 +409,6 @@ class TORCH_API ProcessedNode { return static_cast(fn_); } - bool has_native_op() const { - return static_cast(native_fn_); - } - bool verify_outputs_not_overlapping_with_immutable_inputs() const; private: diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index 349e3c9d044..ff5d69e1cb8 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -128,7 +128,6 @@ inline void fastResizeToZero(at::Tensor& t) { // check if an op has an out variant registered in Static Runtime bool opIsRegistered(const c10::Symbol& op_name); -bool disableUnsafeMathOp(const char* op_name); // check if Static Runtime can run an op natively. // prim ops that are implemented directly in the jit interpreter are implemented // as native ops in Static Runtime