Support non-tuple return values from torch.nn.Module (#6660)

* Support dictionary, namedtuple and Hugging Face ModelOutput types for model return values
This commit is contained in:
baijumeswani 2021-02-16 20:48:32 -08:00 committed by GitHub
parent 7f33671ade
commit 01dfa8e125
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 149 additions and 20 deletions

View file

@ -74,9 +74,20 @@ def _parse_inputs_for_onnx_export(module, *inputs, **kwargs):
def _parse_outputs_for_onnx_export(module, inputs):
def _create_output_dim_names_from_mapping(output):
output_names, dynamic_axes = [], {}
for name, value in output.items():
if not isinstance(value, torch.Tensor):
raise TypeError('ORTModule does not support the following model output type {} within a Mapping'.format(type(value)))
output_names.append(name)
dynamic_axes[name] = {}
for dim_idx in range(len(value.shape)):
dynamic_axes[name].update({dim_idx: '{}_dim{}'.format(name, dim_idx)})
return output_names, dynamic_axes
def _create_output_dim_names(output, output_idx, from_sequence):
if from_sequence and not isinstance(output, torch.Tensor):
raise TypeError('ORTModule does not support the following model output type {} within a Sequence'.format(type(sample_outputs)))
raise TypeError('ORTModule does not support the following model output type {} within a Sequence'.format(type(output)))
output_names, dynamic_axes = [], {}
name = 'output{}'.format(output_idx)
output_names.append(name)
@ -88,6 +99,9 @@ def _parse_outputs_for_onnx_export(module, inputs):
# Do an inference to grab outputs
is_train_mode = module.training
module.eval()
output_names = []
output_dynamic_axes = {}
sample_output_type = None
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = _deepcopy_model_input(*inputs)
@ -100,12 +114,11 @@ def _parse_outputs_for_onnx_export(module, inputs):
" Compute will continue, but unexpected results may occur!")
sample_outputs = model_copy(*sample_inputs_copy)
output_names = []
output_dynamic_axes = {}
sample_output_type = type(sample_outputs)
if isinstance(sample_outputs, torch.Tensor):
output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False)
elif isinstance(sample_outputs, abc.Mapping):
raise NotImplementedError('Dictionaries are not supported as output yet')
output_names, output_dynamic_axes = _create_output_dim_names_from_mapping(sample_outputs)
elif isinstance(sample_outputs, abc.Sequence):
for idx, out in enumerate(sample_outputs):
tmp_output_names, tmp_output_dynamic_axes = _create_output_dim_names(out, idx, True)
@ -115,7 +128,21 @@ def _parse_outputs_for_onnx_export(module, inputs):
raise TypeError('ORTModule does not support the following model output type {}'.format(type(sample_outputs)))
if is_train_mode:
module.train()
return output_names, output_dynamic_axes
return output_names, output_dynamic_axes, sample_output_type
def _populate_user_output(user_output_type, user_output_names, user_outputs):
if issubclass(user_output_type, Mapping):
key_value_pairs = [(user_output_names[i], user_outputs[i]) for i in range(len(user_output_names))]
return user_output_type(key_value_pairs)
elif issubclass(user_output_type, tuple):
try:
# Try constructing the user named tuple from the output tuple
return user_output_type(*user_outputs)
except TypeError:
# The expected output type is not a namedtuple, but is a regular tuple type
pass
return user_outputs
# TODO: PyTorch's to_dlpack() uses the same config for both torch.bool and torch.uint8,
# and converts the config to a torch.uint8 tensor during from_dlpack(). So a boolean tensor
@ -148,6 +175,7 @@ class ORTModule(torch.nn.Module):
self._current_input_shape = None
self._module_gradient_graph_builder = None
self._input_names_require_grad = None
self._original_module_output_type = None
# Training model
self._onnx_training = None
@ -369,7 +397,8 @@ class ORTModule(torch.nn.Module):
proc_inputs = [data for data in inputs if data is not None]
return _ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs))
return _populate_user_output(self._original_module_output_type, self._onnx_graphs_info.user_output_names,
_ORTModuleFunction.apply(*self._convert_training_graph_input_to_list(*proc_inputs, **kwargs)))
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
def _convert_training_graph_input_to_list(self, *inputs, **kwargs):
@ -403,7 +432,7 @@ class ORTModule(torch.nn.Module):
# Setup dynamic axes for onnx model
input_names, dynamic_axes, self._input_names_require_grad = _parse_inputs_for_onnx_export(self._original_module, *inputs, **kwargs)
output_names, output_dynamic_axes = _parse_outputs_for_onnx_export(self._original_module, inputs)
output_names, output_dynamic_axes, self._original_module_output_type = _parse_outputs_for_onnx_export(self._original_module, inputs)
dynamic_axes.update(output_dynamic_axes)
# TODO: Support contrib OPs support? user model has no hint

