[JIT] Opinfo tests for nnc fusion - retry (#72486)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72486

Retry #70465.

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D34061628

Pulled By: davidberard98

fbshipit-source-id: e27ed315bc4ad57cdbfbc9cedffcbb7886004524
(cherry picked from commit 7937808d2ebcc758aad4eac3ae6ffe1056d13fc5)
David Berard 2022-02-09 10:55:34 -08:00 committed by PyTorch MergeBot
parent 7035738b50
commit bbd42c605a
7 changed files with 214 additions and 73 deletions

@@ -21,12 +21,16 @@ torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
enable_profiling_mode_for_profiling_tests, TestCase
enable_profiling_mode_for_profiling_tests, slowTest
from torch.testing._internal.jit_utils import JitTestCase, \
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from textwrap import dedent
from itertools import product, permutations, combinations
@@ -78,33 +82,14 @@ def inline_fusion_groups():
class TestTEFuser(JitTestCase):
def setUp(self):
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu()
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(True)
# TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle
# torch._C._jit_set_te_must_use_llvm_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
self.tensorexpr_options = TensorExprTestOptions()
# note: `self.dynamic_shapes` is instantiated in the specialization of the class
# defined below
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)]
self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy)
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(True)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
@@ -122,18 +107,9 @@ class TestTEFuser(JitTestCase):
self.dtypes = self.int_dtypes + self.fp_dtypes
def tearDown(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
self.tensorexpr_options.restore()
torch._C._jit_set_fusion_strategy(self.old_fusion_strategy)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
def assertAllFused(self, graph, except_for=None):
except_for = except_for if except_for is not None else set()
# TODO - upstream
@@ -2442,7 +2418,13 @@ def get_name(op):
l.append(op.variant_test_name)
return '.'.join(l)
class TestNNCOpInfo(TestCase):
class TestNNCOpInfo(JitCommonTestCase):
def setUp(self):
self.tensorexpr_options = TensorExprTestOptions()
def tearDown(self):
self.tensorexpr_options.restore()
def te_compile(self, device, dtype, op):
if op.name in skip_ops:
return
@@ -2516,6 +2498,27 @@ def f({', '.join(param_names)}):
else:
raise RuntimeError("Expected test to fail. If it now works, move op into works_list")
@slowTest
@onlyCPU
@ops(op_db, dtypes=OpDTypes.supported)
def test_nnc_correctness(self, device, dtype, op):
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant)
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
self.assertEqual(ref, val)
# https://github.com/pytorch/pytorch/issues/35600
# each torch.jit.trace adds state to the _python_cu compilation unit
# since this test traces a lot of functions, out-of-memory can occur
# if the CU is not cleared.
torch.jit._state._python_cu.drop_all_functions()
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
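For readers skimming the hunk above: each (variant, sample) pair goes through the same trace-then-compare pattern, and the traced function is intentionally called twice, since the first call is a profiling run and only subsequent calls execute the NNC-fused graph. A minimal standalone sketch of that pattern, using a hypothetical elementwise function in place of an OpInfo entry:

import torch

def fn(x):
    return (x + 1.0) * 2.0  # hypothetical fusible elementwise computation

x = torch.rand(4, 4)
ref = fn(x.detach().clone())          # eager reference on a fresh copy
traced = torch.jit.trace(fn, (x,))
traced(x.detach().clone())            # first call: profiling run
val = traced(x.detach().clone())      # second call: runs the optimized (fused) graph
torch.testing.assert_close(ref, val)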

@@ -22,7 +22,7 @@ from torch.testing._internal.common_device_type import \
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
import torch.testing._internal.opinfo_helper as opinfo_helper
from torch.testing._internal.composite_compliance import _check_composite_compliance
@@ -935,11 +935,6 @@ class TestGradients(TestCase):
sample = first_sample(self, samples)
result = op(sample.input, *sample.args, **sample.kwargs)
# types.LambdaType gave false positives
def is_lambda(lamb):
LAMBDA = lambda: 0 # noqa: E731
return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended

@@ -8,36 +8,16 @@ import unittest
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
class BaseTestClass(JitTestCase):
def setUp(self):
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(True)
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
self.tensorexpr_options = TensorExprTestOptions()
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
def tearDown(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
self.tensorexpr_options.restore()
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())

@@ -1395,7 +1395,12 @@ void initJitScriptBindings(PyObject* module) {
.def(
"get_class",
[](const std::shared_ptr<CompilationUnit>& self,
const std::string& name) { return self->get_class(name); });
const std::string& name) { return self->get_class(name); })
.def(
"drop_all_functions",
[](const std::shared_ptr<CompilationUnit>& self) {
self->drop_all_functions();
});
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
.def(
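The drop_all_functions binding added above is what lets test_nnc_correctness clean up from Python: per issue 35600, every torch.jit.trace call registers functions on the shared _python_cu compilation unit, so a test that traces thousands of functions can run out of memory. A rough sketch of the workaround, assuming only what the test itself uses:

import torch

def fn(x):
    return x + 1.0

for _ in range(1000):
    traced = torch.jit.trace(fn, (torch.rand(2),))
# every trace above added state to the shared compilation unit; release it explicitly
torch.jit._state._python_cu.drop_all_functions()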

@@ -454,10 +454,11 @@ class SimpleIREvaluatorImpl : public IRVisitor {
value_ = iter->second;
}
// disable ubsan because sometimes this performs out-of-bound casts
// e.g. it will cast negative floats to unsigned char
template <typename SrcType, typename DstType>
std::vector<DstType> castValues(
const Dtype& src_dtype,
const InterpValue& v) {
std::vector<DstType> castValues(const Dtype& src_dtype, const InterpValue& v)
__ubsan_ignore_undefined__ {
const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
std::vector<DstType> dst_values(src_values.size());
for (int i = 0; i < src_dtype.lanes(); ++i) {
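
The check being suppressed above is real undefined behavior: C++ leaves a float-to-unsigned-integral conversion undefined whenever the truncated value is out of range for the destination type, so casting a negative float lane to unsigned char has no portable result. The same cast is reachable from Python; a rough illustration (the printed values are platform-dependent, which is exactly what UBSan objects to):

import torch

x = torch.tensor([-3.5, 300.0])
print(x.to(torch.uint8))  # out-of-range float-to-uint8 casts; do not rely on the values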

@@ -6302,6 +6302,8 @@ def skips_mvlgamma(skip_redundant=False):
skips = (
# out-of-domain values are a hard error for the mvlgamma op.
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'),
# TODO: float16 fails due to tensor not satisfying > (p-1)/2
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo'),
)
if skip_redundant:
# Redundant tests
@@ -8375,7 +8377,12 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1)),
sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
skips=(
# https://github.com/pytorch/pytorch/issues/71784
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
device_type='cpu', dtypes=(torch.float16,)),
)),
OpInfo('addmv',
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
@@ -8637,6 +8644,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_allclose,
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
),
supports_out=False),
OpInfo('broadcast_to',
@@ -8897,6 +8905,8 @@ op_db: List[OpInfo] = [
skips=(
# RuntimeError: Tensor must have a last dimension with stride 1
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
# RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
)),
BinaryUfuncInfo('complex',
dtypes=floating_types(),
@@ -9661,6 +9671,7 @@ op_db: List[OpInfo] = [
# RuntimeError:
# Arguments for call are not valid.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
),
supports_inplace_autograd=False,
sample_inputs_func=sample_inputs_gradient),
@@ -9898,6 +9909,9 @@ op_db: List[OpInfo] = [
# Fails on XLA.
# AssertionError: False is not true : Tensors failed to compare as equal!
DecorateInfo(unittest.expectedFailure, 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
# https://github.com/pytorch/pytorch/issues/71774
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
device_type='cpu', dtypes=(torch.long,)),
)),
OpInfo('linalg.norm',
op=torch.linalg.norm,
@@ -10218,6 +10232,9 @@ op_db: List[OpInfo] = [
# AssertionError: False is not true : Tensors failed to compare as equal!
DecorateInfo(unittest.expectedFailure, 'TestOpInfo',
device_type='xla', dtypes=(torch.long,)),
# https://github.com/pytorch/pytorch/issues/71774
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
device_type='cpu', dtypes=(torch.long,)),
)),
OpInfo('max',
variant_test_name='reduction_with_dim',
@@ -11235,6 +11252,8 @@ op_db: List[OpInfo] = [
skips=(
# Pre-existing condition; Needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
# RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
),
sample_inputs_func=sample_inputs_max_pool),
OpInfo('nn.functional.max_pool2d',
@@ -11601,8 +11620,10 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_batch_norm,
skips=(
DecorateInfo(unittest.skip("We don't want to differentiate wrt running mean / std"),
"TestCommon", "test_floating_inputs_are_differentiable"),)
),
"TestCommon", "test_floating_inputs_are_differentiable"),
# see https://github.com/pytorch/pytorch/issues/71286
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
)),
# This variant tests batch_norm with cuDNN disabled only on CUDA devices
OpInfo('nn.functional.batch_norm',
variant_test_name='without_cudnn',
@@ -12219,7 +12240,11 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
'TestCommon', 'test_noncontiguous_samples',
device_type='cuda', active_if=TEST_WITH_ROCM)),
device_type='cuda', active_if=TEST_WITH_ROCM),
# https://github.com/pytorch/pytorch/issues/71774
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
device_type='cpu', dtypes=(torch.long,)),
),
skips=(
# RuntimeError:
# object has no attribute __rmatmul__:
@@ -13279,6 +13304,7 @@ op_db: List[OpInfo] = [
skips=(
# RuntimeError: attribute lookup is not defined on builtin
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
)),
OpInfo('bfloat16',
op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
@@ -13291,6 +13317,7 @@ op_db: List[OpInfo] = [
skips=(
# RuntimeError: attribute lookup is not defined on builtin
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
)),
OpInfo('bool',
op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
@@ -13507,6 +13534,8 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
# Can't find schemas for this operator for some reason
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
)),
@@ -13613,6 +13642,8 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
# Can't find schemas for this operator for some reason
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
),
@@ -13816,7 +13847,9 @@ op_db: List[OpInfo] = [
# RuntimeError: Arguments for call not valid.
# Expected a value of type 'List[Tensor]' for argument
# 'tensors' but instead found type 'Tensor (inferred)'.
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
# see https://github.com/pytorch/pytorch/issues/71286
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),)),
OpInfo('vstack',
aliases=('row_stack',),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),

@@ -16,7 +16,8 @@ import functools
# Testing utils
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_WINDOWS, \
freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS
freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
is_iterable_of_tensors
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
@@ -762,3 +763,126 @@ def _get_py3_code(code, fn_name):
loader.exec_module(module)
fn = getattr(module, fn_name)
return fn
class TensorExprTestOptions():
def __init__(self):
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(True)
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
def restore(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
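TensorExprTestOptions consolidates the save/flip/restore boilerplate deleted above from TestTEFuser and BaseTestClass. A minimal sketch of the intended usage, mirroring the new TestNNCOpInfo.setUp/tearDown:

import unittest
from torch.testing._internal.jit_utils import TensorExprTestOptions

class MyTEFusionTest(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        # snapshot the current fuser/profiling flags, then enable TE fusion
        self.tensorexpr_options = TensorExprTestOptions()

    def tearDown(self):
        # put every flag back exactly as it was before the test
        self.tensorexpr_options.restore()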
def clone_inputs(args):
inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
for arg in args:
if isinstance(arg, torch.Tensor):
inputs.append(arg.detach().clone())
elif is_iterable_of_tensors(arg):
inputs.append([t.detach().clone() for t in arg])
else:
inputs.append(arg)
return inputs
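clone_inputs exists so that the eager reference call and each traced call in test_nnc_correctness see fresh tensors; an in-place or out= variant would otherwise mutate the shared sample between runs. A quick illustration using the helper above:

import torch

x = torch.rand(3)
fresh = clone_inputs((x, [torch.rand(3)], 2.5))  # tensors and tensor lists become detached copies
fresh[0].add_(1.0)
print(torch.equal(x, fresh[0]))  # False: the original sample input is untouched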
def get_traced_sample_variant_pairs(device, dtype, op):
# tuples of (variant, sample)
outputs: List[Tuple[Any, Any]] = []
samples = op.sample_inputs(device, dtype)
# Acquires variants to test
func = op.get_op()
method = op.get_method()
variants = {
# TODO: inplace tests currently fail, fix and add inplace variant
'function': func, 'method': method,
}
# TODO: find a better way to standardize this on op registration itself.
has_fake_function = op.name in ["resize_", 'resize_as_']
if has_fake_function:
variants = {'method': getattr(torch.Tensor, op.name)}
# In eager mode, these ops can take (Tensor, bool) args; but in
# JIT they can only take (Tensor, Scalar), and bool is not a
# scalar in the JIT type system. So to test these in JIT, the bool
# is converted to an int for the test.
ops_with_unsupported_bool_args = [
{
"name": "div_floor_rounding",
"arg_idx": [0],
},
{
"name": "div_no_rounding_mode",
"arg_idx": [0],
},
{
"name": "div_trunc_rounding",
"arg_idx": [0],
},
{
"name": "index_fill",
"arg_idx": [2],
},
{
"name": "full_like",
"arg_idx": [0],
},
{
"name": "mul",
"arg_idx": [0],
},
{
"name": "new_full",
"arg_idx": [1],
},
]
# doesn't support tracing
if has_fake_function:
return outputs
for sample in samples:
for func_type, variant in variants.items():
if variant is None:
continue
if is_lambda(variant):
continue
matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args)
for op_data in matching_ops:
for idx in op_data["arg_idx"]:
args = list(sample.args)
if len(sample.args) > idx and isinstance(sample.args[idx], bool):
args[idx] = int(args[idx])
sample.args = tuple(args)
outputs.append((variant, sample))
return outputs
# types.LambdaType gave false positives
def is_lambda(lamb):
LAMBDA = lambda: 0 # noqa: E731
return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__
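The "false positives" mentioned in the comment: in CPython, types.LambdaType and types.FunctionType are the same type object, so an isinstance check against LambdaType also matches ordinary def functions; the extra __name__ comparison is what actually distinguishes lambdas. A quick demonstration using the is_lambda above:

import types

def regular_fn():
    return 0

print(types.LambdaType is types.FunctionType)   # True: isinstance alone cannot tell them apart
print(isinstance(regular_fn, types.LambdaType)) # True: the false positive
print(is_lambda(regular_fn))  # False: the __name__ check filters out def functions
print(is_lambda(lambda: 0))   # True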