Fixed Error message for tensor.align_to (#27221)

Summary:
Fixing this [issue1](https://github.com/pytorch/pytorch/issues/27074) and [issue2](https://github.com/pytorch/pytorch/issues/27073)
Tested via unit tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27221

Differential Revision: D17716235

Pulled By: izdeby

fbshipit-source-id: c7bafd16b469c91924ebc3dba77ca56424d4c93c
This commit is contained in:
Iurii Zdebskyi 2019-10-02 14:17:59 -07:00 committed by Facebook Github Bot
parent 162ef02db6
commit 293e35a87c
3 changed files with 5 additions and 6 deletions

View file

@ -152,7 +152,7 @@ Tensor align_to(const Tensor& tensor, DimnameList names) {
const auto& dim = tensor_names[idx];
TORCH_CHECK(dim.isBasic(),
"align_to: All input dims must be named. Found unnamed dim at index ",
dim, " of Tensor", tensor_names);
idx, " of Tensor", tensor_names);
auto it = std::find(names.begin(), names.end(), dim);
TORCH_CHECK(it != names.end(),
"align_to: Cannot find dim ", dim, " from Tensor", names,

View file

@ -1341,7 +1341,7 @@ class TestNamedTensor(TestCase):
self.assertEqual(output.shape, [3, 5, 1, 2])
# All input dimensions must be named
with self.assertRaisesRegex(RuntimeError, "All input dims must be named"):
with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"):
create('None:2,C:3').align_to('N', 'C')
# not enough names
@ -1397,6 +1397,7 @@ class TestNamedTensor(TestCase):
self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
self.assertEqual(output.shape, [3, 5, 1, 2])
@unittest.skip("Not implemented yet")
def test_align_tensors_two_inputs(self):
def _test(tensor_namedshape, align_names, expected_sizes, expected_error):
tensor_names, tensor_sizes = tensor_namedshape
@ -1507,6 +1508,7 @@ class TestNamedTensor(TestCase):
for test in tests:
_test(*test)
@unittest.skip("Not implemented yet")
def test_align_tensors(self):
def reference_fn(*tensors):
longest_names = tensors[0].names

View file

@ -811,7 +811,4 @@ def lu(A, pivot=True, get_infos=False, out=None):
def align_tensors(*tensors):
if not torch._C._BUILD_NAMEDTENSOR:
raise RuntimeError('NYI: torch.align_tensors is experimental and a part '
'of our named tensors project.')
return torch._C._VariableFunctions.align_tensors(tensors)
raise RuntimeError('`align_tensors` not yet implemented.')