mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Fix minor python and cpp warnings from previous PR. (#12140)
Description: In the PR 12018 a few fixable python and cpp warning were introduced that this PR cleans up. Also adding a comment on the intent of test_mul_bool and out testing on test_ones. Motivation and Context When iterating in Python, use a list instead of a set and don't use reserved words Fix long line in cpp Clarify test_mul_bool intent for future developers. fill_ implements torch.ones under the covers but in previous pr verification on the out param was not added so adding it here.
This commit is contained in:
parent
99a370dd02
commit
f1047e0456
2 changed files with 13 additions and 7 deletions
|
|
@ -917,7 +917,9 @@ at::Tensor& fill__Scalar(
|
|||
ORT_LOG_FN(self, value);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) {
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool};
|
||||
!IsSupportedType(self, supportedTypes)) {
|
||||
std::cout << "fill__Scalar - Fell back to cpu!\n";
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class OrtOpTests(unittest.TestCase):
|
|||
ort_ones = cpu_ones.to(device)
|
||||
assert torch.allclose(torch.add(cpu_ones, cpu_ones, alpha=2.5), torch.add(ort_ones, ort_ones, alpha=2.5).cpu())
|
||||
|
||||
# the onnx operator Mul does not support type bool. The following test verifies cpu fall back works.
|
||||
def test_mul_bool(self):
|
||||
device = self.get_device()
|
||||
cpu_ones = torch.ones(3, 3, dtype=bool)
|
||||
|
|
@ -144,13 +145,16 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert not torch.equal(cpu_a, cpu_e)
|
||||
assert not torch.equal(ort_a, ort_e)
|
||||
|
||||
def test_torch_ones(self):
|
||||
def test_ones(self):
|
||||
device = self.get_device()
|
||||
cpu_ones = torch.ones((10, 10))
|
||||
cpu_out_tensor = torch.tensor([])
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
cpu_ones = torch.ones((10, 10), out=cpu_out_tensor)
|
||||
ort_ones = cpu_ones.to(device)
|
||||
ort_ones_device = torch.ones((10, 10), device=device)
|
||||
ort_ones_device = torch.ones((10, 10), out=ort_out_tensor, device=device)
|
||||
assert torch.allclose(cpu_ones, ort_ones.cpu())
|
||||
assert torch.allclose(cpu_ones, ort_ones_device.cpu())
|
||||
assert torch.allclose(cpu_out_tensor, ort_out_tensor.cpu())
|
||||
|
||||
def test_narrow(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
|
|
@ -346,10 +350,10 @@ class OrtOpTests(unittest.TestCase):
|
|||
|
||||
def test_fill(self):
|
||||
device = self.get_device()
|
||||
for type in {torch.int, torch.float}:
|
||||
cpu_tensor = torch.zeros(2, 2, dtype=type)
|
||||
for torch_type in [torch.int, torch.float]:
|
||||
cpu_tensor = torch.zeros(2, 2, dtype=torch_type)
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
for value in {True, 1.1, -1, 0}:
|
||||
for value in [True, 1.1, -1, 0]:
|
||||
cpu_tensor.fill_(value)
|
||||
ort_tensor.fill_(value)
|
||||
assert cpu_tensor.dtype == ort_tensor.dtype
|
||||
|
|
|
|||
Loading…
Reference in a new issue