Add warning to non pickable models (#5037)

This commit is contained in:
Thiago Crepaldi 2020-09-03 11:53:56 -07:00 committed by GitHub
parent 9d1bdef195
commit 9388d49c0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 126 additions and 79 deletions

View file

@ -453,10 +453,16 @@ class ORTTrainer(object):
with torch.no_grad():
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = copy.deepcopy(sample_inputs)
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(model)
try:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(model)
except Exception:
model_copy = model
warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
" Compute will continue, but unexpected results may occur!")
sample_outputs = model_copy(*sample_inputs_copy)
model.train()
if isinstance(sample_outputs, torch.Tensor):
sample_outputs = [sample_outputs]
@ -472,6 +478,7 @@ class ORTTrainer(object):
# Export the model to ONNX
f = io.BytesIO()
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = copy.deepcopy(sample_inputs)

View file

@ -7,7 +7,6 @@ import onnx
import os
import pytest
import torch
import torch.nn.functional as F
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
@ -454,80 +453,6 @@ def testToyBertCheckpointLoadZero():
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [
(None, optim.AdamConfig(), 1),
(None, optim.LambConfig(), 1),
(None, optim.SGDConfig(), 1),
(amp.DynamicLossScaler(), optim.AdamConfig(), 1),
(amp.DynamicLossScaler(), optim.LambConfig(), 5),
#(amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16
])
def testToyBertStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps):
# Common setup
seed = 1
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 4)
def forward(self, y=None, x=None):
if y is not None:
return self.linear(x) + y
else:
return self.linear(x) + torch.ones(2, 4)
model_desc = {'inputs' : [('x', [2, 2]),
('label', [2, ])],
'outputs' : [('loss', [], True),
('output', [2, 4])]}
# Dummy data
data1 = torch.randn(2, 2)
label1 = torch.tensor([0, 1], dtype=torch.int64)
data2 = torch.randn(2, 2)
label2 = torch.tensor([0, 1], dtype=torch.int64)
# Setup training based on test parameters
opts = {'debug' : {'deterministic_compute': True},
'batch' : { 'gradient_accumulation_steps' : gradient_accumulation_steps}}
if loss_scaler:
opts['mixed_precision'] = { 'enabled': True, 'loss_scaler': loss_scaler}
opts = orttrainer.ORTTrainerOptions(opts)
# Training session 1
torch.manual_seed(seed)
onnxruntime.set_seed(seed)
pt_model = LinearModel()
def loss_fn(x, label):
return F.nll_loss(F.log_softmax(x, dim=1), label)
trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
# Check state_dict keys before train. Must be empty
state_dict = checkpoint.experimental_state_dict(trainer)
assert state_dict == {}
# Train once and check initial state
trainer.train_step(x=data1, label=label1)
state_dict = checkpoint.experimental_state_dict(trainer)
assert all([weight in state_dict.keys() for weight in ['linear.bias', 'linear.weight']])
# Initialize training session 2 from state of Training 1
torch.manual_seed(seed)
onnxruntime.set_seed(seed)
trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
checkpoint.experimental_load_state_dict(trainer2, state_dict)
# Verify state was loaded properly
for k,v in state_dict.items():
assert_allclose(v, trainer2._state_dict[k])
# Perform a second step in both training session 1 and 2 and verify they match
trainer.train_step(x=data2, label=label2)
state_dict = checkpoint.experimental_state_dict(trainer)
trainer2.train_step(x=data2, label=label2)
state_dict2 = checkpoint.experimental_state_dict(trainer2)
for k,v in state_dict.items():
assert_allclose(v, state_dict2[k])
def testToyBertCheckpointFrozenWeights():
# Common setup
seed = 1

View file

