reinstate size and shape returning symints (#79560)

This PR redirects `size` and `.shape` to call `sym_sizes`, so both once again return `SymInt`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79560
Approved by: https://github.com/Chillee
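
As a quick illustration of the reinstated behavior, here is a minimal sketch modeled on the tests below; it assumes the `ShapeEnv` helper is importable from `torch.fx.experimental.symbolic_shapes`, as in the test suite at the time of this commit:

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)  # symbolic int with hint value 2
r = torch.empty(a0, device='meta')

# size() and .shape now route through sym_sizes(), so indexing them
# yields C++ SymInt objects rather than plain Python ints.
print(type(r.shape[0]))  # the C++ SymInt class, not int
print(int(r.shape[0]))   # 2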
Nikolay Korovaiko 2022-07-08 01:17:33 +00:00 committed by PyTorch MergeBot
parent 1263247395
commit 8389ccbcd8
10 changed files with 105 additions and 118 deletions

View file

@@ -68,6 +68,9 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
   int64_t size_custom(int64_t d) const override {
     return this->size(d);
   }
+  c10::SymInt sym_size_custom(int64_t d) const override {
+    return c10::SymInt{this->size(d)};
+  }
   IntArrayRef sizes_custom() const override;
   c10::SymIntArrayRef sym_sizes_custom() const override;
   c10::SymIntArrayRef sym_sizes() const override;

View file

@@ -157,11 +157,7 @@ class TORCH_API TensorBase {
   }
   c10::SymInt sym_size(int64_t dim) const {
-    const auto sizes = this->sym_sizes();
-    const auto ndim = static_cast<int64_t>(sizes.size());
-    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
-    return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
+    return impl_->sym_size(dim);
   }
   int64_t size(int64_t dim) const {

View file

@@ -594,6 +594,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return sizes_and_strides_.size_at_unchecked(d).as_int_unchecked();
   }
+  c10::SymInt sym_size(int64_t d) const {
+    if (C10_UNLIKELY(
+            sizes_strides_policy_ >=
+            static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
+      return sym_size_custom(d);
+    }
+    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
+    const auto sizes = this->sym_sizes();
+    return sizes[d];
+  }
   /**
    * Return the stride of a tensor at some dimension, wrapping the dimension
    * if necessary.
@@ -697,6 +708,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
     return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
   }
+  virtual c10::SymInt sym_size_custom(int64_t d) const {
+    // TODO: We could add support to Python dispatch here.
+    // TODO: We could call into aten::size.int instead of
+    // sym_sizes_custom()[d] and enable use of the dispatcher.
+    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
+    return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
+  }
   virtual IntArrayRef sizes_custom() const;
   virtual Device device_custom() const;
   virtual Layout layout_custom() const;

View file

@@ -95,6 +95,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
 }
 TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
+  FLAGS_ltc_enable_symbolic_shapes = true;
   auto xc = torch::rand({10});
   auto x = xc.to(kLazy);
   const size_t Y_DIM = 3;
@@ -105,6 +106,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
   ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
   // shape inference assumes narrow_copy can copy the whole tensor
   AllClose(z.cpu(), zc);
+  FLAGS_ltc_enable_symbolic_shapes = false;
 }
 TEST_F(LazyOpsTest, TestScalarTensor) {

View file

@@ -32,7 +32,7 @@ def register_meta(op):
 @register_meta([aten.add.Tensor, aten.sub.Tensor])
 def binary_meta(a, b):
-    return a.new_empty(a.sym_size())
+    return a.new_empty(a.shape)
 @register_meta(aten.cat.default)
@@ -53,7 +53,7 @@ def cat_meta(tensors, dim=0):
 @register_meta([aten.narrow_copy.SymInt])
 def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
     shape = []
-    for i, x in enumerate(a.sym_size()):
+    for i, x in enumerate(a.shape):
         if i == dim:
             shape.append(length)
         else:
@@ -165,6 +165,14 @@ class FakeSymbolicTensor(torch.Tensor):
             self = args[0]
             return self.sym_shape
+        # some calls can be redirected to `sym_size` rather than
+        # `sym_sizes`. `sym_size` uses `dim` to canonicalize an index
+        # so we need to implement both `sym_size` and `dim` for python
+        # tensors
+        if func_overload == torch.ops.aten.dim.default:
+            self = args[0]
+            return len(self.sym_shape)
         if func_overload == torch.ops.aten.new_empty.default:
             self = args[0]
             shape = args[1]
@@ -174,7 +182,7 @@ class FakeSymbolicTensor(torch.Tensor):
 def create_symbolic_tensor(name, arg, shape_env):
-    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
+    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
     sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
     return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)
@@ -188,22 +196,22 @@ class TestPySymInt(TestCase):
     def test_roundtrip(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
-        self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
-        self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))
+        self.assertTrue(not isinstance(x.shape[0], PySymInt))
+        self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
-        self.assertEqual(int(x.sym_size(0)), 5)
-        self.assertEqual(int(x.sym_size(1)), 4)
-        self.assertEqual(int(x.sym_size(2)), 3)
+        self.assertEqual(int(x.shape[0]), 5)
+        self.assertEqual(int(x.shape[1]), 4)
+        self.assertEqual(int(x.shape[2]), 3)
-        self.assertEqual(int(x.sym_size()[0]), 5)
-        self.assertEqual(int(x.sym_size()[1]), 4)
-        self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
-        self.assertEqual(int(x.sym_size()[2]), 3)
+        self.assertEqual(int(x.size()[0]), 5)
+        self.assertEqual(int(x.size()[1]), 4)
+        self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
+        self.assertEqual(int(x.size()[2]), 3)
-        self.assertEqual(int(x.sym_size(0)), 5)
-        self.assertEqual(int(x.sym_size(1)), 4)
-        self.assertEqual(int(x.sym_size(2)), 3)
-        self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))
+        self.assertEqual(int(x.size(0)), 5)
+        self.assertEqual(int(x.size(1)), 4)
+        self.assertEqual(int(x.size(2)), 3)
+        self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
     @skipIfNoSympy
     def test_binary(self):
@@ -212,16 +220,16 @@ class TestPySymInt(TestCase):
         y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
         z = x + y
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # broadcasting
         y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
         z = x + y
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
     @skipIfNoSympy
     def test_symint_args(self):
@@ -229,16 +237,16 @@ class TestPySymInt(TestCase):
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
         y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
         LAST_DIM = 2
-        z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
-        self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))
+        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
+        self.assertEqual(int(z.shape[2]), int(y.shape[2]))
         # arithmetic expr with two symints
-        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
-        self.assertEqual(int(z.sym_size(2)), 2)
+        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
+        self.assertEqual(int(z.shape[2]), 2)
         # arithmetic expr with a symint and python int
-        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
-        self.assertEqual(int(z.sym_size(2)), 2)
+        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
+        self.assertEqual(int(z.shape[2]), 2)
     @skipIfNoSympy
     def test_symint_vargs(self):
@@ -247,67 +255,67 @@ class TestPySymInt(TestCase):
         y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
         # varargs
-        z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # shape list
-        z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # mixed python symints and ints
-        z = y.expand(x.sym_size(0), y.sym_size(1), 3)
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(x.shape[0], y.shape[1], 3)
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # mixed python symints and ints in a list
-        z = y.expand((x.sym_size(0), y.sym_size(1), 3))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((x.shape[0], y.shape[1], 3))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # mixed python symints and ints
-        z = y.expand(5, y.sym_size(1), x.sym_size(2))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand(5, y.shape[1], x.shape[2])
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
         # mixed python ints and symints in a list
-        z = y.expand((5, y.sym_size(1), x.sym_size(2)))
-        self.assertEqual(int(z.sym_size(0)), 5)
-        self.assertEqual(int(z.sym_size(1)), 4)
-        self.assertEqual(int(z.sym_size(2)), 3)
+        z = y.expand((5, y.shape[1], x.shape[2]))
+        self.assertEqual(int(z.shape[0]), 5)
+        self.assertEqual(int(z.shape[1]), 4)
+        self.assertEqual(int(z.shape[2]), 3)
     @skipIfNoSympy
     def test_size_expressions(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
-        expand_x = x.expand(x.sym_size(0), x.sym_size(0))
-        if expand_x.sym_size(0) > 3:
+        expand_x = x.expand(x.shape[0], x.shape[0])
+        if expand_x.shape[0] > 3:
             result = expand_x + expand_x
         else:
             result = expand_x + expand_x
         gt_op = shape_env.guards[0][0]
         self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
-        self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
-        self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
-        self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))
+        self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
+        self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
     @skipIfNoSympy
     def test_aten_ops(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
-        torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.sym_size(0))
+        torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
-        torch.ops.aten.expand.SymInt(x, [x.sym_size(0), x.sym_size(1), x.sym_size(2)])
+        torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])
     def test_fx_trace_intlist(self):
         class CustomModule(torch.nn.Module):
@@ -327,7 +335,7 @@ class TestPySymInt(TestCase):
         shape_env = ShapeEnv()
         a0 = shape_env.create_symint("a0", 2)
         r = torch.empty(a0, device='meta')
-        self.assertIsInstance(r.sym_size(0), CPP_SYMINT_CLASS)
+        self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
 if __name__ == '__main__':

View file

@@ -1794,7 +1794,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.dim:
                     return data.dim()
-                if func.overloadpacket == torch.ops.aten.size:
+                if func.overloadpacket == torch.ops.aten.sym_size:
                     return (5, 3)
                 return NotImplemented
@@ -1807,13 +1807,13 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.dim:
                     return data.dim()
-                if func.overloadpacket == torch.ops.aten.size:
+                if func.overloadpacket == torch.ops.aten.sym_size:
                     return None
                 return NotImplemented
-        err_msg = "no implementation found for 'torch.ops.aten.size'"
+        err_msg = "no implementation found for 'torch.ops.aten.sym_size'"
         e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
-        with self.assertRaisesRegex(TypeError, err_msg):
+        with self.assertRaisesRegex(RuntimeError, err_msg):
             e.size()
         e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)

View file

@@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
-// TODO: FIXME This should be super temprorary until we fix the XLA issue.
-static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "sym_size(int64_t dim)",
-    "sym_size()",
-    "sym_size(Dimname dim)",
-  });
-  auto& self_ = THPVariable_Unpack(self);
-  ParsedArgs<3> parsed_args;
-  auto r = parser.parse(self, args, kwargs, parsed_args);
-  if(r.has_torch_function()){
-    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
-  }
-  if (r.idx == 0) {
-    if (jit::tracer::isTracing()) {
-      // will error out if a tensor has symints
-      return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
-    } else {
-      return torch::toPyObject(self_.sym_size(r.toInt64(0)));
-    }
-  } else if (r.idx == 1) {
-    return THPSize_NewFromSymSizes(self_);
-  }
-  else if (r.idx == 2) {
-    if (jit::tracer::isTracing()) {
-      TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
-    }
-    return wrap(self_.size(r.dimname(0)));
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
 static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
@@ -152,14 +115,10 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
       // will error out if a tensor has symints
       return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
     } else {
-      return wrap(self_.size(r.toInt64(0)));
-      //return torch::toPyObject(self_.sym_size(r.toInt64(0)));
+      return torch::toPyObject(self_.sym_size(r.toInt64(0)));
     }
   } else if (r.idx == 1) {
-    // we can't do the normal wrapping here because IntArrayRef maps to both
-    // torch.Size and tuple in python.
-    return THPSize_New(self_);
-    //return THPSize_NewFromSymSizes(self_);
+    return THPSize_NewFromSymSizes(self_);
   }
   else if (r.idx == 2) {
     if (jit::tracer::isTracing()) {
@@ -1322,7 +1281,6 @@ PyMethodDef variable_methods[] = {
   {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
   {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
   {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
-  {"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
   {"_storage", THPVariable_storage, METH_NOARGS, NULL},
   {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
   {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},

View file

@@ -1220,8 +1220,7 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
   if (check_has_torch_function((PyObject*)self)) {
     return handle_torch_function_getter(self, "shape");
  }
-  // return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
-  return THPSize_New(THPVariable_Unpack(self));
+  return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
   END_HANDLE_TH_ERRORS
 }

View file

@@ -143,7 +143,9 @@ void LTCTensorImpl::shallow_copy_from(
 }
 c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
-  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
+  return FLAGS_ltc_enable_symbolic_shapes
+      ? c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size())
+      : TensorImpl::sym_sizes_default();
 }
 c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {

View file

@@ -279,8 +279,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._is_zerotensor,
         Tensor._addmm_activation,
         Tensor._nested_tensor_layer_norm,
-        Tensor.to_padded_tensor,
-        Tensor.sym_size
+        Tensor.to_padded_tensor
     }