Eager opgen support for in-place operations with variadic args (#12125)

* use torch library binding frontend for tensorlist

* fix test

* allow in-place modification of variadic args

* fix lint issues

* update ORT eager readme

Co-authored-by: Juan Paez <juanpaez@microsoft.com>
This commit is contained in:
Juan Paez 2022-07-19 21:01:00 -07:00 committed by GitHub
parent 5e2109f7ef
commit 9b6ef17c5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 45 deletions

View file

@ -20,12 +20,18 @@ To run the eager tests:
PYTHONPATH=~/{onnxruntime repo}/build/Linux/Debug python3 ~/{onnxruntime repo}/orttraining/orttraining/eager/test/ort_ops.py
## Mapping aten cpp namespace functions to onnx ops
ORT Eager uses opgen to generate C++ shim code that intercepts PyTorch operator calls and executes them with the ORT
backend. The opgen script is executed as part of the build process, and it can be found in
`orttraining/orttraining/eager/opgen/opgen.py.` The shim code generated by opgen is responsible for mapping aten inputs
to the ORT values passed into the onnx ops, invoking the onnx op, and converting outputs to the correct return
type. Opgen also creates Python bindings for the C++ shim code to allow it to be called directly from PyTorch when
using the ORT eager backend.
Useful links
- [Onnx Op Schema](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [Aten native ops](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml)
For mapping existing aten ops to onnx files start in `orttraining/orttraining/eager/opgen/opgen/atenops.py.` This file
For mapping existing aten ops to onnx files, start in `orttraining/orttraining/eager/opgen/opgen/atenops.py.` This file
drives the generator which the build runs and produces `orttraining/orttraining/eager/ort_aten.g.cpp`. Looking at the
generated code will be helpful to understand the aten op signature and how onnx ops are invoked.
@ -35,16 +41,29 @@ an example `"aten::t": Transpose("self"),` maps the aten transpose function `t`
`self` input param of `t` to the first argument of `Transpose`.
Often mapping ten to onnx is not one to one, so composing is supported. Example
Often mapping aten to onnx is not one to one, so composing is supported. Example
`"aten::zeros_like": ConstantOfShape(Shape("self")),`.
In some case more data shaping is required, so only the signature should be created such as `"aten::equal": SignatureOnly(),`.
The implementation is then added to `orttraining/orttraining/eager/ort_aten.cpp`
It may also be necessary to expose an onnx op in eager mode that does not exist as a native aten op. For this scenario,
start in `orttraining/orttraining/eager/opgen/opgen/custom_ops.py`, which will generate code found in
`orttraining/orttraining/eager/ort_customops.g.cpp`. Custom ops can then be invoked from PyTorch eager like
`torch.ops.ort.<op_name>`. Note that opgen for custom ops will also need header/schema declarations, found in
`orttraining/orttraining/eager/opgen/CustomOpDeclarations.h` (equivalent to RegistrationDeclarations.h for aten ops).
Please add tests for all ops. Tests are defined in `orttraining/orttraining/eager/test/ort_ops.py`
## Decisions worth noting
- Resizing the output tensor. Aten supports resizing any out tensor but prints a warning this is depracated and support
- Resizing the output tensor: Aten supports resizing any out tensor but prints a warning this is depracated and support
will end. With that in mind, we have decided to error in `resize_output` if the output tensor is not empty or already
the right shape.
- Python bindings for ONNX ops: opgen will use the TORCH_LIBRARY binding API instead of PyBind11 to bind ATen and custom
ops to generated C++ code. Using TORCH_LIBRARY and TORCH_LIBRARY_IMPL allows us to introduce backend-specific bindings
for native and custom ops while using the familiar PyTorch interface. More information:
[TORCH_LIBRARY](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html),
[TORCH_LIBRARY_IMPL](https://pytorch.org/tutorials/advanced/extend_dispatcher.html).

View file

@ -1,2 +1,3 @@
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, float alpha, float beta, int transA, int transB);
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const float epsilon, const float momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}
Tensor gemm(const Tensor& A, const Tensor& B, const Tensor& C, double alpha, double beta, int64_t transA, int64_t transB);
std::tuple<Tensor&, Tensor&, Tensor&> batchnorm_inplace(Tensor& X, const Tensor& scale, const Tensor& B, Tensor& input_mean, Tensor& input_var, const double epsilon, const double momentum); // {"schema": "batchnorm_inplace(Tensor(a!) X, Tensor scale, Tensor b, Tensor(b!) input_mean, Tensor(c!) input_var, float epsilon, float momentum) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "False", "default": "True"}
Tensor my_cat(TensorList tensors, int64_t dim); // {"schema": "my_cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "False"}

View file

@ -1,8 +1,9 @@
from opgen.onnxops import BatchNormalization, Gemm
from opgen.onnxops import BatchNormalization, Gemm, Concat
ops = {
"gemm": Gemm("A", "B", "C", "alpha", "beta", "transA", "transB"),
"batchnorm_inplace": BatchNormalization("X", "scale", "B", "input_mean", "input_var", "epsilon", "momentum", 1),
"my_cat": Concat("tensors", "dim"),
}
type_promotion_ops = {}

View file

@ -439,24 +439,25 @@ class ORTGen:
assert isinstance(
output_param.member, ast.TupleMemberType
), "output_param.member must be of TupleMemberType"
output_alias = self._get_alias_info(output_param.member.element_type)
if (
output_alias
and self._get_alias_info(torch_p) == output_alias
and output_alias.is_writable
):
if self._is_inplace(output_param.member.element_type, torch_p):
writer.writeline(
f"{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};"
)
in_place_params[output_index] = cpp_param.identifier.value
break
elif isinstance(return_info, ast.ArrayType):
if self._is_inplace(return_info, torch_p):
writer.writeline(f"for (int i = 0; i < {onnx_op.outputs.count}; i++) {{")
writer.push_indent()
writer.writeline(
f"{onnx_op.outputs}[i] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]}[i];"
)
writer.pop_indent()
writer.writeline("}")
in_place_params[0] = cpp_param.identifier.value
break
else:
output_alias = self._get_alias_info(return_info)
if (
output_alias
and self._get_alias_info(torch_p) == output_alias
and output_alias.is_writable
):
if self._is_inplace(return_info, torch_p):
writer.writeline(
f"{onnx_op.outputs}[0] = ort_input_{onnx_op_index}_{onnx_op.inputs[input_index]};"
)
@ -502,10 +503,12 @@ class ORTGen:
return_outputs = onnx_op.outputs
# TODO: Pick the right "out" Torch parameter; do not assume the first one
# TODO: Handle mutliple results
# TODO: Handle multiple results
# TODO: Assert return type
if len(in_place_params) == 0:
if cpp_func.return_type.desugar().identifier_tokens[0].value == "void":
pass
elif len(in_place_params) == 0:
# tensor options
if set_out_tensor:
writer.writeline(f"return {last_param.identifier.value};")
@ -538,23 +541,22 @@ class ORTGen:
writer.writeline("tensor_options);")
writer.pop_indent()
return
elif len(in_place_params) == 1:
writer.writeline(f"return {in_place_params[0]};")
else:
if len(in_place_params) == 1:
writer.writeline(f"return {in_place_params[0]};")
else:
if not (
isinstance(cpp_func.return_type, ast.TemplateType)
and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple"
):
raise Exception(f"")
tensorRef = "Tensor&," * len(in_place_params)
tensorRef = tensorRef[: len(tensorRef) - 1]
writer.write(f"return std::tuple<{tensorRef}>(")
for index, key in enumerate(sorted(in_place_params)):
if index > 0:
writer.write(", ")
writer.write(in_place_params[key])
writer.writeline(");")
if not (
isinstance(cpp_func.return_type, ast.TemplateType)
and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple"
):
raise Exception(f"")
tensorRef = "Tensor&," * len(in_place_params)
tensorRef = tensorRef[: len(tensorRef) - 1]
writer.write(f"return std::tuple<{tensorRef}>(")
for index, key in enumerate(sorted(in_place_params)):
if index > 0:
writer.write(", ")
writer.write(in_place_params[key])
writer.writeline(");")
def _write_function_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]):
writer.writeline()
@ -581,9 +583,8 @@ class ORTGen:
def _write_custom_ops_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]):
writer.writeline()
writer.writeline("void GenerateCustomOpsBindings(pybind11::module_ m) {")
writer.writeline("TORCH_LIBRARY(ort, m) {")
writer.push_indent()
writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";')
for mapped_func in generated_funcs:
cpp_func = mapped_func.cpp_func
@ -692,3 +693,7 @@ class ORTGen:
cpp_param.torch_param.append(torch_param)
return cpp_func
def _is_inplace(self, element_type, torch_p):
output_alias = self._get_alias_info(element_type)
return output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable

