diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 5215762619..1ac4251ebd 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -917,7 +917,9 @@ at::Tensor& fill__Scalar( ORT_LOG_FN(self, value); if ( - !IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) { + std::vector supportedTypes = + {at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool}; + !IsSupportedType(self, supportedTypes)) { std::cout << "fill__Scalar - Fell back to cpu!\n"; return at::native::call_fallback_fn< &at::native::cpu_fallback, diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index f73abf5d32..ef0ba1984c 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -40,6 +40,7 @@ class OrtOpTests(unittest.TestCase): ort_ones = cpu_ones.to(device) assert torch.allclose(torch.add(cpu_ones, cpu_ones, alpha=2.5), torch.add(ort_ones, ort_ones, alpha=2.5).cpu()) + # The ONNX operator Mul does not support type bool. The following test verifies that the CPU fallback works. 
def test_mul_bool(self): device = self.get_device() cpu_ones = torch.ones(3, 3, dtype=bool) @@ -144,13 +145,16 @@ class OrtOpTests(unittest.TestCase): assert not torch.equal(cpu_a, cpu_e) assert not torch.equal(ort_a, ort_e) - def test_torch_ones(self): + def test_ones(self): device = self.get_device() - cpu_ones = torch.ones((10, 10)) + cpu_out_tensor = torch.tensor([]) + ort_out_tensor = cpu_out_tensor.to(device) + cpu_ones = torch.ones((10, 10), out=cpu_out_tensor) ort_ones = cpu_ones.to(device) - ort_ones_device = torch.ones((10, 10), device=device) + ort_ones_device = torch.ones((10, 10), out=ort_out_tensor, device=device) assert torch.allclose(cpu_ones, ort_ones.cpu()) assert torch.allclose(cpu_ones, ort_ones_device.cpu()) + assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu()) def test_narrow(self): cpu_tensor = torch.rand(10, 10) @@ -346,10 +350,10 @@ class OrtOpTests(unittest.TestCase): def test_fill(self): device = self.get_device() - for type in {torch.int, torch.float}: - cpu_tensor = torch.zeros(2, 2, dtype=type) + for torch_type in [torch.int, torch.float]: + cpu_tensor = torch.zeros(2, 2, dtype=torch_type) ort_tensor = cpu_tensor.to(device) - for value in {True, 1.1, -1, 0}: + for value in [True, 1.1, -1, 0]: cpu_tensor.fill_(value) ort_tensor.fill_(value) assert cpu_tensor.dtype == ort_tensor.dtype