mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[Training/Python] Add option to enable symbolic shape inference (#5107)
This change adds symbolic shape inference to ORT training, which helps static memory planning for models like BART.
This commit is contained in:
parent
14f250a4d0
commit
8dceebda0e
10 changed files with 140 additions and 74 deletions
|
|
@ -297,7 +297,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SymbolicShapeInference.infer_shapes(input_model=lstm_model, output_model=lstm_model)"
|
||||
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(lstm_model)), lstm_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -559,7 +559,7 @@
|
|||
"bert_model_with_shape_inference = os.path.join(bert_model_dir, 'bertsquad10_shaped.onnx')\n",
|
||||
"\n",
|
||||
"# run symbolic shape inference\n",
|
||||
"SymbolicShapeInference.infer_shapes(bert_model, bert_model_with_shape_inference, auto_merge=True, int_max=100000)"
|
||||
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(bert_model), auto_merge=True, int_max=100000), bert_model_with_shape_inference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -692,7 +692,7 @@
|
|||
"gpt2_model_with_shape_inference = os.path.join(gpt2_model_dir, 'model_shaped.onnx')\n",
|
||||
"\n",
|
||||
"# run symbolic shape inference\n",
|
||||
"SymbolicShapeInference.infer_shapes(gpt2_model, gpt2_model_with_shape_inference, auto_merge=True)"
|
||||
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(gpt2_model), auto_merge=True), gpt2_model_with_shape_inference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -892,7 +892,7 @@
|
|||
"source": [
|
||||
"# editing\n",
|
||||
"bidaf_converted = 'bidaf_mod.onnx'\n",
|
||||
"SymbolicShapeInference.infer_shapes(bidaf, bidaf_converted)\n",
|
||||
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(bidaf)), bidaf_converted)\n",
|
||||
"convert_to_scan_model(bidaf_converted, bidaf_converted)\n",
|
||||
"# When quantizing, there's an only_for_scan option to quantize only the GEMV inside Scan ops.\n",
|
||||
"# This is useful when the input dims of LSTM being much bigger than hidden dims.\n",
|
||||
|
|
|
|||
|
|
@ -831,5 +831,7 @@ if __name__ == '__main__':
|
|||
else:
|
||||
raise NotImplementedError('Unknown mode')
|
||||
print('Running symbolic shape inference on output model')
|
||||
SymbolicShapeInference.infer_shapes(args.output, args.output, auto_merge=True)
|
||||
mp = onnx.load(args.output)
|
||||
mp = SymbolicShapeInference.infer_shapes(mp, auto_merge=True)
|
||||
onnx.save(mp, args.output)
|
||||
print('Done!')
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import numpy as np
|
|||
import onnx
|
||||
from onnx import helper, numpy_helper
|
||||
from .node_factory import NodeFactory, ensure_opset
|
||||
from ..tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
class QuantizeConfig:
|
||||
def __init__(self, signed, reserved_bits, type_bits):
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
|
|||
scan_model_name = os.path.splitext(model_name)[0] + '_scan.onnx'
|
||||
convert_to_scan_model(model_name, scan_model_name)
|
||||
# note that symbolic shape inference is needed because model has symbolic batch dim, thus init_state is ConstantOfShape
|
||||
SymbolicShapeInference.infer_shapes(scan_model_name, scan_model_name)
|
||||
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(scan_model_name)), scan_model_name)
|
||||
sess = onnxruntime.InferenceSession(scan_model_name)
|
||||
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
|
||||
avg_scan = top_n_avg(per_iter_cost, top_n)
|
||||
|
|
@ -130,7 +130,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
|
|||
from .model_quantizer import convert_matmul_model
|
||||
int8_model_name = os.path.splitext(model_name)[0] + '_int8.onnx'
|
||||
convert_matmul_model(scan_model_name, int8_model_name)
|
||||
SymbolicShapeInference.infer_shapes(int8_model_name, int8_model_name)
|
||||
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(int8_model_name)), int8_model_name)
|
||||
sess = onnxruntime.InferenceSession(int8_model_name)
|
||||
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
|
||||
avg_int8 = top_n_avg(per_iter_cost, top_n)
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ class SymbolicShapeInference:
|
|||
'CategoryMapper' : self._infer_CategoryMapper,
|
||||
'Compress' : self._infer_Compress,
|
||||
'Concat' : self._infer_Concat,
|
||||
'Constant' : self._infer_Constant,
|
||||
'ConstantOfShape' : self._infer_ConstantOfShape,
|
||||
'Conv' : self._infer_Conv,
|
||||
'CumSum' : self._pass_on_shape_and_type,
|
||||
|
|
@ -91,6 +92,7 @@ class SymbolicShapeInference:
|
|||
'Gather' : self._infer_Gather,
|
||||
'GatherElements' : self._infer_GatherElements,
|
||||
'GatherND' : self._infer_GatherND,
|
||||
'Gelu' : self._pass_on_shape_and_type,
|
||||
'If' : self._infer_If,
|
||||
'Loop' : self._infer_Loop,
|
||||
'MatMul' : self._infer_MatMul,
|
||||
|
|
@ -113,6 +115,7 @@ class SymbolicShapeInference:
|
|||
'Shape' : self._infer_Shape,
|
||||
'Size' : self._infer_Size,
|
||||
'Slice' : self._infer_Slice,
|
||||
'SoftmaxCrossEntropyLoss':self._infer_SoftmaxCrossEntropyLoss,
|
||||
'Split' : self._infer_Split,
|
||||
'SplitToSequence' : self._infer_SplitToSequence,
|
||||
'Squeeze' : self._infer_Squeeze,
|
||||
|
|
@ -189,43 +192,8 @@ class SymbolicShapeInference:
|
|||
d.dim_param = v
|
||||
|
||||
def _preprocess(self, in_mp):
|
||||
out_mp = onnx.ModelProto()
|
||||
out_mp.CopyFrom(in_mp)
|
||||
out_mp.graph.ClearField('node')
|
||||
self.out_mp_ = out_mp
|
||||
|
||||
defined = set([i.name for i in list(in_mp.graph.input) + list(in_mp.graph.initializer)])
|
||||
pending_nodes = []
|
||||
|
||||
# returns True if no more ready nodes
|
||||
def _insert_ready_nodes():
|
||||
ready_nodes = [pn for pn in pending_nodes if all([i in defined for i in pn.input if i])]
|
||||
for rn in ready_nodes:
|
||||
self.out_mp_.graph.node.add().CopyFrom(rn)
|
||||
for o in rn.output:
|
||||
defined.add(o)
|
||||
pending_nodes.remove(rn)
|
||||
return not ready_nodes
|
||||
|
||||
# constant op -> initializer, topological sort
|
||||
for in_n in in_mp.graph.node:
|
||||
if in_n.op_type == 'Constant':
|
||||
t = get_attribute(in_n, 'value')
|
||||
t.name = in_n.output[0]
|
||||
self.out_mp_.graph.initializer.add().CopyFrom(t)
|
||||
defined.add(t.name)
|
||||
else:
|
||||
pending_nodes.append(in_n)
|
||||
_insert_ready_nodes()
|
||||
|
||||
while pending_nodes:
|
||||
if _insert_ready_nodes():
|
||||
break
|
||||
|
||||
if pending_nodes and self.verbose_ > 0:
|
||||
print('SymbolicShapeInference: orphaned nodes discarded: ')
|
||||
print(*[n.op_type + ': ' + n.output[0] for n in pending_nodes], sep='\n')
|
||||
|
||||
self.out_mp_ = onnx.ModelProto()
|
||||
self.out_mp_.CopyFrom(in_mp)
|
||||
self.initializers_ = dict([(i.name, i) for i in self.out_mp_.graph.initializer])
|
||||
self.known_vi_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)])
|
||||
self.known_vi_.update(dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))) for i in self.out_mp_.graph.initializer]))
|
||||
|
|
@ -370,8 +338,6 @@ class SymbolicShapeInference:
|
|||
symbolic_shape_inference = SymbolicShapeInference(self.int_max_, self.auto_merge_, self.guess_output_rank_, self.verbose_)
|
||||
all_shapes_inferred = False
|
||||
symbolic_shape_inference._preprocess(self.tmp_mp_)
|
||||
# note that after _preprocess, Constant node will be converted to initializer and should be appended to subgraph.initializer
|
||||
subgraph.initializer.extend([i for i in symbolic_shape_inference.out_mp_.graph.initializer if i.name not in subgraph_implicit_input and i.name not in subgraph_inputs])
|
||||
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
|
||||
while symbolic_shape_inference.run_:
|
||||
all_shapes_inferred = symbolic_shape_inference._infer_impl(self.tmp_mp_, self.sympy_data_.copy())
|
||||
|
|
@ -638,11 +604,9 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
|
||||
|
||||
def _infer_Conv(self, node):
|
||||
sympy_shape = self._compute_conv_pool_shape(node)
|
||||
self._update_computed_dims(sympy_shape)
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
|
||||
def _infer_Constant(self, node):
    """Record the value of a Constant node as sympy data for its output.

    Reads the node's 'value' attribute, converts it with
    numpy_helper.to_array and stores the result in self.sympy_data_ under
    the name of the node's first output.
    """
    output_name = node.output[0]
    value_tensor = get_attribute(node, 'value')
    self.sympy_data_[output_name] = numpy_helper.to_array(value_tensor)
