mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] Opinfo tests for nnc fusion - retry (#72486)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72486 Retry #70465. Test Plan: Imported from OSS Reviewed By: mikaylagawarecki Differential Revision: D34061628 Pulled By: davidberard98 fbshipit-source-id: e27ed315bc4ad57cdbfbc9cedffcbb7886004524 (cherry picked from commit 7937808d2ebcc758aad4eac3ae6ffe1056d13fc5)
This commit is contained in:
parent
7035738b50
commit
bbd42c605a
7 changed files with 214 additions and 73 deletions
|
|
@ -21,12 +21,16 @@ torch._C._jit_set_profiling_executor(True)
|
|||
torch._C._jit_set_profiling_mode(True)
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
|
||||
enable_profiling_mode_for_profiling_tests, TestCase
|
||||
enable_profiling_mode_for_profiling_tests, slowTest
|
||||
from torch.testing._internal.jit_utils import JitTestCase, \
|
||||
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining
|
||||
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
|
||||
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions
|
||||
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
|
||||
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
|
||||
OpDTypes
|
||||
from torch.testing._internal.common_jit import JitCommonTestCase
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
|
||||
|
||||
from textwrap import dedent
|
||||
from itertools import product, permutations, combinations
|
||||
|
|
@ -78,33 +82,14 @@ def inline_fusion_groups():
|
|||
|
||||
class TestTEFuser(JitTestCase):
|
||||
def setUp(self):
|
||||
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
|
||||
self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu()
|
||||
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
|
||||
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
# TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle
|
||||
# torch._C._jit_set_te_must_use_llvm_cpu(True)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
self.tensorexpr_options = TensorExprTestOptions()
|
||||
|
||||
# note: `self.dynamic_shapes` instantiated in specialization of class
|
||||
# defined below
|
||||
|
||||
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
||||
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
|
||||
|
||||
fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)]
|
||||
self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy)
|
||||
|
||||
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
|
||||
torch._C._debug_set_fusion_group_inlining(False)
|
||||
|
||||
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
|
||||
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(False)
|
||||
|
||||
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
|
||||
self.int_dtypes = [
|
||||
torch.int8,
|
||||
|
|
@ -122,18 +107,9 @@ class TestTEFuser(JitTestCase):
|
|||
self.dtypes = self.int_dtypes + self.fp_dtypes
|
||||
|
||||
def tearDown(self):
|
||||
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
|
||||
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
|
||||
self.tensorexpr_options.restore()
|
||||
torch._C._jit_set_fusion_strategy(self.old_fusion_strategy)
|
||||
|
||||
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
|
||||
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state)
|
||||
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
|
||||
|
||||
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
|
||||
|
||||
def assertAllFused(self, graph, except_for=None):
|
||||
except_for = except_for if except_for is not None else set()
|
||||
# TODO - upstream
|
||||
|
|
@ -2442,7 +2418,13 @@ def get_name(op):
|
|||
l.append(op.variant_test_name)
|
||||
return '.'.join(l)
|
||||
|
||||
class TestNNCOpInfo(TestCase):
|
||||
class TestNNCOpInfo(JitCommonTestCase):
|
||||
def setUp(self):
    # Switch on the TensorExpr fuser configuration for the duration of a test.
    # Constructing TensorExprTestOptions changes the global torch._C JIT flags
    # and keeps the values it replaced, so tearDown can undo them via restore().
    # NOTE(review): the base-class setUp is not invoked here -- confirm that
    # this is intentional.
    self.tensorexpr_options = TensorExprTestOptions()
|
||||
|
||||
def tearDown(self):
    # Put back the JIT flags that were captured when setUp built
    # self.tensorexpr_options.
    # NOTE(review): the base-class tearDown is not invoked here -- confirm that
    # this is intentional.
    self.tensorexpr_options.restore()
|
||||
|
||||
def te_compile(self, device, dtype, op):
|
||||
if op.name in skip_ops:
|
||||
return
|
||||
|
|
@ -2516,6 +2498,27 @@ def f({', '.join(param_names)}):
|
|||
else:
|
||||
raise RuntimeError("Expected test to fail. If it now works, move op into works_list")
|
||||
|
||||
@slowTest
@onlyCPU
@ops(op_db, dtypes=OpDTypes.supported)
def test_nnc_correctness(self, device, dtype, op):
    """Check that a traced variant of ``op`` matches its eager result.

    For every (variant, sample) pair returned by
    ``get_traced_sample_variant_pairs`` the variant is run eagerly to get a
    reference value, then the traced function is called twice and the result
    of the second call is compared against the reference.

    Args:
        device: device handed to the sample generator.
        dtype: dtype handed to the sample generator.
        op: the operator entry under test.
    """
    def fresh_args(sample):
        # Every call gets its own copies so that no call can observe
        # modifications made to the inputs by a previous one.
        return clone_inputs((sample.input, *sample.args))

    try:
        for variant, sample in get_traced_sample_variant_pairs(device, dtype, op):
            trace = create_traced_fn(self, variant)
            ref = variant(*fresh_args(sample), **sample.kwargs)

            # The result of the first traced call is discarded.
            # NOTE(review): presumably this is the profiling run that precedes
            # the optimized execution -- confirm.
            trace(*fresh_args(sample), **sample.kwargs)

            val = trace(*fresh_args(sample), **sample.kwargs)

            self.assertEqual(ref, val)
    finally:
        # https://github.com/pytorch/pytorch/issues/35600
        # each torch.jit.trace adds state to the _python_cu compilation unit
        # since this test traces a lot of functions, out-of-memory can occur
        # if the CU is not cleared.
        # Running this in ``finally`` clears the CU even when tracing or the
        # comparison above raises.
        torch.jit._state._python_cu.drop_all_functions()
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from torch.testing._internal.common_device_type import \
|
|||
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
||||
check_alias_annotation
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
from torch.testing._internal.composite_compliance import _check_composite_compliance
|
||||
|
||||
|
|
@ -935,11 +935,6 @@ class TestGradients(TestCase):
|
|||
sample = first_sample(self, samples)
|
||||
result = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
# types.LambdaType gave false positives
def is_lambda(lamb):
    """Return True only for functions created by a ``lambda`` expression.

    A throwaway lambda is built as a reference; ``lamb`` qualifies when it has
    both the same type and the same ``__name__`` as that reference.
    """
    probe = lambda: 0  # noqa: E731
    if not isinstance(lamb, type(probe)):
        return False
    return lamb.__name__ == probe.__name__
