diff --git a/orttraining/orttraining/eager/opgen/opgen.py b/orttraining/orttraining/eager/opgen/opgen.py index 77192035df..9f90c8e611 100755 --- a/orttraining/orttraining/eager/opgen/opgen.py +++ b/orttraining/orttraining/eager/opgen/opgen.py @@ -3,13 +3,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from opgen.writer import SourceWriter as SourceWriter -from opgen.parser import cpp_create_from_file as CPPParser -import sys -import os -from opgen.generator import ORTGen as ORTGen -from importlib.machinery import SourceFileLoader import argparse +import os +import sys +from importlib.machinery import SourceFileLoader + +from opgen.generator import ORTGen as ORTGen +from opgen.parser import cpp_create_from_file as CPPParser +from opgen.writer import SourceWriter as SourceWriter parser = argparse.ArgumentParser(description="Generate ORT ATen operations") parser.add_argument( @@ -24,7 +25,12 @@ parser.add_argument( args = parser.parse_args() ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module() -ortgen = ORTGen(ops_module.ops, type_promotion_ops=ops_module.type_promotion_ops, custom_ops=args.custom_ops) +ortgen = ORTGen( + ops_module.ops, + type_promotion_ops=ops_module.type_promotion_ops, + custom_ops=args.custom_ops, + aten_output_type=ops_module.aten_output_type, +) regdecs_path = args.header_file print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}") diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 7114827144..477631ffa4 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -45,6 +45,7 @@ class GeluGrad(ONNXOp): ops = {} type_promotion_ops = [] +aten_output_type = {} # the following op list is for ops that have a .out version. 
Often this is the only op needing to be implemented # and the regular and inplace(_) version derive from the .out. @@ -68,7 +69,6 @@ unary_ops_with_out = [ "hardsigmoid", "log", "neg", - "nonzero", "reciprocal", "round", "sigmoid", @@ -92,7 +92,6 @@ unary_ops = [ "det", "isinf", "isnan", - "nonzero", "relu", "selu", ] @@ -169,8 +168,14 @@ hand_implemented = { "aten::equal": SignatureOnly(), "aten::_softmax": Softmax("self", axis="dim"), "aten::argmax.out": SignatureOnly(), + "aten::nonzero": Transpose(NonZero("self")), + "aten::nonzero.out": SignatureOnly(), } +# If the aten op expects a specific output type that differs from self +# add the op and type to aten_output_type +aten_output_type["aten::nonzero"] = "at::ScalarType::Long" + # Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439 # This is done to make sure it is backward and future compatible if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION): diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py index 4ba9d0af6b..61815f930f 100644 --- a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -1,14 +1,11 @@ from copy import deepcopy -from opgen.generator import AttrType, ONNXAttr - -from opgen.generator import ( - ORTGen as ORTGen, - ONNXOp as ONNXOp, - SignatureOnly as SignatureOnly, - MakeTorchFallback as MakeTorchFallback, -) - +from opgen.generator import AttrType +from opgen.generator import MakeTorchFallback as MakeTorchFallback +from opgen.generator import ONNXAttr +from opgen.generator import ONNXOp as ONNXOp +from opgen.generator import ORTGen as ORTGen +from opgen.generator import SignatureOnly as SignatureOnly from opgen.onnxops import * ops = { @@ -17,3 +14,4 @@ ops = { } type_promotion_ops = {} +aten_output_type = {} diff --git 
a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 196d9bf655..5741017c98 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -114,13 +114,18 @@ class ORTGen: _custom_ops: bool def __init__( - self, ops: Optional[Dict[str, ONNXOp]] = None, custom_ops: bool = False, type_promotion_ops: List = () + self, + ops: Optional[Dict[str, ONNXOp]] = None, + custom_ops: bool = False, + type_promotion_ops: List = (), + aten_output_type: Dict = (), ): self._mapped_ops = {} if ops: self.register_many(ops) self._custom_ops = custom_ops self.type_promotion_ops = type_promotion_ops + self.aten_output_type = aten_output_type def register(self, aten_name: str, onnx_op: ONNXOp): self._mapped_ops[aten_name] = onnx_op @@ -517,6 +522,11 @@ class ORTGen: writer.write(".options()") if need_type_promotion: writer.write(".dtype(*promoted_type)") + + # do we need to set type on the returned value + if mapped_func.mapped_op_name in self.aten_output_type: + writer.write(f".dtype({self.aten_output_type[mapped_func.mapped_op_name]})") + writer.writeline(";") writer.writeline("return aten_tensor_from_ort(") diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 1ac4251ebd..e66f968312 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -957,6 +957,24 @@ at::Tensor& fill__Scalar( return self; } +// aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +at::Tensor& nonzero_out( + const at::Tensor& self, + // *, + at::Tensor& out) { + ORT_LOG_FN(self, out); + + auto temp = eager::aten::nonzero(self); + + // resize out, then copy nonzero result into it. 
+ auto& invoker = GetORTInvoker(self.device()); + resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), temp.sizes()); + auto ort_input_out = create_ort_value(invoker, out); + auto ort_temp = create_ort_value(invoker, temp); + copy(invoker, ort_temp, ort_input_out); + + return out; +} } // namespace aten diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h index a29b272ec6..88ddb15a70 100644 --- a/orttraining/orttraining/eager/ort_aten.h +++ b/orttraining/orttraining/eager/ort_aten.h @@ -133,5 +133,12 @@ void resize_impl_ort_( ORTTensorImpl* self, at::IntArrayRef size); +namespace aten { + +// aten::nonzero(Tensor self) -> Tensor +at::Tensor nonzero( + const at::Tensor& self); + +} // namespace aten } // namespace eager } // namespace torch_ort diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index cf15b53dae..597911efa9 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -42,7 +42,6 @@ ops = [ # the following unary ops not been tested: # ["isnan", torch.tensor([1, float('nan'), 2])]] # ["selu", torch.randn(10)]] -# ["nonzero", torch.tensor([0, 2, 1, 3])]] # ["sign", ]] # ["hardsigmoid", ], # ["isinf", ], @@ -476,6 +475,33 @@ class OrtOpTests(unittest.TestCase): assert cpu_tensor.dtype == ort_tensor.dtype assert torch.equal(cpu_tensor, ort_tensor.to("cpu")) + # tests both nonzero and nonzero.out + def test_nonzero(self): + device = self.get_device() + + for cpu_tensor in [ + torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.long), + torch.tensor([[[-1, 0, 1], [0, 1, 0]], [[0, 1, 0], [-1, 0, 1]]], dtype=torch.float), + ]: + ort_tensor = cpu_tensor.to(device) + + cpu_out_tensor = torch.tensor([], dtype=torch.long) + ort_out_tensor = cpu_out_tensor.to(device) + + # nonzero.out + cpu_result = torch.nonzero(cpu_tensor, out=cpu_out_tensor) + ort_result =
torch.nonzero(ort_tensor, out=ort_out_tensor) + assert torch.equal(cpu_out_tensor, ort_out_tensor.to("cpu")) + assert torch.equal(cpu_result, ort_result.to("cpu")) + + # nonzero + cpu_result = torch.nonzero(cpu_tensor) + ort_result = torch.nonzero(ort_tensor) + assert torch.equal(cpu_result, ort_result.to("cpu")) + + # check result between nonzero.out and nonzero + assert torch.equal(ort_result.to("cpu"), ort_out_tensor.to("cpu")) + if __name__ == "__main__": unittest.main()