|
||||
|
||||
def _infer_ConstantOfShape(self, node):
|
||||
sympy_shape = self._get_int_values(node)[0]
|
||||
|
|
@ -662,6 +626,12 @@ class SymbolicShapeInference:
|
|||
vi.type.tensor_type.elem_type,
|
||||
get_shape_from_sympy_shape(sympy_shape)))
|
||||
|
||||
def _infer_Conv(self, node):
    """Infer the output value info of a Conv node.

    The output shape comes from self._compute_conv_pool_shape; the element
    type already recorded on the output value info is kept.
    """
    out_sympy_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(out_sympy_shape)

    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    refreshed_vi = helper.make_tensor_value_info(
        out_name,
        out_vi.type.tensor_type.elem_type,
        get_shape_from_sympy_shape(out_sympy_shape))
    out_vi.CopyFrom(refreshed_vi)
|
||||
|
||||
def _infer_Expand(self, node):
|
||||
expand_to_shape = self._try_get_value(node, 1)
|
||||
if expand_to_shape is not None:
|
||||
|
|
@ -680,8 +650,8 @@ class SymbolicShapeInference:
|
|||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
|
||||
vi.type.tensor_type.elem_type,
|
||||
data_shape[:axis] + indices_shape + data_shape[axis+1:]))
|
||||
if node.input[0] in self.sympy_data_:
|
||||
assert 0 == get_attribute(node, 'axis', 0) # only handle 1D sympy compute
|
||||
# for 1D input, do some sympy compute
|
||||
if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and 0 == get_attribute(node, 'axis', 0):
|
||||
idx = self._get_value(node, 1)
|
||||
data = self.sympy_data_[node.input[0]]
|
||||
if type(data) == list:
|
||||
|
|
@ -1037,11 +1007,20 @@ class SymbolicShapeInference:
|
|||
get_shape_from_sympy_shape(new_sympy_shape)))
|
||||
|
||||
# handle sympy_data if needed, for slice in shape computation
|
||||
if node.input[0] in self.sympy_data_:
|
||||
assert [0] == axes
|
||||
assert len(starts) == 1
|
||||
assert len(ends) == 1
|
||||
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][starts[0]:ends[0]]
|
||||
# Propagate sympy data through the Slice when the input carries it (slice
# used in a shape computation). Only a single slice along axis 0 of 1-D
# data is handled; anything else is left without sympy data.
if node.input[0] in self.sympy_data_ and [0] == axes and len(starts) == 1 and len(ends) == 1:
    input_sympy_data = self.sympy_data_[node.input[0]]
    # np.array is a factory function, not a class, so comparing
    # type(...) against it can never match; np.ndarray is the array type.
    is_1d_ndarray = isinstance(input_sympy_data, np.ndarray) and input_sympy_data.ndim == 1
    if isinstance(input_sympy_data, list) or is_1d_ndarray:
        self.sympy_data_[node.output[0]] = input_sympy_data[starts[0]:ends[0]]
