mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Move test ops gradients and test ops jit to separate files
Fixes #72368 As per reference issue, the test_ops in single file takes around 3:30-4:00Hrs to execute on asan jobs: Reference : pytorch_test_times.json ``` { "commit": "39535fec6c3ff5bf7c2d322d096c59571c3295ed", "JOB_BASE_NAME": "linux-xenial-py3.7-clang7-asan", "job_times": { "test_ops": 14928.355000000636, <- This test group is over 4hrs alone ``` ---- Hence separating test_ops into following parts: 1. TestGradients 2. TestJit 3. TestCommon and TestMathBits Pull Request resolved: https://github.com/pytorch/pytorch/pull/74297 Approved by: https://github.com/malfet
This commit is contained in:
parent
577bf04872
commit
ebca80ed08
5 changed files with 519 additions and 474 deletions
|
|
@ -259,6 +259,8 @@ CORE_TEST_LIST = [
|
|||
"test_modules",
|
||||
"test_nn",
|
||||
"test_ops",
|
||||
"test_ops_gradients",
|
||||
"test_ops_jit",
|
||||
"test_torch"
|
||||
]
|
||||
|
||||
|
|
|
|||
475
test/test_ops.py
475
test/test_ops.py
|
|
@ -1,28 +1,25 @@
|
|||
# Owner(s): ["high priority"]
|
||||
|
||||
from collections.abc import Sequence
|
||||
from functools import partial, wraps
|
||||
from functools import partial
|
||||
import warnings
|
||||
import unittest
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from torch.testing import FileCheck, make_tensor
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_dtype import floating_and_complex_types_and, get_all_dtypes
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
|
||||
gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings, noncontiguous_like,
|
||||
IS_IN_CI, suppress_warnings, noncontiguous_like,
|
||||
TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample)
|
||||
from torch.testing._internal.common_methods_invocations import \
|
||||
(op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(deviceCountAtLeast, instantiate_device_type_tests, ops, onlyCPU,
|
||||
onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta)
|
||||
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
||||
check_alias_annotation
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
||||
|
||||
|
||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
from torch.testing._internal.composite_compliance import _check_composite_compliance
|
||||
|
||||
|
|
@ -722,466 +719,6 @@ class TestCommon(TestCase):
|
|||
for arg in sample.kwargs.values():
|
||||
check_tensor_floating_is_differentiable(arg)
|
||||
|
||||
|
||||
# gradcheck requires double precision
|
||||
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=[torch.double, torch.cdouble])
|
||||
|
||||
|
||||
class TestGradients(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
# Copies inputs to inplace operations to avoid inplace modifications
|
||||
# to leaves requiring gradient
|
||||
def _get_safe_inplace(self, inplace_variant):
|
||||
@wraps(inplace_variant)
|
||||
def _fn(t, *args, **kwargs):
|
||||
return inplace_variant(t.clone(), *args, **kwargs)
|
||||
|
||||
return _fn
|
||||
|
||||
def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
|
||||
check_batched_grad=None, check_batched_forward_grad=False):
|
||||
assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
|
||||
# NB: check_backward_ad does not affect gradgradcheck (always True)
|
||||
if variant is None:
|
||||
self.skipTest("Skipped! Variant not implemented.")
|
||||
if not op.supports_dtype(dtype, torch.device(device).type):
|
||||
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
|
||||
|
||||
def is_inplace(variant):
|
||||
if hasattr(variant, "__wrapped__"):
|
||||
return variant.__wrapped__ is op.get_inplace()
|
||||
return variant is op.get_inplace()
|
||||
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
for sample in samples:
|
||||
if sample.broadcasts_input and is_inplace(variant):
|
||||
continue
|
||||
|
||||
# Note on TensorList inputs
|
||||
#
|
||||
# gradcheck does not support TensorList inputs so here we pass TensorList
|
||||
# inputs of size n as n single Tensor inputs to gradcheck and wrap the op
|
||||
# in a function that puts the n Tensor inputs back into a TensorList
|
||||
def fn(*inputs):
|
||||
# Put tensors back into TensorList since we splat them when passing to gradcheck
|
||||
if is_iterable_of_tensors(sample.input):
|
||||
n = len(sample.input)
|
||||
inputs = (inputs[:n], *inputs[n:])
|
||||
output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
|
||||
if sample.output_process_fn_grad is not None:
|
||||
return sample.output_process_fn_grad(output)
|
||||
return output
|
||||
|
||||
# Splat TensorList inputs into single Tensor inputs
|
||||
gradcheck_args = (sample.input,) if isinstance(sample.input, torch.Tensor) else tuple(sample.input)
|
||||
gradcheck_args += sample.args
|
||||
|
||||
if check == 'gradcheck':
|
||||
if check_batched_grad is None:
|
||||
check_batched_grad = op.check_batched_grad
|
||||
self.assertTrue(gradcheck(fn, gradcheck_args,
|
||||
check_batched_grad=check_batched_grad,
|
||||
check_grad_dtypes=True,
|
||||
nondet_tol=op.gradcheck_nondet_tol,
|
||||
fast_mode=op.gradcheck_fast_mode,
|
||||
check_forward_ad=check_forward_ad,
|
||||
check_backward_ad=check_backward_ad,
|
||||
check_undefined_grad=True,
|
||||
check_batched_forward_grad=check_batched_forward_grad))
|
||||
elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'): # gradgrad check
|
||||
self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
|
||||
for gen_non_contig_grad_outputs in (False, True):
|
||||
kwargs = {
|
||||
"gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
|
||||
"check_batched_grad": op.check_batched_gradgrad,
|
||||
"check_grad_dtypes": True,
|
||||
"nondet_tol": op.gradcheck_nondet_tol,
|
||||
"fast_mode": op.gradcheck_fast_mode
|
||||
}
|
||||
if check == "fwgrad_bwgrad":
|
||||
kwargs["check_fwd_over_rev"] = True
|
||||
kwargs["check_rev_over_rev"] = False
|
||||
kwargs["check_batched_grad"] = False
|
||||
kwargs["check_undefined_grad"] = False
|
||||
|
||||
self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
|
||||
else:
|
||||
self.assertTrue(False, msg="Unknown check requested!")
|
||||
|
||||
def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
|
||||
check_batched_grad=None, check_batched_forward_grad=False):
|
||||
return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
|
||||
check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
|
||||
check_batched_forward_grad=check_batched_forward_grad)
|
||||
|
||||
def _skip_helper(self, op, device, dtype):
|
||||
if not op.supports_autograd and not op.supports_forward_ad:
|
||||
self.skipTest("Skipped! autograd not supported.")
|
||||
if not op.supports_complex_autograd(torch.device(device).type) and dtype.is_complex:
|
||||
self.skipTest("Skipped! Complex autograd not supported.")
|
||||
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
self._grad_test_helper(device, dtype, op, op.get_op())
|
||||
|
||||
# Method grad (and gradgrad, see below) tests are disabled since they're
|
||||
# costly and redundant with function grad (and gradgad) tests
|
||||
# @_gradcheck_ops(op_db)
|
||||
# def test_method_grad(self, device, dtype, op):
|
||||
# self._skip_helper(op, device, dtype)
|
||||
# self._grad_test_helper(device, dtype, op, op.get_method())
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_grad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
self.skipTest("Skipped! Operation does not support gradgrad")
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
|
||||
# Test that forward-over-reverse gradgrad is computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_fwgrad_bwgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
if op.supports_fwgrad_bwgrad:
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it\."
|
||||
hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
|
||||
# Test that gradients of gradients are properly raising
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_fail_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if op.supports_gradgrad:
|
||||
self.skipTest("Skipped! Operation does support gradgrad")
|
||||
|
||||
err_msg = r"derivative for .* is not implemented"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
|
||||
# Method gradgrad (and grad, see above) tests are disabled since they're
|
||||
# costly and redundant with function gradgrad (and grad) tests
|
||||
# @_gradcheck_ops(op_db)
|
||||
# def test_method_gradgrad(self, device, dtype, op):
|
||||
# self._skip_helper(op, device, dtype)
|
||||
# self._gradgrad_test_helper(device, dtype, op, op.get_method())
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")
|
||||
|
||||
def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
|
||||
# TODO: clean up how attributes are passed to gradcheck from OpInfos
|
||||
def call_grad_test_helper():
|
||||
check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
|
||||
(op.check_inplace_batched_forward_grad and is_inplace))
|
||||
self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
|
||||
check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
|
||||
if op.supports_forward_ad:
|
||||
call_grad_test_helper()
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it\."
|
||||
hint_msg = ("Running forward AD for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
call_grad_test_helper()
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
|
||||
self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)
|
||||
|
||||
# Functions that do not support autograd should not fail in forward mode
|
||||
# Inplace functions (such as "resize_") are expected to fail in forward mode and should be skipped
|
||||
# Test only when supports_autograd=False and for double dtype
|
||||
@ops(filter(lambda op: not op.supports_autograd, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
|
||||
def test_nondifferentiable(self, device, dtype, op):
|
||||
# Expecting no errors
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
sample = first_sample(self, samples)
|
||||
result = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
|
||||
# Tests operators for consistency between JIT and eager, also checks
|
||||
# correctness of JIT specific alias schemas and intended
|
||||
# autodifferentiation behavior.
|
||||
# Inherits from JitCommonTestCase instead of TestCase directly to share
|
||||
# functionality with original test_jit.py method operator tests
|
||||
class TestJit(JitCommonTestCase):
|
||||
exact_dtype = True
|
||||
|
||||
# Tests that the forward and backward passes of operations produce the
|
||||
# same values for the cross-product of op variants (function, method, inplace)
|
||||
# and runtimes (eager, traced, scripted).
|
||||
# TODO WARNING: inplace x {traced, scripted} not currently tested
|
||||
@_variant_ops(op_db)
|
||||
def test_variant_consistency_jit(self, device, dtype, op):
|
||||
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
||||
op.supports_complex_autograd(torch.device(device).type))
|
||||
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
# Acquires variants to test
|
||||
func = op.get_op()
|
||||
method = op.get_method()
|
||||
variants = {
|
||||
# TODO: inplace tests currently fail, fix and add inplace variant
|
||||
'function': func, 'method': method,
|
||||
}
|
||||
|
||||
# TODO: find better way to standardize on op registration itself..
|
||||
has_fake_function = op.name in ["resize_", 'resize_as_']
|
||||
|
||||
if has_fake_function:
|
||||
variants = {'method': getattr(torch.Tensor, op.name)}
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
|
||||
support_script = op.supports_scripting
|
||||
|
||||
tested = False
|
||||
for sample in samples:
|
||||
# Test traced and scripted consistency
|
||||
for func_type, variant in variants.items():
|
||||
if variant is None:
|
||||
continue
|
||||
|
||||
# scripting and check_alias_analysis do not work with lambdas
|
||||
# lambdas are typically used as a way to simulate methods without
|
||||
# functional variants, so rely on the other variant for testing
|
||||
# for now
|
||||
if is_lambda(variant):
|
||||
continue
|
||||
|
||||
tested = True
|
||||
|
||||
# Create accessor for script function variant
|
||||
name = op.name + '_' if func_type == 'inplace' else op.name
|
||||
|
||||
# run with disable_autodiff_subgraph_inlining(True) to test
|
||||
# autodiff support. Context manager forces the graph to contain
|
||||
# DifferentiableGraph nodes if they are present
|
||||
with disable_autodiff_subgraph_inlining():
|
||||
# Check scripted forward, grad, and grad grad
|
||||
if support_script:
|
||||
script_fn = create_script_fn(self, name, func_type)
|
||||
|
||||
def out_fn(output):
|
||||
# Processes the output for autograd
|
||||
if sample.output_process_fn_grad is not None:
|
||||
return sample.output_process_fn_grad(output)
|
||||
return output
|
||||
|
||||
def get_sample():
|
||||
return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
|
||||
|
||||
if support_script:
|
||||
check_against_reference(self,
|
||||
script_fn,
|
||||
func,
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
|
||||
# Check traced forward, grad, and grad grad
|
||||
# TODO: fix tracing here
|
||||
supports_tracing = not has_fake_function
|
||||
if op.assert_jit_shape_analysis:
|
||||
self.assertTrue(supports_tracing)
|
||||
|
||||
if supports_tracing:
|
||||
traced_fn = create_traced_fn(self, variant)
|
||||
check_against_reference(self,
|
||||
traced_fn,
|
||||
func,
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
|
||||
# Check alias annotation schema for correctness (make
|
||||
# sure inputs that aren't supposed to be modified aren't)
|
||||
# Note: only runs in float32 because schema isn't affected by dtype,
|
||||
# so running it on all dtypes is would be excessive
|
||||
if dtype == torch.float32:
|
||||
# TODO: no reason why we cant run this with tracing graph
|
||||
if support_script and op.name != "rsub":
|
||||
check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
|
||||
func_type=func_type, aten_name=op.aten_name)
|
||||
|
||||
# TODO: use script graph as well
|
||||
checked_shape_analysis = False
|
||||
if supports_tracing:
|
||||
out = variant(get_sample(), *sample.args, **sample.kwargs)
|
||||
|
||||
# right now, tuple of outputs and tensor output supported
|
||||
# TODO: list of tensor outputs
|
||||
tuple_of_tensors = isinstance(out, tuple) and all([isinstance(elem, torch.Tensor) for elem in out])
|
||||
|
||||
if isinstance(out, torch.Tensor) or tuple_of_tensors:
|
||||
if tuple_of_tensors:
|
||||
sizes = [elem.size() for elem in out]
|
||||
else:
|
||||
sizes = out.size()
|
||||
self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
|
||||
checked_shape_analysis = True
|
||||
if op.assert_jit_shape_analysis:
|
||||
self.assertTrue(checked_shape_analysis)
|
||||
|
||||
# Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
|
||||
if dtype is torch.float32:
|
||||
# Sandcastle doesn't fuse nodes
|
||||
if IS_SANDCASTLE:
|
||||
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
|
||||
fusible_nodes = []
|
||||
else:
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes
|
||||
fusible_nodes = op.autodiff_fusible_nodes
|
||||
|
||||
if supports_tracing:
|
||||
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
if support_script:
|
||||
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
assert tested, "JIT Test does not execute any logic"
|
||||
|
||||
# alias testing is only done with torch.float for the same reason
|
||||
_alias_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float,))
|
||||
|
||||
@_alias_ops((op for op in op_db if op.aliases))
|
||||
def test_jit_alias_remapping(self, device, dtype, op):
|
||||
# Required to avoid undefined value: tensor error in JIT compilation of the function template
|
||||
tensor = torch.tensor
|
||||
|
||||
# NOTE: only tests on first sample
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
sample = first_sample(self, samples)
|
||||
|
||||
# [Scripting Data Preparation]
|
||||
# Prepare data for test scripting
|
||||
# Below we prepare strings of args/kwargs with and without type annotations.
|
||||
# These strings are inserted into function template strings which is then torch scripted.
|
||||
# - args string is ["t0"] corresponding to the "input" tensor required by the op
|
||||
# - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
|
||||
# ["to", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
|
||||
args = ["t0"]
|
||||
|
||||
def quote_strs(v):
|
||||
if isinstance(v, str):
|
||||
return f"'{v}'"
|
||||
|
||||
return str(v)
|
||||
|
||||
args_kw = args + \
|
||||
[f"{v}" for v in sample.args] + \
|
||||
[f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
|
||||
|
||||
# Prepare data for test tracing
|
||||
sample_args_kwargs = ()
|
||||
if len(sample.args) > 0:
|
||||
sample_args_kwargs += (sample.args, )
|
||||
if len(sample.kwargs) > 0:
|
||||
sample_args_kwargs += (sample.kwargs, )
|
||||
|
||||
original_name = op.aten_name
|
||||
original_name_inplace = original_name + "_"
|
||||
expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
|
||||
|
||||
for a_op in op.aliases:
|
||||
inplace = a_op.inplace_variant
|
||||
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
|
||||
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
|
||||
|
||||
# Test scripting:
|
||||
for variant in variants:
|
||||
variant_name = variant.__name__
|
||||
op_name = original_name_inplace if variant is inplace else original_name
|
||||
|
||||
if variant in method_or_inplace:
|
||||
fn_template = '''
|
||||
def _fn(t0{c}):
|
||||
return t0.{alias_name}({args_kw})
|
||||
'''
|
||||
# remove the first input tensor
|
||||
script = fn_template.format(
|
||||
c=", " if len(args_kw[1:]) > 1 else "",
|
||||
args_kw=", ".join(args_kw[1:]),
|
||||
alias_name=variant_name,
|
||||
)
|
||||
else:
|
||||
fn_template = '''
|
||||
def _fn({args}):
|
||||
return variant({args_kw})
|
||||
'''
|
||||
script = fn_template.format(
|
||||
args=", ".join(args),
|
||||
args_kw=", ".join(args_kw),
|
||||
)
|
||||
scripted = torch.jit.CompilationUnit(script)._fn
|
||||
|
||||
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
|
||||
try:
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
except Exception as e:
|
||||
continue
|
||||
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
|
||||
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
inp = clone_input_helper(sample.input)
|
||||
graph = scripted.graph_for(inp)
|
||||
FileCheck().check(op.aten_name).check_not(variant_name).run(graph)
|
||||
|
||||
# Test tracing:
|
||||
for variant in variants:
|
||||
variant_name = variant.__name__
|
||||
op_name = original_name_inplace if variant is inplace else original_name
|
||||
|
||||
def _fn(*sample_args, **sample_kwargs):
|
||||
return variant(*sample_args, **sample_kwargs)
|
||||
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
traced = torch.jit.trace(_fn, *inp)
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
traced(*inp)
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
graph = traced.graph_for(*inp)
|
||||
FileCheck().check(op_name).check_not(variant_name).run(graph)
|
||||
|
||||
class TestMathBits(TestCase):
|
||||
# Tests that
|
||||
# 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
|
||||
|
|
@ -1313,8 +850,6 @@ class TestMathBits(TestCase):
|
|||
|
||||
|
||||
instantiate_device_type_tests(TestCommon, globals())
|
||||
instantiate_device_type_tests(TestGradients, globals())
|
||||
instantiate_device_type_tests(TestJit, globals())
|
||||
instantiate_device_type_tests(TestMathBits, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
228
test/test_ops_gradients.py
Normal file
228
test/test_ops_gradients.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
# Owner(s): ["high priority"]
|
||||
|
||||
from functools import partial, wraps
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, first_sample)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
# gradcheck requires double precision
|
||||
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=[torch.double, torch.cdouble])
|
||||
|
||||
class TestGradients(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
# Copies inputs to inplace operations to avoid inplace modifications
|
||||
# to leaves requiring gradient
|
||||
def _get_safe_inplace(self, inplace_variant):
|
||||
@wraps(inplace_variant)
|
||||
def _fn(t, *args, **kwargs):
|
||||
return inplace_variant(t.clone(), *args, **kwargs)
|
||||
|
||||
return _fn
|
||||
|
||||
def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
|
||||
check_batched_grad=None, check_batched_forward_grad=False):
|
||||
assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
|
||||
# NB: check_backward_ad does not affect gradgradcheck (always True)
|
||||
if variant is None:
|
||||
self.skipTest("Skipped! Variant not implemented.")
|
||||
if not op.supports_dtype(dtype, torch.device(device).type):
|
||||
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
|
||||
|
||||
def is_inplace(variant):
|
||||
if hasattr(variant, "__wrapped__"):
|
||||
return variant.__wrapped__ is op.get_inplace()
|
||||
return variant is op.get_inplace()
|
||||
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
for sample in samples:
|
||||
if sample.broadcasts_input and is_inplace(variant):
|
||||
continue
|
||||
|
||||
# Note on TensorList inputs
|
||||
#
|
||||
# gradcheck does not support TensorList inputs so here we pass TensorList
|
||||
# inputs of size n as n single Tensor inputs to gradcheck and wrap the op
|
||||
# in a function that puts the n Tensor inputs back into a TensorList
|
||||
def fn(*inputs):
|
||||
# Put tensors back into TensorList since we splat them when passing to gradcheck
|
||||
if is_iterable_of_tensors(sample.input):
|
||||
n = len(sample.input)
|
||||
inputs = (inputs[:n], *inputs[n:])
|
||||
output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
|
||||
if sample.output_process_fn_grad is not None:
|
||||
return sample.output_process_fn_grad(output)
|
||||
return output
|
||||
|
||||
# Splat TensorList inputs into single Tensor inputs
|
||||
gradcheck_args = (sample.input,) if isinstance(sample.input, torch.Tensor) else tuple(sample.input)
|
||||
gradcheck_args += sample.args
|
||||
|
||||
if check == 'gradcheck':
|
||||
if check_batched_grad is None:
|
||||
check_batched_grad = op.check_batched_grad
|
||||
self.assertTrue(gradcheck(fn, gradcheck_args,
|
||||
check_batched_grad=check_batched_grad,
|
||||
check_grad_dtypes=True,
|
||||
nondet_tol=op.gradcheck_nondet_tol,
|
||||
fast_mode=op.gradcheck_fast_mode,
|
||||
check_forward_ad=check_forward_ad,
|
||||
check_backward_ad=check_backward_ad,
|
||||
check_undefined_grad=True,
|
||||
check_batched_forward_grad=check_batched_forward_grad))
|
||||
elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'): # gradgrad check
|
||||
self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
|
||||
for gen_non_contig_grad_outputs in (False, True):
|
||||
kwargs = {
|
||||
"gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
|
||||
"check_batched_grad": op.check_batched_gradgrad,
|
||||
"check_grad_dtypes": True,
|
||||
"nondet_tol": op.gradcheck_nondet_tol,
|
||||
"fast_mode": op.gradcheck_fast_mode
|
||||
}
|
||||
if check == "fwgrad_bwgrad":
|
||||
kwargs["check_fwd_over_rev"] = True
|
||||
kwargs["check_rev_over_rev"] = False
|
||||
kwargs["check_batched_grad"] = False
|
||||
kwargs["check_undefined_grad"] = False
|
||||
|
||||
self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
|
||||
else:
|
||||
self.assertTrue(False, msg="Unknown check requested!")
|
||||
|
||||
def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
|
||||
check_batched_grad=None, check_batched_forward_grad=False):
|
||||
return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
|
||||
check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
|
||||
check_batched_forward_grad=check_batched_forward_grad)
|
||||
|
||||
def _skip_helper(self, op, device, dtype):
|
||||
if not op.supports_autograd and not op.supports_forward_ad:
|
||||
self.skipTest("Skipped! autograd not supported.")
|
||||
if not op.supports_complex_autograd(torch.device(device).type) and dtype.is_complex:
|
||||
self.skipTest("Skipped! Complex autograd not supported.")
|
||||
|
||||
# Tests that gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_grad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
self._grad_test_helper(device, dtype, op, op.get_op())
|
||||
|
||||
# Method grad (and gradgrad, see below) tests are disabled since they're
|
||||
# costly and redundant with function grad (and gradgad) tests
|
||||
# @_gradcheck_ops(op_db)
|
||||
# def test_method_grad(self, device, dtype, op):
|
||||
# self._skip_helper(op, device, dtype)
|
||||
# self._grad_test_helper(device, dtype, op, op.get_method())
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_grad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
self.skipTest("Skipped! Operation does not support gradgrad")
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
|
||||
# Test that forward-over-reverse gradgrad is computed correctly
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_fwgrad_bwgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
if op.supports_fwgrad_bwgrad:
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it\."
|
||||
hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
|
||||
# Test that gradients of gradients are properly raising
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_fn_fail_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if op.supports_gradgrad:
|
||||
self.skipTest("Skipped! Operation does support gradgrad")
|
||||
|
||||
err_msg = r"derivative for .* is not implemented"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
|
||||
# Method gradgrad (and grad, see above) tests are disabled since they're
|
||||
# costly and redundant with function gradgrad (and grad) tests
|
||||
# @_gradcheck_ops(op_db)
|
||||
# def test_method_gradgrad(self, device, dtype, op):
|
||||
# self._skip_helper(op, device, dtype)
|
||||
# self._gradgrad_test_helper(device, dtype, op, op.get_method())
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")
|
||||
|
||||
def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
|
||||
# TODO: clean up how attributes are passed to gradcheck from OpInfos
|
||||
def call_grad_test_helper():
|
||||
check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
|
||||
(op.check_inplace_batched_forward_grad and is_inplace))
|
||||
self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
|
||||
check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
|
||||
if op.supports_forward_ad:
|
||||
call_grad_test_helper()
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it\."
|
||||
hint_msg = ("Running forward AD for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
call_grad_test_helper()
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
def test_inplace_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
|
||||
self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)
|
||||
|
||||
# Functions that do not support autograd should not fail in forward mode
|
||||
# Inplace functions (such as "resize_") are expected to fail in forward mode and should be skipped
|
||||
# Test only when supports_autograd=False and for double dtype
|
||||
@ops(filter(lambda op: not op.supports_autograd, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
|
||||
def test_nondifferentiable(self, device, dtype, op):
|
||||
# Expecting no errors
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
sample = first_sample(self, samples)
|
||||
result = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestGradients, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
280
test/test_ops_jit.py
Normal file
280
test/test_ops_jit.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
# Owner(s): ["high priority"]
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import \
|
||||
(run_tests, IS_SANDCASTLE, clone_input_helper, first_sample)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
|
||||
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
||||
|
||||
|
||||
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
# variant testing is only done with torch.float and torch.cfloat to avoid
|
||||
# excessive test times and maximize signal to noise ratio
|
||||
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float, torch.cfloat))
|
||||
|
||||
|
||||
|
||||
# Tests operators for consistency between JIT and eager, also checks
|
||||
# correctness of JIT specific alias schemas and intended
|
||||
# autodifferentiation behavior.
|
||||
# Inherits from JitCommonTestCase instead of TestCase directly to share
|
||||
# functionality with original test_jit.py method operator tests
|
||||
class TestJit(JitCommonTestCase):
|
||||
exact_dtype = True
|
||||
|
||||
# Tests that the forward and backward passes of operations produce the
|
||||
# same values for the cross-product of op variants (function, method, inplace)
|
||||
# and runtimes (eager, traced, scripted).
|
||||
# TODO WARNING: inplace x {traced, scripted} not currently tested
|
||||
@_variant_ops(op_db)
|
||||
def test_variant_consistency_jit(self, device, dtype, op):
|
||||
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
|
||||
op.supports_complex_autograd(torch.device(device).type))
|
||||
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||
|
||||
# Acquires variants to test
|
||||
func = op.get_op()
|
||||
method = op.get_method()
|
||||
variants = {
|
||||
# TODO: inplace tests currently fail, fix and add inplace variant
|
||||
'function': func, 'method': method,
|
||||
}
|
||||
|
||||
# TODO: find better way to standardize on op registration itself..
|
||||
has_fake_function = op.name in ["resize_", 'resize_as_']
|
||||
|
||||
if has_fake_function:
|
||||
variants = {'method': getattr(torch.Tensor, op.name)}
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
|
||||
support_script = op.supports_scripting
|
||||
|
||||
tested = False
|
||||
for sample in samples:
|
||||
# Test traced and scripted consistency
|
||||
for func_type, variant in variants.items():
|
||||
if variant is None:
|
||||
continue
|
||||
|
||||
# scripting and check_alias_analysis do not work with lambdas
|
||||
# lambdas are typically used as a way to simulate methods without
|
||||
# functional variants, so rely on the other variant for testing
|
||||
# for now
|
||||
if is_lambda(variant):
|
||||
continue
|
||||
|
||||
tested = True
|
||||
|
||||
# Create accessor for script function variant
|
||||
name = op.name + '_' if func_type == 'inplace' else op.name
|
||||
|
||||
# run with disable_autodiff_subgraph_inlining(True) to test
|
||||
# autodiff support. Context manager forces the graph to contain
|
||||
# DifferentiableGraph nodes if they are present
|
||||
with disable_autodiff_subgraph_inlining():
|
||||
# Check scripted forward, grad, and grad grad
|
||||
if support_script:
|
||||
script_fn = create_script_fn(self, name, func_type)
|
||||
|
||||
def out_fn(output):
|
||||
# Processes the output for autograd
|
||||
if sample.output_process_fn_grad is not None:
|
||||
return sample.output_process_fn_grad(output)
|
||||
return output
|
||||
|
||||
def get_sample():
|
||||
return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
|
||||
|
||||
if support_script:
|
||||
check_against_reference(self,
|
||||
script_fn,
|
||||
func,
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
|
||||
# Check traced forward, grad, and grad grad
|
||||
# TODO: fix tracing here
|
||||
supports_tracing = not has_fake_function
|
||||
if op.assert_jit_shape_analysis:
|
||||
self.assertTrue(supports_tracing)
|
||||
|
||||
if supports_tracing:
|
||||
traced_fn = create_traced_fn(self, variant)
|
||||
check_against_reference(self,
|
||||
traced_fn,
|
||||
func,
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
|
||||
# Check alias annotation schema for correctness (make
|
||||
# sure inputs that aren't supposed to be modified aren't)
|
||||
# Note: only runs in float32 because schema isn't affected by dtype,
|
||||
# so running it on all dtypes is would be excessive
|
||||
if dtype == torch.float32:
|
||||
# TODO: no reason why we cant run this with tracing graph
|
||||
if support_script and op.name != "rsub":
|
||||
check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
|
||||
func_type=func_type, aten_name=op.aten_name)
|
||||
|
||||
# TODO: use script graph as well
|
||||
checked_shape_analysis = False
|
||||
if supports_tracing:
|
||||
out = variant(get_sample(), *sample.args, **sample.kwargs)
|
||||
|
||||
# right now, tuple of outputs and tensor output supported
|
||||
# TODO: list of tensor outputs
|
||||
tuple_of_tensors = isinstance(out, tuple) and all([isinstance(elem, torch.Tensor) for elem in out])
|
||||
|
||||
if isinstance(out, torch.Tensor) or tuple_of_tensors:
|
||||
if tuple_of_tensors:
|
||||
sizes = [elem.size() for elem in out]
|
||||
else:
|
||||
sizes = out.size()
|
||||
self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
|
||||
checked_shape_analysis = True
|
||||
if op.assert_jit_shape_analysis:
|
||||
self.assertTrue(checked_shape_analysis)
|
||||
|
||||
# Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
|
||||
if dtype is torch.float32:
|
||||
# Sandcastle doesn't fuse nodes
|
||||
if IS_SANDCASTLE:
|
||||
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
|
||||
fusible_nodes = []
|
||||
else:
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes
|
||||
fusible_nodes = op.autodiff_fusible_nodes
|
||||
|
||||
if supports_tracing:
|
||||
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
if support_script:
|
||||
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
assert tested, "JIT Test does not execute any logic"
|
||||
|
||||
# alias testing is only done with torch.float for the same reason
|
||||
_alias_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float,))
|
||||
|
||||
@_alias_ops((op for op in op_db if op.aliases))
|
||||
def test_jit_alias_remapping(self, device, dtype, op):
|
||||
# Required to avoid undefined value: tensor error in JIT compilation of the function template
|
||||
tensor = torch.tensor
|
||||
|
||||
# NOTE: only tests on first sample
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
sample = first_sample(self, samples)
|
||||
|
||||
# [Scripting Data Preparation]
|
||||
# Prepare data for test scripting
|
||||
# Below we prepare strings of args/kwargs with and without type annotations.
|
||||
# These strings are inserted into function template strings which is then torch scripted.
|
||||
# - args string is ["t0"] corresponding to the "input" tensor required by the op
|
||||
# - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
|
||||
# ["to", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
|
||||
args = ["t0"]
|
||||
|
||||
def quote_strs(v):
|
||||
if isinstance(v, str):
|
||||
return f"'{v}'"
|
||||
|
||||
return str(v)
|
||||
|
||||
args_kw = args + \
|
||||
[f"{v}" for v in sample.args] + \
|
||||
[f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
|
||||
|
||||
# Prepare data for test tracing
|
||||
sample_args_kwargs = ()
|
||||
if len(sample.args) > 0:
|
||||
sample_args_kwargs += (sample.args, )
|
||||
if len(sample.kwargs) > 0:
|
||||
sample_args_kwargs += (sample.kwargs, )
|
||||
|
||||
original_name = op.aten_name
|
||||
original_name_inplace = original_name + "_"
|
||||
expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
|
||||
|
||||
for a_op in op.aliases:
|
||||
inplace = a_op.inplace_variant
|
||||
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
|
||||
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
|
||||
|
||||
# Test scripting:
|
||||
for variant in variants:
|
||||
variant_name = variant.__name__
|
||||
op_name = original_name_inplace if variant is inplace else original_name
|
||||
|
||||
if variant in method_or_inplace:
|
||||
fn_template = '''
|
||||
def _fn(t0{c}):
|
||||
return t0.{alias_name}({args_kw})
|
||||
'''
|
||||
# remove the first input tensor
|
||||
script = fn_template.format(
|
||||
c=", " if len(args_kw[1:]) > 1 else "",
|
||||
args_kw=", ".join(args_kw[1:]),
|
||||
alias_name=variant_name,
|
||||
)
|
||||
else:
|
||||
fn_template = '''
|
||||
def _fn({args}):
|
||||
return variant({args_kw})
|
||||
'''
|
||||
script = fn_template.format(
|
||||
args=", ".join(args),
|
||||
args_kw=", ".join(args_kw),
|
||||
)
|
||||
scripted = torch.jit.CompilationUnit(script)._fn
|
||||
|
||||
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
|
||||
try:
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
except Exception as e:
|
||||
continue
|
||||
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
|
||||
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
inp = clone_input_helper(sample.input)
|
||||
graph = scripted.graph_for(inp)
|
||||
FileCheck().check(op.aten_name).check_not(variant_name).run(graph)
|
||||
|
||||
# Test tracing:
|
||||
for variant in variants:
|
||||
variant_name = variant.__name__
|
||||
op_name = original_name_inplace if variant is inplace else original_name
|
||||
|
||||
def _fn(*sample_args, **sample_kwargs):
|
||||
return variant(*sample_args, **sample_kwargs)
|
||||
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
traced = torch.jit.trace(_fn, *inp)
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
traced(*inp)
|
||||
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
|
||||
graph = traced.graph_for(*inp)
|
||||
FileCheck().check(op_name).check_not(variant_name).run(graph)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestJit, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
@ -504,7 +504,7 @@ workarounds. The workaround depends on how your test invokes gradcheck/gradgradc
|
|||
If the test
|
||||
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
|
||||
with `nondet_tol=<tol>` as a keyword argument.
|
||||
- is OpInfo-based (e.g., in test_ops.py), then modify the OpInfo for the test
|
||||
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
|
||||
to have `gradcheck_nondet_tol=<tol>`.
|
||||
- is a Module test (e.g., in common_nn.py), then modify the corresponding
|
||||
module_test entry to have `gradcheck_nondet_tol=<tol>`
|
||||
|
|
@ -717,7 +717,7 @@ workarounds. The workaround depends on how your test invokes gradcheck/gradgradc
|
|||
If the test
|
||||
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
|
||||
with `check_batched_grad=False` as a keyword argument.
|
||||
- is OpInfo-based (e.g., in test_ops.py), then modify the OpInfo for the test
|
||||
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
|
||||
to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.
|
||||
|
||||
If you're modifying an existing operator that supports batched grad computation,
|
||||
|
|
@ -743,7 +743,7 @@ workarounds. The workaround depends on how your test invokes gradcheck/gradgradc
|
|||
If the test
|
||||
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
|
||||
with `check_batched_forward_grad=False` as a keyword argument.
|
||||
- is OpInfo-based (e.g., in test_ops.py), then modify the OpInfo for the test
|
||||
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
|
||||
to have `check_batched_forward_grad=False`
|
||||
"""
|
||||
|
||||
|
|
@ -1196,7 +1196,7 @@ workarounds. The workaround depends on how your test invokes gradcheck/gradgradc
|
|||
If the test
|
||||
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
|
||||
with `fast_mode=False` as a keyword argument.
|
||||
- is OpInfo-based (e.g., in test_ops.py), then modify the OpInfo for the test
|
||||
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
|
||||
to have `gradcheck_fast_mode=False`
|
||||
- is a Module test (e.g., in common_nn.py), then modify the corresponding
|
||||
module_test entry to have `gradcheck_fast_mode=False`
|
||||
|
|
|
|||
Loading…
Reference in a new issue