mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Fix NonZero eager impl. (#12143)
This commit is contained in:
parent
3b0aaa9e0e
commit
48647bc7d7
7 changed files with 90 additions and 20 deletions
|
|
@ -3,13 +3,14 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from opgen.writer import SourceWriter as SourceWriter
|
||||
from opgen.parser import cpp_create_from_file as CPPParser
|
||||
import sys
|
||||
import os
|
||||
from opgen.generator import ORTGen as ORTGen
|
||||
from importlib.machinery import SourceFileLoader
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
from opgen.generator import ORTGen as ORTGen
|
||||
from opgen.parser import cpp_create_from_file as CPPParser
|
||||
from opgen.writer import SourceWriter as SourceWriter
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate ORT ATen operations")
|
||||
parser.add_argument(
|
||||
|
|
@ -24,7 +25,12 @@ parser.add_argument(
|
|||
args = parser.parse_args()
|
||||
ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module()
|
||||
|
||||
ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops)
|
||||
ortgen = ORTGen(
|
||||
ops_module.ops,
|
||||
type_promotion_ops=ops_module.type_promotion_ops,
|
||||
custom_ops=args.custom_ops,
|
||||
aten_output_type=ops_module.aten_output_type,
|
||||
)
|
||||
|
||||
regdecs_path = args.header_file
|
||||
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}")
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class GeluGrad(ONNXOp):
|
|||
|
||||
ops = {}
|
||||
type_promotion_ops = []
|
||||
aten_output_type = {}
|
||||
|
||||
# the following op list is for ops that have a .out version. Often this is the only op needing to be implemented
|
||||
# and the regular and inplace(_) version derive from the .out.
|
||||
|
|
@ -68,7 +69,6 @@ unary_ops_with_out = [
|
|||
"hardsigmoid",
|
||||
"log",
|
||||
"neg",
|
||||
"nonzero",
|
||||
"reciprocal",
|
||||
"round",
|
||||
"sigmoid",
|
||||
|
|
@ -92,7 +92,6 @@ unary_ops = [
|
|||
"det",
|
||||
"isinf",
|
||||
"isnan",
|
||||
"nonzero",
|
||||
"relu",
|
||||
"selu",
|
||||
]
|
||||
|
|
@ -169,8 +168,14 @@ hand_implemented = {
|
|||
"aten::equal": SignatureOnly(),
|
||||
"aten::_softmax": Softmax("self", axis="dim"),
|
||||
"aten::argmax.out": SignatureOnly(),
|
||||
"aten::nonzero": Transpose(NonZero("self")),
|
||||
"aten::nonzero.out": SignatureOnly(),
|
||||
}
|
||||
|
||||
# If the aten op expects a specific output type that differs from self
|
||||
# add the op and type to aten_output_type
|
||||
aten_output_type["aten::nonzero"] = "at::ScalarType::Long"
|
||||
|
||||
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
|
||||
# This is done to make sure it is backward and future compatible
|
||||
if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION):
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
from copy import deepcopy
|
||||
|
||||
from opgen.generator import AttrType, ONNXAttr
|
||||
|
||||
from opgen.generator import (
|
||||
ORTGen as ORTGen,
|
||||
ONNXOp as ONNXOp,
|
||||
SignatureOnly as SignatureOnly,
|
||||
MakeTorchFallback as MakeTorchFallback,
|
||||
)
|
||||
|
||||
from opgen.generator import AttrType
|
||||
from opgen.generator import MakeTorchFallback as MakeTorchFallback
|
||||
from opgen.generator import ONNXAttr
|
||||
from opgen.generator import ONNXOp as ONNXOp
|
||||
from opgen.generator import ORTGen as ORTGen
|
||||
from opgen.generator import SignatureOnly as SignatureOnly
|
||||
from opgen.onnxops import *
|
||||
|
||||
ops = {
|
||||
|
|
@ -17,3 +14,4 @@ ops = {
|
|||
}
|
||||
|
||||
type_promotion_ops = {}
|
||||
aten_output_type = {}
|
||||
|
|
|
|||
|
|
@ -114,13 +114,18 @@ class ORTGen:
|
|||
_custom_ops: bool
|
||||
|
||||
def __init__(
|
||||
self, ops: Optional[Dict[str, ONNXOp]] = None, custom_ops: bool = False, type_promotion_ops: List = ()
|
||||
self,
|
||||
ops: Optional[Dict[str, ONNXOp]] = None,
|
||||
custom_ops: bool = False,
|
||||
type_promotion_ops: List = (),
|
||||
aten_output_type: Dict = (),
|
||||
):
|
||||
self._mapped_ops = {}
|
||||
if ops:
|
||||
self.register_many(ops)
|
||||
self._custom_ops = custom_ops
|
||||
self.type_promotion_ops = type_promotion_ops
|
||||
self.aten_output_type = aten_output_type
|
||||
|
||||
def register(self, aten_name: str, onnx_op: ONNXOp):
|
||||
self._mapped_ops[aten_name] = onnx_op
|
||||
|
|
@ -517,6 +522,11 @@ class ORTGen:
|
|||
writer.write(".options()")
|
||||
if need_type_promotion:
|
||||
writer.write(".dtype(*promoted_type)")
|
||||
|
||||
# do we need to set type on the returned value
|
||||
if mapped_func.mapped_op_name in self.aten_output_type:
|
||||
writer.write(f".dtype({self.aten_output_type[mapped_func.mapped_op_name]})")
|
||||
|
||||
writer.writeline(";")
|
||||
|
||||
writer.writeline("return aten_tensor_from_ort(")
|
||||
|
|
|
|||
|
|
@ -957,6 +957,24 @@ at::Tensor& fill__Scalar(
|
|||
return self;
|
||||
}
|
||||
|
||||
// aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& nonzero_out(
|
||||
const at::Tensor& self,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, out);
|
||||
|
||||
auto temp = eager::aten::nonzero(self);
|
||||
|
||||
// resize out, then copy nonzero result into it.
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), temp.sizes());
|
||||
auto ort_input_out = create_ort_value(invoker, out);
|
||||
auto ort_temp = create_ort_value(invoker, temp);
|
||||
copy(invoker, ort_temp, ort_input_out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace aten
|
||||
|
||||
|
|
|
|||
|
|
@ -133,5 +133,12 @@ void resize_impl_ort_(
|
|||
ORTTensorImpl* self,
|
||||
at::IntArrayRef size);
|
||||
|
||||
namespace aten {
|
||||
|
||||
// aten::nonzero(Tensor self) -> Tensor
|
||||
at::Tensor nonzero(
|
||||
const at::Tensor& self);
|
||||
|
||||
} // namespace aten
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ ops = [
|
|||
# the following unary ops not been tested:
|
||||
# ["isnan", torch.tensor([1, float('nan'), 2])]]
|
||||
# ["selu", torch.randn(10)]]
|
||||
# ["nonzero", torch.tensor([0, 2, 1, 3])]]
|
||||
# ["sign", ]]
|
||||
# ["hardsigmoid", ],
|
||||
# ["isinf", ],
|
||||
|
|
@ -476,6 +475,33 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert cpu_tensor.dtype == ort_tensor.dtype
|
||||
assert torch.equal(cpu_tensor, ort_tensor.to("cpu"))
|
||||
|
||||
# tests both nonzero and nonzero.out
|
||||
def test_nonzero(self):
|
||||
device = self.get_device()
|
||||
|
||||
for cpu_tensor in [
|
||||
torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.long),
|
||||
torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.float),
|
||||
]:
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
|
||||
cpu_out_tensor = torch.tensor([], dtype=torch.long)
|
||||
ort_out_tensor = cpu_out_tensor.to(device)
|
||||
|
||||
# nonzero.out
|
||||
cpu_result = torch.nonzero(cpu_tensor, out=cpu_out_tensor)
|
||||
ort_result = torch.nonzero(ort_tensor, out=ort_out_tensor)
|
||||
assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu"))
|
||||
assert torch.equal(cpu_result, ort_result.to("cpu"))
|
||||
|
||||
# nonzero
|
||||
cpu_result = torch.nonzero(cpu_tensor)
|
||||
ort_result = torch.nonzero(ort_tensor)
|
||||
assert torch.equal(cpu_result, ort_result.to("cpu"))
|
||||
|
||||
# check result between nonzero.out and nonzero
|
||||
assert torch.equal(ort_result.to("cpu"), ort_out_tensor.to("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue