mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Support non-tuple return values from torch.nn.Module (#6660)
* Support dictionary, namedtuple and HuggingFace ModelOutput types for model return values
This commit is contained in:
parent
7f33671ade
commit
01dfa8e125
9 changed files with 149 additions and 20 deletions
|
|
@ -74,9 +74,20 @@ def _parse_inputs_for_onnx_export(module, *inputs, **kwargs):
|
|||
|
||||
def _parse_outputs_for_onnx_export(module, inputs):
|
||||
|
||||
def _create_output_dim_names_from_mapping(output):
|
||||
output_names, dynamic_axes = [], {}
|
||||
for name, value in output.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise TypeError('ORTModule does not support the following model output type {} within a Mapping'.format(type(value)))
|
||||
output_names.append(name)
|
||||
dynamic_axes[name] = {}
|
||||
for dim_idx in range(len(value.shape)):
|
||||
dynamic_axes[name].update({dim_idx: '{}_dim{}'.format(name, dim_idx)})
|
||||
return output_names, dynamic_axes
|
||||
|
||||
def _create_output_dim_names(output, output_idx, from_sequence):
|
||||
if from_sequence and not isinstance(output, torch.Tensor):
|
||||
raise TypeError('ORTModule does not support the following model output type {} within a Sequence'.format(type(sample_outputs)))
|
||||
raise TypeError('ORTModule does not support the following model output type {} within a Sequence'.format(type(output)))
|
||||
output_names, dynamic_axes = [], {}
|
||||
name = 'output{}'.format(output_idx)
|
||||
output_names.append(name)
|
||||
|
|
@ -88,6 +99,9 @@ def _parse_outputs_for_onnx_export(module, inputs):
|
|||
# Do an inference to grab outputs
|
||||
is_train_mode = module.training
|
||||
module.eval()
|
||||
output_names = []
|
||||
output_dynamic_axes = {}
|
||||
sample_output_type = None
|
||||
with torch.no_grad():
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_inputs_copy = _deepcopy_model_input(*inputs)
|
||||
|
|
@ -100,12 +114,11 @@ def _parse_outputs_for_onnx_export(module, inputs):
|
|||
" Compute will continue, but unexpected results may occur!")
|
||||
|
||||
sample_outputs = model_copy(*sample_inputs_copy)
|
||||
output_names = []
|
||||
output_dynamic_axes = {}
|
||||
sample_output_type = type(sample_outputs)
|
||||
if isinstance(sample_outputs, torch.Tensor):
|
||||
output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False)
|
||||
elif isinstance(sample_outputs, abc.Mapping):
|
||||
raise NotImplementedError('Dictionaries are not supported as output yet')
|
||||
output_names, output_dynamic_axes = _create_output_dim_names_from_mapping(sample_outputs)
|
||||
elif isinstance(sample_outputs, abc.Sequence):
|
||||
for idx, out in enumerate(sample_outputs):
|
||||
tmp_output_names, tmp_output_dynamic_axes = _create_output_dim_names(out, idx, True)
|
||||
|
|
@ -115,7 +128,21 @@ def _parse_outputs_for_onnx_export(module, inputs):
|
|||
raise TypeError('ORTModule does not support the following model output type {}'.format(type(sample_outputs)))
|
||||
if is_train_mode:
|
||||
module.train()
|
||||
return output_names, output_dynamic_axes
|
||||
return output_names, output_dynamic_axes, sample_output_type
|
||||
|
||||
def _populate_user_output(user_output_type, user_output_names, user_outputs):
|
||||
if issubclass(user_output_type, Mapping):
|
||||
key_value_pairs = [(user_output_names[i], user_outputs[i]) for i in range(len(user_output_names))]
|
||||
return user_output_type(key_value_pairs)
|
||||
elif issubclass(user_output_type, tuple):
|
||||
try:
|
||||
# Try constructing the user named tuple from the output tuple
|
||||
return user_output_type(*user_outputs)
|
||||
except TypeError:
|
||||
# The expected output type is not a namedtuple, but is a regular tuple type
|
||||
pass
|
||||
|
||||
return user_outputs
|
||||
|
||||
# TODO: PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
|
||||
# and convert the config to torch.uint8 tensor during from_dlpack(). So a boolean tensor
|
||||
|
|
@ -148,6 +175,7 @@ class ORTModule(torch.nn.Module):
|
|||
self._current_input_shape = None
|
||||
self._module_gradient_graph_builder = None
|
||||
self._input_names_require_grad = None
|
||||
self._original_module_output_type = None
|
||||
|
||||
# Training model
|
||||
self._onnx_training = None
|
||||
|
|
@ -369,7 +397,8 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
proc_inputs = [data for data in inputs if data is not None]
|
||||
|
||||
return _ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs))
|
||||
return _populate_user_output(self._original_module_output_type, self._onnx_graphs_info.user_output_names,
|
||||
_ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs)))
|
||||
|
||||
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
|
||||
def _convert_training_graph_input_to_list(self, *inputs, **kwargs):
|
||||
|
|
@ -403,7 +432,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Setup dynamic axes for onnx model
|
||||
input_names, dynamic_axes, self._input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
|
||||
output_names, output_dynamic_axes = _parse_outputs_for_onnx_export(self._original_module, inputs)
|
||||
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs)
|
||||
dynamic_axes.update(output_dynamic_axes)
|
||||
|
||||
# TODO: Support contrib OPs support? user model has no hint
|
||||
|
|
|
|||
|
|
@ -9,3 +9,7 @@ Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663),
|
|||
|
||||
```sh
|
||||
pip install -r tools/ci_build/github/linux/docker/scripts/training/requirements.txt
|
||||
|
||||
2. Install second set of dependencies for ortmodule:
|
||||
```sh
|
||||
pip install -r tools/ci_build/github/linux/docker/scripts/training/ortmodule/requirements.txt
|
||||
|
|
|
|||
|
|
@ -4,9 +4,12 @@
|
|||
|
||||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
import pytest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
from collections import OrderedDict
|
||||
from collections import namedtuple
|
||||
|
||||
from onnxruntime.training import _utils, ORTModule
|
||||
import _test_helpers
|
||||
|
|
@ -351,8 +354,15 @@ def test_gpu_reserved_memory_with_torch_no_grad():
|
|||
assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
|
||||
assert mem_reserved_before_export == mem_reserved_after_export_with_torch_no_grad
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
|
||||
def test_exception_raised_for_dict_return_value_module(device):
|
||||
@pytest.mark.parametrize("return_type, device", [
|
||||
(dict, 'cpu'),
|
||||
(dict, 'cuda'),
|
||||
(OrderedDict, 'cpu'),
|
||||
(OrderedDict, 'cuda'),
|
||||
(SequenceClassifierOutput, 'cpu'),
|
||||
(SequenceClassifierOutput, 'cuda')
|
||||
])
|
||||
def test_dict_return_value_module(return_type, device):
|
||||
class NeuralNetDictOutput(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNetDictOutput, self).__init__()
|
||||
|
|
@ -372,8 +382,8 @@ def test_exception_raised_for_dict_return_value_module(device):
|
|||
def forward(self, input1, input2, input3):
|
||||
out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
|
||||
out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
|
||||
out3 = self.fc3_2(self.relu3(self.fc3_1(input2)))
|
||||
return {'a': out1, 'b': out2, 'c': out3}
|
||||
out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
|
||||
return return_type([('loss', out1), ('logits', out2), ('hidden_states', out3)])
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetDictOutput(D_in, H, D_out).to(device)
|
||||
|
|
@ -382,9 +392,81 @@ def test_exception_raised_for_dict_return_value_module(device):
|
|||
y = torch.randn(N, D_in, device=device)
|
||||
z = torch.randn(N, D_in, device=device)
|
||||
|
||||
with pytest.raises(NotImplementedError) as not_implemented_error:
|
||||
output = model(x, y, z)
|
||||
assert isinstance(output, return_type)
|
||||
assert 'loss' in output and 'logits' in output and 'hidden_states' in output
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
def test_dict_of_tuple_return_value_module(device):
|
||||
class NeuralNetDictOfTuplesOutput(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNetDictOfTuplesOutput, self).__init__()
|
||||
|
||||
self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu1 = torch.nn.ReLU()
|
||||
self.fc1_2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
self.fc2_1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu2 = torch.nn.ReLU()
|
||||
self.fc2_2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
self.fc3_1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu3 = torch.nn.ReLU()
|
||||
self.fc3_2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1, input2, input3):
|
||||
out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
|
||||
out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
|
||||
out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
|
||||
return {'loss': (out1, out2, out3)}
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetDictOfTuplesOutput(D_in, H, D_out).to(device)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
y = torch.randn(N, D_in, device=device)
|
||||
z = torch.randn(N, D_in, device=device)
|
||||
|
||||
with pytest.raises(TypeError) as type_error:
|
||||
model(x, y, z)
|
||||
assert str(not_implemented_error.value) == 'Dictionaries are not supported as output yet'
|
||||
assert 'ORTModule does not support the following model output type' in str(type_error.value)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
def test_named_tuple_return_value_module(device):
    """A module returning a namedtuple must yield that same namedtuple type through ORTModule."""
    ReturnValue = namedtuple('NamedTupleReturnValue', 'loss logits hidden_states')

    class NeuralNetNamedTupleOutput(torch.nn.Module):
        """Three independent Linear -> ReLU -> Linear branches, one per forward input."""

        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetNamedTupleOutput, self).__init__()

            # Branch fed by input1
            self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu1 = torch.nn.ReLU()
            self.fc1_2 = torch.nn.Linear(hidden_size, num_classes)

            # Branch fed by input2
            self.fc2_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu2 = torch.nn.ReLU()
            self.fc2_2 = torch.nn.Linear(hidden_size, num_classes)

            # Branch fed by input3
            self.fc3_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu3 = torch.nn.ReLU()
            self.fc3_2 = torch.nn.Linear(hidden_size, num_classes)

        def forward(self, input1, input2, input3):
            hidden = self.relu1(self.fc1_1(input1))
            loss = self.fc1_2(hidden)

            hidden = self.relu2(self.fc2_1(input2))
            logits = self.fc2_2(hidden)

            hidden = self.relu3(self.fc3_1(input3))
            hidden_states = self.fc3_2(hidden)

            return ReturnValue(loss, logits, hidden_states)

    batch_size, in_features, hidden_features, out_features = 64, 784, 500, 10
    wrapped = ORTModule(NeuralNetNamedTupleOutput(in_features, hidden_features, out_features).to(device))
    x, y, z = [torch.randn(batch_size, in_features, device=device) for _ in range(3)]

    result = wrapped(x, y, z)

    # The namedtuple type itself must survive, not just "some tuple".
    assert isinstance(result, tuple)
    assert isinstance(result, ReturnValue)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
