Fix checkpoint API and copy samples into build dir (#4887)

* Fix state_dict APIs
* Copy samples to build folder and fix CI
This commit is contained in:
Thiago Crepaldi 2020-08-22 00:09:48 -07:00 committed by GitHub
parent 6260d073b3
commit dce2ce7a4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 33 additions and 34 deletions

View file

@ -570,6 +570,9 @@ set(all_dependencies ${onnxruntime_test_providers_dependencies} )
set(TEST_DATA_SRC ${TEST_SRC_DIR}/testdata)
set(TEST_DATA_DES $<TARGET_FILE_DIR:${test_data_target}>/testdata)
set(TEST_SAMPLES_SRC ${REPO_ROOT}/samples)
set(TEST_SAMPLES_DES $<TARGET_FILE_DIR:${test_data_target}>/samples)
# Copy test data from source to destination.
add_custom_command(
TARGET ${test_data_target} POST_BUILD
@ -577,6 +580,13 @@ add_custom_command(
${TEST_DATA_SRC}
${TEST_DATA_DES})
# Copy test samples from source to destination.
add_custom_command(
TARGET ${test_data_target} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${TEST_SAMPLES_SRC}
${TEST_SAMPLES_DES})
if (onnxruntime_USE_DNNL)
list(APPEND onnx_test_libs dnnl)
add_custom_command(

View file

@ -38,8 +38,8 @@ def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
# In this case we cache a reference to desired state and delay the restore until after initialization
# Unexpected behavior will result if the user changes the reference before initialization
if not ort_trainer._training_session:
ort_trainer.state_dict_ = state_dict
ort_trainer.strict_ = strict
ort_trainer._state_dict = state_dict
ort_trainer._load_state_dict_strict = strict
return
# update onnx model from loaded state dict
@ -55,7 +55,7 @@ def experimental_load_state_dict(ort_trainer, state_dict, strict=False):
ort_trainer._update_onnx_model_initializers(new_initializers)
# create new session based on updated onnx model
ort_trainer.state_dict_ = None
ort_trainer._state_dict = None
ort_trainer._init_session()
# load training state

View file

@ -7,7 +7,7 @@ from inspect import signature
import warnings
import onnxruntime as ort
from . import _utils, amp, optim, postprocess, ORTTrainerOptions
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions
from .model_desc_validation import _ORTTrainerModelDesc
class TrainStepInfo(object):
@ -276,8 +276,8 @@ class ORTTrainer(object):
assert isinstance(path, str), "'path' must be a valid path string"
dir_name = os.path.dirname(path)
file_name = os.path.basename(path)
if not dir_name or not os.path.exists(dir_name) or not file_name:
warnings.warn("'path' is not valid. It must contain an existing folder + filename")
if (dir_name and not os.path.exists(dir_name)) or not file_name:
warnings.warn("'path' is not valid or does not exist")
return
with open(path, "wb") as f:
@ -393,13 +393,6 @@ class ORTTrainer(object):
ordered_input_list = [*ordered_input_list,
list(sig_loss.parameters.keys())[1]]
# Check whether input names from model match inputs from ModelDescription
match = True
for ordered_list_key, input_name in zip(ordered_input_list, input_names):
if ordered_list_key != input_name:
match = False
break
class CombineTorchModelLossFnWrapInput(torch.nn.Module):
def __init__(self, model, loss_fn, input_names):
super().__init__()
@ -409,7 +402,6 @@ class ORTTrainer(object):
def forward(self, *inputs):
sig = signature(self.model.forward)
ordered_list_keys = list(sig.parameters.keys())
input_dict = {}
for key in sig.parameters.keys():
@ -642,7 +634,7 @@ class ORTTrainer(object):
# TODO: thiagofc: Checkpoint related for redesign
if self._state_dict:
self.load_state_dict(self._state_dict, self._load_state_dict_strict)
checkpoint.load_state_dict(self, self._state_dict, self._load_state_dict_strict)
self._state_dict = None
def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):

View file

@ -43,7 +43,7 @@ def generate_random_input_from_model_desc(desc, seed=1, device = "cuda:0"):
size.append(s)
else:
size.append(dims[s] if s in dims else 1)
sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=torch.int64).to(device))
sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device))
return sample_input
# EXPERIMENTAL HELPER FUNCTIONS
@ -84,7 +84,7 @@ def optimizer_parameters(model):
def load_bert_onnx_model():
bert_onnx_model_path = os.path.join('..', '..', '..', 'onnxruntime', 'test', 'testdata', "bert_toy_postprocessed.onnx")
bert_onnx_model_path = os.path.join('testdata', "bert_toy_postprocessed.onnx")
model = onnx.load(bert_onnx_model_path)
return model
@ -564,7 +564,7 @@ def testORTTrainerFrozenWeights(model_params):
def testToyBERTSaveAsONNX():
device = 'cuda'
onnx_file_name = os.path.join('..','..','..','temp_toy_bert_onnx_model.onnx')
onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'
if os.path.exists(onnx_file_name):
os.remove(onnx_file_name)
assert not os.path.exists(onnx_file_name)
@ -583,7 +583,7 @@ def testToyBERTSaveAsONNX():
},
})
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)#, options=opts)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
trainer.save_as_onnx(onnx_file_name)
assert os.path.exists(onnx_file_name)

