From fb3f1f5cc133d3b86d5cc412abb2e8a6fe963207 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 18 Feb 2021 09:08:10 -0800 Subject: [PATCH] Enable custom ops on ORTModule (#6740) --- .../orttraining/python/training/ortmodule.py | 8 ++-- .../python/orttraining_test_ortmodule_api.py | 38 +++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 7bcff50122..0d4784614c 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -16,6 +16,7 @@ from collections import abc from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict from onnxruntime.capi import _pybind_state as C +from onnxruntime.training import register_custom_ops_pytorch_exporter from . import _utils @@ -162,6 +163,9 @@ class ORTModule(torch.nn.Module): assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module" super(ORTModule, self).__init__() + # Support contrib OPs + register_custom_ops_pytorch_exporter.register_custom_op() + # TODO: Single device support for now self._device = _utils.get_device_from_module(module) self._device_changed = False @@ -435,10 +439,6 @@ class ORTModule(torch.nn.Module): output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs) dynamic_axes.update(output_dynamic_axes) - # TODO: Support contrib OPs support? user model has no hint - # from onnxruntime.training import register_custom_ops_pytorch_exporter - # register_custom_ops_pytorch_exporter.register_custom_op() - # Export torch.nn.Module to ONNX f = io.BytesIO() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 847195e554..850097df8a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -613,3 +613,41 @@ def test_model_with_different_cuda_devices(device): model.to(device) x = torch.randn(N, D_in, device=device) model(x) + +def test_register_custom_ops_pytorch_exporter_tensor_triu(): + class SimpleTensorTriuModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc1(x) + mask = torch.ones_like(x).triu(diagonal=1) + x = x * mask + return x + + model = SimpleTensorTriuModel() + model = ORTModule(model) + user_input = torch.ones(1, 10, 10) + + output = model(user_input) + assert list(output.shape) == [1, 10, 10] + +def test_register_custom_ops_pytorch_exporter_torch_triu(): + class SimpleTorchTriuModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc1(x) + mask = torch.triu(torch.ones_like(x)) + x = x * mask + return x + + model = SimpleTorchTriuModel() + model = ORTModule(model) + user_input = torch.ones(1, 10, 10) + + output = model(user_input) + assert list(output.shape) == [1, 10, 10]