|
||||
|
||||
|
||||
# Tests operators for consistency between JIT and eager, also checks
|
||||
# correctness of JIT specific alias schemas and intended
|
||||
|
|
|
|||
|
|
@ -8,36 +8,16 @@ import unittest
|
|||
|
||||
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
|
||||
|
||||
|
||||
class BaseTestClass(JitTestCase):
|
||||
def setUp(self):
|
||||
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
||||
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
|
||||
|
||||
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
|
||||
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
|
||||
torch._C._debug_set_fusion_group_inlining(False)
|
||||
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(False)
|
||||
|
||||
self.tensorexpr_options = TensorExprTestOptions()
|
||||
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
|
||||
|
||||
def tearDown(self):
|
||||
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
|
||||
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
|
||||
|
||||
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
|
||||
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
|
||||
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
|
||||
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
|
||||
self.tensorexpr_options.restore()
|
||||
|
||||
def assertLastGraphAllFused(self):
|
||||
self.assertAllFused(torch.jit.last_executed_optimized_graph())
|
||||
|
|
|
|||
|
|
@ -1395,7 +1395,12 @@ void initJitScriptBindings(PyObject* module) {
|
|||
.def(
|
||||
"get_class",
|
||||
[](const std::shared_ptr<CompilationUnit>& self,
|
||||
const std::string& name) { return self->get_class(name); });
|
||||
const std::string& name) { return self->get_class(name); })
|
||||
.def(
|
||||
"drop_all_functions",
|
||||
[](const std::shared_ptr<CompilationUnit>& self) {
|
||||
self->drop_all_functions();
|
||||
});
|
||||
|
||||
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
|
||||
.def(
|
||||
|
|
|
|||
|
|
@ -454,10 +454,11 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
|||
value_ = iter->second;
|
||||
}
|
||||
|
||||
// disable ubsan because sometimes this performs out-of-bound casts
|
||||
// e.g. it will cast negative floats to unsigned char
|
||||
template <typename SrcType, typename DstType>
|
||||
std::vector<DstType> castValues(
|
||||
const Dtype& src_dtype,
|
||||
const InterpValue& v) {
|
||||
std::vector<DstType> castValues(const Dtype& src_dtype, const InterpValue& v)
|
||||
__ubsan_ignore_undefined__ {
|
||||
const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
|
||||
std::vector<DstType> dst_values(src_values.size());
|
||||
for (int i = 0; i < src_dtype.lanes(); ++i) {
|
||||
|
|
|
|||
|
|
@ -6302,6 +6302,8 @@ def skips_mvlgamma(skip_redundant=False):
|
|||
skips = (
|
||||
# outside domain values are hard error for mvlgamma op.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'),
|
||||
# TODO: float16 fails due to tensor not satisfying > (p-1)/2
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo'),
|
||||
)
|
||||
if skip_redundant:
|
||||
# Redundant tests
|
||||
|
|
@ -8375,7 +8377,12 @@ op_db: List[OpInfo] = [
|
|||
supports_fwgrad_bwgrad=True,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
|
||||
sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)),
|
||||
sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
|
||||
skips=(
|
||||
# https://github.com/pytorch/pytorch/issues/71784
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||
device_type='cpu', dtypes=(torch.float16,)),
|
||||
)),
|
||||
OpInfo('addmv',
|
||||
dtypes=all_types_and_complex_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
|
||||
|
|
@ -8637,6 +8644,7 @@ op_db: List[OpInfo] = [
|
|||
sample_inputs_func=sample_inputs_allclose,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
),
|
||||
supports_out=False),
|
||||
OpInfo('broadcast_to',
|
||||
|
|
@ -8897,6 +8905,8 @@ op_db: List[OpInfo] = [
|
|||
skips=(
|
||||
# RuntimeError: Tensor must have a last dimension with stride 1
|
||||
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
|
||||
# RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
|
||||
)),
|
||||
BinaryUfuncInfo('complex',
|
||||
dtypes=floating_types(),
|
||||
|
|
@ -9661,6 +9671,7 @@ op_db: List[OpInfo] = [
|
|||
# RuntimeError:
|
||||
# Arguments for call are not valid.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
),
|
||||
supports_inplace_autograd=False,
|
||||
sample_inputs_func=sample_inputs_gradient),
|
||||
|
|
@ -9898,6 +9909,9 @@ op_db: List[OpInfo] = [
|
|||
# Fails on XLA.
|
||||
# AssertionError: False is not true : Tensors failed to compare as equal!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
|
||||
# https://github.com/pytorch/pytorch/issues/71774
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||
device_type='cpu', dtypes=(torch.long,)),
|
||||
)),
|
||||
OpInfo('linalg.norm',
|
||||
op=torch.linalg.norm,
|
||||
|
|
@ -10218,6 +10232,9 @@ op_db: List[OpInfo] = [
|
|||
# AssertionError: False is not true : Tensors failed to compare as equal!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestOpInfo',
|
||||
device_type='xla', dtypes=(torch.long,)),
|
||||
# https://github.com/pytorch/pytorch/issues/71774
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||
device_type='cpu', dtypes=(torch.long,)),
|
||||
)),
|
||||
OpInfo('max',
|
||||
variant_test_name='reduction_with_dim',
|
||||
|
|
@ -11235,6 +11252,8 @@ op_db: List[OpInfo] = [
|
|||
skips=(
|
||||
# Pre-existing condition; Needs to be fixed
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
|
||||
# RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
|
||||
),
|
||||
sample_inputs_func=sample_inputs_max_pool),
|
||||
OpInfo('nn.functional.max_pool2d',
|
||||
|
|
@ -11601,8 +11620,10 @@ op_db: List[OpInfo] = [
|
|||
sample_inputs_func=sample_inputs_batch_norm,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("We don't want to differentiate wrt running mean / std"),
|
||||
"TestCommon", "test_floating_inputs_are_differentiable"),)
|
||||
),
|
||||
"TestCommon", "test_floating_inputs_are_differentiable"),
|
||||
# see https://github.com/pytorch/pytorch/issues/71286
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
)),
|
||||
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
|
||||
OpInfo('nn.functional.batch_norm',
|
||||
variant_test_name='without_cudnn',
|
||||
|
|
@ -12219,7 +12240,11 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.expectedFailure, 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
|
||||
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
|
||||
'TestCommon', 'test_noncontiguous_samples',
|
||||
device_type='cuda', active_if=TEST_WITH_ROCM)),
|
||||
device_type='cuda', active_if=TEST_WITH_ROCM),
|
||||
# https://github.com/pytorch/pytorch/issues/71774
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||
device_type='cpu', dtypes=(torch.long,)),
|
||||
),
|
||||
skips=(
|
||||
# RuntimeError:
|
||||
# object has no attribute __rmatmul__:
|
||||
|
|
@ -13279,6 +13304,7 @@ op_db: List[OpInfo] = [
|
|||
skips=(
|
||||
# RuntimeError: attribute lookup is not defined on builtin
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
)),
|
||||
OpInfo('bfloat16',
|
||||
op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
|
||||
|
|
@ -13291,6 +13317,7 @@ op_db: List[OpInfo] = [
|
|||
skips=(
|
||||
# RuntimeError: attribute lookup is not defined on builtin
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
)),
|
||||
OpInfo('bool',
|
||||
op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
|
||||
|
|
@ -13507,6 +13534,8 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
|
||||
# Empty tensor data is garbage so it's hard to make comparisons with it.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
|
||||
# Empty tensor data is garbage so it's hard to make comparisons with it.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
# Can't find schemas for this operator for some reason
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
|
||||
)),
|
||||
|
|
@ -13613,6 +13642,8 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
|
||||
# Empty tensor data is garbage so it's hard to make comparisons with it.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
|
||||
# Empty tensor data is garbage so it's hard to make comparisons with it.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
# Can't find schemas for this operator for some reason
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
|
||||
),
|
||||
|
|
@ -13816,7 +13847,9 @@ op_db: List[OpInfo] = [
|
|||
# RuntimeError: Arguments for call not valid.
|
||||
# Expected a value of type 'List[Tensor]' for argument
|
||||
# 'tensors' but instead found type 'Tensor (inferred)'.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
|
||||
# see https://github.com/pytorch/pytorch/issues/71286
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),)),
|
||||
OpInfo('vstack',
|
||||
aliases=('row_stack',),
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ import functools
|
|||
# Testing utils
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, \
|
||||
freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS
|
||||
freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
|
||||
is_iterable_of_tensors
|
||||
from torch.testing._internal.common_jit import JitCommonTestCase
|
||||
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
|
||||
|
||||
|
|
@ -762,3 +763,126 @@ def _get_py3_code(code, fn_name):
|
|||
loader.exec_module(module)
|
||||
fn = getattr(module, fn_name)
|
||||
return fn
|
||||
|
||||
class TensorExprTestOptions:
    """Turn on the JIT settings used by TensorExpr fuser tests and remember the old ones.

    Constructing an instance immediately changes the global ``torch._C`` JIT
    flags; calling :meth:`restore` writes the remembered values back.
    """

    def __init__(self):
        backend = torch._C

        # The return values of the two profiling setters are kept and later
        # passed back to the same setters by restore().
        self.old_profiling_executor = backend._jit_set_profiling_executor(True)
        self.old_profiling_mode = backend._jit_set_profiling_mode(True)

        # Read both fuser flags before overriding either of them.
        self.old_cpu_fuser_state = backend._jit_can_fuse_on_cpu()
        self.old_gpu_fuser_state = backend._jit_can_fuse_on_gpu()
        backend._jit_override_can_fuse_on_cpu(True)
        backend._jit_override_can_fuse_on_gpu(True)

        # Enable the TensorExpr fuser itself.
        self.texpr_fuser_state = backend._jit_texpr_fuser_enabled()
        backend._jit_set_texpr_fuser_enabled(True)

        # Fusion group inlining is switched off.
        self.old_fusion_inlining = backend._debug_get_fusion_group_inlining()
        backend._debug_set_fusion_group_inlining(False)

        # The "must use LLVM on CPU" requirement is switched off.
        self.old_te_must_use_llvm_cpu = backend._jit_get_te_must_use_llvm_cpu()
        backend._jit_set_te_must_use_llvm_cpu(False)

    def restore(self):
        """Write every value captured by ``__init__`` back to the JIT."""
        backend = torch._C
        saved_settings = (
            (backend._jit_set_profiling_executor, self.old_profiling_executor),
            (backend._jit_set_profiling_mode, self.old_profiling_mode),
            (backend._jit_set_texpr_fuser_enabled, self.texpr_fuser_state),
            (backend._jit_override_can_fuse_on_gpu, self.old_gpu_fuser_state),
            (backend._jit_override_can_fuse_on_cpu, self.old_cpu_fuser_state),
            (backend._debug_set_fusion_group_inlining, self.old_fusion_inlining),
            (backend._jit_set_te_must_use_llvm_cpu, self.old_te_must_use_llvm_cpu),
        )
        for apply_setting, previous_value in saved_settings:
            apply_setting(previous_value)
|
||||
|
||||
def clone_inputs(args):
    """Return a list holding a detached copy of every tensor found in ``args``.

    Args:
        args: iterable of call arguments.

    Returns:
        A new list with one entry per argument, in order. A tensor becomes
        ``arg.detach().clone()``; an argument accepted by
        ``is_iterable_of_tensors`` becomes a list of such clones; anything
        else is placed in the result as the same object.
    """
    def duplicate(value):
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        if is_iterable_of_tensors(value):
            return [member.detach().clone() for member in value]
        return value

    inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = [duplicate(arg) for arg in args]
    return inputs
|
||||
|
||||
def get_traced_sample_variant_pairs(device, dtype, op):
    """Collect the (variant, sample) pairs of ``op`` to exercise through tracing.

    Args:
        device: forwarded to ``op.sample_inputs``.
        dtype: forwarded to ``op.sample_inputs``.
        op: object providing ``name``, ``formatted_name``, ``sample_inputs``,
            ``get_op`` and ``get_method``.
            NOTE(review): presumably an OpInfo -- confirm against callers.

    Returns:
        A list of ``(variant, sample)`` tuples, sample-major, covering the
        function and method variants that are neither ``None`` nor lambdas.
        The list is empty for ``resize_`` and ``resize_as_``. For the ops
        listed in the bool-argument table below, the affected entries of
        ``sample.args`` are rewritten in place from ``bool`` to ``int``.
    """
    # tuples of (variant, sample)
    outputs: List[Tuple[Any, Any]] = []

    # TODO: find better way to standardize on op registration itself..
    # These ops don't support tracing, so return before any samples or
    # variants are requested from the op.
    if op.name in ("resize_", "resize_as_"):
        return outputs

    samples = op.sample_inputs(device, dtype)

    # Acquires variants to test
    # TODO: inplace tests currently fail, fix and add inplace variant
    func = op.get_op()
    method = op.get_method()
    variants = [
        variant for variant in (func, method)
        if variant is not None and not is_lambda(variant)
    ]

    # In eager mode, these ops can take (Tensor, bool) args; but in
    # JIT they can only take (Tensor, Scalar), and bool is not a
    # scalar in the JIT type system. So to test these in JIT, the bool
    # is converted to an int for the test.
    # Maps the op's formatted name to the positions in sample.args to convert.
    ops_with_unsupported_bool_args = {
        "div_floor_rounding": [0],
        "div_no_rounding_mode": [0],
        "div_trunc_rounding": [0],
        "index_fill": [2],
        "full_like": [0],
        "mul": [0],
        "new_full": [1],
    }
    # The lookup depends only on the op, so it is done once, outside the loops.
    bool_arg_positions = ops_with_unsupported_bool_args.get(op.formatted_name, ())

    for sample in samples:
        # Samples are only touched when at least one variant will use them.
        if variants and bool_arg_positions:
            args = list(sample.args)
            for idx in bool_arg_positions:
                if len(args) > idx and isinstance(args[idx], bool):
                    args[idx] = int(args[idx])
            sample.args = tuple(args)

        outputs.extend((variant, sample) for variant in variants)

    return outputs
|
||||
|
||||
# types.LambdaType gave false positives
def is_lambda(lamb):
    """Tell whether ``lamb`` was produced by a ``lambda`` expression.

    ``lamb`` must be an instance of the type of a freshly created lambda and
    must carry the same ``__name__`` as that lambda; ordinary ``def``
    functions share the type but not the name.
    """
    reference = lambda: 0  # noqa: E731
    same_kind = isinstance(lamb, type(reference))
    return same_kind and lamb.__name__ == reference.__name__
|
||||
|
|
|
|||
Loading…
Reference in a new issue