mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Improve dynamic axes to work without data descriptors
This commit is contained in:
parent
7729bb3c8d
commit
f7f435fc27
3 changed files with 39 additions and 33 deletions
|
|
@ -18,7 +18,7 @@ from . import _utils
|
|||
|
||||
|
||||
ONNX_OPSET_VERSION = 12
|
||||
__TEMP_ENABLE_METHOD_TIMING__ = True
|
||||
__TEMP_ENABLE_METHOD_TIMING__ = False
|
||||
|
||||
# Needed to re-implement PyTorch's cpu,cuda,to methods
|
||||
T = TypeVar('T', bound='Module')
|
||||
|
|
@ -56,7 +56,7 @@ def _onnx_value_info_to_buffer_tensor(value_info, device):
|
|||
|
||||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module, dynamic_axes=None):
|
||||
def __init__(self, module):
|
||||
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
|
|
@ -66,10 +66,10 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._dynamic_axes = dynamic_axes
|
||||
self._onnx_training = None
|
||||
|
||||
self._curr_inputs_size = None
|
||||
# Related to training graph split/shape inference
|
||||
self._current_input_shape = None
|
||||
self._module_gradient_graph_builder = None
|
||||
|
||||
# Forward pass
|
||||
|
|
@ -158,11 +158,12 @@ class ORTModule(torch.nn.Module):
|
|||
if not self._onnx_forward or self._require_export:
|
||||
self._require_export = False
|
||||
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, self._dynamic_axes, *inputs, **kwargs)
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
|
||||
# Build full training graph and split in forward/backward
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
grad_builder_config.initializer_names_to_train = initializer_names
|
||||
grad_builder_config.input_names_require_grad = []
|
||||
self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
|
||||
|
|
@ -171,18 +172,14 @@ class ORTModule(torch.nn.Module):
|
|||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
|
||||
inputs_size = [list(input.size()) for input in inputs if input is not None]
|
||||
if self._curr_inputs_size is None or self._curr_inputs_size != inputs_size:
|
||||
self._curr_inputs_size = inputs_size
|
||||
self._module_gradient_graph_builder.build_and_split(self._curr_inputs_size)
|
||||
# Perform shape inference and re-split forward/backward graph for bacthes with different shapes
|
||||
new_input_shape = [list(input.size()) for input in inputs if input is not None]
|
||||
if self._current_input_shape is None or self._current_input_shape != new_input_shape:
|
||||
self._current_input_shape = new_input_shape
|
||||
self._module_gradient_graph_builder.build_and_split(self._current_input_shape)
|
||||
self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model())
|
||||
self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model())
|
||||
self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info()
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
|
||||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString())
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
|
||||
|
||||
|
|
@ -197,6 +194,10 @@ class ORTModule(torch.nn.Module):
|
|||
for output in self._onnx_backward.graph.output:
|
||||
self._backward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device))
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
class _ORTModuleFunction(torch.autograd.Function):
|
||||
|
|
@ -275,7 +276,7 @@ class ORTModule(torch.nn.Module):
|
|||
TODO: How IO binding model inputs and outputs affects initializer copies?
|
||||
|
||||
ONNX Runtime forward requires an order list of:
|
||||
* User input: computed from ONNX forward graph, excluding initializers as input
|
||||
* User input: computed from forward InferenceSession
|
||||
* Initializers: computed from original PyTorch model parameters
|
||||
|
||||
This codes assumes the exported model's inputs and initializers
|
||||
|
|
@ -349,7 +350,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _get_forward_graph(module, dynamic_axes, *inputs, **kwargs):
|
||||
def _get_forward_graph(module, *inputs, **kwargs):
|
||||
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
|
||||
|
||||
TODO: How to support dynamic axes? Dimensions are determined by samples
|
||||
|
|
@ -364,7 +365,15 @@ class ORTModule(torch.nn.Module):
|
|||
# Ignore optional *inputs explicitly specified as None
|
||||
sig = signature(module.forward)
|
||||
all_input_names = sig.parameters.keys()
|
||||
input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None]
|
||||
# input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None]
|
||||
input_names = []
|
||||
dynamic_axes = {}
|
||||
for input_idx, name in enumerate(all_input_names):
|
||||
if inputs[input_idx] is not None:
|
||||
input_names.append(name)
|
||||
dynamic_axes[name] = {}
|
||||
for dim_idx in range(len(inputs[input_idx].shape)):
|
||||
dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'})
|
||||
|
||||
# TODO: Support contrib OPs support? user model has no hint
|
||||
# from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128
|
||||
|
||||
# Perform one full pass over the training set.
|
||||
print('\n======== Epoch {:} / {:} ========'.format(epoch + 1, args.epochs))
|
||||
print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch + 1, args.epochs, args.batch_size))
|
||||
|
||||
# Measure how long the training epoch takes.
|
||||
t0 = time.time()
|
||||
|
|
@ -140,7 +140,7 @@ def test(model, validation_dataloader, device, args):
|
|||
# ========================================
|
||||
# After the completion of each training epoch, measure our performance on
|
||||
# our validation set.
|
||||
print("\nRunning Validation...")
|
||||
print("\nRunning Validation with batch size {:} ...".format(args.test_batch_size))
|
||||
|
||||
# Put the model in evaluation mode--the dropout layers behave differently
|
||||
# during evaluation.
|
||||
|
|
@ -380,11 +380,7 @@ def main():
|
|||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'},
|
||||
'attention_mask': {0: 'batch_size', 1: 'seq_len'},
|
||||
'labels': {0: 'batch_size'},
|
||||
'210': {0: 'batch'}}
|
||||
model = ORTModule(model, dynamic_axes)
|
||||
model = ORTModule(model)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class NeuralNet(torch.nn.Module):
|
|||
|
||||
|
||||
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
|
||||
print('\n======== Epoch {:} / {:} ========'.format(epoch+1, args.epochs))
|
||||
print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch+1, args.epochs, args.batch_size))
|
||||
model.train()
|
||||
# Measure how long the training epoch takes.
|
||||
t0 = time.time()
|
||||
|
|
@ -96,8 +96,8 @@ def test(args, model, device, loss_fn, test_loader):
|
|||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
test_loss /= len(test_loader.dataset)
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
print('\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
args.test_batch_size, test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
# Report the final accuracy for this validation run.
|
||||
|
|
@ -119,10 +119,10 @@ def main():
|
|||
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for training (default: 20)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for testing (default: 20)')
|
||||
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for testing (default: 64)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
||||
|
|
@ -157,6 +157,7 @@ def main():
|
|||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True)
|
||||
test_loader = None
|
||||
if args.test_batch_size > 0:
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('./data', train=False, transform=transforms.Compose([
|
||||
|
|
|
|||
Loading…
Reference in a new issue