mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Fix state_dict and save_as_onnx for training (#3774)
This commit is contained in:
parent
5dfc91db51
commit
f7ff5a7aa1
2 changed files with 86 additions and 26 deletions
|
|
@ -82,7 +82,7 @@ def create_ort_trainer(gradient_accumulation_steps,
|
|||
use_mixed_precision=use_mixed_precision,
|
||||
allreduce_post_accumulation=allreduce_post_accumulation,
|
||||
partition_optimizer = partition_optimizer)
|
||||
|
||||
|
||||
return model, model_desc, device
|
||||
|
||||
def runBertTrainingTest(gradient_accumulation_steps,
|
||||
|
|
@ -92,7 +92,7 @@ def runBertTrainingTest(gradient_accumulation_steps,
|
|||
use_internel_loss_scale=False):
|
||||
torch.manual_seed(1)
|
||||
onnxruntime.set_seed(1)
|
||||
|
||||
|
||||
loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None
|
||||
|
||||
model, model_desc, device = create_ort_trainer(gradient_accumulation_steps,
|
||||
|
|
@ -237,7 +237,7 @@ class MNISTWrapper():
|
|||
kwargs = {'num_workers': 0, 'pin_memory': True}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args_batch_size, shuffle=True, **kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
|
|
@ -259,7 +259,7 @@ class MNISTWrapper():
|
|||
return model, model_desc
|
||||
|
||||
def get_trainer(self, model, model_desc, device):
|
||||
return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
|
||||
return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
|
||||
torch.float32), device, _opset_version=12)
|
||||
|
||||
class TestOrtTrainer(unittest.TestCase):
|
||||
|
|
@ -355,6 +355,57 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch")
|
||||
assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch")
|
||||
|
||||
def testMNISTStateDict(self):
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda")
|
||||
|
||||
mnist = MNISTWrapper()
|
||||
train_loader, test_loader = mnist.get_loaders()
|
||||
model, model_desc = mnist.get_model()
|
||||
|
||||
trainer = mnist.get_trainer(model, model_desc, device)
|
||||
state_dict = trainer.state_dict()
|
||||
assert state_dict == {}
|
||||
|
||||
learningRate = 0.02
|
||||
epoch = 0
|
||||
|
||||
data, target = next(iter(train_loader))
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
|
||||
|
||||
state_dict = trainer.state_dict()
|
||||
assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'}
|
||||
|
||||
def testMNISTSaveAsONNX(self):
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda")
|
||||
onnx_file_name = 'mnist.onnx'
|
||||
if os.path.exists(onnx_file_name):
|
||||
os.remove(onnx_file_name)
|
||||
|
||||
mnist = MNISTWrapper()
|
||||
train_loader, test_loader = mnist.get_loaders()
|
||||
model, model_desc = mnist.get_model()
|
||||
|
||||
trainer = mnist.get_trainer(model, model_desc, device)
|
||||
trainer.save_as_onnx(onnx_file_name)
|
||||
assert not os.path.exists(onnx_file_name)
|
||||
|
||||
learningRate = 0.02
|
||||
epoch = 0
|
||||
|
||||
data, target = next(iter(train_loader))
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
|
||||
|
||||
trainer.save_as_onnx(onnx_file_name)
|
||||
assert os.path.exists(onnx_file_name)
|
||||
|
||||
def testBertTrainingBasic(self):
|
||||
expected_losses = [
|
||||
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
|
||||
|
|
@ -379,7 +430,7 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266,
|
||||
11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668]
|
||||
expected_eval_loss = [10.959011]
|
||||
|
||||
|
||||
actual_losses, actual_eval_loss = runBertTrainingTest(
|
||||
gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False)
|
||||
|
||||
|
|
@ -436,7 +487,7 @@ class TestOrtTrainer(unittest.TestCase):
|
|||
|
||||
ckpt_dir = get_name("ort_ckpt")
|
||||
load_checkpoint(model, ckpt_dir, 'bert_toy_lamb')
|
||||
|
||||
|
||||
expected_eval_loss = [10.997552871]
|
||||
|
||||
input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[ 635],[27399],[20647],[18519],[15546]], device=device)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import io
|
||||
import os
|
||||
import warnings
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
|
|
@ -92,7 +93,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
|
|||
device_index = input_get_device_index(input)
|
||||
iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype),
|
||||
list(input.size()), input.data_ptr())
|
||||
|
||||
|
||||
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
|
||||
torch_outputs = {}
|
||||
for output_desc in output_descs_resolved:
|
||||
|
|
@ -526,11 +527,11 @@ def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", c
|
|||
|
||||
assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir)
|
||||
|
||||
checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size)
|
||||
checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.partition_optimizer_, model.world_rank, model.world_size)
|
||||
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
|
||||
if os.path.exists(checkpoint_file):
|
||||
print("WARNING: {} already exists, overwriting.".format(checkpoint_file))
|
||||
warnings.warn("{} already exists, overwriting.".format(checkpoint_file))
|
||||
|
||||
torch.save(checkpoint_state_dict, checkpoint_file)
|
||||
|
||||
|
|
@ -538,7 +539,7 @@ def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partiti
|
|||
checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size)
|
||||
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
|
||||
if is_partitioned:
|
||||
if is_partitioned:
|
||||
assert_msg = ("Couldn't find checkpoint file {}." +
|
||||
"Optimizer partitioning is enabled using ZeRO. Please make sure that the "+
|
||||
"checkpoint file exists for rank {} of {}.").format(checkpoint_file,model.world_rank, model.world_size)
|
||||
|
|
@ -561,24 +562,24 @@ def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict):
|
|||
|
||||
model.load_state_dict(aggregate_state_dict, strict=strict)
|
||||
|
||||
# aggregate other keys in the state_dict.
|
||||
# aggregate other keys in the state_dict.
|
||||
# Values will be overwritten for matching keys among workers
|
||||
all_checkpoint_states=dict()
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint_state = torch.load(checkpoint_file, map_location='cpu')
|
||||
del(checkpoint_state['model'])
|
||||
all_checkpoint_states.update(checkpoint_state)
|
||||
all_checkpoint_states.update(checkpoint_state)
|
||||
return all_checkpoint_states
|
||||
|
||||
def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False):
|
||||
checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix)
|
||||
is_partitioned = False
|
||||
if len(checkpoint_files) > 1:
|
||||
print(f"WARNING: Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." +
|
||||
warnings.warn(f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." +
|
||||
"Attempting to load ZeRO checkpoint.")
|
||||
is_partitioned = True
|
||||
if (not model.partition_optimizer_) and is_partitioned:
|
||||
return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
|
||||
return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict)
|
||||
else:
|
||||
return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
|
||||
|
||||
|
|
@ -629,7 +630,7 @@ class ORTTrainer():
|
|||
else:
|
||||
self.onnx_model_ = model
|
||||
if loss_fn is not None:
|
||||
print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
|
||||
warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
|
||||
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
|
||||
self.loss_fn_ = None
|
||||
|
||||
|
|
@ -657,7 +658,7 @@ class ORTTrainer():
|
|||
self.loss_scaler_ = loss_scaler
|
||||
|
||||
if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None:
|
||||
print("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
|
||||
warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.")
|
||||
self.training_optimizer_name_ = training_optimizer_name
|
||||
self.learning_rate_description_ = learning_rate_description
|
||||
self.map_optimizer_attributes_ = map_optimizer_attributes
|
||||
|
|
@ -678,7 +679,7 @@ class ORTTrainer():
|
|||
if self.onnx_model_ is None:
|
||||
return
|
||||
|
||||
self.verify_fully_optimized_model(self.onnx_model_)
|
||||
self._verify_fully_optimized_model(self.onnx_model_)
|
||||
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
|
||||
create_ort_training_session_with_optimizer(
|
||||
self.onnx_model_, self.device_,
|
||||
|
|
@ -719,7 +720,7 @@ class ORTTrainer():
|
|||
IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)]
|
||||
|
||||
if self.state_dict_:
|
||||
self.load_state_dict(self.state_dict_, self.strict_)
|
||||
self.load_state_dict(self.state_dict_, self.strict_)
|
||||
self.state_dict_ = None
|
||||
|
||||
def _init_onnx_model(self, inputs):
|
||||
|
|
@ -742,6 +743,10 @@ class ORTTrainer():
|
|||
self.is_train = False
|
||||
|
||||
def state_dict(self):
|
||||
if not self.session:
|
||||
warnings.warn("ONNXRuntime training session is not initialized yet. "
|
||||
"Please run train_step or eval_step at least once before calling state_dict().")
|
||||
return {}
|
||||
session_state = self.session.get_state()
|
||||
torch_state = {}
|
||||
for name in session_state:
|
||||
|
|
@ -763,6 +768,10 @@ class ORTTrainer():
|
|||
self.session.load_state(session_state, strict)
|
||||
|
||||
def save_as_onnx(self, path):
|
||||
if not self.session:
|
||||
warnings.warn("ONNXRuntime training session is not initialized yet. "
|
||||
"Please run train_step or eval_step at least once before calling save_as_onnx().")
|
||||
return
|
||||
state_tensors = self.session.get_state()
|
||||
# replace the initializers with new value
|
||||
new_weights = []
|
||||
|
|
@ -781,7 +790,7 @@ class ORTTrainer():
|
|||
with open(path, "wb") as f:
|
||||
f.write(self.onnx_model_.SerializeToString())
|
||||
|
||||
def prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
|
||||
def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs):
|
||||
fetches = None
|
||||
if type(args) == tuple and len(args) == 1 and type(args[0]) == list:
|
||||
input = tuple(args[0])
|
||||
|
|
@ -842,17 +851,17 @@ class ORTTrainer():
|
|||
loss_scale = torch.tensor([self.loss_scaler_.loss_scale_])
|
||||
|
||||
if self.onnx_model_ is None:
|
||||
sample_input, _ = self.prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
None, None, *args, **kwargs)
|
||||
self._init_onnx_model(sample_input)
|
||||
|
||||
if self.use_mixed_precision:
|
||||
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
|
||||
input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale,
|
||||
learning_rate, loss_scale, *args, **kwargs)
|
||||
assert len(self.input_desc_with_lr_and_loss_scale) == len(input)
|
||||
input_descs = self.input_desc_with_lr_and_loss_scale
|
||||
else:
|
||||
input, fetches = self.prepare_input_and_fetches(self.input_desc_with_lr,
|
||||
input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr,
|
||||
learning_rate, loss_scale, *args, **kwargs)
|
||||
assert len(self.input_desc_with_lr) == len(input)
|
||||
input_descs = self.input_desc_with_lr
|
||||
|
|
@ -925,8 +934,8 @@ class ORTTrainer():
|
|||
"""
|
||||
|
||||
# with model_loss_cls, the last input is label, first output is loss
|
||||
input, fetches = self.prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
None, None, *args, **kwargs)
|
||||
input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_,
|
||||
None, None, *args, **kwargs)
|
||||
|
||||
if self.onnx_model_ is None:
|
||||
if self.torch_model_ is not None:
|
||||
|
|
@ -958,7 +967,7 @@ class ORTTrainer():
|
|||
else:
|
||||
return [session_run_results[output_desc.name_] for output_desc in output_desc]
|
||||
|
||||
def verify_fully_optimized_model(self, model):
|
||||
def _verify_fully_optimized_model(self, model):
|
||||
assert(len(model.graph.output) > 0)
|
||||
# model's first output must be the loss tensor
|
||||
if model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT and\
|
||||
|
|
@ -983,7 +992,7 @@ class ORTModel():
|
|||
self.world_size = world_size
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.opset_version = _opset_version
|
||||
|
||||
|
||||
# Adding to not break checkpointing functions for ORTModel
|
||||
self.partition_optimizer_ = False
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue