mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Eager opgen support for in-place operations with variadic args (#12125)
* use torch library binding frontend for tensorlist * fix test * allow in-place modification of variadic args * fix lint issues * update ORT eager readme Co-authored-by: Juan Paez <juanpaez@microsoft.com>
This commit is contained in:
parent
5e2109f7ef
commit
9b6ef17c5f
6 changed files with 75 additions and 45 deletions
|
|
@ -20,12 +20,18 @@ To run the eager tests:
|
|||
PYTHONPATH=~/{onnxruntime repo}/build/Linux/Debug python3 ~/{onnxruntime repo}/orttraining/orttraining/eager/test/ort_ops.py
|
||||
|
||||
## Mapping aten cpp namespace functions to onnx ops
|
||||
ORT Eager uses opgen to generate C++ shim code that intercepts PyTorch operator calls and executes them with the ORT
|
||||
backend. The opgen script is executed as part of the build process, and it can be found in
|
||||
`orttraining/orttraining/eager/opgen/opgen.py`. The shim code generated by opgen is responsible for mapping aten inputs
|
||||
to the ORT values passed into the onnx ops, invoking the onnx op, and converting outputs to the correct return
|
||||
type. Opgen also creates Python bindings for the C++ shim code to allow it to be called directly from PyTorch when
|
||||
using the ORT eager backend.
|
||||
|
||||
Useful links
|
||||
- [Onnx Op Schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
|
||||
- [Aten native ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml)
|
||||
|
||||
For mapping existing aten ops to onnx files start in `orttraining/orttraining/eager/opgen/opgen/atenops.py.` This file
|
||||
For mapping existing aten ops to onnx files, start in `orttraining/orttraining/eager/opgen/opgen/atenops.py`. This file
|
||||
drives the generator which the build runs and produces `orttraining/orttraining/eager/ort_aten.g.cpp`. Looking at the
|
||||
generated code will be helpful to understand the aten op signature and how onnx ops are invoked.
|
||||
|
||||
|
|
@ -35,16 +41,29 @@ an example `"aten::t": Transpose("self"),` maps the aten transpose function `t`
|
|||
`self` input param of `t` to the first argument of `Transpose`.
|
||||
|
||||
|
||||
Often mapping ten to onnx is not one to one, so composing is supported. Example
|
||||
Often mapping aten to onnx is not one to one, so composing is supported. Example
|
||||
`"aten::zeros_like": ConstantOfShape(Shape("self")),`.
|
||||
|
||||
|
||||
In some cases more data shaping is required, so only the signature should be created, such as `"aten::equal": SignatureOnly(),`.
|
||||
The implementation is then added to `orttraining/orttraining/eager/ort_aten.cpp`
|
||||
|
||||
|
||||
It may also be necessary to expose an onnx op in eager mode that does not exist as a native aten op. For this scenario,
|
||||
start in `orttraining/orttraining/eager/opgen/opgen/custom_ops.py`, which will generate code found in
|
||||
`orttraining/orttraining/eager/ort_customops.g.cpp`. Custom ops can then be invoked from PyTorch eager like
|
||||
`torch.ops.ort.<op_name>`. Note that opgen for custom ops will also need header/schema declarations, found in
|
||||
`orttraining/orttraining/eager/opgen/CustomOpDeclarations.h` (equivalent to RegistrationDeclarations.h for aten ops).
|
||||
|
||||
|
||||
Please add tests for all ops. Tests are defined in `orttraining/orttraining/eager/test/ort_ops.py`
|
||||
|
||||
## Decisions worth noting
|
||||
- Resizing the output tensor. Aten supports resizing any out tensor but prints a warning this is depracated and support
|
||||
- Resizing the output tensor: Aten supports resizing any out tensor but prints a warning that this is deprecated and support
|
||||
will end. With that in mind, we have decided to error in `resize_output` if the output tensor is not empty or already
|
||||
the right shape.
|
||||
- Python bindings for ONNX ops: opgen will use the TORCH_LIBRARY binding API instead of PyBind11 to bind ATen and custom
|
||||
ops to generated C++ code. Using TORCH_LIBRARY and TORCH_LIBRARY_IMPL allows us to introduce backend-specific bindings
|
||||
for native and custom ops while using the familiar PyTorch interface. More information:
|
||||
[TORCH_LIBRARY](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html),
|
||||
[TORCH_LIBRARY_IMPL](https://pytorch.org/tutorials/advanced/extend_dispatcher.html).
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}
|
||||
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, double alpha, double beta, int64_t transA, int64_t transB);
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const double epsilon, const double momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}
|
||||
Tensor my_cat(TensorList tensors, int64_t dim); // {"schema": "my_cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "False"}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from opgen.onnxops import BatchNormalization, Gemm
|
||||
from opgen.onnxops import BatchNormalization, Gemm, Concat
|
||||
|
||||
ops = {
|
||||
"gemm": Gemm("A", "B", "C", "alpha", "beta", "transA", "transB"),
|
||||
"batchnorm_inplace": BatchNormalization("X", "scale", "B", "input_mean", "input_var", "epsilon", "momentum", 1),
|
||||
"my_cat": Concat("tensors", "dim"),
|
||||
}
|
||||
|
||||
type_promotion_ops = {}
|
||||
|
|
|
|||
|
|
@ -439,24 +439,25 @@ class ORTGen:
|
|||
assert isinstance(
|
||||
output_param.member, ast.TupleMemberType
|
||||
), "output_param.member must be of TupleMemberType"
|
||||
output_alias = self._get_alias_info(output_param.member.element_type)
|
||||
if (
|
||||
output_alias
|
||||
and self._get_alias_info(torch_p) == output_alias
|
||||
and output_alias.is_writable
|
||||
):
|
||||
if self._is_inplace(output_param.member.element_type, torch_p):
|
||||
writer.writeline(
|
||||
f"{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};"
|
||||
)
|
||||
in_place_params[output_index] = cpp_param.identifier.value
|
||||
break
|
||||
elif isinstance(return_info, ast.ArrayType):
|
||||
if self._is_inplace(return_info, torch_p):
|
||||
writer.writeline(f"for (int i = 0; i < {onnx_op.outputs.count}; i++) {{")
|
||||
writer.push_indent()
|
||||
writer.writeline(
|
||||
f"{onnx_op.outputs}[i] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]}[i];"
|
||||
)
|
||||
writer.pop_indent()
|
||||
writer.writeline("}")
|
||||
in_place_params[0] = cpp_param.identifier.value
|
||||
break
|
||||
else:
|
||||
output_alias = self._get_alias_info(return_info)
|
||||
if (
|
||||
output_alias
|
||||
and self._get_alias_info(torch_p) == output_alias
|
||||
and output_alias.is_writable
|
||||
):
|
||||
if self._is_inplace(return_info, torch_p):
|
||||
writer.writeline(
|
||||
f"{onnx_op.outputs}[0] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};"
|
||||
)
|
||||
|
|
@ -502,10 +503,12 @@ class ORTGen:
|
|||
return_outputs = onnx_op.outputs
|
||||
|
||||
# TODO: Pick the right "out" Torch parameter; do not assume the first one
|
||||
# TODO: Handle mutliple results
|
||||
# TODO: Handle multiple results
|
||||
# TODO: Assert return type
|
||||
|
||||
if len(in_place_params) == 0:
|
||||
if cpp_func.return_type.desugar().identifier_tokens[0].value == "void":
|
||||
pass
|
||||
elif len(in_place_params) == 0:
|
||||
# tensor options
|
||||
if set_out_tensor:
|
||||
writer.writeline(f"return {last_param.identifier.value};")
|
||||
|
|
@ -538,23 +541,22 @@ class ORTGen:
|
|||
writer.writeline("tensor_options);")
|
||||
writer.pop_indent()
|
||||
return
|
||||
elif len(in_place_params) == 1:
|
||||
writer.writeline(f"return {in_place_params[0]};")
|
||||
else:
|
||||
if len(in_place_params) == 1:
|
||||
writer.writeline(f"return {in_place_params[0]};")
|
||||
else:
|
||||
if not (
|
||||
isinstance(cpp_func.return_type, ast.TemplateType)
|
||||
and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple"
|
||||
):
|
||||
raise Exception(f"")
|
||||
tensorRef = "Tensor&," * len(in_place_params)
|
||||
tensorRef = tensorRef[: len(tensorRef) - 1]
|
||||
writer.write(f"return std::tuple<{tensorRef}>(")
|
||||
for index, key in enumerate(sorted(in_place_params)):
|
||||
if index > 0:
|
||||
writer.write(", ")
|
||||
writer.write(in_place_params[key])
|
||||
writer.writeline(");")
|
||||
if not (
|
||||
isinstance(cpp_func.return_type, ast.TemplateType)
|
||||
and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple"
|
||||
):
|
||||
raise Exception(f"")
|
||||
tensorRef = "Tensor&," * len(in_place_params)
|
||||
tensorRef = tensorRef[: len(tensorRef) - 1]
|
||||
writer.write(f"return std::tuple<{tensorRef}>(")
|
||||
for index, key in enumerate(sorted(in_place_params)):
|
||||
if index > 0:
|
||||
writer.write(", ")
|
||||
writer.write(in_place_params[key])
|
||||
writer.writeline(");")
|
||||
|
||||
def _write_function_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]):
|
||||
writer.writeline()
|
||||
|
|
@ -581,9 +583,8 @@ class ORTGen:
|
|||
|
||||
def _write_custom_ops_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]):
|
||||
writer.writeline()
|
||||
writer.writeline("void GenerateCustomOpsBindings(pybind11::module_ m) {")
|
||||
writer.writeline("TORCH_LIBRARY(ort, m) {")
|
||||
writer.push_indent()
|
||||
writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";')
|
||||
|
||||
for mapped_func in generated_funcs:
|
||||
cpp_func = mapped_func.cpp_func
|
||||
|
|
@ -692,3 +693,7 @@ class ORTGen:
|
|||
cpp_param.torch_param.append(torch_param)
|
||||
|
||||
return cpp_func
|
||||
|
||||
def _is_inplace(self, element_type, torch_p):
|
||||
output_alias = self._get_alias_info(element_type)
|
||||
return output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ void addObjectMethodsForEager(py::module& m){
|
|||
THPDevice_New(at::Device(at::DeviceType::ORT, device_index)));
|
||||
},
|
||||
py::arg("device_index") = 0);
|
||||
|
||||
|
||||
m.def("aten_ort_tensor_to_ort_value", [](at::Tensor data) {
|
||||
return ORTTensor_toORTValue(data);
|
||||
});
|
||||
|
|
@ -76,7 +76,7 @@ void addObjectMethodsForEager(py::module& m){
|
|||
return OrtValue_To_ATen_Tensor(ortvalue);
|
||||
});
|
||||
|
||||
m.def("set_device", [](size_t device_index,
|
||||
m.def("set_device", [](size_t device_index,
|
||||
const std::string& provider_type,
|
||||
const std::unordered_map<std::string, std::string>& arguments){
|
||||
auto status = GetORTBackendsManager().set_device(device_index, provider_type, arguments);
|
||||
|
|
@ -89,9 +89,6 @@ void addObjectMethodsForEager(py::module& m){
|
|||
m.def("get_ort_device_provider_info", [](size_t torch_device_index){
|
||||
return GetORTBackendsManager().GetOrtDeviceProviderInfo(torch_device_index);
|
||||
});
|
||||
|
||||
auto customop_module = m.def_submodule("custom_ops");
|
||||
torch_ort::eager::GenerateCustomOpsBindings(customop_module);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
|
||||
ort_ones = cpu_ones.to(device)
|
||||
cpu_ans = cpu_ones * 4
|
||||
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
|
||||
ort_ans = torch.ops.ort.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
|
||||
assert torch.allclose(cpu_ans, ort_ans.cpu())
|
||||
|
||||
def test_batchnormalization_inplace(self):
|
||||
|
|
@ -148,11 +148,18 @@ class OrtOpTests(unittest.TestCase):
|
|||
bias = torch.Tensor([0.0, 1.0]).to(device)
|
||||
mean = torch.Tensor([0.0, 3.0]).to(device)
|
||||
var = torch.Tensor([1.0, 1.5]).to(device)
|
||||
y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
|
||||
y, mean_out, var_out = torch.ops.ort.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
|
||||
assert torch.allclose(x.cpu(), y.cpu()), "x != y"
|
||||
assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out"
|
||||
assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out"
|
||||
|
||||
def test_variadic_inputs(self):
|
||||
device = self.get_device()
|
||||
tensor = torch.ones(2, 2).to(device)
|
||||
expected = torch.ones(2, 6)
|
||||
out = torch.ops.ort.my_cat([tensor, tensor, tensor], 1)
|
||||
assert torch.allclose(expected, out.cpu())
|
||||
|
||||
def test_max(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to("ort")
|
||||
|
|
|
|||
Loading…
Reference in a new issue