From 9c15c68ed41b2090e20bb178b66745d7093dab2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 28 Oct 2021 22:04:54 +0200 Subject: [PATCH] Enable fallback when forward fails due to non contiguous tensor (#9369) --- .../python/training/ortmodule/_io.py | 2 +- .../training/ortmodule/_training_manager.py | 3 +- .../python/training/ortmodule/_utils.py | 6 +- .../orttraining_test_ortmodule_fallback.py | 116 ++++++++++++++++++ 4 files changed, 124 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 8b925ba2f5..1a62d5ab14 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -538,7 +538,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input schema=schema, num_positionals=len(inputs), num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs, - keyword_names=kwargs.keys()) + keyword_names=list(kwargs.keys())) def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs): diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 182c8dbf57..d5e8eeca84 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -47,7 +47,8 @@ class TrainingManager(GraphExecutionManager): forward_inputs = C.OrtValueVector() forward_inputs.reserve(len(inputs)) for input in inputs: - forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool) + dlp = _utils._torch_tensor_to_dlpack(input) + forward_inputs.push_back(dlp, input.dtype == torch.bool) forward_outputs = C.OrtValueVector() # Run and return module outputs. 
diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 3775c17634..7156f5ece3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -5,7 +5,8 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.capi import _pybind_state as C -from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception +from ._fallback_exceptions import ( + ORTModuleDeviceException, wrap_exception, ORTModuleIOError) from ._torch_module_pytorch import TorchModulePytorch import os @@ -52,6 +53,9 @@ def _torch_tensor_to_dlpack(tensor): # We need to convert bool tensor to unit8 tensor to workaround this. # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack # and PyTorch support bool type. + if not tensor.is_contiguous(): + raise ORTModuleIOError( + "Only contiguous tensors are supported.") if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): tensor = tensor.to(torch.uint8) return to_dlpack(tensor) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 2183d2123b..e4d714c1c7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -5,6 +5,8 @@ import copy import itertools import os +import math +import numpy as np import torch import pytest import warnings @@ -572,3 +574,117 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback): assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0]) del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + + + +@pytest.mark.parametrize("is_training,persist_fallback", + 
list(itertools.product([True, False], repeat=2)))
+def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback):
+    """Run a single training step of a small Transformer wrapped in ORTModule.
+
+    Validates the fix for issue: https://github.com/pytorch/ort/issues/92
+    is_training: True for torch.nn.Module training model, eval mode otherwise.
+    persist_fallback: ORTMODULE_FALLBACK_RETRY is set to str(not persist_fallback).
+    """
+
+    class PositionalEncoding(torch.nn.Module):
+
+        def __init__(self, d_model, dropout=0.1, max_len=5000):
+            super().__init__()
+            self.dropout = torch.nn.Dropout(p=dropout)
+            position = torch.arange(max_len).unsqueeze(1)
+            div_term = (torch.exp(torch.arange(0, d_model, 2) *
+                                  (-math.log(10000.0) / d_model)))
+            pe = torch.zeros(max_len, 1, d_model)
+            pe[:, 0, 0::2] = torch.sin(position * div_term)
+            pe[:, 0, 1::2] = torch.cos(position * div_term)
+            self.register_buffer('pe', pe)
+
+        def forward(self, x):
+            x = x + self.pe[:x.size(0)]
+            return self.dropout(x)
+
+    class TransformerModel(torch.nn.Module):
+
+        def __init__(self, ntoken, d_model, nhead, d_hid,
+                     nlayers, dropout=0.5):
+            super().__init__()
+            self.model_type = 'Transformer'
+            encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
+            self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
+            self.pos_encoder = PositionalEncoding(d_model, dropout)
+            self.encoder = torch.nn.Embedding(ntoken, d_model)
+            self.d_model = d_model
+            self.decoder = torch.nn.Linear(d_model, ntoken)
+            self.init_weights()
+
+        def init_weights(self):
+            initrange = 0.1
+            self.encoder.weight.data.uniform_(-initrange, initrange)
+            self.decoder.bias.data.zero_()
+            self.decoder.weight.data.uniform_(-initrange, initrange)
+
+        def forward(self, src, src_mask):
+            src = self.encoder(src) * math.sqrt(self.d_model)
+            src = self.pos_encoder(src)
+            output = self.transformer_encoder(src, src_mask)
+            output = self.decoder(output)
+            return output
+
+    def generate_square_subsequent_mask(sz):
+        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
+
+    def get_batch(source, i):
+        seq_len = min(bptt, len(source) - 1 - i)
+        data = source[i:i+seq_len]
+        target = source[i+1:i+1+seq_len].reshape(-1)
+        return data, target
+
+    policy = 'FALLBACK_UNSUPPORTED_DEVICE'
+    os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
+    os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
+    os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED'
+    try:
+        criterion = torch.nn.CrossEntropyLoss()
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        train_data = np.random.randint(1, 12455, 1000)
+        ends = np.random.randint(2, 20, 100).cumsum()
+        ends = ends[ends < train_data.shape[0] - 2]
+        train_data[ends] = 0
+        train_data[-1] = 0
+
+        train_data = torch.tensor(np.array(train_data, dtype=np.int64))
+        train_data = train_data.to(torch.int64).to(device)
+        bptt = 35
+        src_mask = generate_square_subsequent_mask(bptt).to(device)
+        ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
+        pt_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
+        model = ORTModule(pt_model).to(device)
+        pt_model.train(is_training)
+        model.train(is_training)
+        optimizer = torch.optim.SGD(model.parameters(), lr=5.0)
+
+        n_iter = 0
+        for epoch in range(1, 2):
+            # NOTE(review): this re-enables training mode even when is_training is False -- confirm intended.
+            model.train()  # turn on train mode
+            for i in range(0, train_data.size(0) - 1, bptt):
+                data, targets = get_batch(train_data, i)
+                batch_size = data.size(0)
+                if batch_size != bptt:  # only on last batch
+                    src_mask = src_mask[:batch_size, :batch_size]
+                output = model(data, src_mask)
+                nrows = min(ntokens, targets.shape[0])
+                loss = criterion(output.view(nrows, -1), targets)
+
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+                optimizer.step()
+                n_iter += 1
+                break  # stop after the first batch
+
+        assert n_iter > 0
+    finally:
+        # Remove the skip-check override even if the step above raised.
+        del os.environ['ORTMODULE_SKIPCHECK_POLICY']