View file

@ -9,3 +9,7 @@ Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663),
```sh
pip install -r tools/ci_build/github/linux/docker/scripts/training/requirements.txt
2. Install second set of dependencies for ortmodule:
```sh
pip install -r tools/ci_build/github/linux/docker/scripts/training/ortmodule/requirements.txt

View file

@ -4,9 +4,12 @@
import torch
from transformers import AutoConfig, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
import pytest
import warnings
from unittest.mock import patch
from collections import OrderedDict
from collections import namedtuple
from onnxruntime.training import _utils, ORTModule
import _test_helpers
@ -351,8 +354,15 @@ def test_gpu_reserved_memory_with_torch_no_grad():
assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
assert mem_reserved_before_export == mem_reserved_after_export_with_torch_no_grad
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
def test_exception_raised_for_dict_return_value_module(device):
@pytest.mark.parametrize("return_type, device", [
(dict, 'cpu'),
(dict, 'cuda'),
(OrderedDict, 'cpu'),
(OrderedDict, 'cuda'),
(SequenceClassifierOutput, 'cpu'),
(SequenceClassifierOutput, 'cuda')
])
def test_dict_return_value_module(return_type, device):
class NeuralNetDictOutput(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNetDictOutput, self).__init__()
@ -372,8 +382,8 @@ def test_exception_raised_for_dict_return_value_module(device):
def forward(self, input1, input2, input3):
out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
out3 = self.fc3_2(self.relu3(self.fc3_1(input2)))
return {'a': out1, 'b': out2, 'c': out3}
out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
return return_type([('loss', out1), ('logits', out2), ('hidden_states', out3)])
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetDictOutput(D_in, H, D_out).to(device)
@ -382,9 +392,81 @@ def test_exception_raised_for_dict_return_value_module(device):
y = torch.randn(N, D_in, device=device)
z = torch.randn(N, D_in, device=device)
with pytest.raises(NotImplementedError) as not_implemented_error:
output = model(x, y, z)
assert isinstance(output, return_type)
assert 'loss' in output and 'logits' in output and 'hidden_states' in output
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_dict_of_tuple_return_value_module(device):
    """A mapping output whose values are not tensors must raise TypeError."""

    class NeuralNetDictOfTuplesOutput(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetDictOfTuplesOutput, self).__init__()

            self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu1 = torch.nn.ReLU()
            self.fc1_2 = torch.nn.Linear(hidden_size, num_classes)

            self.fc2_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu2 = torch.nn.ReLU()
            self.fc2_2 = torch.nn.Linear(hidden_size, num_classes)

            self.fc3_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu3 = torch.nn.ReLU()
            self.fc3_2 = torch.nn.Linear(hidden_size, num_classes)

        def forward(self, input1, input2, input3):
            out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
            out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
            out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
            # The single mapping value is a tuple, not a tensor.
            return {'loss': (out1, out2, out3)}

    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetDictOfTuplesOutput(D_in, H, D_out).to(device)
    model = ORTModule(model)
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_in, device=device)
    z = torch.randn(N, D_in, device=device)

    with pytest.raises(TypeError) as type_error:
        model(x, y, z)
    # Checked after the `with` block so the assertion actually runs once the
    # expected exception has been captured.
    assert 'ORTModule does not support the following model output type' in str(type_error.value)
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
def test_named_tuple_return_value_module(device):
    """A namedtuple returned by the wrapped module comes back as the same namedtuple type."""
    ReturnValue = namedtuple('NamedTupleReturnValue', 'loss logits hidden_states')

    class NeuralNetNamedTupleOutput(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()

            # First branch
            self.fc1_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu1 = torch.nn.ReLU()
            self.fc1_2 = torch.nn.Linear(hidden_size, num_classes)

            # Second branch
            self.fc2_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu2 = torch.nn.ReLU()
            self.fc2_2 = torch.nn.Linear(hidden_size, num_classes)

            # Third branch
            self.fc3_1 = torch.nn.Linear(input_size, hidden_size)
            self.relu3 = torch.nn.ReLU()
            self.fc3_2 = torch.nn.Linear(hidden_size, num_classes)

        def forward(self, input1, input2, input3):
            hidden1 = self.relu1(self.fc1_1(input1))
            out1 = self.fc1_2(hidden1)
            hidden2 = self.relu2(self.fc2_1(input2))
            out2 = self.fc2_2(hidden2)
            hidden3 = self.relu3(self.fc3_1(input3))
            out3 = self.fc3_2(hidden3)
            return ReturnValue(out1, out2, out3)

    batch_size, in_features, hidden_features, out_features = 64, 784, 500, 10
    wrapped = ORTModule(NeuralNetNamedTupleOutput(in_features, hidden_features, out_features).to(device))

    x = torch.randn(batch_size, in_features, device=device)
    y = torch.randn(batch_size, in_features, device=device)
    z = torch.randn(batch_size, in_features, device=device)

    output = wrapped(x, y, z)
    assert isinstance(output, tuple)
    assert isinstance(output, ReturnValue)
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
def test_exception_raised_for_custom_class_return_value_module(device):
@ -413,7 +495,7 @@ def test_exception_raised_for_custom_class_return_value_module(device):
def forward(self, input1, input2, input3):
out1 = self.fc1_2(self.relu1(self.fc1_1(input1)))
out2 = self.fc2_2(self.relu2(self.fc2_1(input2)))
out3 = self.fc3_2(self.relu3(self.fc3_1(input2)))
out3 = self.fc3_2(self.relu3(self.fc3_1(input3)))
return CustomClass(out1, out2, out3)
N, D_in, H, D_out = 64, 784, 500, 10

View file

@ -22,7 +22,8 @@ jobs:
--update --build \
--build_wheel \
" \
-m
-m \
-u
DisplayName: 'Build'
- bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdata-storage-key) -s "//orttrainingtestdata.file.core.windows.net/mnist" -d "/mnist"

