Fix NonZero eager impl. (#12143)

This commit is contained in:
Wil Brady 2022-07-13 05:50:33 -04:00 committed by GitHub
parent 3b0aaa9e0e
commit 48647bc7d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 20 deletions

View file

@ -3,13 +3,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from opgen.writer import SourceWriter as SourceWriter
from opgen.parser import cpp_create_from_file as CPPParser
import sys
import os
from opgen.generator import ORTGen as ORTGen
from importlib.machinery import SourceFileLoader
import argparse
import os
import sys
from importlib.machinery import SourceFileLoader
from opgen.generator import ORTGen as ORTGen
from opgen.parser import cpp_create_from_file as CPPParser
from opgen.writer import SourceWriter as SourceWriter
parser = argparse.ArgumentParser(description="Generate ORT ATen operations")
parser.add_argument(
@ -24,7 +25,12 @@ parser.add_argument(
args = parser.parse_args()
ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
ortgen = ORTGen(
ops_module.ops,
type_promotion_ops=ops_module.type_promotion_ops,
custom_ops=args.custom_ops,
aten_output_type=ops_module.aten_output_type,
)
regdecs_path = args.header_file
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}")

View file

@ -45,6 +45,7 @@ class GeluGrad(ONNXOp):
ops = {}
type_promotion_ops = []
aten_output_type = {}
# the following op list is for ops that have a .out version. Often this is the only op needing to be implemented
# and the regular and inplace(_) version derive from the .out.
@ -68,7 +69,6 @@ unary_ops_with_out = [
"hardsigmoid",
"log",
"neg",
"nonzero",
"reciprocal",
"round",
"sigmoid",
@ -92,7 +92,6 @@ unary_ops = [
"det",
"isinf",
"isnan",
"nonzero",
"relu",
"selu",
]
@ -169,8 +168,14 @@ hand_implemented = {
"aten::equal": SignatureOnly(),
"aten::_softmax": Softmax("self", axis="dim"),
"aten::argmax.out": SignatureOnly(),
"aten::nonzero": Transpose(NonZero("self")),
"aten::nonzero.out": SignatureOnly(),
}
# If the aten op expects a specific output type that differs from self
# add the op and type to aten_output_type
aten_output_type["aten::nonzero"] = "at::ScalarType::Long"
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
# This is done to make sure it is backward and future compatible
if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION):

View file

@ -1,14 +1,11 @@
from copy import deepcopy
from opgen.generator import AttrType, ONNXAttr
from opgen.generator import (
ORTGen as ORTGen,
ONNXOp as ONNXOp,
SignatureOnly as SignatureOnly,
MakeTorchFallback as MakeTorchFallback,
)
from opgen.generator import AttrType
from opgen.generator import MakeTorchFallback as MakeTorchFallback
from opgen.generator import ONNXAttr
from opgen.generator import ONNXOp as ONNXOp
from opgen.generator import ORTGen as ORTGen
from opgen.generator import SignatureOnly as SignatureOnly
from opgen.onnxops import *
ops = {
@ -17,3 +14,4 @@ ops = {
}
type_promotion_ops = {}
aten_output_type = {}

View file

@ -114,13 +114,18 @@ class ORTGen:
_custom_ops: bool
def __init__(
self, ops: Optional[Dict[str, ONNXOp]] = None, custom_ops: bool = False, type_promotion_ops: List = ()
self,
ops: Optional[Dict[str, ONNXOp]] = None,
custom_ops: bool = False,
type_promotion_ops: List = (),
aten_output_type: Dict = (),
):
self._mapped_ops = {}
if ops:
self.register_many(ops)
self._custom_ops = custom_ops
self.type_promotion_ops = type_promotion_ops
self.aten_output_type = aten_output_type
def register(self, aten_name: str, onnx_op: ONNXOp):
self._mapped_ops[aten_name] = onnx_op
@ -517,6 +522,11 @@ class ORTGen:
writer.write(".options()")
if need_type_promotion:
writer.write(".dtype(*promoted_type)")
# do we need to set type on the returned value
if mapped_func.mapped_op_name in self.aten_output_type:
writer.write(f".dtype({self.aten_output_type[mapped_func.mapped_op_name]})")
writer.writeline(";")
writer.writeline("return aten_tensor_from_ort(")

View file

@ -957,6 +957,24 @@ at::Tensor& fill__Scalar(
return self;
}
// aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
at::Tensor& nonzero_out(
const at::Tensor& self,
// *,
at::Tensor& out) {
ORT_LOG_FN(self, out);
auto temp = eager::aten::nonzero(self);
// resize out, then copy nonzero result into it.
auto& invoker = GetORTInvoker(self.device());
resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), temp.sizes());
auto ort_input_out = create_ort_value(invoker, out);
auto ort_temp = create_ort_value(invoker, temp);
copy(invoker, ort_temp, ort_input_out);
return out;
}
} // namespace aten

View file

@ -133,5 +133,12 @@ void resize_impl_ort_(
ORTTensorImpl* self,
at::IntArrayRef size);
namespace aten {
// aten::nonzero(Tensor self) -> Tensor
at::Tensor nonzero(
const at::Tensor& self);
} // namespace aten
} // namespace eager
} // namespace torch_ort

View file

@ -42,7 +42,6 @@ ops = [
# the following unary ops not been tested:
# ["isnan", torch.tensor([1, float('nan'), 2])]]
# ["selu", torch.randn(10)]]
# ["nonzero", torch.tensor([0, 2, 1, 3])]]
# ["sign", ]]
# ["hardsigmoid", ],
# ["isinf", ],
@ -476,6 +475,33 @@ class OrtOpTests(unittest.TestCase):
assert cpu_tensor.dtype == ort_tensor.dtype
assert torch.equal(cpu_tensor, ort_tensor.to("cpu"))
# tests both nonzero and nonzero.out
def test_nonzero(self):
device = self.get_device()
for cpu_tensor in [
torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.long),
torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.float),
]:
ort_tensor = cpu_tensor.to(device)
cpu_out_tensor = torch.tensor([], dtype=torch.long)
ort_out_tensor = cpu_out_tensor.to(device)
# nonzero.out
cpu_result = torch.nonzero(cpu_tensor, out=cpu_out_tensor)
ort_result = torch.nonzero(ort_tensor, out=ort_out_tensor)
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
assert torch.equal(cpu_result, ort_result.to("cpu"))
# nonzero
cpu_result = torch.nonzero(cpu_tensor)
ort_result = torch.nonzero(ort_tensor)
assert torch.equal(cpu_result, ort_result.to("cpu"))
# check result between nonzero.out and nonzero
assert torch.equal(ort_result.to("cpu"), ort_out_tensor.to("cpu"))
if __name__ == "__main__":
unittest.main()