mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
reinstate size and shape returning symints (#79560)
This PR redirects `size` and `.shape` to call `sym_sizes` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79560 Approved by: https://github.com/Chillee
This commit is contained in:
parent
1263247395
commit
8389ccbcd8
10 changed files with 105 additions and 118 deletions
|
|
@ -68,6 +68,9 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
|
|||
int64_t size_custom(int64_t d) const override {
|
||||
return this->size(d);
|
||||
}
|
||||
c10::SymInt sym_size_custom(int64_t d) const override {
|
||||
return c10::SymInt{this->size(d)};
|
||||
}
|
||||
IntArrayRef sizes_custom() const override;
|
||||
c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
c10::SymIntArrayRef sym_sizes() const override;
|
||||
|
|
|
|||
|
|
@ -157,11 +157,7 @@ class TORCH_API TensorBase {
|
|||
}
|
||||
|
||||
c10::SymInt sym_size(int64_t dim) const {
|
||||
const auto sizes = this->sym_sizes();
|
||||
const auto ndim = static_cast<int64_t>(sizes.size());
|
||||
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
|
||||
return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
|
||||
|
||||
return impl_->sym_size(dim);
|
||||
}
|
||||
|
||||
int64_t size(int64_t dim) const {
|
||||
|
|
|
|||
|
|
@ -594,6 +594,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return sizes_and_strides_.size_at_unchecked(d).as_int_unchecked();
|
||||
}
|
||||
|
||||
c10::SymInt sym_size(int64_t d) const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_size_custom(d);
|
||||
}
|
||||
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
|
||||
const auto sizes = this->sym_sizes();
|
||||
return sizes[d];
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the stride of a tensor at some dimension, wrapping the dimension
|
||||
* if necessary.
|
||||
|
|
@ -697,6 +708,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
|
||||
return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
|
||||
}
|
||||
|
||||
virtual c10::SymInt sym_size_custom(int64_t d) const {
|
||||
// TODO: We could add support to Python dispatch here.
|
||||
// TODO: We could call into aten::size.int instead of
|
||||
// sym_sizes_custom()[d] and enable use of the dispatcher.
|
||||
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
|
||||
return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
|
||||
}
|
||||
|
||||
virtual IntArrayRef sizes_custom() const;
|
||||
virtual Device device_custom() const;
|
||||
virtual Layout layout_custom() const;
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
|
|||
}
|
||||
|
||||
TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
|
||||
FLAGS_ltc_enable_symbolic_shapes = true;
|
||||
auto xc = torch::rand({10});
|
||||
auto x = xc.to(kLazy);
|
||||
const size_t Y_DIM = 3;
|
||||
|
|
@ -105,6 +106,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
|
|||
ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
|
||||
// shape inference assumes narrow_copy can copy the whole tensor
|
||||
AllClose(z.cpu(), zc);
|
||||
FLAGS_ltc_enable_symbolic_shapes = false;
|
||||
}
|
||||
|
||||
TEST_F(LazyOpsTest, TestScalarTensor) {
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def register_meta(op):
|
|||
|
||||
@register_meta([aten.add.Tensor, aten.sub.Tensor])
|
||||
def binary_meta(a, b):
|
||||
return a.new_empty(a.sym_size())
|
||||
return a.new_empty(a.shape)
|
||||
|
||||
|
||||
@register_meta(aten.cat.default)
|
||||
|
|
@ -53,7 +53,7 @@ def cat_meta(tensors, dim=0):
|
|||
@register_meta([aten.narrow_copy.SymInt])
|
||||
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
|
||||
shape = []
|
||||
for i, x in enumerate(a.sym_size()):
|
||||
for i, x in enumerate(a.shape):
|
||||
if i == dim:
|
||||
shape.append(length)
|
||||
else:
|
||||
|
|
@ -165,6 +165,14 @@ class FakeSymbolicTensor(torch.Tensor):
|
|||
self = args[0]
|
||||
return self.sym_shape
|
||||
|
||||
# some calls can be redirected to `sym_size` rather than
|
||||
# `sym_sizes`. `sym_size` uses `dim` to canonicalize an index
|
||||
# so we need to implement both `sym_size` and `dim` for python
|
||||
# tensors
|
||||
if func_overload == torch.ops.aten.dim.default:
|
||||
self = args[0]
|
||||
return len(self.sym_shape)
|
||||
|
||||
if func_overload == torch.ops.aten.new_empty.default:
|
||||
self = args[0]
|
||||
shape = args[1]
|
||||
|
|
@ -174,7 +182,7 @@ class FakeSymbolicTensor(torch.Tensor):
|
|||
|
||||
|
||||
def create_symbolic_tensor(name, arg, shape_env):
|
||||
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
|
||||
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
|
||||
sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
|
||||
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)
|
||||
|
||||
|
|
@ -188,22 +196,22 @@ class TestPySymInt(TestCase):
|
|||
def test_roundtrip(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
|
||||
self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))
|
||||
self.assertTrue(not isinstance(x.shape[0], PySymInt))
|
||||
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
|
||||
|
||||
self.assertEqual(int(x.sym_size(0)), 5)
|
||||
self.assertEqual(int(x.sym_size(1)), 4)
|
||||
self.assertEqual(int(x.sym_size(2)), 3)
|
||||
self.assertEqual(int(x.shape[0]), 5)
|
||||
self.assertEqual(int(x.shape[1]), 4)
|
||||
self.assertEqual(int(x.shape[2]), 3)
|
||||
|
||||
self.assertEqual(int(x.sym_size()[0]), 5)
|
||||
self.assertEqual(int(x.sym_size()[1]), 4)
|
||||
self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
|
||||
self.assertEqual(int(x.sym_size()[2]), 3)
|
||||
self.assertEqual(int(x.size()[0]), 5)
|
||||
self.assertEqual(int(x.size()[1]), 4)
|
||||
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
|
||||
self.assertEqual(int(x.size()[2]), 3)
|
||||
|
||||
self.assertEqual(int(x.sym_size(0)), 5)
|
||||
self.assertEqual(int(x.sym_size(1)), 4)
|
||||
self.assertEqual(int(x.sym_size(2)), 3)
|
||||
self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))
|
||||
self.assertEqual(int(x.size(0)), 5)
|
||||
self.assertEqual(int(x.size(1)), 4)
|
||||
self.assertEqual(int(x.size(2)), 3)
|
||||
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_binary(self):
|
||||
|
|
@ -212,16 +220,16 @@ class TestPySymInt(TestCase):
|
|||
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
|
||||
|
||||
z = x + y
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# broadcasting
|
||||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
||||
z = x + y
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_args(self):
|
||||
|
|
@ -229,16 +237,16 @@ class TestPySymInt(TestCase):
|
|||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
|
||||
LAST_DIM = 2
|
||||
z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
|
||||
self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))
|
||||
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
|
||||
self.assertEqual(int(z.shape[2]), int(y.shape[2]))
|
||||
|
||||
# arithmetic expr with two symints
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
|
||||
self.assertEqual(int(z.sym_size(2)), 2)
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
|
||||
self.assertEqual(int(z.shape[2]), 2)
|
||||
|
||||
# arithmetic expr with a symint and python int
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
|
||||
self.assertEqual(int(z.sym_size(2)), 2)
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
|
||||
self.assertEqual(int(z.shape[2]), 2)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_vargs(self):
|
||||
|
|
@ -247,67 +255,67 @@ class TestPySymInt(TestCase):
|
|||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
||||
|
||||
# varargs
|
||||
z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# shape list
|
||||
z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# mixed python symints and ints
|
||||
z = y.expand(x.sym_size(0), y.sym_size(1), 3)
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand(x.shape[0], y.shape[1], 3)
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# mixed python symints and ints in a list
|
||||
z = y.expand((x.sym_size(0), y.sym_size(1), 3))
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand((x.shape[0], y.shape[1], 3))
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# mixed python symints and ints
|
||||
z = y.expand(5, y.sym_size(1), x.sym_size(2))
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand(5, y.shape[1], x.shape[2])
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
# mixed python ints and symints in a list
|
||||
z = y.expand((5, y.sym_size(1), x.sym_size(2)))
|
||||
self.assertEqual(int(z.sym_size(0)), 5)
|
||||
self.assertEqual(int(z.sym_size(1)), 4)
|
||||
self.assertEqual(int(z.sym_size(2)), 3)
|
||||
z = y.expand((5, y.shape[1], x.shape[2]))
|
||||
self.assertEqual(int(z.shape[0]), 5)
|
||||
self.assertEqual(int(z.shape[1]), 4)
|
||||
self.assertEqual(int(z.shape[2]), 3)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_size_expressions(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
expand_x = x.expand(x.sym_size(0), x.sym_size(0))
|
||||
if expand_x.sym_size(0) > 3:
|
||||
expand_x = x.expand(x.shape[0], x.shape[0])
|
||||
if expand_x.shape[0] > 3:
|
||||
result = expand_x + expand_x
|
||||
else:
|
||||
result = expand_x + expand_x
|
||||
|
||||
gt_op = shape_env.guards[0][0]
|
||||
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
|
||||
self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
|
||||
self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
|
||||
self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))
|
||||
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
|
||||
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
|
||||
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_aten_ops(self):
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.sym_size(0))
|
||||
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
torch.ops.aten.expand.SymInt(x, [x.sym_size(0), x.sym_size(1), x.sym_size(2)])
|
||||
torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])
|
||||
|
||||
def test_fx_trace_intlist(self):
|
||||
class CustomModule(torch.nn.Module):
|
||||
|
|
@ -327,7 +335,7 @@ class TestPySymInt(TestCase):
|
|||
shape_env = ShapeEnv()
|
||||
a0 = shape_env.create_symint("a0", 2)
|
||||
r = torch.empty(a0, device='meta')
|
||||
self.assertIsInstance(r.sym_size(0), CPP_SYMINT_CLASS)
|
||||
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -1794,7 +1794,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
|||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func.overloadpacket == torch.ops.aten.dim:
|
||||
return data.dim()
|
||||
if func.overloadpacket == torch.ops.aten.size:
|
||||
if func.overloadpacket == torch.ops.aten.sym_size:
|
||||
return (5, 3)
|
||||
return NotImplemented
|
||||
|
||||
|
|
@ -1807,13 +1807,13 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
|||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func.overloadpacket == torch.ops.aten.dim:
|
||||
return data.dim()
|
||||
if func.overloadpacket == torch.ops.aten.size:
|
||||
if func.overloadpacket == torch.ops.aten.sym_size:
|
||||
return None
|
||||
return NotImplemented
|
||||
|
||||
err_msg = "no implementation found for 'torch.ops.aten.size'"
|
||||
err_msg = "no implementation found for 'torch.ops.aten.sym_size'"
|
||||
e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
e.size()
|
||||
|
||||
e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
|
||||
|
|
|
|||
|
|
@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// TODO: FIXME This should be super temprorary until we fix the XLA issue.
|
||||
static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"sym_size(int64_t dim)",
|
||||
"sym_size()",
|
||||
"sym_size(Dimname dim)",
|
||||
});
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
ParsedArgs<3> parsed_args;
|
||||
auto r = parser.parse(self, args, kwargs, parsed_args);
|
||||
|
||||
if(r.has_torch_function()){
|
||||
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
|
||||
}
|
||||
if (r.idx == 0) {
|
||||
if (jit::tracer::isTracing()) {
|
||||
// will error out if a tensor has symints
|
||||
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
|
||||
} else {
|
||||
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
|
||||
}
|
||||
} else if (r.idx == 1) {
|
||||
return THPSize_NewFromSymSizes(self_);
|
||||
}
|
||||
else if (r.idx == 2) {
|
||||
if (jit::tracer::isTracing()) {
|
||||
TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
|
||||
}
|
||||
return wrap(self_.size(r.dimname(0)));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -152,14 +115,10 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
|
|||
// will error out if a tensor has symints
|
||||
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
|
||||
} else {
|
||||
return wrap(self_.size(r.toInt64(0)));
|
||||
//return torch::toPyObject(self_.sym_size(r.toInt64(0)));
|
||||
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
|
||||
}
|
||||
} else if (r.idx == 1) {
|
||||
// we can't do the normal wrapping here because IntArrayRef maps to both
|
||||
// torch.Size and tuple in python.
|
||||
return THPSize_New(self_);
|
||||
//return THPSize_NewFromSymSizes(self_);
|
||||
return THPSize_NewFromSymSizes(self_);
|
||||
}
|
||||
else if (r.idx == 2) {
|
||||
if (jit::tracer::isTracing()) {
|
||||
|
|
@ -1322,7 +1281,6 @@ PyMethodDef variable_methods[] = {
|
|||
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"_storage", THPVariable_storage, METH_NOARGS, NULL},
|
||||
{"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
|
||||
{"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
|
|
|
|||
|
|
@ -1220,8 +1220,7 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
|
|||
if (check_has_torch_function((PyObject*)self)) {
|
||||
return handle_torch_function_getter(self, "shape");
|
||||
}
|
||||
// return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
|
||||
return THPSize_New(THPVariable_Unpack(self));
|
||||
return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -143,7 +143,9 @@ void LTCTensorImpl::shallow_copy_from(
|
|||
}
|
||||
|
||||
c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
|
||||
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
|
||||
return FLAGS_ltc_enable_symbolic_shapes
|
||||
? c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size())
|
||||
: TensorImpl::sym_sizes_default();
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
|
||||
|
|
|
|||
|
|
@ -279,8 +279,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
Tensor._is_zerotensor,
|
||||
Tensor._addmm_activation,
|
||||
Tensor._nested_tensor_layer_norm,
|
||||
Tensor.to_padded_tensor,
|
||||
Tensor.sym_size
|
||||
Tensor.to_padded_tensor
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue