[Training/Python] Add option to enable symbolic shape inference (#5107)

This change adds symbolic shape inference to ORT training, which helps static memory planning for models like BART.
This commit is contained in:
KeDengMS 2020-09-22 10:49:07 -07:00 committed by GitHub
parent 14f250a4d0
commit 8dceebda0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 74 deletions

View file

@ -297,7 +297,7 @@
"metadata": {},
"outputs": [],
"source": [
"SymbolicShapeInference.infer_shapes(input_model=lstm_model, output_model=lstm_model)"
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(lstm_model)), lstm_model)"
]
},
{
@ -559,7 +559,7 @@
"bert_model_with_shape_inference = os.path.join(bert_model_dir, 'bertsquad10_shaped.onnx')\n",
"\n",
"# run symbolic shape inference\n",
"SymbolicShapeInference.infer_shapes(bert_model, bert_model_with_shape_inference, auto_merge=True, int_max=100000)"
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(bert_model), auto_merge=True, int_max=100000), bert_model_with_shape_inference)"
]
},
{
@ -692,7 +692,7 @@
"gpt2_model_with_shape_inference = os.path.join(gpt2_model_dir, 'model_shaped.onnx')\n",
"\n",
"# run symbolic shape inference\n",
"SymbolicShapeInference.infer_shapes(gpt2_model, gpt2_model_with_shape_inference, auto_merge=True)"
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(gpt2_model), auto_merge=True), gpt2_model_with_shape_inference)"
]
},
{
@ -892,7 +892,7 @@
"source": [
"# editing\n",
"bidaf_converted = 'bidaf_mod.onnx'\n",
"SymbolicShapeInference.infer_shapes(bidaf, bidaf_converted)\n",
"onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(bidaf)), bidaf_converted)\n",
"convert_to_scan_model(bidaf_converted, bidaf_converted)\n",
"# When quantizing, there's an only_for_scan option to quantize only the GEMV inside Scan ops.\n",
"# This is useful when the input dims of LSTM being much bigger than hidden dims.\n",

View file

@ -831,5 +831,7 @@ if __name__ == '__main__':
else:
raise NotImplementedError('Unknown mode')
print('Running symbolic shape inference on output model')
SymbolicShapeInference.infer_shapes(args.output, args.output, auto_merge=True)
mp = onnx.load(args.output)
mp = SymbolicShapeInference.infer_shapes(mp, auto_merge=True)
onnx.save(mp, args.output)
print('Done!')

View file

@ -9,7 +9,6 @@ import numpy as np
import onnx
from onnx import helper, numpy_helper
from .node_factory import NodeFactory, ensure_opset
from ..tools.symbolic_shape_infer import SymbolicShapeInference
class QuantizeConfig:
def __init__(self, signed, reserved_bits, type_bits):

View file

@ -120,7 +120,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
scan_model_name = os.path.splitext(model_name)[0] + '_scan.onnx'
convert_to_scan_model(model_name, scan_model_name)
# note that symbolic shape inference is needed because model has symbolic batch dim, thus init_state is ConstantOfShape
SymbolicShapeInference.infer_shapes(scan_model_name, scan_model_name)
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(scan_model_name)), scan_model_name)
sess = onnxruntime.InferenceSession(scan_model_name)
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
avg_scan = top_n_avg(per_iter_cost, top_n)
@ -130,7 +130,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
from .model_quantizer import convert_matmul_model
int8_model_name = os.path.splitext(model_name)[0] + '_int8.onnx'
convert_matmul_model(scan_model_name, int8_model_name)
SymbolicShapeInference.infer_shapes(int8_model_name, int8_model_name)
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(int8_model_name)), int8_model_name)
sess = onnxruntime.InferenceSession(int8_model_name)
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
avg_int8 = top_n_avg(per_iter_cost, top_n)

View file

