mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Enable custom ops on ORTModule (#6740)
This commit is contained in:
parent
b7b5612159
commit
fb3f1f5cc1
2 changed files with 42 additions and 4 deletions
|
|
@ -16,6 +16,7 @@ from collections import abc
|
|||
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
from . import _utils
|
||||
|
||||
|
||||
|
|
@ -162,6 +163,9 @@ class ORTModule(torch.nn.Module):
|
|||
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
# Support contrib OPs
|
||||
register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
|
||||
# TODO: Single device support for now
|
||||
self._device = _utils.get_device_from_module(module)
|
||||
self._device_changed = False
|
||||
|
|
@ -435,10 +439,6 @@ class ORTModule(torch.nn.Module):
|
|||
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs)
|
||||
dynamic_axes.update(output_dynamic_axes)
|
||||
|
||||
# TODO: Support contrib OPs support? user model has no hint
|
||||
# from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
# register_custom_ops_pytorch_exporter.register_custom_op()
|
||||
|
||||
# Export torch.nn.Module to ONNX
|
||||
f = io.BytesIO()
|
||||
|
||||
|
|
|
|||
|
|
@ -613,3 +613,41 @@ def test_model_with_different_cuda_devices(device):
|
|||
model.to(device)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
model(x)
|
||||
|
||||
def test_register_custom_ops_pytorch_exporter_tensor_triu():
|
||||
class SimpleTensorTriuModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
mask = torch.ones_like(x).triu(diagonal=1)
|
||||
x = x * mask
|
||||
return x
|
||||
|
||||
model = SimpleTensorTriuModel()
|
||||
model = ORTModule(model)
|
||||
user_input = torch.ones(1, 10, 10)
|
||||
|
||||
output = model(user_input)
|
||||
assert list(output.shape) == [1, 10, 10]
|
||||
|
||||
def test_register_custom_ops_pytorch_exporter_torch_triu():
|
||||
class SimpleTorchTriuModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
mask = torch.triu(torch.ones_like(x))
|
||||
x = x * mask
|
||||
return x
|
||||
|
||||
model = SimpleTorchTriuModel()
|
||||
model = ORTModule(model)
|
||||
user_input = torch.ones(1, 10, 10)
|
||||
|
||||
output = model(user_input)
|
||||
assert list(output.shape) == [1, 10, 10]
|
||||
|
|
|
|||
Loading…
Reference in a new issue