View file

@ -22,6 +22,7 @@ jobs:
--update --build \
--build_wheel \
" \
-u
DisplayName: 'Build'
- bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdata-storage-key) -s "//orttrainingtestdata.file.core.windows.net/mnist" -d "/mnist"

View file

@ -4,14 +4,16 @@ set -e -x
SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )"
INSTALL_DEPS_TRAINING=false
INSTALL_DEPS_DISTRIBUTED_SETUP=false
ORTMODULE_BUILD=false
while getopts p:d:tm parameter_Option
while getopts p:d:tmu parameter_Option
do case "${parameter_Option}"
in
p) PYTHON_VER=${OPTARG};;
d) DEVICE_TYPE=${OPTARG};;
t) INSTALL_DEPS_TRAINING=true;;
m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;;
u) ORTMODULE_BUILD=true;;
esac
done
@ -117,8 +119,10 @@ ${PYTHON_EXE} -m pip install -r ${0/%install_deps\.sh/requirements\.txt}
if [ $DEVICE_TYPE = "gpu" ]; then
if [[ $INSTALL_DEPS_TRAINING = true ]]; then
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/requirements.txt}
# Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663), we install it separately through secondary/requirements.txt
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/secondary\/requirements.txt}
if [[ $ORTMODULE_BUILD = true ]]; then
# Due to a [bug on DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/663), we install it separately through ortmodule/requirements.txt
${PYTHON_EXE} -m pip install -r ${0/%install_deps.sh/training\/ortmodule\/requirements.txt}
fi
fi
if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then
source ${0/%install_deps.sh/install_openmpi.sh}

View file

@ -0,0 +1,3 @@
deepspeed
numpy==1.19.5
transformers==v4.3.2

View file

@ -6,10 +6,11 @@ SOURCE_ROOT=$(realpath $SCRIPT_DIR/../../../../)
CUDA_VER=cuda10.1-cudnn7.6
YOCTO_VERSION="4.19"
INSTALL_DEPS_DISTRIBUTED_SETUP=false
ORTMODULE_BUILD=false
ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY
echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as "$ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV
while getopts c:o:d:r:p:x:a:v:y:t:i:m parameter_Option
while getopts c:o:d:r:p:x:a:v:y:t:i:mu parameter_Option
do case "${parameter_Option}"
in
#android, ubuntu16.04, ubuntu18.04, CentOS7
@ -36,6 +37,8 @@ t) EXTRA_IMAGE_TAG=${OPTARG};;
i) IMAGE_CACHE_CONTAINER_REGISTRY_NAME=${OPTARG};;
# install distributed setup dependencies
m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;;
# install ortmodule specific dependencies
u) ORTMODULE_BUILD=true;;
esac
done
@ -86,6 +89,9 @@ else
if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then
INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -m"
fi
if [[ $ORTMODULE_BUILD = true ]]; then
INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -u"
fi
$GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \
--docker-build-args="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\"" \
--dockerfile $DOCKER_FILE --context .