From 8389ccbcd809e2596bd4f8a20d069fc919beefdf Mon Sep 17 00:00:00 2001
From: Nikolay Korovaiko
Date: Fri, 8 Jul 2022 01:17:33 +0000
Subject: [PATCH] reinstate size and shape returning symints (#79560)

This PR redirects `size` and `.shape` to call `sym_sizes`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79560
Approved by: https://github.com/Chillee
---
 aten/src/ATen/NestedTensorImpl.h              |   3 +
 aten/src/ATen/core/TensorBase.h               |   6 +-
 c10/core/TensorImpl.h                         |  20 +++
 test/cpp/lazy/test_lazy_ops.cpp               |   2 +
 test/test_dynamic_shapes.py                   | 128 ++++++++++--------
 test/test_python_dispatch.py                  |   8 +-
 .../templates/python_variable_methods.cpp     |  46 +------
 torch/csrc/autograd/python_variable.cpp       |   3 +-
 torch/csrc/lazy/core/tensor_impl.cpp          |   4 +-
 torch/overrides.py                            |   3 +-
 10 files changed, 105 insertions(+), 118 deletions(-)

diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h
index 7c8ff7f3aa6..b0f4a33c991 100644
--- a/aten/src/ATen/NestedTensorImpl.h
+++ b/aten/src/ATen/NestedTensorImpl.h
@@ -68,6 +68,9 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
   int64_t size_custom(int64_t d) const override {
     return this->size(d);
   }
+  c10::SymInt sym_size_custom(int64_t d) const override {
+    return c10::SymInt{this->size(d)};
+  }
   IntArrayRef sizes_custom() const override;
   c10::SymIntArrayRef sym_sizes_custom() const override;
   c10::SymIntArrayRef sym_sizes() const override;
diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
index 80d70b4f850..a8452e79229 100644
--- a/aten/src/ATen/core/TensorBase.h
+++ b/aten/src/ATen/core/TensorBase.h
@@ -157,11 +157,7 @@ class TORCH_API TensorBase {
   }
 
   c10::SymInt sym_size(int64_t dim) const {
-    const auto sizes = this->sym_sizes();
-    const auto ndim = static_cast<int64_t>(sizes.size());
-    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
-    return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
-
+    return impl_->sym_size(dim);
   }
 
   int64_t size(int64_t dim) const {
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 8039a065939..ec6615addb0 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -594,6 +594,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return sizes_and_strides_.size_at_unchecked(d).as_int_unchecked();
   }
 
+  c10::SymInt sym_size(int64_t d) const {
+    if (C10_UNLIKELY(
+            sizes_strides_policy_ >=
+            static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
+      return sym_size_custom(d);
+    }
+    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
+    const auto sizes = this->sym_sizes();
+    return sizes[d];
+  }
+
   /**
    * Return the stride of a tensor at some dimension, wrapping the dimension
    * if necessary.
@@ -697,6 +708,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
     return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
   }
+
+  virtual c10::SymInt sym_size_custom(int64_t d) const {
+    // TODO: We could add support to Python dispatch here.
+    // TODO: We could call into aten::size.int instead of
+    // sym_sizes_custom()[d] and enable use of the dispatcher.
+    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
+    return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
+  }
+
   virtual IntArrayRef sizes_custom() const;
   virtual Device device_custom() const;
   virtual Layout layout_custom() const;
diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp
index e0ad18068bd..1b15b2b6ced 100644
--- a/test/cpp/lazy/test_lazy_ops.cpp
+++ b/test/cpp/lazy/test_lazy_ops.cpp
@@ -95,6 +95,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
 }
 
 TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
+  FLAGS_ltc_enable_symbolic_shapes = true;
   auto xc = torch::rand({10});
   auto x = xc.to(kLazy);
   const size_t Y_DIM = 3;
@@ -105,6 +106,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
   ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
   // shape inference assumes narrow_copy can copy the whole tensor
   AllClose(z.cpu(), zc);
+  FLAGS_ltc_enable_symbolic_shapes = false;
 }
 
 TEST_F(LazyOpsTest, TestScalarTensor) {
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index d9aad3e4fe9..2acf423d9f5 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -32,7 +32,7 @@ def register_meta(op):
 
 @register_meta([aten.add.Tensor, aten.sub.Tensor])
 def binary_meta(a, b):
-    return a.new_empty(a.sym_size())
+    return a.new_empty(a.shape)
 
 
 @register_meta(aten.cat.default)
@@ -53,7 +53,7 @@ def cat_meta(tensors, dim=0):
 @register_meta([aten.narrow_copy.SymInt])
 def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
     shape = []
-    for i, x in enumerate(a.sym_size()):
+    for i, x in enumerate(a.shape):
         if i == dim:
             shape.append(length)
         else:
@@ -165,6 +165,14 @@ class FakeSymbolicTensor(torch.Tensor):
             self = args[0]
             return self.sym_shape
 
+        # some calls can be redirected to `sym_size` rather than
+        # `sym_sizes`. `sym_size` uses `dim` to canonicalize an index
+        # so we need to implement both `sym_size` and `dim` for python
+        # tensors
+        if func_overload == torch.ops.aten.dim.default:
+            self = args[0]
+            return len(self.sym_shape)
+
         if func_overload == torch.ops.aten.new_empty.default:
             self = args[0]
             shape = args[1]
@@ -174,7 +182,7 @@ class FakeSymbolicTensor(torch.Tensor):
 
 
 def create_symbolic_tensor(name, arg, shape_env):
-    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
+    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
     sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
     return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)
 
@@ -188,22 +196,22 @@ class TestPySymInt(TestCase):
     def test_roundtrip(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
-        self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
-        self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))
+        self.assertTrue(not isinstance(x.shape[0], PySymInt))
+        self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
 
-        self.assertEqual(int(x.sym_size(0)), 5)
-        self.assertEqual(int(x.sym_size(1)), 4)
-        self.assertEqual(int(x.sym_size(2)), 3)
+        self.assertEqual(int(x.shape[0]), 5)
+        self.assertEqual(int(x.shape[1]), 4)
+        self.assertEqual(int(x.shape[2]), 3)
 
-        self.assertEqual(int(x.sym_size()[0]), 5)
-        self.assertEqual(int(x.sym_size()[1]), 4)
-        self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
-        self.assertEqual(int(x.sym_size()[2]), 3)
+        self.assertEqual(int(x.size()[0]), 5)
+        self.assertEqual(int(x.size()[1]), 4)
+        self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
+        self.assertEqual(int(x.size()[2]), 3)
 
-        self.assertEqual(int(x.sym_size(0)), 5)
-        self.assertEqual(int(x.sym_size(1)), 4)
-        self.assertEqual(int(x.sym_size(2)), 3)
-        self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))
+        self.assertEqual(int(x.size(0)), 5)
+        self.assertEqual(int(x.size(1)), 4)
+        self.assertEqual(int(x.size(2)), 3)
+        self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
 
     @skipIfNoSympy
     def test_binary(self):
@@ -212,16 +220,16 @@ class TestPySymInt(TestCase):
         y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
         z = x + y
 
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # broadcasting
         y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
         z = x + y
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
     @skipIfNoSympy
     def test_symint_args(self):
@@ -229,16 +237,16 @@ class TestPySymInt(TestCase):
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
         y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
         LAST_DIM = 2
-        z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
-        self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))
+        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
+        self.assertEqual(int(z.shape[2]), int(y.shape[2]))
 
         # arithmetic expr with two symints
-        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
-        self.assertEqual(int(z.sym_size(2)), 2)
+        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
+        self.assertEqual(int(z.shape[2]), 2)
 
         # arithmetic expr with a symint and python int
-        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
-        self.assertEqual(int(z.sym_size(2)), 2)
+        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
+        self.assertEqual(int(z.shape[2]), 2)
 
     @skipIfNoSympy
     def test_symint_vargs(self):
@@ -247,67 +255,67 @@ class TestPySymInt(TestCase):
         y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
 
         # varargs
-        z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # shape list
-        z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # mixed python symints and ints
-        z = y.expand(x.sym_size(0), y.sym_size(1), 3)
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(x.shape[0], y.shape[1], 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # mixed python symints and ints in a list
-        z = y.expand((x.sym_size(0), y.sym_size(1), 3))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((x.shape[0], y.shape[1], 3))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # mixed python symints and ints
-        z = y.expand(5, y.sym_size(1), x.sym_size(2))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(5, y.shape[1], x.shape[2])
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
         # mixed python ints and symints in a list
-        z = y.expand((5, y.sym_size(1), x.sym_size(2)))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((5, y.shape[1], x.shape[2]))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
 
     @skipIfNoSympy
     def test_size_expressions(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
-        expand_x = x.expand(x.sym_size(0), x.sym_size(0))
-        if expand_x.sym_size(0) > 3:
+        expand_x = x.expand(x.shape[0], x.shape[0])
+        if expand_x.shape[0] > 3:
             result = expand_x + expand_x
         else:
             result = expand_x + expand_x
 
         gt_op = shape_env.guards[0][0]
         self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
-        self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
-        self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
-        self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))
+        self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
 
     @skipIfNoSympy
     def test_aten_ops(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
-        torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.sym_size(0))
+        torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])
 
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
-        torch.ops.aten.expand.SymInt(x, [x.sym_size(0), x.sym_size(1), x.sym_size(2)])
+        torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])
 
     def test_fx_trace_intlist(self):
         class CustomModule(torch.nn.Module):
@@ -327,7 +335,7 @@ class TestPySymInt(TestCase):
         shape_env = ShapeEnv()
         a0 = shape_env.create_symint("a0", 2)
         r = torch.empty(a0, device='meta')
-        self.assertIsInstance(r.sym_size(0), CPP_SYMINT_CLASS)
+        self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
 
 
 if __name__ == '__main__':
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 46f4edb1847..109e7b0ab04 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -1794,7 +1794,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.dim:
                     return data.dim()
-                if func.overloadpacket == torch.ops.aten.size:
+                if func.overloadpacket == torch.ops.aten.sym_size:
                     return (5, 3)
                 return NotImplemented
 
@@ -1807,13 +1807,13 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.dim:
                     return data.dim()
-                if func.overloadpacket == torch.ops.aten.size:
+                if func.overloadpacket == torch.ops.aten.sym_size:
                     return None
                 return NotImplemented
 
-        err_msg = "no implementation found for 'torch.ops.aten.size'"
+        err_msg = "no implementation found for 'torch.ops.aten.sym_size'"
         e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
-        with self.assertRaisesRegex(TypeError, err_msg):
+        with self.assertRaisesRegex(RuntimeError, err_msg):
             e.size()
 
         e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index fdbecf062b4..4e075d0b31e 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
 
-// TODO: FIXME This should be super temprorary until we fix the XLA issue.
-static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "sym_size(int64_t dim)",
-    "sym_size()",
-    "sym_size(Dimname dim)",
-  });
-  auto& self_ = THPVariable_Unpack(self);
-  ParsedArgs<3> parsed_args;
-  auto r = parser.parse(self, args, kwargs, parsed_args);
-
-  if(r.has_torch_function()){
-    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
-  }
-  if (r.idx == 0) {
-    if (jit::tracer::isTracing()) {
-      // will error out if a tensor has symints
-      return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
-    } else {
-      return torch::toPyObject(self_.sym_size(r.toInt64(0)));
-    }
-  } else if (r.idx == 1) {
-    return THPSize_NewFromSymSizes(self_);
-  }
-  else if (r.idx == 2) {
-    if (jit::tracer::isTracing()) {
-      TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
-    }
-    return wrap(self_.size(r.dimname(0)));
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-
 static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
@@ -152,14 +115,10 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
       // will error out if a tensor has symints
       return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
     } else {
-      return wrap(self_.size(r.toInt64(0)));
-      //return torch::toPyObject(self_.sym_size(r.toInt64(0)));
+      return torch::toPyObject(self_.sym_size(r.toInt64(0)));
     }
   } else if (r.idx == 1) {
-    // we can't do the normal wrapping here because IntArrayRef maps to both
-    // torch.Size and tuple in python.
-    return THPSize_New(self_);
-    //return THPSize_NewFromSymSizes(self_);
+    return THPSize_NewFromSymSizes(self_);
   } else if (r.idx == 2) {
     if (jit::tracer::isTracing()) {
@@ -1322,7 +1281,6 @@ PyMethodDef variable_methods[] = {
   {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
   {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
   {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
-  {"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
   {"_storage", THPVariable_storage, METH_NOARGS, NULL},
   {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
   {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 2f3bc0bde9f..771add371a4 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -1220,8 +1220,7 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
   if (check_has_torch_function((PyObject*)self)) {
     return handle_torch_function_getter(self, "shape");
   }
-  // return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
-  return THPSize_New(THPVariable_Unpack(self));
+  return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
   END_HANDLE_TH_ERRORS
 }
 
diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp
index 1434084e502..85a5e9f2125 100644
--- a/torch/csrc/lazy/core/tensor_impl.cpp
+++ b/torch/csrc/lazy/core/tensor_impl.cpp
@@ -143,7 +143,9 @@ void LTCTensorImpl::shallow_copy_from(
 }
 
 c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
-  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
+  return FLAGS_ltc_enable_symbolic_shapes
+      ? c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size())
+      : TensorImpl::sym_sizes_default();
 }
 
 c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
diff --git a/torch/overrides.py b/torch/overrides.py
index 28d6f661913..45654d1b003 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -279,8 +279,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._is_zerotensor,
         Tensor._addmm_activation,
         Tensor._nested_tensor_layer_norm,
-        Tensor.to_padded_tensor,
-        Tensor.sym_size
+        Tensor.to_padded_tensor
     }
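Below is a short, illustrative sketch of the user-visible behavior this patch reinstates. It is not part of the diff: it reuses the helpers from the test/test_dynamic_shapes.py hunks above (`ShapeEnv`, `create_symbolic_tensor`, `CPP_SYMINT_CLASS`), so it only runs in that test file's import context, on a sympy-enabled build.

    import torch

    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

    # After this patch the ordinary Python surface hands back SymInts:
    # .shape goes through THPSize_NewFromSymSizes, and size(dim) through
    # Tensor::sym_size, so no separate sym_size() call is needed.
    assert isinstance(x.shape[0], CPP_SYMINT_CLASS)
    assert isinstance(x.size(0), CPP_SYMINT_CLASS)
    assert int(x.shape[0]) == 5  # SymInts still convert back to plain ints

The dedicated `Tensor.sym_size` Python binding is deleted outright (see the python_variable_methods.cpp and torch/overrides.py hunks), so callers switch to `size()`/`.shape`, as the updated tests do.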
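The new `TensorImpl::sym_size` canonicalizes `d` with `maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false)` before indexing into `sym_sizes()`. The following is an illustrative Python model of that wrapping rule, written for exposition rather than taken from PyTorch; the error message wording is approximate.

    def maybe_wrap_dim(d: int, ndim: int) -> int:
        # Negative dims index from the end; anything outside [-ndim, ndim)
        # is rejected, matching array-style access. With wrap_scalar
        # disabled, a 0-dim tensor accepts no index at all.
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} out of range for ndim={ndim}")
        return d + ndim if d < 0 else d

    assert maybe_wrap_dim(-1, 3) == 2  # x.size(-1) reads the last dimension
    assert maybe_wrap_dim(1, 3) == 1

This is also why `sym_size_custom` can index `sym_sizes_custom()[d]` without a further check: by the time it runs, `maybe_wrap_dim` has already enforced the bounds.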