|
||||
def test_exception_raised_for_custom_class_return_value_module(device):
|
||||
|
|
@ -413,7 +495,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
|
|||
def forward(self, input1, input2, input3):
|
||||
out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
|
||||
out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
|
||||
out3 = self.fc3_2(self.relu3(self.fc3_1(input2)))
|
||||
out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
|
||||
return CustomClass(out1, out2, out3)
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ jobs:
|
|||
--update --build \
|
||||
--build_wheel \
|
||||
" \
|
||||
-m
|
||||
-m \
|
||||
-u
|
||||
DisplayName: 'Build'
|
||||
|
||||
- bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdata-storage-key) -s "//orttrainingtestdata.file.core.windows.net/mnist" -d "/mnist"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ jobs:
|
|||
--update --build \
|
||||
--build_wheel \
|
||||
" \
|
||||
-u
|
||||
DisplayName: 'Build'
|
||||
|
||||
- bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdata-storage-key) -s "//orttrainingtestdata.file.core.windows.net/mnist" -d "/mnist"
|
||||
|
|
|
|||
|
|
@ -4,14 +4,16 @@ set -e -x
|
|||
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
|
||||
INSTALL_DEPS_TRAINING=false
|
||||
INSTALL_DEPS_DISTRIBUTED_SETUP=false
|
||||
ORTMODULE_BUILD=false
|
||||
|
||||
while getopts p:d:tm parameter_Option
|
||||
while getopts p:d:tmu parameter_Option
|
||||
do case "${parameter_Option}"
|
||||
in
|
||||
p) PYTHON_VER=${OPTARG};;
|
||||
d) DEVICE_TYPE=${OPTARG};;
|
||||
t) INSTALL_DEPS_TRAINING=true;;
|
||||
m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;;
|
||||
u) ORTMODULE_BUILD=true;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
|
@ -117,8 +119,10 @@ ${PYTHON_EXE} -m pip install -r ${0/%install_deps\.sh/requirements\.txt}
|
|||
if [ $DEVICE_TYPE = "gpu" ]; then
|
||||
if [[ $INSTALL_DEPS_TRAINING = true ]]; then
|
||||
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/requirements.txt}
|
||||
# Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663), we install it separately through secondary/requirements.txt
|
||||
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/secondary\/requirements.txt}
|
||||
if [[ $ORTMODULE_BUILD = true ]]; then
|
||||
# Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663), we install it separately through ortmodule/requirements.txt
|
||||
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/ortmodule\/requirements.txt}
|
||||
fi
|
||||
fi
|
||||
if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then
|
||||
source ${0/%install_deps.sh/install_openmpi.sh}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
deepspeed
|
||||
numpy==1.19.5
|
||||
transformers==v4.3.2
|
||||
|
|
@ -1 +0,0 @@
|
|||
deepspeed
|
||||
|
|
@ -6,10 +6,11 @@ SOURCE_ROOT=$(realpath $SCRIPT_DIR/../../../../)
|
|||
CUDA_VER=cuda10.1-cudnn7.6
|
||||
YOCTO_VERSION="4.19"
|
||||
INSTALL_DEPS_DISTRIBUTED_SETUP=false
|
||||
ORTMODULE_BUILD=false
|
||||
ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY
|
||||
echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as "$ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV
|
||||
|
||||
while getopts c:o:d:r:p:x:a:v:y:t:i:m parameter_Option
|
||||
while getopts c:o:d:r:p:x:a:v:y:t:i:mu parameter_Option
|
||||
do case "${parameter_Option}"
|
||||
in
|
||||
#android, ubuntu16.04, ubuntu18.04, CentOS7
|
||||
|
|
@ -36,6 +37,8 @@ t) EXTRA_IMAGE_TAG=${OPTARG};;
|
|||
i) IMAGE_CACHE_CONTAINER_REGISTRY_NAME=${OPTARG};;
|
||||
# install distributed setup dependencies
|
||||
m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;;
|
||||
# install ortmodule specific dependencies
|
||||
u) ORTMODULE_BUILD=true;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
|
@ -86,6 +89,9 @@ else
|
|||
if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then
|
||||
INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -m"
|
||||
fi
|
||||
if [[ $ORTMODULE_BUILD = true ]]; then
|
||||
INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -u"
|
||||
fi
|
||||
$GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \
|
||||
--docker-build-args="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\"" \
|
||||
--dockerfile $DOCKER_FILE --context .
|
||||
|
|
|
|||
Loading…
Reference in a new issue