Revert D29952381: [Static Runtime] Ensure that unittests only use out variants or native ops

Test Plan: revert-hammer

Differential Revision:
D29952381 (8737e17af2)

Original commit changeset: e60e70b80ccf

fbshipit-source-id: 59dc2f920b7ceaf94ba8f5f36024e7cc710f6645
This commit is contained in:
Rong Rong (AI Infra) 2021-08-04 14:19:56 -07:00 committed by Facebook GitHub Bot
parent 491d89da1b
commit 7f1b672b7a
5 changed files with 8 additions and 32 deletions

View file

@ -343,8 +343,8 @@ TEST(StaticRuntime, IndividualOps_Mul) {
std::vector<IValue> scalar_args1{a, 42};
std::vector<IValue> scalar_args2{c, 42};
testStaticRuntime(mul_scalar, scalar_args1, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(mul_scalar, scalar_args1, scalar_args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(mul_scalar, scalar_args1);
testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);
}
TEST(StaticRuntime, IndividualOps_Log) {
@ -455,8 +455,8 @@ TEST(StaticRuntime, IndividualOps_Norm) {
auto dtype = at::ScalarType::Float;
std::vector<IValue> args2{a, 2};
testStaticRuntime(norm_2arg, args2, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(norm_2arg, args2, {b, 2}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(norm_2arg, args2);
testStaticRuntime(norm_2arg, args2, {b, 2});
std::vector<IValue> args3{a, 2, dtype};
testStaticRuntime(norm_3arg, args3);
@ -485,8 +485,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) {
testStaticRuntime(reshape_script_3, args);
testStaticRuntime(reshape_script_4, args);
testStaticRuntime(reshape_script_5, args);
// tensor.sigmoid_ is delegated to the interpreter.
testStaticRuntime(reshape_inplace_script, args, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(reshape_inplace_script, args);
testStaticRuntime(reshape_incontiguous_script, args);
testStaticRuntime(reshape_script_1, args, args1);
@ -494,8 +493,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) {
testStaticRuntime(reshape_script_3, args, args1);
testStaticRuntime(reshape_script_4, args, args1);
testStaticRuntime(reshape_script_5, args, args1);
// tensor.sigmoid_ is delegated to the interpreter.
testStaticRuntime(reshape_inplace_script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
testStaticRuntime(reshape_inplace_script, args, args1);
testStaticRuntime(reshape_incontiguous_script, args, args1);
}

View file

@ -7,7 +7,6 @@
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <memory>
#include <unordered_map>
@ -121,17 +120,6 @@ void compareTensorLists(
}
}
void checkNoInterpreterOp(const StaticModule& smodule) {
static const std::unordered_set<std::string> fb_only_ops{"aten::add"};
for (const auto& pnode : smodule.nodes()) {
auto op_name = pnode.node()->kind().toQualString();
if (disableUnsafeMathOp(op_name) || fb_only_ops.count(op_name) > 0) {
continue;
}
EXPECT_TRUE(pnode.has_out_variant() || pnode.has_native_op());
}
}
void compareResults(
const IValue& expect,
const IValue& actual,
@ -191,8 +179,7 @@ void testStaticRuntime(
const std::vector<IValue>& args,
const std::vector<IValue>& args2,
const bool use_allclose,
const bool use_equalnan,
const bool expect_fallback) {
const bool use_equalnan) {
auto test_context = makeTestContext(source);
std::vector<IValue> args_tensors, args_copy;
@ -209,9 +196,6 @@ void testStaticRuntime(
for (bool enable_out_variant : {true, false}) {
auto smodule = test_context->makeStaticModule(
{true, enable_out_variant, enable_out_variant});
if (enable_out_variant && !expect_fallback) {
checkNoInterpreterOp(smodule);
}
auto actual = smodule(args, {});
if (actual.isTensor()) {
EXPECT_GE(smodule.nodes().size(), 2)

View file

@ -24,8 +24,7 @@ void testStaticRuntime(
const std::vector<c10::IValue>& args,
const std::vector<c10::IValue>& args2 = {},
const bool use_allclose = false,
const bool use_equalnan = false,
const bool expect_fallback = false);
const bool use_equalnan = false);
} // namespace test
} // namespace jit

View file

@ -409,10 +409,6 @@ class TORCH_API ProcessedNode {
return static_cast<bool>(fn_);
}
bool has_native_op() const {
return static_cast<bool>(native_fn_);
}
bool verify_outputs_not_overlapping_with_immutable_inputs() const;
private:

View file

@ -128,7 +128,6 @@ inline void fastResizeToZero(at::Tensor& t) {
// check if an op has an out variant registered in Static Runtime
bool opIsRegistered(const c10::Symbol& op_name);
bool disableUnsafeMathOp(const char* op_name);
// check if Static Runtime can run an op natively.
// prim ops that are implemented directly in the jit interpreter are implemented
// as native ops in Static Runtime