diff --git a/orttraining/orttraining/eager/README.md b/orttraining/orttraining/eager/README.md index 25acbf3e78..684a8e8bbb 100644 --- a/orttraining/orttraining/eager/README.md +++ b/orttraining/orttraining/eager/README.md @@ -20,12 +20,18 @@ To run the eager tests: PYTHONPATH=~/{onnxruntime repo}/build/Linux/Debug python3 ~/{onnxruntime repo}/orttraining/orttraining/eager/test/ort_ops.py ## Mapping aten cpp namespace functions to onnx ops +ORT Eager uses opgen to generate C++ shim code that intercepts PyTorch operator calls and executes them with the ORT +backend. The opgen script is executed as part of the build process, and it can be found in +`orttraining/orttraining/eager/opgen/opgen.py`. The shim code generated by opgen is responsible for mapping aten inputs +to the ORT values passed into the onnx ops, invoking the onnx op, and converting outputs to the correct return +type. Opgen also creates Python bindings for the C++ shim code to allow it to be called directly from PyTorch when +using the ORT eager backend. Useful links - [Onnx Op Schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md) - [Aten native ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml) -For mapping existing aten ops to onnx files start in `orttraining/orttraining/eager/opgen/opgen/atenops.py.` This file +For mapping existing aten ops to onnx files, start in `orttraining/orttraining/eager/opgen/opgen/atenops.py`. This file drives the generator which the build runs and produces `orttraining/orttraining/eager/ort_aten.g.cpp`. Looking at the generated code will be helpful to understand the aten op signature and how onnx ops are invoked. @@ -35,16 +41,29 @@ an example `"aten::t": Transpose("self"),` maps the aten transpose function `t` `self` input param of `t` to the first argument of `Transpose`. -Often mapping ten to onnx is not one to one, so composing is supported. 
Example +Often mapping aten to onnx is not one to one, so composing is supported. Example `"aten::zeros_like": ConstantOfShape(Shape("self")),`. In some case more data shaping is required, so only the signature should be created such as `"aten::equal": SignatureOnly(),`. The implementation is then added to `orttraining/orttraining/eager/ort_aten.cpp` + +It may also be necessary to expose an onnx op in eager mode that does not exist as a native aten op. For this scenario, +start in `orttraining/orttraining/eager/opgen/opgen/custom_ops.py`, which will generate code found in +`orttraining/orttraining/eager/ort_customops.g.cpp`. Custom ops can then be invoked from PyTorch eager like +`torch.ops.ort.<op_name>`. Note that opgen for custom ops will also need header/schema declarations, found in +`orttraining/orttraining/eager/opgen/CustomOpDeclarations.h` (equivalent to RegistrationDeclarations.h for aten ops). + + Please add tests for all ops. Tests are defined in `orttraining/orttraining/eager/test/ort_ops.py` ## Decisions worth noting -- Resizing the output tensor. Aten supports resizing any out tensor but prints a warning this is depracated and support +- Resizing the output tensor: Aten supports resizing any out tensor but prints a warning this is deprecated and support will end. With that in mind, we have decided to error in `resize_output` if the output tensor is not empty or already the right shape. +- Python bindings for ONNX ops: opgen will use the TORCH_LIBRARY binding API instead of PyBind11 to bind ATen and custom +ops to generated C++ code. Using TORCH_LIBRARY and TORCH_LIBRARY_IMPL allows us to introduce backend-specific bindings +for native and custom ops while using the familiar PyTorch interface. More information: +[TORCH_LIBRARY](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html), +[TORCH_LIBRARY_IMPL](https://pytorch.org/tutorials/advanced/extend_dispatcher.html). 
diff --git a/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h index a7bf492f74..f8c4893afa 100644 --- a/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h +++ b/orttraining/orttraining/eager/opgen/CustomOpDeclarations.h @@ -1,2 +1,3 @@ -Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB); -std::tuple batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"} \ No newline at end of file +Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, double alpha, double beta, int64_t transA, int64_t transB); +std::tuple batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const double epsilon, const double momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) 
input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"} +Tensor my_cat(TensorList tensors, int64_t dim); // {"schema": "my_cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "False"} diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py index ab0b11cebf..a8031fe7d8 100644 --- a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -1,8 +1,9 @@ -from opgen.onnxops import BatchNormalization, Gemm +from opgen.onnxops import BatchNormalization, Gemm, Concat ops = { "gemm": Gemm("A", "B", "C", "alpha", "beta", "transA", "transB"), "batchnorm_inplace": BatchNormalization("X", "scale", "B", "input_mean", "input_var", "epsilon", "momentum", 1), + "my_cat": Concat("tensors", "dim"), } type_promotion_ops = {} diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 9b4941ed6f..6767158e75 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -439,24 +439,25 @@ class ORTGen: assert isinstance( output_param.member, ast.TupleMemberType ), "output_param.member must be of TupleMemberType" - output_alias = self._get_alias_info(output_param.member.element_type) - if ( - output_alias - and self._get_alias_info(torch_p) == output_alias - and output_alias.is_writable - ): + if self._is_inplace(output_param.member.element_type, torch_p): writer.writeline( f"{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};" ) in_place_params[output_index] = cpp_param.identifier.value break + elif isinstance(return_info, ast.ArrayType): + if self._is_inplace(return_info, torch_p): + writer.writeline(f"for (int i = 0; i < {onnx_op.outputs.count}; i++) {{") + writer.push_indent() + 
writer.writeline( + f"{onnx_op.outputs}[i] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]}[i];" + ) + writer.pop_indent() + writer.writeline("}") + in_place_params[0] = cpp_param.identifier.value + break else: - output_alias = self._get_alias_info(return_info) - if ( - output_alias - and self._get_alias_info(torch_p) == output_alias - and output_alias.is_writable - ): + if self._is_inplace(return_info, torch_p): writer.writeline( f"{onnx_op.outputs}[0] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};" ) @@ -502,10 +503,12 @@ class ORTGen: return_outputs = onnx_op.outputs # TODO: Pick the right "out" Torch parameter; do not assume the first one - # TODO: Handle mutliple results + # TODO: Handle multiple results # TODO: Assert return type - if len(in_place_params) == 0: + if cpp_func.return_type.desugar().identifier_tokens[0].value == "void": + pass + elif len(in_place_params) == 0: # tensor options if set_out_tensor: writer.writeline(f"return {last_param.identifier.value};") @@ -538,23 +541,22 @@ class ORTGen: writer.writeline("tensor_options);") writer.pop_indent() return + elif len(in_place_params) == 1: + writer.writeline(f"return {in_place_params[0]};") else: - if len(in_place_params) == 1: - writer.writeline(f"return {in_place_params[0]};") - else: - if not ( - isinstance(cpp_func.return_type, ast.TemplateType) - and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple" - ): - raise Exception(f"") - tensorRef = "Tensor&," * len(in_place_params) - tensorRef = tensorRef[: len(tensorRef) - 1] - writer.write(f"return std::tuple<{tensorRef}>(") - for index, key in enumerate(sorted(in_place_params)): - if index > 0: - writer.write(", ") - writer.write(in_place_params[key]) - writer.writeline(");") + if not ( + isinstance(cpp_func.return_type, ast.TemplateType) + and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple" + ): + raise Exception(f"") + tensorRef = "Tensor&," * len(in_place_params) + tensorRef = tensorRef[: 
len(tensorRef) - 1] + writer.write(f"return std::tuple<{tensorRef}>(") + for index, key in enumerate(sorted(in_place_params)): + if index > 0: + writer.write(", ") + writer.write(in_place_params[key]) + writer.writeline(");") def _write_function_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]): writer.writeline() @@ -581,9 +583,8 @@ class ORTGen: def _write_custom_ops_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]): writer.writeline() - writer.writeline("void GenerateCustomOpsBindings(pybind11::module_ m) {") + writer.writeline("TORCH_LIBRARY(ort, m) {") writer.push_indent() - writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";') for mapped_func in generated_funcs: cpp_func = mapped_func.cpp_func @@ -692,3 +693,7 @@ class ORTGen: cpp_param.torch_param.append(torch_param) return cpp_func + + def _is_inplace(self, element_type, torch_p): + output_alias = self._get_alias_info(element_type) + return output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable diff --git a/orttraining/orttraining/eager/ort_eager.cpp b/orttraining/orttraining/eager/ort_eager.cpp index f7d3d730c3..3937cbfd7d 100644 --- a/orttraining/orttraining/eager/ort_eager.cpp +++ b/orttraining/orttraining/eager/ort_eager.cpp @@ -68,7 +68,7 @@ void addObjectMethodsForEager(py::module& m){ THPDevice_New(at::Device(at::DeviceType::ORT, device_index))); }, py::arg("device_index") = 0); - + m.def("aten_ort_tensor_to_ort_value", [](at::Tensor data) { return ORTTensor_toORTValue(data); }); @@ -76,7 +76,7 @@ void addObjectMethodsForEager(py::module& m){ return OrtValue_To_ATen_Tensor(ortvalue); }); - m.def("set_device", [](size_t device_index, + m.def("set_device", [](size_t device_index, const std::string& provider_type, const std::unordered_map& arguments){ auto status = GetORTBackendsManager().set_device(device_index, provider_type, arguments); @@ -89,9 +89,6 @@ void 
addObjectMethodsForEager(py::module& m){ m.def("get_ort_device_provider_info", [](size_t torch_device_index){ return GetORTBackendsManager().GetOrtDeviceProviderInfo(torch_device_index); }); - - auto customop_module = m.def_submodule("custom_ops"); - torch_ort::eager::GenerateCustomOpsBindings(customop_module); } } diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index a93c364335..7c7aa73d37 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -138,7 +138,7 @@ class OrtOpTests(unittest.TestCase): cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) ort_ones = cpu_ones.to(device) cpu_ans = cpu_ones * 4 - ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) + ort_ans = torch.ops.ort.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) assert torch.allclose(cpu_ans, ort_ans.cpu()) def test_batchnormalization_inplace(self): @@ -148,11 +148,18 @@ class OrtOpTests(unittest.TestCase): bias = torch.Tensor([0.0, 1.0]).to(device) mean = torch.Tensor([0.0, 3.0]).to(device) var = torch.Tensor([1.0, 1.5]).to(device) - y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9) + y, mean_out, var_out = torch.ops.ort.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9) assert torch.allclose(x.cpu(), y.cpu()), "x != y" assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out" assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out" + def test_variadic_inputs(self): + device = self.get_device() + tensor = torch.ones(2, 2).to(device) + expected = torch.ones(2, 6) + out = torch.ops.ort.my_cat([tensor, tensor, tensor], 1) + assert torch.allclose(expected, out.cpu()) + def test_max(self): cpu_tensor = torch.rand(10, 10) ort_tensor = cpu_tensor.to("ort")