Enable custom ops on ORTModule (#6740)

This commit is contained in:
Thiago Crepaldi 2021-02-18 09:08:10 -08:00 committed by GitHub
parent b7b5612159
commit fb3f1f5cc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 4 deletions

View file

@ -16,6 +16,7 @@ from collections import abc
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
from onnxruntime.capi import _pybind_state as C
from onnxruntime.training import register_custom_ops_pytorch_exporter
from . import _utils
@ -162,6 +163,9 @@ class ORTModule(torch.nn.Module):
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
super(ORTModule, self).__init__()
# Support contrib OPs
register_custom_ops_pytorch_exporter.register_custom_op()
# TODO: Single device support for now
self._device = _utils.get_device_from_module(module)
self._device_changed = False
@ -435,10 +439,6 @@ class ORTModule(torch.nn.Module):
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs)
dynamic_axes.update(output_dynamic_axes)
# TODO: Support contrib OPs support? user model has no hint
# from onnxruntime.training import register_custom_ops_pytorch_exporter
# register_custom_ops_pytorch_exporter.register_custom_op()
# Export torch.nn.Module to ONNX
f = io.BytesIO()

View file

@ -613,3 +613,41 @@ def test_model_with_different_cuda_devices(device):
model.to(device)
x = torch.randn(N, D_in, device=device)
model(x)
def test_register_custom_ops_pytorch_exporter_tensor_triu():
class SimpleTensorTriuModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 10)
def forward(self, x):
x = self.fc1(x)
mask = torch.ones_like(x).triu(diagonal=1)
x = x * mask
return x
model = SimpleTensorTriuModel()
model = ORTModule(model)
user_input = torch.ones(1, 10, 10)
output = model(user_input)
assert list(output.shape) == [1, 10, 10]
def test_register_custom_ops_pytorch_exporter_torch_triu():
class SimpleTorchTriuModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 10)
def forward(self, x):
x = self.fc1(x)
mask = torch.triu(torch.ones_like(x))
x = x * mask
return x
model = SimpleTorchTriuModel()
model = ORTModule(model)
user_input = torch.ones(1, 10, 10)
output = model(user_input)
assert list(output.shape) == [1, 10, 10]