Enable fallback when forward fails due to a non-contiguous tensor (#9369)

This commit is contained in:
Xavier Dupré 2021-10-28 22:04:54 +02:00 committed by GitHub
parent a01a3f2552
commit 9c15c68ed4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 124 additions and 3 deletions

View file

@ -538,7 +538,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
schema=schema,
num_positionals=len(inputs),
num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
keyword_names=kwargs.keys())
keyword_names=list(kwargs.keys()))
def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs):

View file

@ -47,7 +47,8 @@ class TrainingManager(GraphExecutionManager):
forward_inputs = C.OrtValueVector()
forward_inputs.reserve(len(inputs))
for input in inputs:
forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
dlp = _utils._torch_tensor_to_dlpack(input)
forward_inputs.push_back(dlp, input.dtype == torch.bool)
forward_outputs = C.OrtValueVector()
# Run and return module outputs.

View file

@ -5,7 +5,8 @@
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi import _pybind_state as C
from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
from ._fallback_exceptions import (
ORTModuleDeviceException, wrap_exception, ORTModuleIOError)
from ._torch_module_pytorch import TorchModulePytorch
import os
@ -52,6 +53,9 @@ def _torch_tensor_to_dlpack(tensor):
# We need to convert bool tensor to uint8 tensor to work around this.
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
# and PyTorch support bool type.
if not tensor.is_contiguous():
raise ORTModuleIOError(
"Only contiguous tensors are supported.")
if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
tensor = tensor.to(torch.uint8)
return to_dlpack(tensor)

View file

@ -5,6 +5,8 @@
import copy
import itertools
import os
import math
import numpy as np
import torch
import pytest
import warnings
@ -572,3 +574,117 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
@pytest.mark.parametrize("is_training,persist_fallback",
                         list(itertools.product([True, False], repeat=2)))
def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback):
    """Run one optimizer step of a small Transformer wrapped in ORTModule.

    The fallback policy is set to ``FALLBACK_UNSUPPORTED_DEVICE`` and the
    skip-check policy to ``SKIP_CHECK_DISABLED`` through environment
    variables. The only assertion is that at least one training iteration
    completes without raising.

    Args:
        is_training: passed to ``train()`` on both the PyTorch model and the
            ORTModule wrapper before the loop starts.
        persist_fallback: ``ORTMODULE_FALLBACK_RETRY`` is set to the string
            form of its negation.
    """
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # Validate fix for issue: https://github.com/pytorch/ort/issues/92

    class PositionalEncoding(torch.nn.Module):
        """Add a fixed sin/cos buffer to the input, then apply dropout."""

        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super().__init__()
            self.dropout = torch.nn.Dropout(p=dropout)

            position = torch.arange(max_len).unsqueeze(1)
            div_term = (torch.exp(torch.arange(0, d_model, 2) *
                                  (-math.log(10000.0) / d_model)))
            pe = torch.zeros(max_len, 1, d_model)
            # Even feature indices hold the sine terms, odd ones the cosine terms.
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)

        def forward(self, x):
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

    class TransformerModel(torch.nn.Module):
        """Embedding -> positional encoding -> TransformerEncoder -> Linear."""

        def __init__(self, ntoken, d_model, nhead, d_hid,
                     nlayers, dropout=0.5):
            super().__init__()
            self.model_type = 'Transformer'
            encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
            self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            self.encoder = torch.nn.Embedding(ntoken, d_model)
            self.d_model = d_model
            self.decoder = torch.nn.Linear(d_model, ntoken)

            self.init_weights()

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, src, src_mask):
            src = self.encoder(src) * math.sqrt(self.d_model)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, src_mask)
            output = self.decoder(output)
            return output

    def generate_square_subsequent_mask(sz):
        # -inf strictly above the diagonal, zeros on and below it.
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def get_batch(source, i):
        # `bptt` is read from the enclosing test function; it is assigned
        # before the first call.
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].reshape(-1)
        return data, target

    policy = 'FALLBACK_UNSUPPORTED_DEVICE'
    os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
    os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
    os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED'

    try:
        criterion = torch.nn.CrossEntropyLoss()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Random token ids with some positions (and the last one) forced to 0.
        train_data = np.random.randint(1, 12455, 1000)
        ends = np.random.randint(2, 20, 100).cumsum()
        ends = ends[ends < train_data.shape[0] - 2]
        train_data[ends] = 0
        train_data[-1] = 0
        train_data = torch.tensor(np.array(train_data, dtype=np.int64))
        train_data = train_data.to(torch.int64).to(device)

        bptt = 35
        src_mask = generate_square_subsequent_mask(bptt).to(device)

        ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
        pt_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
        model = ORTModule(pt_model).to(device)

        pt_model.train(is_training)
        model.train(is_training)
        optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

        n_iter = 0
        for _epoch in range(1, 2):
            # NOTE(review): this call switches the model to train mode
            # regardless of `is_training`, so the mode selected above does not
            # survive into the loop. Kept as in the original; confirm intent.
            model.train()  # turn on train mode

            for i in range(0, train_data.size(0) - 1, bptt):
                data, targets = get_batch(train_data, i)
                batch_size = data.size(0)
                if batch_size != bptt:  # only on last batch
                    src_mask = src_mask[:batch_size, :batch_size]
                output = model(data, src_mask)
                nrows = min(ntokens, targets.shape[0])
                loss = criterion(output.view(nrows, -1), targets)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                n_iter += 1
                # A single step is enough for this check.
                break

        assert n_iter > 0
    finally:
        # Remove the skip-check override even when the body raises or the
        # assertion fails, so it cannot leak into later tests. The two
        # fallback variables are left set, exactly as before this change.
        os.environ.pop('ORTMODULE_SKIPCHECK_POLICY', None)