@ -6,14 +6,14 @@ import onnx
import os
import pytest
import torch
import torch.nn.functional as F
from onnxruntime import set_seed
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
ModelDescription as Legacy_ModelDescription,\
LossScaler as Legacy_LossScaler,\
ORTTrainer as Legacy_ORTTrainer
from onnxruntime.experimental import _utils, amp, optim, orttrainer, TrainStepInfo,\
from onnxruntime.experimental import _utils, amp, checkpoint, optim, orttrainer, TrainStepInfo,\
model_desc_validation as md_val,\
orttrainer_options as orttrainer_options
import _test_commons,_test_helpers
@ -850,6 +850,121 @@ def testORTTrainerFrozenWeights(model_params):
assert not all([param in session_state for param in model_params])
@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [
(None, optim.AdamConfig(), 1),
(None, optim.LambConfig(), 1),
(None, optim.SGDConfig(), 1),
(amp.DynamicLossScaler(), optim.AdamConfig(), 1),
(amp.DynamicLossScaler(), optim.LambConfig(), 5),
#(amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16
])
def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps):
# Common setup
seed = 1
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 4)
def forward(self, y=None, x=None):
if y is not None:
return self.linear(x) + y
else:
return self.linear(x) + torch.ones(2, 4)
model_desc = {'inputs' : [('x', [2, 2]),
('label', [2, ])],
'outputs' : [('loss', [], True),
('output', [2, 4])]}
# Dummy data
data1 = torch.randn(2, 2)
label1 = torch.tensor([0, 1], dtype=torch.int64)
data2 = torch.randn(2, 2)
label2 = torch.tensor([0, 1], dtype=torch.int64)
# Setup training based on test parameters
opts = {'debug' : {'deterministic_compute': True},
'batch' : { 'gradient_accumulation_steps' : gradient_accumulation_steps}}
if loss_scaler:
opts['mixed_precision'] = { 'enabled': True, 'loss_scaler': loss_scaler}
opts = orttrainer.ORTTrainerOptions(opts)
# Training session 1
torch.manual_seed(seed)
set_seed(seed)
pt_model = LinearModel()
def loss_fn(x, label):
return F.nll_loss(F.log_softmax(x, dim=1), label)
trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
# Check state_dict keys before train. Must be empty
state_dict = checkpoint.experimental_state_dict(trainer)
assert state_dict == {}
# Train once and check initial state
trainer.train_step(x=data1, label=label1)
state_dict = checkpoint.experimental_state_dict(trainer)
assert all([weight in state_dict.keys() for weight in ['linear.bias', 'linear.weight']])
# Initialize training session 2 from state of Training 1
torch.manual_seed(seed)
set_seed(seed)
trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
checkpoint.experimental_load_state_dict(trainer2, state_dict)
# Verify state was loaded properly
for k,v in state_dict.items():
assert_allclose(v, trainer2._state_dict[k])
# Perform a second step in both training session 1 and 2 and verify they match
trainer.train_step(x=data2, label=label2)
state_dict = checkpoint.experimental_state_dict(trainer)
trainer2.train_step(x=data2, label=label2)
state_dict2 = checkpoint.experimental_state_dict(trainer2)
for k,v in state_dict.items():
assert_allclose(v, state_dict2[k])
def testORTTrainerNonPickableModel():
# Common setup
import threading
seed = 1
class UnpickableModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 4)
self._lock = threading.Lock()
def forward(self, y=None, x=None):
with self._lock:
if y is not None:
return self.linear(x) + y
else:
return self.linear(x) + torch.ones(2, 4)
model_desc = {'inputs' : [('x', [2, 2]),
('label', [2, ])],
'outputs' : [('loss', [], True),
('output', [2, 4])]}
# Dummy data
data = torch.randn(2, 2)
label = torch.tensor([0, 1], dtype=torch.int64)
# Setup training based on test parameters
opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}})
# Training session
torch.manual_seed(seed)
set_seed(seed)
pt_model = UnpickableModel()
def loss_fn(x, label):
return F.nll_loss(F.log_softmax(x, dim=1), label)
optim_config = optim.AdamConfig()
trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts)
# Train must succeed despite warning
_, _ = trainer.train_step(data, label)
###############################################################################
# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
###############################################################################