|
||||
|
||||
def _infer_SoftmaxCrossEntropyLoss(self, node):
    """Infer output value infos for a SoftmaxCrossEntropyLoss node.

    The first output takes the element type of the first input. When a
    second output is present, it is given that same element type together
    with the shape of the first input.
    """
    loss_vi = self.known_vi_[node.output[0]]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    loss_vi.type.tensor_type.elem_type = elem_type

    # Nothing more to do when the optional second output is absent.
    if len(node.output) <= 1:
        return

    input_shape = self._get_shape(node, 0)
    second_vi = self.known_vi_[node.output[1]]
    second_vi.CopyFrom(
        helper.make_tensor_value_info(second_vi.name, elem_type, input_shape))
|
||||
|
||||
def _infer_Split_Common(self, node, make_value_info_func):
|
||||
input_sympy_shape = self._get_sympy_shape(node, 0)
|
||||
|
|
@ -1276,22 +1255,20 @@ class SymbolicShapeInference:
|
|||
output.CopyFrom(self.known_vi_[output.name])
|
||||
|
||||
@staticmethod
|
||||
def infer_shapes(input_model, output_model, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
|
||||
in_mp = onnx.load(input_model)
|
||||
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
|
||||
onnx_opset = get_opset(in_mp)
|
||||
if not onnx_opset or onnx_opset < 7:
|
||||
print('Only support models of onnx opset 7 and above.')
|
||||
return
|
||||
return None
|
||||
symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
|
||||
all_shapes_inferred = False
|
||||
symbolic_shape_inference._preprocess(in_mp)
|
||||
while symbolic_shape_inference.run_:
|
||||
all_shapes_inferred = symbolic_shape_inference._infer_impl(in_mp)
|
||||
symbolic_shape_inference._update_output_from_vi()
|
||||
if output_model:
|
||||
onnx.save(symbolic_shape_inference.out_mp_, output_model)
|
||||
if not all_shapes_inferred:
|
||||
sys.exit(1)
|
||||
raise Exception("Incomplete symbolic shape inference")
|
||||
return symbolic_shape_inference.out_mp_
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
@ -1309,5 +1286,7 @@ if __name__ == '__main__':
|
|||
if args.output:
|
||||
print('output model ' + args.output)
|
||||
print('Doing symbolic shape inference...')
|
||||
out_mp = SymbolicShapeInference.infer_shapes(args.input, args.output, args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
|
||||
print('Done!')
|
||||
out_mp = SymbolicShapeInference.infer_shapes(onnx.load(args.input), args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
|
||||
if args.output and out_mp:
|
||||
onnx.save(out_mp, args.output)
|
||||
print('Done!')
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
import unittest
|
||||
import onnx
|
||||
import os
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
class TestSymbolicShapeInference(unittest.TestCase):
|
||||
def test_symbolic_shape_infer(self):
|
||||
|
|
@ -17,8 +18,7 @@ class TestSymbolicShapeInference(unittest.TestCase):
|
|||
continue # skip some bad model files
|
||||
print("Running symbolic shape inference on : " + str(filename))
|
||||
SymbolicShapeInference.infer_shapes(
|
||||
input_model=str(filename),
|
||||
output_model=None,
|
||||
in_mp=onnx.load(str(filename)),
|
||||
auto_merge=True,
|
||||
int_max=100000,
|
||||
guess_output_rank=True)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import warnings
|
|||
from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
|
||||
import onnxruntime.capi.pt_patch
|
||||
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
DEFAULT_OPSET_VERSION = 12
|
||||
|
||||
class IODescription():
|
||||
|
|
@ -320,8 +322,14 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
|
|||
import copy
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
sample_inputs_copy = copy.deepcopy(sample_inputs)
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(model)
|
||||
try:
|
||||
# Deepcopy model, in case model is stateful and changes after model run.
|
||||
model_copy = copy.deepcopy(model)
|
||||
except Exception:
|
||||
model_copy = model
|
||||
warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
|
||||
" Compute will continue, but unexpected results may occur!")
|
||||
|
||||
sample_outputs = model_copy(*sample_inputs_copy)
|
||||
if isinstance(sample_outputs, torch.Tensor):
|
||||
sample_outputs = [sample_outputs]
|
||||
|
|
@ -539,7 +547,7 @@ class ORTTrainer():
|
|||
global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0,
|
||||
enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION,
|
||||
_enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False,
|
||||
use_invertible_layernorm_grad=False):
|
||||
use_invertible_layernorm_grad=False, run_symbolic_shape_infer=False):
|
||||
super(ORTTrainer, self).__init__()
|
||||
"""
|
||||
Initialize ORTTrainer.
|
||||
|
|
@ -607,6 +615,8 @@ class ORTTrainer():
|
|||
Defaults to None
|
||||
use_invertible_layernorm_grad: use invertible layernorm grad
|
||||
Defaults to False
|
||||
run_symbolic_shape_infer: run symbolic shape inference
|
||||
Defaults to False
|
||||
"""
|
||||
warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it')
|
||||
self.is_train = True
|
||||
|
|
@ -669,6 +679,7 @@ class ORTTrainer():
|
|||
self.state_dict_ = None
|
||||
self._use_deterministic_compute = _use_deterministic_compute
|
||||
self.use_invertible_layernorm_grad = use_invertible_layernorm_grad
|
||||
self.run_symbolic_shape_infer = run_symbolic_shape_infer
|
||||
|
||||
# use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs.
|
||||
# see prepare_input_and_fetches for more details.
|
||||
|
|
@ -681,6 +692,10 @@ class ORTTrainer():
|
|||
return
|
||||
|
||||
self._verify_fully_optimized_model(self.onnx_model_)
|
||||
|
||||
if self.run_symbolic_shape_infer:
|
||||
self.onnx_model_ = SymbolicShapeInference.infer_shapes(self.onnx_model_, auto_merge=True, guess_output_rank=True)
|
||||
|
||||
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
|
||||
create_ort_training_session_with_optimizer(
|
||||
self.onnx_model_, self.device_,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import onnxruntime as ort
|
|||
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions
|
||||
from .model_desc_validation import _ORTTrainerModelDesc
|
||||
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
class TrainStepInfo(object):
|
||||
r"""Private class used to store runtime information from current train step.
|
||||
|
||||
|
|
@ -671,6 +673,9 @@ class ORTTrainer(object):
|
|||
if self._onnx_model is None:
|
||||
return
|
||||
|
||||
if self.options.utils.run_symbolic_shape_infer:
|
||||
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
|
||||
|
||||
# Create training session used by train_step
|
||||
self._create_ort_training_session()
|
||||
|
||||
|
|
|
|||
|
|
@ -132,6 +132,10 @@ class ORTTrainerOptions(object):
|
|||
'invertible_layer_norm_gradient' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
},
|
||||
'run_symbolic_shape_infer' : {
|
||||
'type' : 'boolean',
|
||||
'default' : False
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
@ -225,6 +229,8 @@ class ORTTrainerOptions(object):
|
|||
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
|
||||
utils.invertible_layer_norm_gradient (bool, default is False):
|
||||
enables use of invertible layer norm gradients
|
||||
utils.run_symbolic_shape_infer (bool, default is False):
|
||||
runs symbolic shape inference on the model
|
||||
debug (dict):
|
||||
debug options
|
||||
debug.deterministic_compute (bool, default is False)
|
||||
|
|
@ -445,6 +451,10 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
|
|||
'invertible_layer_norm_gradient' : {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'run_symbolic_shape_infer' : {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ def testORTTrainerOptionsDefaultValues(test_input):
|
|||
'frozen_weights': [],
|
||||
'grad_norm_clip': True,
|
||||
'invertible_layer_norm_gradient': False,
|
||||
'run_symbolic_shape_infer': False
|
||||
},
|
||||
'debug': {
|
||||
'deterministic_compute': False,
|
||||
|
|
@ -1290,3 +1291,58 @@ def testLossScalerLegacyAndExperimentalRandomAllFinite():
|
|||
assert_allclose(new_loss_scale, old_loss_scale)
|
||||
out.append(new_loss_scale)
|
||||
assert new_loss_scale > 1e-7
|
||||
|
||||
def testORTTrainerRunSymbolicShapeInfer():
    """Check that running symbolic shape inference does not change training losses.

    The same transformer model is trained three times from the same seed:
    with default options, with `utils.run_symbolic_shape_infer` enabled, and
    with the legacy trainer built with `run_symbolic_shape_infer=True`. The
    per-step losses of the last two runs are compared against the first.
    """
    # Common data
    seed = 0
    total_steps = 12
    device = 'cuda'
    torch.set_printoptions(precision=10)

    def reseed():
        torch.manual_seed(seed)
        set_seed(seed)

    def collect_losses(step_fn, batcher, dataset):
        # Run the fixed number of steps and keep each loss on the CPU.
        losses = []
        for step in range(total_steps):
            data, targets = batcher(dataset, step)
            loss, _ = step_fn(data, targets)
            losses.append(loss.cpu())
        return losses

    # Baseline: no symbolic shape inference
    reseed()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    expected_loss = collect_losses(trainer.train_step, batcher_fn, train_data)

    # Same setup with symbolic shape inference switched on
    reseed()
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    options.utils.run_symbolic_shape_infer = True
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    new_loss = collect_losses(trainer.train_step, batcher_fn, train_data)

    # Legacy API with symbolic shape inference switched on
    reseed()
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       run_symbolic_shape_infer=True,
                                       _use_deterministic_compute=True)

    def legacy_step(data, targets):
        # The legacy trainer takes the learning rate as an extra tensor input.
        return legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr]))

    legacy_loss = collect_losses(legacy_step, batcher_fn, train_data)

    # Compare losses
    _test_helpers.assert_model_outputs(new_loss, expected_loss)
    _test_helpers.assert_model_outputs(legacy_loss, expected_loss)
|
||||
Loading…
Reference in a new issue