Fix minor python and cpp warnings from previous PR. (#12140)

Description: In the PR 12018 a few fixable python and cpp warning were introduced that this PR cleans up. Also adding a comment on the intent of test_mul_bool and out testing on test_ones.

Motivation and Context

When iterating in Python, use a list instead of a set and don't use reserved words
Fix long line in cpp
Clarify test_mul_bool intent for future developers.
fill_ implements torch.ones under the covers but in previous pr verification on the out param was not added so adding it here.
This commit is contained in:
Wil Brady 2022-07-11 16:18:40 -04:00 committed by GitHub
parent 99a370dd02
commit f1047e0456
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View file

@ -917,7 +917,9 @@ at::Tensor& fill__Scalar(
ORT_LOG_FN(self, value);
if (
!IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) {
std::vector<at::ScalarType> supportedTypes =
{at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool};
!IsSupportedType(self, supportedTypes)) {
std::cout << "fill__Scalar - Fell back to cpu!\n";
return at::native::call_fallback_fn<
&at::native::cpu_fallback,

View file

@ -40,6 +40,7 @@ class OrtOpTests(unittest.TestCase):
ort_ones = cpu_ones.to(device)
assert torch.allclose(torch.add(cpu_ones, cpu_ones, alpha=2.5), torch.add(ort_ones, ort_ones, alpha=2.5).cpu())
# the onnx operator Mul does not support type bool. The following test verifies cpu fall back works.
def test_mul_bool(self):
device = self.get_device()
cpu_ones = torch.ones(3, 3, dtype=bool)
@ -144,13 +145,16 @@ class OrtOpTests(unittest.TestCase):
assert not torch.equal(cpu_a, cpu_e)
assert not torch.equal(ort_a, ort_e)
def test_torch_ones(self):
def test_ones(self):
device = self.get_device()
cpu_ones = torch.ones((10, 10))
cpu_out_tensor = torch.tensor([])
ort_out_tensor = cpu_out_tensor.to(device)
cpu_ones = torch.ones((10, 10), out=cpu_out_tensor)
ort_ones = cpu_ones.to(device)
ort_ones_device = torch.ones((10, 10), device=device)
ort_ones_device = torch.ones((10, 10), out=ort_out_tensor, device=device)
assert torch.allclose(cpu_ones, ort_ones.cpu())
assert torch.allclose(cpu_ones, ort_ones_device.cpu())
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
def test_narrow(self):
cpu_tensor = torch.rand(10, 10)
@ -346,10 +350,10 @@ class OrtOpTests(unittest.TestCase):
def test_fill(self):
device = self.get_device()
for type in {torch.int, torch.float}:
cpu_tensor = torch.zeros(2, 2, dtype=type)
for torch_type in [torch.int, torch.float]:
cpu_tensor = torch.zeros(2, 2, dtype=torch_type)
ort_tensor = cpu_tensor.to(device)
for value in {True, 1.1, -1, 0}:
for value in [True, 1.1, -1, 0]:
cpu_tensor.fill_(value)
ort_tensor.fill_(value)
assert cpu_tensor.dtype == ort_tensor.dtype