mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Add warning to non pickable models (#5037)
This commit is contained in:
parent
9d1bdef195
commit
9388d49c0d
3 changed files with 126 additions and 79 deletions
|
|
@ -453,10 +453,16 @@ class ORTTrainer(object):
|
|||
with torch.no_grad():
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(model)
|
||||
try:
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(model)
|
||||
except Exception:
|
||||
model_copy = model
|
||||
warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
|
||||
" Compute will continue, but unexpected results may occur!")
|
||||
sample_outputs = model_copy(*sample_inputs_copy)
|
||||
model.train()
|
||||
|
||||
if isinstance(sample_outputs, torch.Tensor):
|
||||
sample_outputs = [sample_outputs]
|
||||
|
||||
|
|
@ -472,6 +478,7 @@ class ORTTrainer(object):
|
|||
|
||||
# Export the model to ONNX
|
||||
f = io.BytesIO()
|
||||
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import onnx
|
|||
import os
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
|
||||
|
|
@ -454,80 +453,6 @@ def testToyBertCheckpointLoadZero():
|
|||
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [
    (None, optim.AdamConfig(), 1),
    (None, optim.LambConfig(), 1),
    (None, optim.SGDConfig(), 1),
    (amp.DynamicLossScaler(), optim.AdamConfig(), 1),
    (amp.DynamicLossScaler(), optim.LambConfig(), 5),
    # (amp.DynamicLossScaler(), optim.SGDConfig(), 1),  # SGD doesn't support fp16
])
def testToyBertStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps):
    """Round-trip a trainer's state dict through a second trainer built from the same model and loss_fn.

    The test checks that:
      * the state dict is empty before any training step,
      * after one step it holds the linear layer's weight and bias,
      * a second trainer loaded from that state dict reports the same values, and
      * both trainers still agree after one more identical training step.
    """
    seed = 1

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 4)

        def forward(self, y=None, x=None):
            # Run the linear layer first, then add either the supplied offset or a tensor of ones.
            projected = self.linear(x)
            return projected + (y if y is not None else torch.ones(2, 4))

    model_desc = {
        'inputs': [('x', [2, 2]),
                   ('label', [2])],
        'outputs': [('loss', [], True),
                    ('output', [2, 4])],
    }

    # Dummy data (drawn before the RNG is seeded, exactly as the trainers will consume it)
    data1 = torch.randn(2, 2)
    label1 = torch.tensor([0, 1], dtype=torch.int64)
    data2 = torch.randn(2, 2)
    label2 = torch.tensor([0, 1], dtype=torch.int64)

    # Trainer options derived from the test parameters
    raw_options = {
        'debug': {'deterministic_compute': True},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    }
    if loss_scaler:
        raw_options['mixed_precision'] = {'enabled': True, 'loss_scaler': loss_scaler}
    opts = orttrainer.ORTTrainerOptions(raw_options)

    def reseed():
        # Both training sessions start from the same PyTorch / ONNX Runtime seeds.
        torch.manual_seed(seed)
        onnxruntime.set_seed(seed)

    def loss_fn(x, label):
        return F.nll_loss(F.log_softmax(x, dim=1), label)

    # Training session 1
    reseed()
    pt_model = LinearModel()
    trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)

    # Before the first step there must be no state at all
    assert checkpoint.experimental_state_dict(trainer) == {}

    # One step later the linear layer's parameters must be present
    trainer.train_step(x=data1, label=label1)
    state_dict = checkpoint.experimental_state_dict(trainer)
    for weight in ('linear.bias', 'linear.weight'):
        assert weight in state_dict.keys()

    # Training session 2 starts from the state of session 1
    reseed()
    trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)

    # The loaded state must mirror what was saved
    for name, saved in state_dict.items():
        assert_allclose(saved, trainer2._state_dict[name])

    # A further identical step on each trainer must leave them in agreement
    trainer.train_step(x=data2, label=label2)
    state_dict = checkpoint.experimental_state_dict(trainer)
    trainer2.train_step(x=data2, label=label2)
    state_dict2 = checkpoint.experimental_state_dict(trainer2)
    for name, expected in state_dict.items():
        assert_allclose(expected, state_dict2[name])
