mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert D29952381: [Static Runtime] Ensure that unittests only use out variants or native ops
Test Plan: revert-hammer
Differential Revision:
D29952381 (8737e17af2)
Original commit changeset: e60e70b80ccf
fbshipit-source-id: 59dc2f920b7ceaf94ba8f5f36024e7cc710f6645
This commit is contained in:
parent
491d89da1b
commit
7f1b672b7a
5 changed files with 8 additions and 32 deletions
|
|
@ -343,8 +343,8 @@ TEST(StaticRuntime, IndividualOps_Mul) {
|
|||
std::vector<IValue> scalar_args1{a, 42};
|
||||
std::vector<IValue> scalar_args2{c, 42};
|
||||
|
||||
testStaticRuntime(mul_scalar, scalar_args1, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(mul_scalar, scalar_args1, scalar_args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(mul_scalar, scalar_args1);
|
||||
testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, IndividualOps_Log) {
|
||||
|
|
@ -455,8 +455,8 @@ TEST(StaticRuntime, IndividualOps_Norm) {
|
|||
auto dtype = at::ScalarType::Float;
|
||||
|
||||
std::vector<IValue> args2{a, 2};
|
||||
testStaticRuntime(norm_2arg, args2, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(norm_2arg, args2, {b, 2}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(norm_2arg, args2);
|
||||
testStaticRuntime(norm_2arg, args2, {b, 2});
|
||||
|
||||
std::vector<IValue> args3{a, 2, dtype};
|
||||
testStaticRuntime(norm_3arg, args3);
|
||||
|
|
@ -485,8 +485,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) {
|
|||
testStaticRuntime(reshape_script_3, args);
|
||||
testStaticRuntime(reshape_script_4, args);
|
||||
testStaticRuntime(reshape_script_5, args);
|
||||
// tensor.sigmoid_ is delegated to the interpreter.
|
||||
testStaticRuntime(reshape_inplace_script, args, /*args2=*/{}, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(reshape_inplace_script, args);
|
||||
testStaticRuntime(reshape_incontiguous_script, args);
|
||||
|
||||
testStaticRuntime(reshape_script_1, args, args1);
|
||||
|
|
@ -494,8 +493,7 @@ TEST(StaticRuntime, IndividualOps_Reshape) {
|
|||
testStaticRuntime(reshape_script_3, args, args1);
|
||||
testStaticRuntime(reshape_script_4, args, args1);
|
||||
testStaticRuntime(reshape_script_5, args, args1);
|
||||
// tensor.sigmoid_ is delegated to the interpreter.
|
||||
testStaticRuntime(reshape_inplace_script, args, args1, /*use_allclose=*/false, /*use_equalnan=*/false, /*expect_fallback=*/true);
|
||||
testStaticRuntime(reshape_inplace_script, args, args1);
|
||||
testStaticRuntime(reshape_incontiguous_script, args, args1);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||
#include <torch/csrc/jit/runtime/static/impl.h>
|
||||
#include <torch/csrc/jit/runtime/static/ops.h>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
|
|
@ -121,17 +120,6 @@ void compareTensorLists(
|
|||
}
|
||||
}
|
||||
|
||||
void checkNoInterpreterOp(const StaticModule& smodule) {
|
||||
static const std::unordered_set<std::string> fb_only_ops{"aten::add"};
|
||||
for (const auto& pnode : smodule.nodes()) {
|
||||
auto op_name = pnode.node()->kind().toQualString();
|
||||
if (disableUnsafeMathOp(op_name) || fb_only_ops.count(op_name) > 0) {
|
||||
continue;
|
||||
}
|
||||
EXPECT_TRUE(pnode.has_out_variant() || pnode.has_native_op());
|
||||
}
|
||||
}
|
||||
|
||||
void compareResults(
|
||||
const IValue& expect,
|
||||
const IValue& actual,
|
||||
|
|
@ -191,8 +179,7 @@ void testStaticRuntime(
|
|||
const std::vector<IValue>& args,
|
||||
const std::vector<IValue>& args2,
|
||||
const bool use_allclose,
|
||||
const bool use_equalnan,
|
||||
const bool expect_fallback) {
|
||||
const bool use_equalnan) {
|
||||
auto test_context = makeTestContext(source);
|
||||
|
||||
std::vector<IValue> args_tensors, args_copy;
|
||||
|
|
@ -209,9 +196,6 @@ void testStaticRuntime(
|
|||
for (bool enable_out_variant : {true, false}) {
|
||||
auto smodule = test_context->makeStaticModule(
|
||||
{true, enable_out_variant, enable_out_variant});
|
||||
if (enable_out_variant && !expect_fallback) {
|
||||
checkNoInterpreterOp(smodule);
|
||||
}
|
||||
auto actual = smodule(args, {});
|
||||
if (actual.isTensor()) {
|
||||
EXPECT_GE(smodule.nodes().size(), 2)
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ void testStaticRuntime(
|
|||
const std::vector<c10::IValue>& args,
|
||||
const std::vector<c10::IValue>& args2 = {},
|
||||
const bool use_allclose = false,
|
||||
const bool use_equalnan = false,
|
||||
const bool expect_fallback = false);
|
||||
const bool use_equalnan = false);
|
||||
|
||||
} // namespace test
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -409,10 +409,6 @@ class TORCH_API ProcessedNode {
|
|||
return static_cast<bool>(fn_);
|
||||
}
|
||||
|
||||
  // Returns true iff a native implementation (native_fn_) has been bound
  // for this node. NOTE(review): named by analogy with has_out_variant()
  // above — presumably "native" here means an op implemented directly in
  // Static Runtime rather than dispatched to the interpreter; confirm
  // against the surrounding class documentation.
  bool has_native_op() const {
    return static_cast<bool>(native_fn_);
  }
|
||||
|
||||
bool verify_outputs_not_overlapping_with_immutable_inputs() const;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -128,7 +128,6 @@ inline void fastResizeToZero(at::Tensor& t) {
|
|||
|
||||
// check if an op has an out variant registered in Static Runtime
|
||||
bool opIsRegistered(const c10::Symbol& op_name);
|
||||
bool disableUnsafeMathOp(const char* op_name);
|
||||
// Check whether Static Runtime can run an op natively.
// Prim ops that are implemented directly in the JIT interpreter are
// re-implemented as native ops in Static Runtime.
|
||||
|
|
|
|||
Loading…
Reference in a new issue