mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Enable fallback when forward fails due to non contiguous tensor (#9369)
This commit is contained in:
parent
a01a3f2552
commit
9c15c68ed4
4 changed files with 124 additions and 3 deletions
|
|
@ -538,7 +538,7 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
schema=schema,
|
||||
num_positionals=len(inputs),
|
||||
num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
|
||||
keyword_names=kwargs.keys())
|
||||
keyword_names=list(kwargs.keys()))
|
||||
|
||||
|
||||
def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs):
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
forward_inputs = C.OrtValueVector()
|
||||
forward_inputs.reserve(len(inputs))
|
||||
for input in inputs:
|
||||
forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
|
||||
dlp = _utils._torch_tensor_to_dlpack(input)
|
||||
forward_inputs.push_back(dlp, input.dtype == torch.bool)
|
||||
|
||||
forward_outputs = C.OrtValueVector()
|
||||
# Run and return module outputs.
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
|
||||
from ._fallback_exceptions import (
|
||||
ORTModuleDeviceException, wrap_exception, ORTModuleIOError)
|
||||
from ._torch_module_pytorch import TorchModulePytorch
|
||||
|
||||
import os
|
||||
|
|
@ -52,6 +53,9 @@ def _torch_tensor_to_dlpack(tensor):
|
|||
# We need to convert bool tensor to uint8 tensor to work around this.
|
||||
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
|
||||
# and PyTorch support bool type.
|
||||
if not tensor.is_contiguous():
|
||||
raise ORTModuleIOError(
|
||||
"Only contiguous tensors are supported.")
|
||||
if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
|
||||
tensor = tensor.to(torch.uint8)
|
||||
return to_dlpack(tensor)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
import copy
|
||||
import itertools
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytest
|
||||
import warnings
|
||||
|
|
@ -572,3 +574,117 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
|
|||
assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])
|
||||
|
||||
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_training,persist_fallback",
                         list(itertools.product([True, False], repeat=2)))
def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback):
    """Run one optimizer step of a small transformer wrapped in ORTModule.

    The test configures the ORTModule fallback environment variables, builds
    a transformer language model, wraps it in ``ORTModule`` and performs a
    single forward/backward/optimizer step. It passes when that step
    completes without raising (``n_iter > 0``).

    Args:
        is_training: True for torch.nn.Module training model, eval mode
            otherwise. See the NOTE(review) in the loop below: the mode is
            re-set to training right before the forward pass.
        persist_fallback: ``ORTMODULE_FALLBACK_RETRY`` is set to the string
            form of ``not persist_fallback``.
    """
    # Validate fix for issue: https://github.com/pytorch/ort/issues/92

    policy = 'FALLBACK_UNSUPPORTED_DEVICE'
    os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
    os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
    os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED'

    class PositionalEncoding(torch.nn.Module):
        """Adds a fixed sinusoidal buffer ``pe`` to the input, then dropout."""

        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super().__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = (torch.exp(torch.arange(0, d_model, 2) *
                                  (-math.log(10000.0) / d_model)))
            # Even feature indices hold sines, odd indices hold cosines.
            pe = torch.zeros(max_len, 1, d_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)

        def forward(self, x):
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

    class TransformerModel(torch.nn.Module):
        """Embedding -> positional encoding -> transformer encoder -> linear."""

        def __init__(self, ntoken, d_model, nhead, d_hid,
                     nlayers, dropout=0.5):
            super().__init__()
            self.model_type = 'Transformer'
            encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
            self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            self.encoder = torch.nn.Embedding(ntoken, d_model)
            self.d_model = d_model
            self.decoder = torch.nn.Linear(d_model, ntoken)
            self.init_weights()

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, src, src_mask):
            src = self.encoder(src) * math.sqrt(self.d_model)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, src_mask)
            output = self.decoder(output)
            return output

    def generate_square_subsequent_mask(sz):
        # -inf strictly above the diagonal, as produced by torch.triu(..., diagonal=1).
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def get_batch(source, i):
        # `bptt` is resolved from the enclosing test scope at call time.
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].reshape(-1)
        return data, target

    try:
        criterion = torch.nn.CrossEntropyLoss()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Random token stream with zeros written at random positions and at the end.
        train_data = np.random.randint(1, 12455, 1000)
        ends = np.random.randint(2, 20, 100).cumsum()
        ends = ends[ends < train_data.shape[0] - 2]
        train_data[ends] = 0
        train_data[-1] = 0

        train_data = torch.tensor(np.array(train_data, dtype=np.int64))
        train_data = train_data.to(torch.int64).to(device)
        bptt = 35
        src_mask = generate_square_subsequent_mask(bptt).to(device)
        ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
        pt_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
        model = ORTModule(pt_model).to(device)
        pt_model.train(is_training)
        model.train(is_training)
        optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

        n_iter = 0
        for _epoch in range(1, 2):
            # NOTE(review): this unconditionally switches back to training
            # mode, which appears to override the `is_training` setting
            # applied above -- confirm this is intended.
            model.train()  # turn on train mode

            for i in range(0, train_data.size(0) - 1, bptt):
                data, targets = get_batch(train_data, i)
                batch_size = data.size(0)
                if batch_size != bptt:  # only on last batch
                    src_mask = src_mask[:batch_size, :batch_size]
                output = model(data, src_mask)
                nrows = min(ntokens, targets.shape[0])
                loss = criterion(output.view(nrows, -1), targets)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                n_iter += 1
                # A single step is enough to exercise the forward/backward path.
                break

        assert n_iter > 0
    finally:
        # Always drop the skip-check override so a failure here cannot leak
        # the setting into tests that run afterwards.
        del os.environ['ORTMODULE_SKIPCHECK_POLICY']
|
||||
|
|
|
|||
Loading…
Reference in a new issue