@ -81,6 +81,7 @@ class SymbolicShapeInference:
'CategoryMapper' : self._infer_CategoryMapper,
'Compress' : self._infer_Compress,
'Concat' : self._infer_Concat,
'Constant' : self._infer_Constant,
'ConstantOfShape' : self._infer_ConstantOfShape,
'Conv' : self._infer_Conv,
'CumSum' : self._pass_on_shape_and_type,
@ -91,6 +92,7 @@ class SymbolicShapeInference:
'Gather' : self._infer_Gather,
'GatherElements' : self._infer_GatherElements,
'GatherND' : self._infer_GatherND,
'Gelu' : self._pass_on_shape_and_type,
'If' : self._infer_If,
'Loop' : self._infer_Loop,
'MatMul' : self._infer_MatMul,
@ -113,6 +115,7 @@ class SymbolicShapeInference:
'Shape' : self._infer_Shape,
'Size' : self._infer_Size,
'Slice' : self._infer_Slice,
'SoftmaxCrossEntropyLoss':self._infer_SoftmaxCrossEntropyLoss,
'Split' : self._infer_Split,
'SplitToSequence' : self._infer_SplitToSequence,
'Squeeze' : self._infer_Squeeze,
@ -189,43 +192,8 @@ class SymbolicShapeInference:
d.dim_param = v
def _preprocess(self, in_mp):
out_mp = onnx.ModelProto()
out_mp.CopyFrom(in_mp)
out_mp.graph.ClearField('node')
self.out_mp_ = out_mp
defined = set([i.name for i in list(in_mp.graph.input) + list(in_mp.graph.initializer)])
pending_nodes = []
# returns True if no more ready nodes
def _insert_ready_nodes():
ready_nodes = [pn for pn in pending_nodes if all([i in defined for i in pn.input if i])]
for rn in ready_nodes:
self.out_mp_.graph.node.add().CopyFrom(rn)
for o in rn.output:
defined.add(o)
pending_nodes.remove(rn)
return not ready_nodes
# constant op -> initializer, topological sort
for in_n in in_mp.graph.node:
if in_n.op_type == 'Constant':
t = get_attribute(in_n, 'value')
t.name = in_n.output[0]
self.out_mp_.graph.initializer.add().CopyFrom(t)
defined.add(t.name)
else:
pending_nodes.append(in_n)
_insert_ready_nodes()
while pending_nodes:
if _insert_ready_nodes():
break
if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ')
print(*[n.op_type + ': ' + n.output[0] for n in pending_nodes], sep='\n')
self.out_mp_ = onnx.ModelProto()
self.out_mp_.CopyFrom(in_mp)
self.initializers_ = dict([(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)])
self.known_vi_.update(dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))) for i in self.out_mp_.graph.initializer]))
@ -370,8 +338,6 @@ class SymbolicShapeInference:
symbolic_shape_inference = SymbolicShapeInference(self.int_max_, self.auto_merge_, self.guess_output_rank_, self.verbose_)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_)
# note that after _preprocess, Constant node will be converted to initializer and should be appended to subgraph.initializer
subgraph.initializer.extend([i for i in symbolic_shape_inference.out_mp_.graph.initializer if i.name not in subgraph_implicit_input and i.name not in subgraph_inputs])
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(self.tmp_mp_, self.sympy_data_.copy())
@ -638,11 +604,9 @@ class SymbolicShapeInference:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_Conv(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_Constant(self, node):
    """Store the Constant node's tensor value in ``self.sympy_data_``.

    The node's ``value`` attribute is converted to a numpy array and stored
    under the name of the node's first output.
    """
    value_tensor = get_attribute(node, 'value')
    constant_array = numpy_helper.to_array(value_tensor)
    self.sympy_data_[node.output[0]] = constant_array
def _infer_ConstantOfShape(self, node):
sympy_shape = self._get_int_values(node)[0]
@ -662,6 +626,12 @@ class SymbolicShapeInference:
vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Conv(self, node):
    """Update the value info of the Conv node's first output.

    The shape comes from ``_compute_conv_pool_shape``; the element type
    already recorded for the output is kept.
    """
    out_sympy_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(out_sympy_shape)
    out_vi = self.known_vi_[node.output[0]]
    refreshed_vi = helper.make_tensor_value_info(
        node.output[0],
        out_vi.type.tensor_type.elem_type,
        get_shape_from_sympy_shape(out_sympy_shape))
    out_vi.CopyFrom(refreshed_vi)
def _infer_Expand(self, node):
expand_to_shape = self._try_get_value(node, 1)
if expand_to_shape is not None:
@ -680,8 +650,8 @@ class SymbolicShapeInference:
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
vi.type.tensor_type.elem_type,
data_shape[:axis] + indices_shape + data_shape[axis+1:]))
if node.input[0] in self.sympy_data_:
assert 0 == get_attribute(node, 'axis', 0) # only handle 1D sympy compute
# for 1D input, do some sympy compute
if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and 0 == get_attribute(node, 'axis', 0):
idx = self._get_value(node, 1)
data = self.sympy_data_[node.input[0]]
if type(data) == list:
@ -1037,11 +1007,20 @@ class SymbolicShapeInference:
get_shape_from_sympy_shape(new_sympy_shape)))
# handle sympy_data if needed, for slice in shape computation
if node.input[0] in self.sympy_data_:
assert [0] == axes
assert len(starts) == 1
assert len(ends) == 1
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][starts[0]:ends[0]]
if node.input[0] in self.sympy_data_ and [0] == axes and len(starts) == 1 and len(ends) == 1:
input_sympy_data = self.sympy_data_[node.input[0]]
if type(input_sympy_data) == list or (type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1):
self.sympy_data_[node.output[0]] = input_sympy_data[starts[0]:ends[0]]
def _infer_SoftmaxCrossEntropyLoss(self, node):
    """Fill in output value info for a SoftmaxCrossEntropyLoss node.

    The first output takes the element type of the first input. When the
    node has a second output, that output takes both the element type and
    the shape of the first input.
    """
    loss_vi = self.known_vi_[node.output[0]]
    input_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    loss_vi.type.tensor_type.elem_type = input_elem_type
    if len(node.output) <= 1:
        # only the first output is present; nothing more to annotate
        return
    input_shape = self._get_shape(node, 0)
    second_vi = self.known_vi_[node.output[1]]
    second_vi.CopyFrom(
        helper.make_tensor_value_info(second_vi.name, input_elem_type, input_shape))