|
||||
|
||||
|
||||
def testToyBertCheckpointFrozenWeights():
|
||||
# Common setup
|
||||
seed = 1
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ import onnx
|
|||
import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from onnxruntime import set_seed
|
||||
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
|
||||
ModelDescription as Legacy_ModelDescription,\
|
||||
LossScaler as Legacy_LossScaler,\
|
||||
ORTTrainer as Legacy_ORTTrainer
|
||||
from onnxruntime.experimental import _utils, amp, optim, orttrainer, TrainStepInfo,\
|
||||
from onnxruntime.experimental import _utils, amp, checkpoint, optim, orttrainer, TrainStepInfo,\
|
||||
model_desc_validation as md_val,\
|
||||
orttrainer_options as orttrainer_options
|
||||
import _test_commons,_test_helpers
|
||||
|
|
@ -850,6 +850,121 @@ def testORTTrainerFrozenWeights(model_params):
|
|||
assert not all([param in session_state for param in model_params])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [
    (None, optim.AdamConfig(), 1),
    (None, optim.LambConfig(), 1),
    (None, optim.SGDConfig(), 1),
    (amp.DynamicLossScaler(), optim.AdamConfig(), 1),
    (amp.DynamicLossScaler(), optim.LambConfig(), 5),
    # (amp.DynamicLossScaler(), optim.SGDConfig(), 1),  # SGD doesn't support fp16
])
def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps):
    """Check state-dict save/load for an ORTTrainer that wraps a PyTorch model plus a loss function.

    A first trainer runs one step and its state dict is loaded into a second
    trainer built from the same model. The two trainers must report matching
    state right after loading and again after one more identical step.
    """
    seed = 1

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 4)

        def forward(self, y=None, x=None):
            # Linear projection first; the addend is ``y`` when given, otherwise ones.
            hidden = self.linear(x)
            addend = y if y is not None else torch.ones(2, 4)
            return hidden + addend

    model_desc = {
        'inputs': [('x', [2, 2]),
                   ('label', [2])],
        'outputs': [('loss', [], True),
                    ('output', [2, 4])],
    }

    # Dummy data (generated before any seeding, same order as consumed below)
    data1 = torch.randn(2, 2)
    label1 = torch.tensor([0, 1], dtype=torch.int64)
    data2 = torch.randn(2, 2)
    label2 = torch.tensor([0, 1], dtype=torch.int64)

    # Build trainer options from the parametrized inputs
    option_dict = {
        'debug': {'deterministic_compute': True},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    }
    if loss_scaler:
        option_dict['mixed_precision'] = {'enabled': True, 'loss_scaler': loss_scaler}
    opts = orttrainer.ORTTrainerOptions(option_dict)

    def loss_fn(x, label):
        return F.nll_loss(F.log_softmax(x, dim=1), label)

    # --- Training session 1 ---
    torch.manual_seed(seed)
    set_seed(seed)
    pt_model = LinearModel()
    trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)

    # No training yet: the state dict must be empty
    initial_state = checkpoint.experimental_state_dict(trainer)
    assert initial_state == {}

    # After a single step the linear layer's parameters must be tracked
    trainer.train_step(x=data1, label=label1)
    state_dict = checkpoint.experimental_state_dict(trainer)
    for weight in ('linear.bias', 'linear.weight'):
        assert weight in state_dict.keys()

    # --- Training session 2, initialized from session 1's state ---
    torch.manual_seed(seed)
    set_seed(seed)
    trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)

    # Loaded values must equal the saved ones
    for key, value in state_dict.items():
        assert_allclose(value, trainer2._state_dict[key])

    # One more identical step on each trainer; resulting states must match
    trainer.train_step(x=data2, label=label2)
    state_dict = checkpoint.experimental_state_dict(trainer)
    trainer2.train_step(x=data2, label=label2)
    state_dict2 = checkpoint.experimental_state_dict(trainer2)
    for key, value in state_dict.items():
        assert_allclose(value, state_dict2[key])
|
||||
|
||||
|
||||
def testORTTrainerNonPickableModel():
    """Train a model that cannot be deep-copied and verify ORTTrainer warns instead of failing.

    The model holds a ``threading.Lock``, which ``copy.deepcopy`` cannot
    duplicate. ORTTrainer is expected to fall back to the original model and
    emit a warning containing "cannot be deep copied". The original test only
    checked that training did not crash; it now also asserts the warning, so
    a regression that silently drops it is caught.
    """
    # Common setup
    import threading

    seed = 1

    class UnpickableModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 4)
            # The lock is what makes this module impossible to deepcopy/pickle.
            self._lock = threading.Lock()

        def forward(self, y=None, x=None):
            with self._lock:
                if y is not None:
                    return self.linear(x) + y
                else:
                    return self.linear(x) + torch.ones(2, 4)

    model_desc = {'inputs': [('x', [2, 2]),
                             ('label', [2, ])],
                  'outputs': [('loss', [], True),
                              ('output', [2, 4])]}

    # Dummy data
    data = torch.randn(2, 2)
    label = torch.tensor([0, 1], dtype=torch.int64)

    # Setup training based on test parameters
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Training session
    torch.manual_seed(seed)
    set_seed(seed)
    pt_model = UnpickableModel()

    def loss_fn(x, label):
        return F.nll_loss(F.log_softmax(x, dim=1), label)

    optim_config = optim.AdamConfig()

    # Train must succeed despite warning. Construction and the first step are
    # both inside the block so the check does not depend on where ORTTrainer
    # performs the deep copy (presumably on the first train step; verify against ORTTrainer).
    with pytest.warns(UserWarning, match="cannot be deep copied"):
        trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts)
        _, _ = trainer.train_step(data, label)
|
||||
|
||||
###############################################################################
|
||||
# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
|
||||
###############################################################################
|
||||
|
|
|
|||
Loading…
Reference in a new issue