diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 944d309df2..a0a59ce42f 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -18,7 +18,7 @@ from . import _utils ONNX_OPSET_VERSION = 12 -__TEMP_ENABLE_METHOD_TIMING__ = True +__TEMP_ENABLE_METHOD_TIMING__ = False # Needed to re-implement PyTorch's cpu,cuda,to methods T = TypeVar('T', bound='Module') @@ -56,7 +56,7 @@ def _onnx_value_info_to_buffer_tensor(value_info, device): class ORTModule(torch.nn.Module): - def __init__(self, module, dynamic_axes=None): + def __init__(self, module): assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module" super(ORTModule, self).__init__() @@ -66,10 +66,10 @@ class ORTModule(torch.nn.Module): # User module is wrapped to use its initializers and save computed gradients self._original_module = module - self._dynamic_axes = dynamic_axes self._onnx_training = None - self._curr_inputs_size = None + # Related to training graph split/shape inference + self._current_input_shape = None self._module_gradient_graph_builder = None # Forward pass @@ -158,11 +158,12 @@ class ORTModule(torch.nn.Module): if not self._onnx_forward or self._require_export: self._require_export = False - self._onnx_training = ORTModule._get_forward_graph(self._original_module, self._dynamic_axes, *inputs, **kwargs) - grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() - + self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs) # TODO: PyTorch exporter bug: changes the initializer order initializer_names = [p[0] for p in self._original_module.named_parameters()] + + # Build full training graph and split in forward/backward + grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() grad_builder_config.initializer_names_to_train = initializer_names 
grad_builder_config.input_names_require_grad = [] self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder() @@ -171,18 +172,14 @@ class ORTModule(torch.nn.Module): if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') - inputs_size = [list(input.size()) for input in inputs if input is not None] - if self._curr_inputs_size is None or self._curr_inputs_size != inputs_size: - self._curr_inputs_size = inputs_size - self._module_gradient_graph_builder.build_and_split(self._curr_inputs_size) + # Perform shape inference and re-split forward/backward graph for batches with different shapes + new_input_shape = [list(input.size()) for input in inputs if input is not None] + if self._current_input_shape is None or self._current_input_shape != new_input_shape: + self._current_input_shape = new_input_shape + self._module_gradient_graph_builder.build_and_split(self._current_input_shape) self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model()) self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model()) self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info() - - if self._save_onnx: - onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx') - onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx') - self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString()) self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString()) @@ -197,6 +194,10 @@ class ORTModule(torch.nn.Module): for output in self._onnx_backward.graph.output: self._backward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device)) + if self._save_onnx: + onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx') + onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx') + # Use a 
custom torch.autograd.Function to associate self.backward_graph as the # gradient implementation for self.forward_graph. class _ORTModuleFunction(torch.autograd.Function): @@ -275,7 +276,7 @@ class ORTModule(torch.nn.Module): TODO: How IO binding model inputs and outputs affects initializer copies? ONNX Runtime forward requires an order list of: - * User input: computed from ONNX forward graph, excluding initializers as input + * User input: computed from forward InferenceSession * Initializers: computed from original PyTorch model parameters This codes assumes the exported model's inputs and initializers @@ -349,7 +350,7 @@ class ORTModule(torch.nn.Module): @staticmethod - def _get_forward_graph(module, dynamic_axes, *inputs, **kwargs): + def _get_forward_graph(module, *inputs, **kwargs): '''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input TODO: How to support dynamic axes? Dimensions are determined by samples @@ -364,7 +365,15 @@ class ORTModule(torch.nn.Module): # Ignore optional *inputs explicitly specified as None sig = signature(module.forward) all_input_names = sig.parameters.keys() - input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None] + # input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None] + input_names = [] + dynamic_axes = {} + for input_idx, name in enumerate(all_input_names): + if inputs[input_idx] is not None: + input_names.append(name) + dynamic_axes[name] = {} + for dim_idx in range(len(inputs[input_idx].shape)): + dynamic_axes[name].update({dim_idx : f'input{input_idx}_dim{dim_idx}'}) # TODO: Support contrib OPs support? 
user model has no hint # from onnxruntime.training import register_custom_ops_pytorch_exporter diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 2be6206d23..003f0b604b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -28,7 +28,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128 # Perform one full pass over the training set. - print('\n======== Epoch {:} / {:} ========'.format(epoch + 1, args.epochs)) + print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch + 1, args.epochs, args.batch_size)) # Measure how long the training epoch takes. t0 = time.time() @@ -140,7 +140,7 @@ def test(model, validation_dataloader, device, args): # ======================================== # After the completion of each training epoch, measure our performance on # our validation set. - print("\nRunning Validation...") + print("\nRunning Validation with batch size {:} ...".format(args.test_batch_size)) # Put the model in evaluation mode--the dropout layers behave differently # during evaluation. 
@@ -380,11 +380,7 @@ def main(): ) if not args.pytorch_only: - dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'}, - 'attention_mask': {0: 'batch_size', 1: 'seq_len'}, - 'labels': {0: 'batch_size'}, - '210': {0: 'batch'}} - model = ORTModule(model, dynamic_axes) + model = ORTModule(model) # TODO: change it to False to stop saving ONNX models model._save_onnx = True diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py index 0a996b08a8..fdc74f6bf8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py @@ -24,7 +24,7 @@ class NeuralNet(torch.nn.Module): def train(args, model, device, optimizer, loss_fn, train_loader, epoch): - print('\n======== Epoch {:} / {:} ========'.format(epoch+1, args.epochs)) + print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch+1, args.epochs, args.batch_size)) model.train() # Measure how long the training epoch takes. t0 = time.time() @@ -96,8 +96,8 @@ def test(args, model, device, loss_fn, test_loader): pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), + print('\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + args.test_batch_size, test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) # Report the final accuracy for this validation run. @@ -119,10 +119,10 @@ def main(): help='number of steps to train. 
Set -1 to run through whole dataset (default: -1)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', - help='input batch size for testing (default: 20)') + parser.add_argument('--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') + parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', + help='input batch size for testing (default: 64)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=42, metavar='S', @@ -157,6 +157,7 @@ def main(): transforms.Normalize((0.1307,), (0.3081,))])), batch_size=args.batch_size, shuffle=True) + test_loader = None if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( datasets.MNIST('./data', train=False, transform=transforms.Compose([