View file

@ -27,7 +27,7 @@ import _test_commons,_test_helpers
def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False):
# Loads external Pytorch TransformerModel into utils
pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
pytorch_transformer_path = os.path.join('samples', 'python', 'pytorch_transformer')
pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
pt_model = _utils.import_module_from_file(pt_model_path)
ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
@ -622,7 +622,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device)
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type)
# Save current model as ONNX as a file
file_name = os.path.join('..','..','..','temp_onnx_model.onnx')
file_name = os.path.join('_____temp_onnx_model.onnx')
trainer.save_as_onnx(file_name)
assert os.path.exists(file_name)
with open(file_name, "rb") as f:
@ -660,7 +660,7 @@ def testORTDeterministicCompute(seed, device):
# Setup for the first ORTTRainer run
torch.manual_seed(seed)
set_seed(seed)
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
data, targets = batcher_fn(train_data, 0)
_ = first_trainer.train_step(data, targets)
@ -687,7 +687,6 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
total_steps = len(expected_loss)
torch.manual_seed(seed)
set_seed(seed)
bptt=35
# Setup ORTTrainer
loss_scaler = amp.DynamicLossScaler()
@ -717,7 +716,7 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
trainer._train_step_info.fetches=['loss']
loss = trainer.eval_step(val_data, val_targets)
trainer._train_step_info.fetches=[]
loss, preds = trainer.eval_step(val_data, val_targets)
loss, _ = trainer.eval_step(val_data, val_targets)
# Compare loss to ground truth computed from current ORTTrainer API
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=1e-4)
@ -742,7 +741,7 @@ def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
'debug' : {'deterministic_compute' : True}})
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
@ -768,7 +767,7 @@ def testORTTrainerDynamicShape(dynamic_axes):
# Setup ORTTrainer
options = orttrainer.ORTTrainerOptions({})
model, model_desc, my_loss, batcher_fn,\
train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
train_data, _, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
@ -796,7 +795,7 @@ def testORTTrainerFrozenWeights(model_params):
# Setup ORTTrainer WITHOUT frozen weights
options = orttrainer.ORTTrainerOptions({})
model, model_desc, my_loss, batcher_fn,\
train_data, val_data, _ = _load_pytorch_transformer_model(device)
train_data, _, _ = _load_pytorch_transformer_model(device)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
for i in range(total_steps):
@ -834,7 +833,6 @@ def testORTTrainerFrozenWeights(model_params):
def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
# Common data
total_steps = 5
bptt = 35
# Setup for the experimental ORTTRainer run
torch.manual_seed(seed)
@ -848,7 +846,7 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
'deterministic_compute': True
},
})
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
# Training loop
for i in range(total_steps):
@ -876,7 +874,6 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device):
def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
# Common data
total_steps = 5
bptt=35
# Setup experimental API
torch.manual_seed(seed)
@ -887,7 +884,7 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device):
'enabled' : True,
'loss_scaler' : loss_scaler},
'debug' : {'deterministic_compute' : True,}})
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
# Training loop
@ -940,14 +937,14 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps},
'debug' : {'deterministic_compute' : True}})
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
# Training loop
experimental_loss = []
for i in range(total_steps):
data, targets = batcher_fn(train_data, i)
exp_loss, exp_preds = trainer.train_step(data, targets)
exp_loss, _ = trainer.train_step(data, targets)
experimental_loss.append(exp_loss.cpu())
# Setup legacy API
@ -962,7 +959,7 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie
legacy_loss = []
for i in range(total_steps):
data, targets = batcher_fn(train_data, i)
leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))
legacy_loss.append(leg_loss.cpu())
# Compare legacy vs experimental APIs