def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0)
@ -1276,22 +1255,20 @@ class SymbolicShapeInference:
output.CopyFrom(self.known_vi_[output.name])
@staticmethod
def infer_shapes(input_model, output_model, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
in_mp = onnx.load(input_model)
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7:
print('Only support models of onnx opset 7 and above.')
return
return None
symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(in_mp)
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(in_mp)
symbolic_shape_inference._update_output_from_vi()
if output_model:
onnx.save(symbolic_shape_inference.out_mp_, output_model)
if not all_shapes_inferred:
sys.exit(1)
raise Exception("Incomplete symbolic shape inference")
return symbolic_shape_inference.out_mp_
def parse_arguments():
parser = argparse.ArgumentParser()
@ -1309,5 +1286,7 @@ if __name__ == '__main__':
if args.output:
print('output model ' + args.output)
print('Doing symbolic shape inference...')
out_mp = SymbolicShapeInference.infer_shapes(args.input, args.output, args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
print('Done!')
out_mp = SymbolicShapeInference.infer_shapes(onnx.load(args.input), args.int_max, args.auto_merge, args.guess_output_rank, args.verbose)
if args.output and out_mp:
onnx.save(out_mp, args.output)
print('Done!')

View file

@ -2,11 +2,12 @@
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import unittest
import onnx
import os
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
import sys
from pathlib import Path
import sys
import unittest
class TestSymbolicShapeInference(unittest.TestCase):
def test_symbolic_shape_infer(self):
@ -17,8 +18,7 @@ class TestSymbolicShapeInference(unittest.TestCase):
continue # skip some bad model files
print("Running symbolic shape inference on : " + str(filename))
SymbolicShapeInference.infer_shapes(
input_model=str(filename),
output_model=None,
in_mp=onnx.load(str(filename)),
auto_merge=True,
int_max=100000,
guess_output_rank=True)

View file

@ -16,6 +16,8 @@ import warnings
from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
import onnxruntime.capi.pt_patch
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
DEFAULT_OPSET_VERSION = 12
class IODescription():
@ -320,8 +322,14 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
import copy
# Deepcopy inputs, since input values may change after model run.
sample_inputs_copy = copy.deepcopy(sample_inputs)
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(model)
try:
# Deepcopy model, in case model is stateful and changes after model run.
model_copy = copy.deepcopy(model)
except Exception:
model_copy = model
warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX."
" Compute will continue, but unexpected results may occur!")
sample_outputs = model_copy(*sample_inputs_copy)
if isinstance(sample_outputs, torch.Tensor):
sample_outputs = [sample_outputs]
@ -539,7 +547,7 @@ class ORTTrainer():
global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0,
enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION,
_enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False,
use_invertible_layernorm_grad=False):
use_invertible_layernorm_grad=False, run_symbolic_shape_infer=False):
super(ORTTrainer, self).__init__()
"""
Initialize ORTTrainer.
@ -607,6 +615,8 @@ class ORTTrainer():
Defaults to None
use_invertible_layernorm_grad: use invertible layernorm grad
Defaults to False
run_symbolic_shape_infer: run symbolic shape inference
Defaults to False
"""
warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it')
self.is_train = True
@ -669,6 +679,7 @@ class ORTTrainer():
self.state_dict_ = None
self._use_deterministic_compute = _use_deterministic_compute
self.use_invertible_layernorm_grad = use_invertible_layernorm_grad
self.run_symbolic_shape_infer = run_symbolic_shape_infer
# use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs.
# see prepare_input_and_fetches for more details.
@ -681,6 +692,10 @@ class ORTTrainer():
return
self._verify_fully_optimized_model(self.onnx_model_)
if self.run_symbolic_shape_infer:
self.onnx_model_ = SymbolicShapeInference.infer_shapes(self.onnx_model_, auto_merge=True, guess_output_rank=True)
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
create_ort_training_session_with_optimizer(
self.onnx_model_, self.device_,

View file

@ -10,6 +10,8 @@ import onnxruntime as ort
from . import _utils, amp, checkpoint, optim, postprocess, ORTTrainerOptions
from .model_desc_validation import _ORTTrainerModelDesc
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
@ -671,6 +673,9 @@ class ORTTrainer(object):
if self._onnx_model is None:
return
if self.options.utils.run_symbolic_shape_infer:
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
# Create training session used by train_step
self._create_ort_training_session()

View file

@ -132,6 +132,10 @@ class ORTTrainerOptions(object):
'invertible_layer_norm_gradient' : {
'type' : 'boolean',
'default' : False
},
'run_symbolic_shape_infer' : {
'type' : 'boolean',
'default' : False
}
}
},
@ -225,6 +229,8 @@ class ORTTrainerOptions(object):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
utils.invertible_layer_norm_gradient (bool, default is False):
enables use of invertible layer norm gradients
utils.run_symbolic_shape_infer (bool, default is False):
runs symbolic shape inference on the model
debug (dict):
debug options
debug.deterministic_compute (bool, default is False)
@ -445,6 +451,10 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
'invertible_layer_norm_gradient' : {
'type': 'boolean',
'default': False
},
'run_symbolic_shape_infer' : {
'type': 'boolean',
'default': False
}
}
},

