mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Fix checkpoint API and copy samples into build dir (#4887)
* Fix state_dict APIs
* Copy samples to build folder and fix CI
This commit is contained in:
parent
6260d073b3
commit
dce2ce7a4f
5 changed files with 33 additions and 34 deletions
|
|
@ -570,6 +570,9 @@ set(all_dependencies ${onnxruntime_test_providers_dependencies} )
|
|||
set(TEST_DATA_SRC ${TEST_SRC_DIR}/testdata)
|
||||
set(TEST_DATA_DES $<TARGET_FILE_DIR:${test_data_target}>/testdata)
|
||||
|
||||
set(TEST_SAMPLES_SRC ${REPO_ROOT}/samples)
|
||||
set(TEST_SAMPLES_DES $<TARGET_FILE_DIR:${test_data_target}>/samples)
|
||||
|
||||
# Copy test data from source to destination.
|
||||
add_custom_command(
|
||||
TARGET ${test_data_target} POST_BUILD
|
||||
|
|
@ -577,6 +580,13 @@ add_custom_command(
|
|||
${TEST_DATA_SRC}
|
||||
${TEST_DATA_DES})
|
||||
|
||||
# Copy test samples from source to destination.
|
||||
add_custom_command(
|
||||
TARGET ${test_data_target} POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
${TEST_SAMPLES_SRC}
|
||||
${TEST_SAMPLES_DES})
|
||||
|
||||
if (onnxruntime_USE_DNNL)
|
||||
list(APPEND onnx_test_libs dnnl)
|
||||
add_custom_command(
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
|
|||
# In this case we cache a reference to desired state and delay the restore until after initialization
|
||||
# Unexpected behavior will result if the user changes the reference before initialization
|
||||
if not ort_trainer._training_session:
|
||||
ort_trainer.state_dict_ = state_dict
|
||||
ort_trainer.strict_ = strict
|
||||
ort_trainer._state_dict = state_dict
|
||||
ort_trainer._load_state_dict_strict = strict
|
||||
return
|
||||
|
||||
# update onnx model from loaded state dict
|
||||
|
|
@ -55,7 +55,7 @@ def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
|
|||
ort_trainer._update_onnx_model_initializers(new_initializers)
|
||||
|
||||
# create new session based on updated onnx model
|
||||
ort_trainer.state_dict_ = None
|
||||
ort_trainer._state_dict = None
|
||||
ort_trainer._init_session()
|
||||
|
||||
# load training state
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from inspect import signature
|
|||
import warnings
|
||||
|
||||
import onnxruntime as ort
|
||||
from . import _utils, amp, optim, postprocess, ORTTrainerOptions
|
||||
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions
|
||||
from .model_desc_validation import _ORTTrainerModelDesc
|
||||
|
||||
class TrainStepInfo(object):
|
||||
|
|
@ -276,8 +276,8 @@ class ORTTrainer(object):
|
|||
assert isinstance(path, str), "'path' must be a valid path string"
|
||||
dir_name = os.path.dirname(path)
|
||||
file_name = os.path.basename(path)
|
||||
if not dir_name or not os.path.exists(dir_name) or not file_name:
|
||||
warnings.warn("'path' is not valid. It must contain an existing folder + filename")
|
||||
if (dir_name and not os.path.exists(dir_name)) or not file_name:
|
||||
warnings.warn("'path' is not valid or does not exist")
|
||||
return
|
||||
|
||||
with open(path, "wb") as f:
|
||||
|
|
@ -393,13 +393,6 @@ class ORTTrainer(object):
|
|||
ordered_input_list = [*ordered_input_list,
|
||||
list(sig_loss.parameters.keys())[1]]
|
||||
|
||||
# Check whether input names from model match inputs from ModelDescription
|
||||
match = True
|
||||
for ordered_list_key, input_name in zip(ordered_input_list, input_names):
|
||||
if ordered_list_key != input_name:
|
||||
match = False
|
||||
break
|
||||
|
||||
class CombineTorchModelLossFnWrapInput(torch.nn.Module):
|
||||
def __init__(self, model, loss_fn, input_names):
|
||||
super().__init__()
|
||||
|
|
@ -409,7 +402,6 @@ class ORTTrainer(object):
|
|||
|
||||
def forward(self, *inputs):
|
||||
sig = signature(self.model.forward)
|
||||
ordered_list_keys = list(sig.parameters.keys())
|
||||
|
||||
input_dict = {}
|
||||
for key in sig.parameters.keys():
|
||||
|
|
@ -642,7 +634,7 @@ class ORTTrainer(object):
|
|||
|
||||
# TODO: thiagofc: Checkpoint related for redesign
|
||||
if self._state_dict:
|
||||
self.load_state_dict(self._state_dict, self._load_state_dict_strict)
|
||||
checkpoint.load_state_dict(self, self._state_dict, self._load_state_dict_strict)
|
||||
self._state_dict = None
|
||||
|
||||
def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def generate_random_input_from_model_desc(desc, seed=1, device = "cuda:0"):
|
|||
size.append(s)
|
||||
else:
|
||||
size.append(dims[s] if s in dims else 1)
|
||||
sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=torch.int64).to(device))
|
||||
sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device))
|
||||
return sample_input
|
||||
|
||||
# EXPERIMENTAL HELPER FUNCTIONS
|
||||
|
|
@ -84,7 +84,7 @@ def optimizer_parameters(model):
|
|||
|
||||
|
||||
def load_bert_onnx_model():
|
||||
bert_onnx_model_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata', "bert_toy_postprocessed.onnx")
|
||||
bert_onnx_model_path = os.path.join('testdata', "bert_toy_postprocessed.onnx")
|
||||
model = onnx.load(bert_onnx_model_path)
|
||||
return model
|
||||
|
||||
|
|
@ -564,7 +564,7 @@ def testORTTrainerFrozenWeights(model_params):
|
|||
|
||||
def testToyBERTSaveAsONNX():
|
||||
device = 'cuda'
|
||||
onnx_file_name = os.path.join('..','..','..','temp_toy_bert_onnx_model.onnx')
|
||||
onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
|
||||
if os.path.exists(onnx_file_name):
|
||||
os.remove(onnx_file_name)
|
||||
assert not os.path.exists(onnx_file_name)
|
||||
|
|
@ -583,7 +583,7 @@ def testToyBERTSaveAsONNX():
|
|||
},
|
||||
})
|
||||
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)#, options=opts)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
|
||||
|
||||
trainer.save_as_onnx(onnx_file_name)
|
||||
assert os.path.exists(onnx_file_name)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import _test_commons,_test_helpers
|
|||
|
||||
def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False):
|
||||
# Loads external Pytorch TransformerModel into utils
|
||||
pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
|
||||
pytorch_transformer_path = os.path.join('samples', 'python', 'pytorch_transformer')
|
||||
pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
|
||||
pt_model = _utils.import_module_from_file(pt_model_path)
|
||||
ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
|
||||
|
|
@ -622,7 +622,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device)
|
|||
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)
|
||||
|
||||
# Save current model as ONNX as a file
|
||||
file_name = os.path.join('..','..','..','temp_onnx_model.onnx')
|
||||
file_name = os.path.join('_____temp_onnx_model.onnx')
|
||||
trainer.save_as_onnx(file_name)
|
||||
assert os.path.exists(file_name)
|
||||
with open(file_name, "rb") as f:
|
||||
|
|
@ -660,7 +660,7 @@ def testORTDeterministicCompute(seed, device):
|
|||
# Setup for the first ORTTRainer run
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
|
||||
data, targets = batcher_fn(train_data, 0)
|
||||
_ = first_trainer.train_step(data, targets)
|
||||
|
|
@ -687,7 +687,6 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
|
|||
total_steps = len(expected_loss)
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
bptt=35
|
||||
|
||||
# Setup ORTTrainer
|
||||
loss_scaler = amp.DynamicLossScaler()
|
||||
|
|
@ -717,7 +716,7 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
|
|||
trainer._train_step_info.fetches=['loss']
|
||||
loss = trainer.eval_step(val_data, val_targets)
|
||||
trainer._train_step_info.fetches=[]
|
||||
loss, preds = trainer.eval_step(val_data, val_targets)
|
||||
loss, _ = trainer.eval_step(val_data, val_targets)
|
||||
|
||||
# Compare loss to ground truth computed from current ORTTrainer API
|
||||
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
|
||||
|
|
@ -742,7 +741,7 @@ def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps
|
|||
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
|
||||
'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
|
||||
'debug' : {'deterministic_compute' : True}})
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
|
||||
|
|
@ -768,7 +767,7 @@ def testORTTrainerDynamicShape(dynamic_axes):
|
|||
# Setup ORTTrainer
|
||||
options = orttrainer.ORTTrainerOptions({})
|
||||
model, model_desc, my_loss, batcher_fn,\
|
||||
train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
|
||||
train_data, _, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
|
||||
|
|
@ -796,7 +795,7 @@ def testORTTrainerFrozenWeights(model_params):
|
|||
# Setup ORTTrainer WITHOUT frozen weights
|
||||
options = orttrainer.ORTTrainerOptions({})
|
||||
model, model_desc, my_loss, batcher_fn,\
|
||||
train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
for i in range(total_steps):
|
||||
|
|
@ -834,7 +833,6 @@ def testORTTrainerFrozenWeights(model_params):
|
|||
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
|
||||
# Common data
|
||||
total_steps = 5
|
||||
bptt = 35
|
||||
|
||||
# Setup for the experimental ORTTRainer run
|
||||
torch.manual_seed(seed)
|
||||
|
|
@ -848,7 +846,7 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
|
|||
'deterministic_compute': True
|
||||
},
|
||||
})
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
|
||||
# Training loop
|
||||
for i in range(total_steps):
|
||||
|
|
@ -876,7 +874,6 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
|
|||
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
|
||||
# Common data
|
||||
total_steps = 5
|
||||
bptt=35
|
||||
|
||||
# Setup experimental API
|
||||
torch.manual_seed(seed)
|
||||
|
|
@ -887,7 +884,7 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
|
|||
'enabled' : True,
|
||||
'loss_scaler' : loss_scaler},
|
||||
'debug' : {'deterministic_compute' : True,}})
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
# Training loop
|
||||
|
|
@ -940,14 +937,14 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie
|
|||
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
|
||||
'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
|
||||
'debug' : {'deterministic_compute' : True}})
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
# Training loop
|
||||
experimental_loss = []
|
||||
for i in range(total_steps):
|
||||
data, targets = batcher_fn(train_data, i)
|
||||
exp_loss, exp_preds = trainer.train_step(data, targets)
|
||||
exp_loss, _ = trainer.train_step(data, targets)
|
||||
experimental_loss.append(exp_loss.cpu())
|
||||
|
||||
# Setup legacy API
|
||||
|
|
@ -962,7 +959,7 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie
|
|||
legacy_loss = []
|
||||
for i in range(total_steps):
|
||||
data, targets = batcher_fn(train_data, i)
|
||||
leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
|
||||
leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
|
||||
legacy_loss.append(leg_loss.cpu())
|
||||
|
||||
# Compare legacy vs experimental APIs
|
||||
|
|
|
|||
Loading…
Reference in a new issue