Fix jit outplace tracing and reapply changes to *_like operators. (#28839)

Summary:
Reapply the reverted changes and fix the files `gen_variable_type.py` and `test_jit.py`.

https://github.com/pytorch/pytorch/issues/27891 Cleanup testing of _like operators
https://github.com/pytorch/pytorch/issues/27890 Add memory format support to randn_like operator
https://github.com/pytorch/pytorch/issues/27889 Add memory format support to randint_like operator
https://github.com/pytorch/pytorch/issues/27562 Add memory format support to zeros_like operator
https://github.com/pytorch/pytorch/issues/27561 Add memory format support to rand_like operator
https://github.com/pytorch/pytorch/issues/27270 Add memory format support to ones_like operator
https://github.com/pytorch/pytorch/issues/27262 Add memory format support to full_like operator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28839
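
A minimal usage sketch (not part of this commit) of the memory_format keyword that the *_like factories gain here; torch.preserve_format and torch.channels_last are existing memory formats:

import torch

# Channels-last (NHWC-strided) input tensor
x = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)

# The *_like factories now accept the same optional memory_format
# keyword that empty_like already supports.
y = torch.zeros_like(x, memory_format=torch.preserve_format)
z = torch.full_like(x, 7, memory_format=torch.channels_last)

assert y.is_contiguous(memory_format=torch.channels_last)
assert z.is_contiguous(memory_format=torch.channels_last)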

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

buck test mode/dev //language_technology/neural_mt/os/pytorch_translate/test:test_onnx -- 'test_forced_decoder_export_vocab_reduction \(language_technology\.neural_mt\.os\.pytorch_translate\.test\.test_onnx\.TestONNX\)'

Differential Revision: D18203397

Pulled By: VitalyFedyunin

fbshipit-source-id: eea41cbd4c232cf5a54172b1e1b16b173798f298
Vitaly Fedyunin 2019-10-31 13:19:31 -07:00 committed by Facebook Github Bot
parent 0e441dd386
commit 4bfe2f0900
13 changed files with 224 additions and 113 deletions


@ -323,12 +323,21 @@ Tensor& full_out(Tensor& result, IntArrayRef size, Scalar fill_value) {
return result.fill_(fill_value);
}
Tensor full_like(const Tensor& self, Scalar fill_value) {
return native::full_like(self, fill_value, self.options());
Tensor full_like(
const Tensor& self,
Scalar fill_value,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::full_like(
self, fill_value, self.options(), optional_memory_format);
}
Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& options) {
return native::full(self.sizes(), fill_value, options);
Tensor full_like(
const Tensor& self,
Scalar fill_value,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.fill_(fill_value);
}
Tensor new_full(
@ -375,14 +384,20 @@ Tensor& ones_out(Tensor& result, IntArrayRef size) {
return native::full_out(result, size, /*fill_value=*/1);
}
Tensor ones_like(const Tensor& self) {
return native::ones(self.sizes(), self.options());
Tensor ones_like(
const Tensor& self,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.fill_(1);
}
Tensor ones_like(const Tensor& self, const TensorOptions& options) {
return native::ones(self.sizes(), options);
Tensor ones_like(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::ones_like(
self, self.options(), optional_memory_format);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ scalar_tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor scalar_tensor(Scalar s, const TensorOptions& options) {
@ -409,12 +424,18 @@ Tensor& rand_out(Tensor& result, IntArrayRef size, Generator* generator) {
return result.uniform_(0, 1, generator);
}
Tensor rand_like(const Tensor& self) {
return native::rand_like(self, self.options());
Tensor rand_like(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::rand_like(self, self.options(), optional_memory_format);
}
Tensor rand_like(const Tensor& self, const TensorOptions& options) {
return native::rand(self.sizes(), options);
Tensor rand_like(
const Tensor& self,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.uniform_(0, 1, nullptr);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -476,18 +497,30 @@ Tensor& randint_out(
return result.random_(low, high, generator);
}
Tensor randint_like(const Tensor& self, int64_t high) {
return native::randint_like(self, high, self.options());
Tensor randint_like(
const Tensor& self,
int64_t high,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::randint_like(
self, high, self.options(), optional_memory_format);
}
Tensor randint_like(const Tensor& self, int64_t low, int64_t high) {
return native::randint_like(self, low, high, self.options());
Tensor randint_like(
const Tensor& self,
int64_t low,
int64_t high,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::randint_like(
self, low, high, self.options(), optional_memory_format);
}
Tensor randint_like(
const Tensor& self,
int64_t high,
const TensorOptions& options) {
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(0, high, nullptr);
return native::randint(high, self.sizes(), nullptr, options);
}
@ -495,8 +528,10 @@ Tensor randint_like(
const Tensor& self,
int64_t low,
int64_t high,
const TensorOptions& options) {
return native::randint(low, high, self.sizes(), nullptr, options);
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(low, high, nullptr);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -531,12 +566,18 @@ Tensor& normal_out(Tensor& result, double mean, double std,
return result.normal_(mean, std, generator);
}
Tensor randn_like(const Tensor& self) {
return native::randn_like(self, self.options());
Tensor randn_like(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::randn_like(self, self.options(), optional_memory_format);
}
Tensor randn_like(const Tensor& self, const TensorOptions& options) {
return native::randn(self.sizes(), nullptr, options);
Tensor randn_like(
const Tensor& self,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto result = at::empty_like(self, options, optional_memory_format);
return result.normal_(0, 1, nullptr);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -710,17 +751,24 @@ Tensor& zeros_out(Tensor& result, IntArrayRef size) {
return result.zero_();
}
Tensor zeros_like(const Tensor& self) {
return native::zeros_like(self, self.options());
Tensor zeros_like(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
return native::zeros_like(self, self.options(), optional_memory_format);
}
Tensor zeros_like(const Tensor& self, const TensorOptions& options) {
Tensor zeros_like(
const Tensor& self,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
if (options.layout() == kSparse && self.is_sparse()) {
auto res = at::empty({0}, options); // to be resized
res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
res.sparse_resize_and_clear_(
self.sizes(), self.sparse_dim(), self.dense_dim());
return res;
}
return native::zeros(self.sizes(), options);
auto result = at::empty_like(self, options, optional_memory_format);
return result.zero_();
}
Tensor new_zeros(


@ -1214,10 +1214,9 @@
- func: full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
- func: full_like(Tensor self, Scalar fill_value) -> Tensor
use_c10_dispatcher: full
- func: full_like(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None) -> Tensor
- func: full_like.dtype(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: full_like.dtype(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
@ -1983,10 +1982,9 @@
- func: ones.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: ones_like(Tensor self) -> Tensor
use_c10_dispatcher: full
- func: ones_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
- func: ones_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: ones_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
use_c10_dispatcher: full
@ -2059,10 +2057,9 @@
- func: rand.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
- func: rand_like(Tensor self) -> Tensor
use_c10_dispatcher: full
- func: rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
- func: rand_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: rand_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -2080,15 +2077,13 @@
- func: randint.low_generator_out(int low, int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
- func: randint_like(Tensor self, int high) -> Tensor
use_c10_dispatcher: full
- func: randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor
- func: randint_like.low(Tensor self, int low, int high) -> Tensor
use_c10_dispatcher: full
- func: randint_like.low(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor
- func: randint_like.dtype(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: randint_like.dtype(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -2104,10 +2099,9 @@
- func: randn.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
- func: randn_like(Tensor self) -> Tensor
use_c10_dispatcher: full
- func: randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
- func: randn_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: randn_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -2843,10 +2837,9 @@
- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: zeros_like(Tensor self) -> Tensor
use_c10_dispatcher: full
- func: zeros_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
- func: zeros_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
- func: zeros_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=None) -> Tensor
- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
use_c10_dispatcher: full


@ -1795,6 +1795,20 @@ graph(%Ra, %Rb):
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
ge(x)
def test_force_outplace_check_fill(self):
def f(x):
return torch.empty(x.shape).fill_(7)
x = torch.randn(10, 15)
ft = torch.jit.trace(f, x, _force_outplace=True)
self.assertEqual(f(x), ft(x))
def test_force_outplace_check_zero(self):
def f(x):
return torch.empty(x.shape).zero_()
x = torch.randn(10, 15)
ft = torch.jit.trace(f, x, _force_outplace=True)
self.assertEqual(f(x), ft(x))
def do_trace_size(self, requires_grad):
def fn(x):
return x.view(x.shape[1] * 2, x.size(0), 2)


@ -5565,14 +5565,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
torch.testing.assert_allclose(expected_norm, actual_norm)
def test_memory_format(self):
x = torch.randn(10, 3, 32, 32)
x = torch.randn(4, 3, 8, 8)
nhwc = x.contiguous(memory_format=torch.channels_last)
self.assertFalse(nhwc.is_contiguous())
self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, x)
def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self):
x = torch.randn(10, 32, 32, 3).permute(0, 3, 1, 2)
x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
alias = x.contiguous(memory_format=torch.channels_last)
alias.fill_(7)
self.assertEqual(x, alias)
@ -11857,13 +11857,13 @@ class TestTorchDeviceType(TestCase):
_test_serialization(BytesIOContext)
def test_memory_format_preserved_after_permute(self, device):
x = torch.randn(10, 3, 32, 32, device=device)
x = torch.randn(4, 3, 8, 8, device=device)
nhwc = x.contiguous(memory_format=torch.channels_last)
y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
def test_memory_format_empty_like(self, device):
x = torch.randn(10, 3, 32, 32, device=device)
x = torch.randn(4, 3, 8, 8, device=device)
nhwc = x.contiguous(memory_format=torch.channels_last)
like = torch.empty_like(nhwc, memory_format=torch.preserve_format)
@ -12536,7 +12536,7 @@ class TestTorchDeviceType(TestCase):
def test_memory_format_to(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
def transformation_fn(tensor, **kwargs):
return tensor.to(dtype=torch.float64, **kwargs)
@ -12545,7 +12545,7 @@ class TestTorchDeviceType(TestCase):
def test_memory_format_type(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
def transformation_fn(tensor, **kwargs):
return tensor.type(torch.float64, **kwargs)
@ -12554,7 +12554,7 @@ class TestTorchDeviceType(TestCase):
def test_memory_format_clone(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
def transformation_fn(tensor, **kwargs):
return tensor.clone(**kwargs)
@ -12575,18 +12575,26 @@ class TestTorchDeviceType(TestCase):
torch.sum(x, (2, 1), out=res2)
self.assertEqual(res1, res2)
def test_memory_format_empty_like_strides(self, device):
def test_memory_format_factory_like_functions_preserve_strides(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
def transformation_fn(tensor, **kwargs):
return torch.empty_like(tensor, **kwargs)
transformation_fns = [
lambda t, **kwargs: torch.zeros_like(t, **kwargs),
lambda t, **kwargs: torch.ones_like(t, **kwargs),
lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
lambda t, **kwargs: torch.randn_like(t, **kwargs),
lambda t, **kwargs: torch.rand_like(t, **kwargs),
lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
lambda t, **kwargs: torch.empty_like(t, **kwargs)]
self._test_memory_format_transformations(device, input_generator_fn, transformation_fn, compare_data=False)
for transformation_fn in transformation_fns:
self._test_memory_format_transformations(device, input_generator_fn, transformation_fn, compare_data=False)
def test_memory_format_type_shortcuts(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).clamp(0, 1).round().contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).clamp(0, 1).round().contiguous(memory_format=torch.channels_last)
def get_fn(fn_name):
def transformation_fn(tensor, **kwargs):
@ -12603,14 +12611,14 @@ class TestTorchDeviceType(TestCase):
# Test 'float' separately to avoid float->float no-op.
def input_generator_fn_double(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float64).clamp(0, 1).round().contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float64).clamp(0, 1).round().contiguous(memory_format=torch.channels_last)
self._test_memory_format_transformations(device, input_generator_fn_double, get_fn('float'))
@onlyCUDA
def test_memory_format_cpu_and_cuda_ops(self, device):
def input_generator_fn(device):
return torch.randn((10, 3, 32, 32), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
return torch.randn((4, 3, 8, 8), device=device, dtype=torch.float32).contiguous(memory_format=torch.channels_last)
def transformation_cpu_fn(tensor, **kwargs):
return tensor.cpu(**kwargs)


@ -63,11 +63,8 @@ def process_function(decl, has_tensor_options, disable_autograd):
actuals.append(actual)
requires_grad = "options.requires_grad()" if has_tensor_options else "false"
if decl['name'].endswith('_like') and not has_tensor_options:
# it's a tensor
if decl['name'] == 'empty_like':
actuals.insert(-1, '{}.options().is_variable(false)'.format(actuals[0]))
else:
actuals.append('{}.options().is_variable(false)'.format(actuals[0]))
# Insert TensorOptions before MemoryFormat
actuals.insert(-1, '{}.options().is_variable(false)'.format(actuals[0]))
if not disable_autograd:
pre_record_trace, post_record_trace = format_trace(decl)


@ -48,8 +48,41 @@ DONT_RECORD_TRACE = {
# These functions have their names recorded under trace renamed,
RENAME_TRACE = {
'zero': 'zeros_like',
'fill': 'full_like',
'zero': 'zeros_like', # replacing aten::zero_ with aten::zeros_like
'fill': 'full_like', # replacing aten::fill_ with aten::full_like
}
# `torch.jit.trace` has an undocumented keyword argument `_force_outplace`,
# which forces the JIT to replace functions with their outplace variants (for
# example `aten::add_` becomes `aten::add`).
#
# This replacement is implemented in place with minimal modification of the
# argument stack (it assumes that the outplace call takes the same arguments
# as the inplace version).
#
# However, no such substitution is available for the `aten::fill_`
# and `aten::zero_` operators, because `aten::fill` and `aten::zero` were
# never implemented. So the JIT tracing hack replaces `aten::zero_` with
# `aten::zeros_like` and `aten::fill_` with `aten::full_like`.
#
# But since these replacements can take different arguments, we also have
# to hack into the stack and add the missing ones.
#
# Possible alternatives would be:
#
# - Add `aten::fill` and `aten::zero`
#
# - Or keep the `aten::zeros_like` arguments aligned with the `aten::zero_`
#   arguments (inside `native_functions.yaml`)
RENAME_TRACE_ADD_ARGS = {
'fill': '''\
c10::optional<MemoryFormat> memory_format = c10::nullopt;
jit::tracer::addInputs(node, "memory_format", memory_format);
''',
'zero': '''\
c10::optional<MemoryFormat> memory_format = c10::nullopt;
jit::tracer::addInputs(node, "memory_format", memory_format);
''',
}
# (declaration name, argument name) -> attribute name
@ -232,6 +265,7 @@ RECORD_FUNCTION("${name}", std::vector<c10::IValue>({${input_names}}), Node::pee
""")
SELECT = CodeTemplate("""\
if (${cond}) {
${true}
} else {
@ -408,7 +442,21 @@ def format_prerecord_trace(declaration):
is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])
local['set_op_name'] = format_trace_op_name(declaration)
local['add_trace_inputs'] = format_trace_inputs(declaration)
is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name'])
add_args = ''
if is_inplace:
api_name = uninplace_api_name(declaration['api_name'])
add_args = RENAME_TRACE_ADD_ARGS.get(api_name, '')
if add_args:
select_params = {}
select_params['cond'] = 'tracer_state->force_outplace'
select_params['true'] = add_args
select_params['false'] = ''
additional_inputs = SELECT.substitute(select_params)
else:
additional_inputs = ''
local['add_trace_inputs'] = format_trace_inputs(declaration) + additional_inputs
local['inplace_guard'] = ''
if is_inplace:


@ -115,6 +115,9 @@ static std::string typeCastedValueName(
// cast here, which may end up being a no-op if the tensor's scalar type
// is `double`.
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
} else if (t->kind() == TypeKind::NoneType) {
// Support None value for optional arguments like memory format
return vn;
} else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
if (*scalar_type != outtype) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";


@ -898,17 +898,17 @@ bool Node::isNondeterministic() const {
"aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
"aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
"aten::rand_like(Tensor self) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
"aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
"aten::randint_like(Tensor self, int high) -> Tensor",
"aten::randint_like(Tensor self, int low, int high) -> Tensor",
"aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
"aten::randn_like(Tensor self) -> Tensor",
"aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
if (nondeterministic_ops.find(this) == nullptr) {


@ -27,7 +27,7 @@ namespace {
// or all tensor inputs have the same scalar type and
// output is identified in PropagateInputShapes
// - Output and all tensor inputs should be on the same device
// - Produces contiguous outputs
// - Produces dense non-overlapping outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
bool isSimpleMap(Node* node) {
@ -66,7 +66,6 @@ bool isSimpleMap(Node* node) {
"aten::pow(Tensor self, Tensor exponent) -> Tensor",
"aten::pow(Tensor self, Scalar exponent) -> Tensor",
"aten::pow(Scalar self, Tensor exponent) -> Tensor",
"aten::rand_like(Tensor self) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
@ -79,6 +78,7 @@ bool isSimpleMap(Node* node) {
"aten::sqrt(Tensor self) -> Tensor",
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::tan(Tensor self) -> Tensor",
"aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::tanh(Tensor self) -> Tensor",
"aten::trunc(Tensor self) -> Tensor",
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",


@ -857,13 +857,13 @@ class ShapePropagator {
"aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
"aten::alias(Tensor self) -> Tensor",
"aten::empty_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::full_like(Tensor self, Scalar fill_value) -> Tensor",
"aten::ones_like(Tensor self) -> Tensor",
"aten::rand_like(Tensor self) -> Tensor",
"aten::randint_like(Tensor self, int high) -> Tensor",
"aten::randint_like(Tensor self, int low, int high) -> Tensor",
"aten::randn_like(Tensor self) -> Tensor",
"aten::zeros_like(Tensor self) -> Tensor",
"aten::full_like(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::ones_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int high, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::randn_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::zeros_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
},
[](Node* node) -> type_vec_t {
auto input_type = node->input(0)->type()->cast<TensorType>();
@ -1411,14 +1411,14 @@ class ShapePropagator {
// - has ScalarType dtype, Layout layout and Device device arguments
static const register_formula_for like_factories_with_options{
{
"aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor",
"aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
"aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
"aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory, MemoryFormat? memory_format=None) -> Tensor",
},
[](Node* node) -> type_vec_t {
if (auto type =


@ -923,11 +923,11 @@ const std::vector<std::string> functions = {
return torch.log2(self), backward
def rand_like(self):
def rand_like(self, *, memory_format: Optional[int]):
def backward(grad_output):
return None
return torch.rand_like(self), backward
return torch.rand_like(self, memory_format=memory_format), backward
def reciprocal(self):
result = torch.reciprocal(self)
@ -1309,7 +1309,7 @@ const std::vector<std::string> functions = {
return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
)",
R"(
R"(
def AD_sizes_if_not_equal_multi_1(t1, t2, res):
return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())


@ -245,8 +245,8 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, 0)
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 0)
@ -256,8 +256,8 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, 1)
@parse_args('v', 'i', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, 1)
@ -272,7 +272,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
return _constant_fill(g, sizes, dtype, const_value)
@parse_args('v', 'f', 'i', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
@parse_args('v', 'f', 'i', 'v', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
return _constant_fill(g, shape, dtype, fill_value)


@ -1212,8 +1212,8 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False):
value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
@ -1229,8 +1229,8 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False):
value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
@parse_args('v', 'i', 'v', 'v', 'v', 'v')
def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
@ -1251,8 +1251,8 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False):
value_t=torch.tensor([const_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'f', 'i', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
@parse_args('v', 'f', 'i', 'v', 'v', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None):
shape = g.op("Shape", input)
if dtype is None:
dtype = 6 # float
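
A hedged sketch (not from this diff) of why ONNX export keeps working with the updated symbolics above: they only gain an extra, unused memory_format parameter, so a model that uses a *_like factory still exports as before:

import io
import torch

class ZerosLikeModule(torch.nn.Module):
    def forward(self, x):
        # zeros_like is lowered through the zeros_like symbolic shown above
        return torch.zeros_like(x) + x

buf = io.BytesIO()
torch.onnx.export(ZerosLikeModule(), torch.randn(2, 3), buf)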