View file

@ -98,6 +98,7 @@ def testORTTrainerOptionsDefaultValues(test_input):
'frozen_weights': [],
'grad_norm_clip': True,
'invertible_layer_norm_gradient': False,
'run_symbolic_shape_infer': False
},
'debug': {
'deterministic_compute': False,
@ -1290,3 +1291,58 @@ def testLossScalerLegacyAndExperimentalRandomAllFinite():
assert_allclose(new_loss_scale, old_loss_scale)
out.append(new_loss_scale)
assert new_loss_scale > 1e-7
def testORTTrainerRunSymbolicShapeInfer():
    """Check that turning on symbolic shape inference leaves the losses unchanged.

    Three runs are made from the same seed: the experimental trainer without
    the option, the experimental trainer with ``utils.run_symbolic_shape_infer``
    set, and the legacy trainer with ``run_symbolic_shape_infer=True``. The
    losses of the last two runs are compared against the first.
    """
    # Common data
    seed = 0
    total_steps = 12
    device = 'cuda'
    torch.set_printoptions(precision=10)

    def _reseed():
        # Reset both seed sources before each run.
        torch.manual_seed(seed)
        set_seed(seed)

    def _run_steps(step_fn):
        # Uses whichever batcher_fn / train_data are bound in the enclosing
        # scope at call time.
        losses = []
        for step in range(total_steps):
            data, targets = batcher_fn(train_data, step)
            loss, _ = step_fn(data, targets)
            losses.append(loss.cpu())
        return losses

    # Setup without symbolic shape inference
    _reseed()
    options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
                                            'debug' : {'deterministic_compute' : True}})
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    expected_loss = _run_steps(trainer.train_step)

    # Setup with symbolic shape inference
    _reseed()
    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    optim_config = optim.LambConfig(lr=0.001)
    options.utils.run_symbolic_shape_infer = True
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
    new_loss = _run_steps(trainer.train_step)

    # Setup with symbolic shape inference in legacy API
    _reseed()
    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer",
                                       None, lr_desc, device=device,
                                       run_symbolic_shape_infer=True,
                                       _use_deterministic_compute=True)
    legacy_loss = _run_steps(
        lambda data, targets: legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])))

    # Compare losses
    _test_helpers.assert_model_outputs(new_loss, expected_loss)
    _test_helpers.assert_model_outputs(legacy_loss, expected_loss)