View file

@ -68,7 +68,7 @@ void addObjectMethodsForEager(py::module& m){
THPDevice_New(at::Device(at::DeviceType::ORT, device_index)));
},
py::arg("device_index") = 0);
m.def("aten_ort_tensor_to_ort_value", [](at::Tensor data) {
return ORTTensor_toORTValue(data);
});
@ -76,7 +76,7 @@ void addObjectMethodsForEager(py::module& m){
return OrtValue_To_ATen_Tensor(ortvalue);
});
m.def("set_device", [](size_t device_index,
m.def("set_device", [](size_t device_index,
const std::string& provider_type,
const std::unordered_map<std::string, std::string>& arguments){
auto status = GetORTBackendsManager().set_device(device_index, provider_type, arguments);
@ -89,9 +89,6 @@ void addObjectMethodsForEager(py::module& m){
m.def("get_ort_device_provider_info", [](size_t torch_device_index){
return GetORTBackendsManager().GetOrtDeviceProviderInfo(torch_device_index);
});
auto customop_module = m.def_submodule("custom_ops");
torch_ort::eager::GenerateCustomOpsBindings(customop_module);
}
}

View file

@ -138,7 +138,7 @@ class OrtOpTests(unittest.TestCase):
cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
ort_ones = cpu_ones.to(device)
cpu_ans = cpu_ones * 4
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
ort_ans = torch.ops.ort.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
assert torch.allclose(cpu_ans, ort_ans.cpu())
def test_batchnormalization_inplace(self):
@ -148,11 +148,18 @@ class OrtOpTests(unittest.TestCase):
bias = torch.Tensor([0.0, 1.0]).to(device)
mean = torch.Tensor([0.0, 3.0]).to(device)
var = torch.Tensor([1.0, 1.5]).to(device)
y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
y, mean_out, var_out = torch.ops.ort.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9)
assert torch.allclose(x.cpu(), y.cpu()), "x != y"
assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out"
assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out"
def test_variadic_inputs(self):
device = self.get_device()
tensor = torch.ones(2, 2).to(device)
expected = torch.ones(2, 6)
out = torch.ops.ort.my_cat([tensor, tensor, tensor], 1)
assert torch.allclose(expected, out.cpu())
def test_max(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to("ort")