From 4bfe2f0900a602017a31c0aaa8f00750412a2d83 Mon Sep 17 00:00:00 2001 From: Vitaly Fedyunin Date: Thu, 31 Oct 2019 13:19:31 -0700 Subject: [PATCH] Fix jit outplace tracing and reapply changes to *_like operators. (#28839) Summary: Reapply reverted and fix files `gen_variable_type.py` `test_jit.py` https://github.com/pytorch/pytorch/issues/27891 Cleanup testing of _like operators https://github.com/pytorch/pytorch/issues/27890 Add memory format support to randn_like operator https://github.com/pytorch/pytorch/issues/27889 Add memory format support to randint_like operator https://github.com/pytorch/pytorch/issues/27562 Add memory format support to zeros_like operator https://github.com/pytorch/pytorch/issues/27561 Add memory format support to rand_like operator https://github.com/pytorch/pytorch/issues/27270 Add memory format support to ones_like operator https://github.com/pytorch/pytorch/issues/27262 Add memory format support to full_like operator Pull Request resolved: https://github.com/pytorch/pytorch/pull/28839 Test Plan: Imported from GitHub, without a `Test Plan:` line. buck test mode/dev //language_technology/neural_mt/os/pytorch_translate/test:test_onnx -- 'test_forced_decoder_export_vocab_reduction \(language_technology\.neural_mt\.os\.pytorch_translate\.test\.test_onnx\.TestONNX\)' Differential Revision: D18203397 Pulled By: VitalyFedyunin fbshipit-source-id: eea41cbd4c232cf5a54172b1e1b16b173798f298 --- aten/src/ATen/native/TensorFactories.cpp | 106 +++++++++++++++------ aten/src/ATen/native/native_functions.yaml | 35 +++---- test/test_jit.py | 14 +++ test/test_torch.py | 38 +++++--- tools/autograd/gen_variable_factories.py | 7 +- tools/autograd/gen_variable_type.py | 54 ++++++++++- torch/csrc/jit/fuser/codegen.cpp | 3 + torch/csrc/jit/ir.cpp | 16 ++-- torch/csrc/jit/passes/graph_fuser.cpp | 4 +- torch/csrc/jit/passes/shape_analysis.cpp | 30 +++--- torch/csrc/jit/symbolic_script.cpp | 6 +- torch/onnx/symbolic_opset8.py | 12 +-- torch/onnx/symbolic_opset9.py | 12 +-- 13 files changed, 224 insertions(+), 113 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 07cbca3e573..8cd6dacb299 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -323,12 +323,21 @@ Tensor& full_out(Tensor& result, IntArrayRef size, Scalar fill_value) { return result.fill_(fill_value); } -Tensor full_like(const Tensor& self, Scalar fill_value) { - return native::full_like(self, fill_value, self.options()); +Tensor full_like( + const Tensor& self, + Scalar fill_value, + c10::optional optional_memory_format) { + return native::full_like( + self, fill_value, self.options(), optional_memory_format); } -Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& options) { - return native::full(self.sizes(), fill_value, options); +Tensor full_like( + const Tensor& self, + Scalar fill_value, + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, optional_memory_format); + return result.fill_(fill_value); } Tensor new_full( @@ -375,14 +384,20 @@ Tensor& ones_out(Tensor& result, IntArrayRef size) { return native::full_out(result, size, /*fill_value=*/1); } -Tensor ones_like(const Tensor& self) { - return native::ones(self.sizes(), self.options()); +Tensor ones_like( + const Tensor& self, + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, 
optional_memory_format); + return result.fill_(1); } -Tensor ones_like(const Tensor& self, const TensorOptions& options) { - return native::ones(self.sizes(), options); +Tensor ones_like( + const Tensor& self, + c10::optional optional_memory_format) { + return native::ones_like( + self, self.options(), optional_memory_format); } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ scalar_tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor scalar_tensor(Scalar s, const TensorOptions& options) { @@ -409,12 +424,18 @@ Tensor& rand_out(Tensor& result, IntArrayRef size, Generator* generator) { return result.uniform_(0, 1, generator); } -Tensor rand_like(const Tensor& self) { - return native::rand_like(self, self.options()); +Tensor rand_like( + const Tensor& self, + c10::optional optional_memory_format) { + return native::rand_like(self, self.options(), optional_memory_format); } -Tensor rand_like(const Tensor& self, const TensorOptions& options) { - return native::rand(self.sizes(), options); +Tensor rand_like( + const Tensor& self, + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, optional_memory_format); + return result.uniform_(0, 1, nullptr); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -476,18 +497,30 @@ Tensor& randint_out( return result.random_(low, high, generator); } -Tensor randint_like(const Tensor& self, int64_t high) { - return native::randint_like(self, high, self.options()); +Tensor randint_like( + const Tensor& self, + int64_t high, + c10::optional optional_memory_format) { + return native::randint_like( + self, high, self.options(), optional_memory_format); } -Tensor randint_like(const Tensor& self, int64_t low, int64_t high) { - return native::randint_like(self, low, high, self.options()); +Tensor randint_like( + const Tensor& self, + int64_t low, + int64_t high, + c10::optional optional_memory_format) { + return native::randint_like( + self, low, high, self.options(), optional_memory_format); } Tensor randint_like( const Tensor& self, int64_t high, - const TensorOptions& options) { + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, optional_memory_format); + return result.random_(0, high, nullptr); return native::randint(high, self.sizes(), nullptr, options); } @@ -495,8 +528,10 @@ Tensor randint_like( const Tensor& self, int64_t low, int64_t high, - const TensorOptions& options) { - return native::randint(low, high, self.sizes(), nullptr, options); + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, optional_memory_format); + return result.random_(low, high, nullptr); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -531,12 +566,18 @@ Tensor& normal_out(Tensor& result, double mean, double std, return result.normal_(mean, std, generator); } -Tensor randn_like(const Tensor& self) { - return native::randn_like(self, self.options()); +Tensor randn_like( + const Tensor& self, + c10::optional optional_memory_format) { + return native::randn_like(self, self.options(), optional_memory_format); } -Tensor randn_like(const Tensor& self, const TensorOptions& options) { - return native::randn(self.sizes(), nullptr, options); +Tensor randn_like( + const Tensor& self, + const TensorOptions& options, + c10::optional optional_memory_format) { + auto result = at::empty_like(self, options, optional_memory_format); + return 
result.normal_(0, 1, nullptr); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -710,17 +751,24 @@ Tensor& zeros_out(Tensor& result, IntArrayRef size) { return result.zero_(); } -Tensor zeros_like(const Tensor& self) { - return native::zeros_like(self, self.options()); +Tensor zeros_like( + const Tensor& self, + c10::optional optional_memory_format) { + return native::zeros_like(self, self.options(), optional_memory_format); } -Tensor zeros_like(const Tensor& self, const TensorOptions& options) { +Tensor zeros_like( + const Tensor& self, + const TensorOptions& options, + c10::optional optional_memory_format) { if (options.layout() == kSparse && self.is_sparse()) { auto res = at::empty({0}, options); // to be resized - res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim()); + res.sparse_resize_and_clear_( + self.sizes(), self.sparse_dim(), self.dense_dim()); return res; } - return native::zeros(self.sizes(), options); + auto result = at::empty_like(self, options, optional_memory_format); + return result.zero_(); } Tensor new_zeros( diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 60bca427379..00ee3122b46 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1214,10 +1214,9 @@ - func: full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) -- func: full_like(Tensor self, Scalar fill_value) -> Tensor - use_c10_dispatcher: full +- func: full_like(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None) -> Tensor -- func: full_like.dtype(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: full_like.dtype(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor - func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -1983,10 +1982,9 @@ - func: ones.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) -- func: ones_like(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: ones_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor -- func: ones_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: ones_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor use_c10_dispatcher: full @@ -2059,10 +2057,9 @@ - func: rand.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) -- func: rand_like(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor -- func: rand_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: rand_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor - func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2080,15 +2077,13 @@ - func: randint.low_generator_out(int low, int high, int[] size, *, Generator? 
generator, Tensor(a!) out) -> Tensor(a!) -- func: randint_like(Tensor self, int high) -> Tensor - use_c10_dispatcher: full +- func: randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor -- func: randint_like.low(Tensor self, int low, int high) -> Tensor - use_c10_dispatcher: full +- func: randint_like.low(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor -- func: randint_like.dtype(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: randint_like.dtype(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor -- func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor - func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2104,10 +2099,9 @@ - func: randn.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) -- func: randn_like(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor -- func: randn_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: randn_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor - func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2843,10 +2837,9 @@ - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) -- func: zeros_like(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: zeros_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor -- func: zeros_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor +- func: zeros_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? 
memory_format=None) -> Tensor - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor use_c10_dispatcher: full diff --git a/test/test_jit.py b/test/test_jit.py index 0195d3b7bcd..dd533e0f502 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1795,6 +1795,20 @@ graph(%Ra, %Rb): with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'): ge(x) + def test_force_outplace_check_fill(self): + def f(x): + return torch.empty(x.shape).fill_(7) + x = torch.randn(10, 15) + ft = torch.jit.trace(f, x, _force_outplace=True) + self.assertEqual(f(x), ft(x)) + + def test_force_outplace_check_zero(self): + def f(x): + return torch.empty(x.shape).zero_() + x = torch.randn(10, 15) + ft = torch.jit.trace(f, x, _force_outplace=True) + self.assertEqual(f(x), ft(x)) + def do_trace_size(self, requires_grad): def fn(x): return x.view(x.shape[1] * 2, x.size(0), 2) diff --git a/test/test_torch.py b/test/test_torch.py index ff43b80159b..03445872e80 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5565,14 +5565,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], torch.testing.assert_allclose(expected_norm, actual_norm) def test_memory_format(self): - x = torch.randn(10, 3, 32, 32) + x = torch.randn(4, 3, 8, 8) nhwc = x.contiguous(memory_format=torch.channels_last) self.assertFalse(nhwc.is_contiguous()) self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last)) self.assertEqual(nhwc, x) def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self): - x = torch.randn(10, 32, 32, 3).permute(0, 3, 1, 2) + x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2) alias = x.contiguous(memory_format=torch.channels_last) alias.fill_(7) self.assertEqual(x, alias) @@ -11857,13 +11857,13 @@ class TestTorchDeviceType(TestCase): _test_serialization(BytesIOContext) def test_memory_format_preserved_after_permute(self, device): - x = torch.randn(10, 3, 32, 32, device=device) + x = torch.randn(4, 3, 8, 8, device=device) nhwc = x.contiguous(memory_format=torch.channels_last) y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) def test_memory_format_empty_like(self, device): - x = torch.randn(10, 3, 32, 32, device=device) + x = torch.randn(4, 3, 8, 8, device=device) nhwc = x.contiguous(memory_format=torch.channels_last) like = torch.empty_like(nhwc, memory_format=torch.preserve_format) @@ -12536,7 +12536,7 @@ class TestTorchDeviceType(TestCase): def test_memory_format_to(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) def transformation_fn(tensor, **kwargs): return tensor.to(dtype=torch.float64, **kwargs) @@ -12545,7 +12545,7 @@ class TestTorchDeviceType(TestCase): def test_memory_format_type(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) def transformation_fn(tensor, **kwargs): return tensor.type(torch.float64, **kwargs) @@ -12554,7 +12554,7 @@ class TestTorchDeviceType(TestCase): def test_memory_format_clone(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, 
dtype=torch.float32).contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) def transformation_fn(tensor, **kwargs): return tensor.clone(**kwargs) @@ -12575,18 +12575,26 @@ class TestTorchDeviceType(TestCase): torch.sum(x, (2, 1), out=res2) self.assertEqual(res1, res2) - def test_memory_format_empty_like_strides(self, device): + def test_memory_format_factory_like_functions_preserve_strides(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) - def transformation_fn(tensor, **kwargs): - return torch.empty_like(tensor, **kwargs) + transformation_fns = [ + lambda t, **kwargs: torch.zeros_like(t, **kwargs), + lambda t, **kwargs: torch.ones_like(t, **kwargs), + lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs), + lambda t, **kwargs: torch.randint_like(t, 100, **kwargs), + lambda t, **kwargs: torch.randn_like(t, **kwargs), + lambda t, **kwargs: torch.rand_like(t, **kwargs), + lambda t, **kwargs: torch.full_like(t, 7, **kwargs), + lambda t, **kwargs: torch.empty_like(t, **kwargs)] - self._test_memory_format_transformations(device, input_generator_fn, transformation_fn, compare_data=False) + for transformation_fn in transformation_fns: + self._test_memory_format_transformations(device, input_generator_fn, transformation_fn, compare_data=False) def test_memory_format_type_shortcuts(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).clamp(0, 1).round().contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).clamp(0, 1).round().contiguous(memory_format=torch.channels_last) def get_fn(fn_name): def transformation_fn(tensor, **kwargs): @@ -12603,14 +12611,14 @@ class TestTorchDeviceType(TestCase): # Test 'float' separately to avoid float->float no-op. 
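The test added above (`test_memory_format_factory_like_functions_preserve_strides`) exercises the core behavior of this patch: every `*_like` factory now forwards an optional `memory_format` to `at::empty_like`, so the layout of the source tensor can either be preserved or overridden. A minimal sketch of that behavior (the shapes and the choice of `zeros_like`/`rand_like` here are arbitrary, not part of the patch):

    import torch

    x = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)

    # preserve_format keeps the channels_last (NHWC) strides of x
    y = torch.zeros_like(x, memory_format=torch.preserve_format)
    assert y.is_contiguous(memory_format=torch.channels_last)

    # contiguous_format forces plain NCHW strides regardless of x's layout
    z = torch.rand_like(x, memory_format=torch.contiguous_format)
    assert z.is_contiguous()
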
def input_generator_fn_double(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float64).clamp(0, 1).round().contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float64).clamp(0, 1).round().contiguous(memory_format=torch.channels_last) self._test_memory_format_transformations(device, input_generator_fn_double, get_fn('float')) @onlyCUDA def test_memory_format_cpu_and_cuda_ops(self, device): def input_generator_fn(device): - return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) + return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last) def transformation_cpu_fn(tensor, **kwargs): return tensor.cpu(**kwargs) diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index eaa793738b3..11a4fd1a925 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -63,11 +63,8 @@ def process_function(decl, has_tensor_options, disable_autograd): actuals.append(actual) requires_grad = "options.requires_grad()" if has_tensor_options else "false" if decl['name'].endswith('_like') and not has_tensor_options: - # it's a tensor - if decl['name'] == 'empty_like': - actuals.insert(-1, '{}.options().is_variable(false)'.format(actuals[0])) - else: - actuals.append('{}.options().is_variable(false)'.format(actuals[0])) + # Insert TensorOptions before MemoryFormat + actuals.insert(-1, '{}.options().is_variable(false)'.format(actuals[0])) if not disable_autograd: pre_record_trace, post_record_trace = format_trace(decl) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index aff2704c41a..a382b827a44 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -48,8 +48,41 @@ DONT_RECORD_TRACE = { # These functions have their names recorded under trace renamed, RENAME_TRACE = { - 'zero': 'zeros_like', - 'fill': 'full_like', + 'zero': 'zeros_like', # replacing aten::zero_ with aten::zeros_like + 'fill': 'full_like', # replacing aten::fill_ with aten::full_like +} + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. 
+# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + 'fill': '''\ + c10::optional memory_format = c10::nullopt; + jit::tracer::addInputs(node, "memory_format", memory_format); +''', + 'zero': '''\ + c10::optional memory_format = c10::nullopt; + jit::tracer::addInputs(node, "memory_format", memory_format); +''', } # (declaration name, argument name) -> attribute name @@ -232,6 +265,7 @@ RECORD_FUNCTION("${name}", std::vector({${input_names}}), Node::pee """) SELECT = CodeTemplate("""\ + if (${cond}) { ${true} } else { @@ -408,7 +442,21 @@ def format_prerecord_trace(declaration): is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name']) local['set_op_name'] = format_trace_op_name(declaration) - local['add_trace_inputs'] = format_trace_inputs(declaration) + + is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name']) + add_args = '' + if is_inplace: + api_name = uninplace_api_name(declaration['api_name']) + add_args = RENAME_TRACE_ADD_ARGS.get(api_name, '') + if add_args: + select_params = {} + select_params['cond'] = 'tracer_state->force_outplace' + select_params['true'] = add_args + select_params['false'] = '' + additional_inputs = SELECT.substitute(select_params) + else: + additional_inputs = '' + local['add_trace_inputs'] = format_trace_inputs(declaration) + additional_inputs local['inplace_guard'] = '' if is_inplace: diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp index 04853827c43..4aff18ea0f6 100644 --- a/torch/csrc/jit/fuser/codegen.cpp +++ b/torch/csrc/jit/fuser/codegen.cpp @@ -115,6 +115,9 @@ static std::string typeCastedValueName( // cast here, which may end up being a no-op if the tensor's scalar type // is `double`. return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; + } else if (t->kind() == TypeKind::NoneType) { + // Support None value for optional arguments like memory format + return vn; } else if (auto scalar_type = t->expect()->scalarType()) { if (*scalar_type != outtype) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index e3fc9fa900c..21230e49c22 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -898,17 +898,17 @@ bool Node::isNondeterministic() const { "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::rand_like(Tensor self) -> Tensor", - "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", + "aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? 
pin_memory) -> Tensor", - "aten::randint_like(Tensor self, int high) -> Tensor", - "aten::randint_like(Tensor self, int low, int high) -> Tensor", - "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", + "aten::randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::randn_like(Tensor self) -> Tensor", - "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", + "aten::randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; if (nondeterministic_ops.find(this) == nullptr) { diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index b07464e4bc8..c332ba292ad 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -27,7 +27,7 @@ namespace { // or all tensor inputs have the same scalar type and // output is identified in PropagateInputShapes // - Output and all tensor inputs should be on the same device -// - Produces contiguous outputs +// - Produces dense non-overlapping outputs // Some of these restrictions may be relaxable, but you should // carefully read the code first, as we rely on these assumptions. bool isSimpleMap(Node* node) { @@ -66,7 +66,6 @@ bool isSimpleMap(Node* node) { "aten::pow(Tensor self, Tensor exponent) -> Tensor", "aten::pow(Tensor self, Scalar exponent) -> Tensor", "aten::pow(Scalar self, Tensor exponent) -> Tensor", - "aten::rand_like(Tensor self) -> Tensor", "aten::reciprocal(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", @@ -79,6 +78,7 @@ bool isSimpleMap(Node* node) { "aten::sqrt(Tensor self) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::tan(Tensor self) -> Tensor", + "aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "aten::tanh(Tensor self) -> Tensor", "aten::trunc(Tensor self) -> Tensor", "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 5898fb12fb3..5cbed529e3b 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -857,13 +857,13 @@ class ShapePropagator { "aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor", "aten::alias(Tensor self) -> Tensor", "aten::empty_like(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor", - "aten::full_like(Tensor self, Scalar fill_value) -> Tensor", - "aten::ones_like(Tensor self) -> Tensor", - "aten::rand_like(Tensor self) -> Tensor", - "aten::randint_like(Tensor self, int high) -> Tensor", - "aten::randint_like(Tensor self, int low, int high) -> Tensor", - "aten::randn_like(Tensor self) -> Tensor", - "aten::zeros_like(Tensor self) -> Tensor", + "aten::full_like(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::ones_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::zeros_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", }, [](Node* node) -> type_vec_t { auto input_type = node->input(0)->type()->cast(); @@ -1411,14 +1411,14 @@ class ShapePropagator { // - has ScalarType dtype, Layeout layout and Device device arguments static const register_formula_for like_factories_with_options{ { - "aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor", - "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", - "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor", + "aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor", + "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? 
memory_format=None) -> Tensor", }, [](Node* node) -> type_vec_t { if (auto type = diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp index 8e55bddedc3..9e05f2462a6 100644 --- a/torch/csrc/jit/symbolic_script.cpp +++ b/torch/csrc/jit/symbolic_script.cpp @@ -923,11 +923,11 @@ const std::vector functions = { return torch.log2(self), backward - def rand_like(self): + def rand_like(self, *, memory_format: Optional[int]): def backward(grad_output): return None - return torch.rand_like(self), backward + return torch.rand_like(self, memory_format=memory_format), backward def reciprocal(self): result = torch.reciprocal(self) @@ -1309,7 +1309,7 @@ const std::vector functions = { return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward )", - R"( + R"( def AD_sizes_if_not_equal_multi_1(t1, t2, res): return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size()) diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index 9471abcf557..8a5531d33f3 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -245,8 +245,8 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False): return _constant_fill(g, sizes, dtype, 0) -@parse_args('v', 'i', 'v', 'v', 'v') -def zeros_like(g, input, dtype, layout, device, pin_memory=False): +@parse_args('v', 'i', 'v', 'v', 'v', 'v') +def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) return _constant_fill(g, shape, dtype, 0) @@ -256,8 +256,8 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False): return _constant_fill(g, sizes, dtype, 1) -@parse_args('v', 'i', 'v', 'v', 'v') -def ones_like(g, input, dtype, layout, device, pin_memory=False): +@parse_args('v', 'i', 'v', 'v', 'v', 'v') +def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) return _constant_fill(g, shape, dtype, 1) @@ -272,7 +272,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False): return _constant_fill(g, sizes, dtype, const_value) -@parse_args('v', 'f', 'i', 'v', 'v', 'v') -def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False): +@parse_args('v', 'f', 'i', 'v', 'v', 'v', 'v') +def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) return _constant_fill(g, shape, dtype, fill_value) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 0ac74b53a5c..c801464aca1 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1212,8 +1212,8 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False): value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype])) -@parse_args('v', 'i', 'v', 'v', 'v') -def zeros_like(g, input, dtype, layout, device, pin_memory=False): +@parse_args('v', 'i', 'v', 'v', 'v', 'v') +def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) if dtype is None: dtype = 6 # float @@ -1229,8 +1229,8 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False): value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype])) -@parse_args('v', 'i', 'v', 'v', 'v') -def ones_like(g, input, dtype, layout, device, pin_memory=False): +@parse_args('v', 'i', 'v', 'v', 'v', 'v') +def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = 
g.op("Shape", input) if dtype is None: dtype = 6 # float @@ -1251,8 +1251,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False): value_t=torch.tensor([const_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype])) -@parse_args('v', 'f', 'i', 'v', 'v', 'v') -def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False): +@parse_args('v', 'f', 'i', 'v', 'v', 'v', 'v') +def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) if dtype is None: dtype = 6 # float