From f13c2a61d5efc3e5cf7cdf529b45cd574487a7c2 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Mon, 16 Nov 2020 16:30:00 -0800 Subject: [PATCH] Refactor IObinding --- .../orttraining/python/training/ortmodule.py | 245 ++++++++++-------- .../orttraining_test_ortmodule_iobinding.py | 161 ------------ ...py => orttraining_test_ortmodule_mnist.py} | 14 +- run_ortmodule_mvp_bert_finetuning.sh | 3 +- run_ortmodule_mvp_mnist.sh | 3 +- run_ortmodule_mvp_mnist_iobinding.sh | 18 -- 6 files changed, 141 insertions(+), 303 deletions(-) delete mode 100644 orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py rename orttraining/orttraining/test/python/{orttraining_test_ortmodule_basic.py => orttraining_test_ortmodule_mnist.py} (96%) delete mode 100644 run_ortmodule_mvp_mnist_iobinding.sh diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 7b29d3af11..2f8654e3a2 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -9,50 +9,59 @@ import warnings import numpy as np from inspect import signature +# Needed to re-implement PyTorch's cpu,cuda,to methods +from torch import Tensor, device, dtype +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict + from onnxruntime.capi import _pybind_state as C from . import _utils ONNX_OPSET_VERSION = 12 +# Needed to re-implement PyTorch's cpu,cuda,to methods +T = TypeVar('T', bound='Module') -def get_device_index(device): - if type(device) == str: - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) - return 0 if device.index is None else device.index +def _create_iobinding(io_binding, inputs, model, output_buffers, device): + '''Creates IO binding for a `model` inputs and output''' + def _get_device_index(device): + if type(device) == str: + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + device = torch.device(device) + return 0 if device.index is None else device.index - -def prepare_io_binding(io_binding, inputs, model, output_buffers, device): - idx = 0 - for value_info in model.graph.input: - io_binding.bind_input(value_info.name, inputs[idx].device.type, get_device_index(inputs[idx].device), - _utils.dtype_torch_to_numpy(inputs[idx].dtype), list(inputs[idx].size()), + for idx, value_info in enumerate(model.graph.input): + io_binding.bind_input(value_info.name, inputs[idx].device.type, + _get_device_index(inputs[idx].device), + _utils.dtype_torch_to_numpy(inputs[idx].dtype), + list(inputs[idx].size()), inputs[idx].data_ptr()) - idx += 1 for value_info in model.graph.output: name = value_info.name output_tensor = output_buffers[name] - io_binding.bind_output(name, output_tensor.device.type, get_device_index(device), - _utils.dtype_torch_to_numpy(output_tensor.dtype), list(output_tensor.size()), + io_binding.bind_output(name, output_tensor.device.type, + _get_device_index(device), + _utils.dtype_torch_to_numpy(output_tensor.dtype), + list(output_tensor.size()), output_tensor.data_ptr()) +def _onnx_value_info_to_buffer_tensor(value_info, device): + '''Create a torch zeroed tensor with the same shape and type of `value_info`''' -def value_info_to_buffer_tensor(value_info, device): shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim] dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type) return torch.zeros(shape, device=device, dtype=dtype) - class ORTModule(torch.nn.Module): - def __init__(self, module, device="cpu", use_iobinding=False): + def __init__(self, module): assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module" super(ORTModule, self).__init__() - self._device = device - self._use_iobinding = use_iobinding + self._export_again = False + # TODO: This is incorrect when different layers may be in different devices + self._device = next(module.parameters()).device # User module is wrapped to use its initializers and save computed gradients self._original_module = module @@ -62,18 +71,78 @@ class ORTModule(torch.nn.Module): # Forward pass self._onnx_forward = None self._forward_session = None + self._forward_io_binding = None + self._forward_output_buffers = {} # Backward pass self._onnx_backward = None self._backward_session = None + self._backward_io_binding = None + self._backward_output_buffers = {} # Log level self._loglevel = getattr(logging, 'WARNING') - # TODO: debug flags + # Debug flags self._save_onnx = False self._save_onnx_prefix = '' + def cpu(self: T) -> T: + '''Thin layer to capture device for ORTModule IO bindings''' + if self._device != 'cpu': + self._require_export = True + self._device = 'cpu' + return super(ORTModule, self).cpu() + + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: + '''Thin layer to capture device for ORTModule IO bindings''' + if device: + device = str(device) + else: + device = 'cuda' + if self._device != str(device): + self._require_export = True + self._device = device + return super(ORTModule, self).cuda(device) + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ...) -> T: + '''Thin layer to capture device for ORTModule IO bindings''' + if device: + device = str(device) + else: + device = None + if self._device != str(device) and device is not None: + self._require_export = True + self._device = device + return super(ORTModule, self).to(device, dtype, non_blocking) + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + '''Thin layer to capture device for ORTModule IO bindings''' + # TODO: Should we do anything? + self._require_export = False + return super(ORTModule, self).to(dtype, non_blocking) + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + '''Thin layer to capture device for ORTModule IO bindings''' + # TODO: Long shot by sending model to tensor's device + device = None + if tensor: + device = str(tensor.device) + if self._device != str(device) and device is not None: + self._require_export = True + self._device = device + return super(ORTModule, self).to(tensor, non_blocking) + + def to(self, *args, **kwargs): + '''Thin layer to capture device for ORTModule IO bindings''' + # TODO: Should we do anything? + self._require_export = False + return super(ORTModule, self).to(args, kwargs) def forward(self, *inputs, **kwargs): '''Forward pass starts here and continues at `_ORTModuleFunction.forward` @@ -82,13 +151,15 @@ class ORTModule(torch.nn.Module): Next, a full training graph is splitted in forward and backward graph which are used to instantiate ONNX Runtime InferenceSession`s ''' - if not self._onnx_forward: + if not self._onnx_forward or self._require_export: + self._require_export = False + self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs) grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() # TODO: PyTorch exporter bug: changes the initializer order - # Use the order in original module initializer_names = [p[0] for p in self._original_module.named_parameters()] - self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names) + self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \ + ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names) if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') @@ -99,19 +170,18 @@ class ORTModule(torch.nn.Module): # TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info) - execution_providers = ['CPUExecutionProvider'] if self._device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider'] - self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=execution_providers) - self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=execution_providers) + self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString()) + self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString()) - if self._use_iobinding: - self._forward_io_binding = self._forward_session.io_binding() - self._forward_output_buffers = {} - for output in self._onnx_forward.graph.output: - self._forward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device) - self._backward_io_binding = self._backward_session.io_binding() - self._backward_output_buffers = {} - for output in self._onnx_backward.graph.output: - self._backward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device) + # IO binding + self._forward_io_binding = self._forward_session.io_binding() + self._forward_output_buffers = {} + for output in self._onnx_forward.graph.output: + self._forward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device)) + self._backward_io_binding = self._backward_session.io_binding() + self._backward_output_buffers = {} + for output in self._onnx_backward.graph.output: + self._backward_output_buffers[output.name] = _onnx_value_info_to_buffer_tensor(output, str(self._device)) # Use a custom torch.autograd.Function to associate self.backward_graph as the # gradient implementation for self.forward_graph. @@ -129,42 +199,26 @@ class ORTModule(torch.nn.Module): * Intermediate tensors ''' - if not self._use_iobinding: - # Convert input to dict of torch tensors - data_dict = self._convert_forward_input_list_to_dict(*inputs) + # Use IO binding + _create_iobinding(self._forward_io_binding, inputs, + self._onnx_forward, + self._forward_output_buffers, + str(self._device)) - # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement) - data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict) - - # Feed forward - outputs, intermediate = self._run_forward_graph(data_dict_numpy) - outputs = tuple(torch.from_numpy(item) for item in outputs) - - # Save input, initializers and intermediate tensors to be used during backward - user_input = self._onnx_graphs_info.user_input_names - backward_user_input = self._onnx_graphs_info.backward_user_input_names - ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input) - forward_initializer = self._onnx_graphs_info.initializer_names_to_train - backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input - ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer) - intermediate = tuple(torch.from_numpy(item) for item in intermediate) - ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate]) - - # TODO: Support original module output (currently dict is not supported) - if len(outputs) == 1: - return outputs[0] - return outputs - - # Use IO binding. - prepare_io_binding(self._forward_io_binding, inputs, self._onnx_forward, self._forward_output_buffers, self._device) + # Run self._forward_session.run_with_iobinding(self._forward_io_binding) + # Stash tensors needed by backward forward_input_dict = self._convert_forward_input_list_to_dict(*inputs) - ctx_inputs = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_user_input_names) - ctx_initializers = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_intializer_names_as_input) - ctx_intermediates = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.intermediate_tensor_names) + ctx_inputs = tuple(forward_input_dict[name] \ + for name in self._onnx_graphs_info.backward_user_input_names) + ctx_initializers = tuple(forward_input_dict[name] \ + for name in self._onnx_graphs_info.backward_intializer_names_as_input) + ctx_intermediates = tuple(self._forward_output_buffers[name] \ + for name in self._onnx_graphs_info.intermediate_tensor_names) ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates]) + # Return model output outputs = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.user_output_names) return outputs[0] if len(outputs) == 1 else outputs @@ -178,22 +232,22 @@ class ORTModule(torch.nn.Module): TODO: Input gradient is hard-coded to torch.tensor([1.]) ''' - if not self._use_iobinding: - saved_tensors = ctx.saved_tensors - grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output]) - result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names) - result += [torch.from_numpy(grad) for grad in grad_weights] - return tuple(result) - - # Use IO binding. + # Use IO binding grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output)) backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names) - prepare_io_binding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output], self._onnx_backward, self._backward_output_buffers, self._device) + _create_iobinding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output], + self._onnx_backward, + self._backward_output_buffers, + str(self._device)) + + # Run self._backward_session.run_with_iobinding(self._backward_io_binding) + # Return input and initializer gradients results = [torch.tensor([1])] * len(self._onnx_graphs_info.user_input_names) - results += [self._backward_output_buffers[name] for name in self._onnx_graphs_info.initializer_grad_names_to_train] + results += [self._backward_output_buffers[name] \ + for name in self._onnx_graphs_info.initializer_grad_names_to_train] return tuple(results) proc_inputs = [data for data in inputs if data is not None] @@ -203,6 +257,7 @@ class ORTModule(torch.nn.Module): '''Creates forward `*inputs` list from user input and PyTorch initializers TODO: **kwargs is not supported + TODO: How IO binding model inputs and outputs affects initializer copies? ONNX Runtime forward requires an order list of: * User input: computed from forward InferenceSession @@ -224,18 +279,6 @@ class ORTModule(torch.nn.Module): return result - def _convert_dict_torch_to_numpy(self, tensor_dict): - '''Convert `tensor_dict` PyTorch tensors to numpy tensors - - This is a ONNX Runtime requirement - - TODO: #UseIOBinding - ''' - result = {} - for k,v in tensor_dict.items(): - result.update({k : v.detach().cpu().numpy()}) - return result - def _convert_forward_input_list_to_dict(self, *inputs): '''Convert forward `*inputs` list to dict @@ -302,34 +345,6 @@ class ORTModule(torch.nn.Module): return result - def _run_forward_graph(self, inputs): - '''Execute forward pass on ONNX Runtime - - Output order has to be specified to ONNX Runtime backend - to distinguish intermediate from output tensors - ''' - - forward_output = self._forward_session.run([*self._onnx_graphs_info.user_output_names, - *self._onnx_graphs_info.intermediate_tensor_names], inputs) - output = forward_output[:len(self._onnx_graphs_info.user_output_names)] - intermediates = forward_output[len(self._onnx_graphs_info.user_output_names):] - return output, intermediates - - def _run_backward_graph(self, *inputs, **kwargs): - '''Execute backward pass on ONNX Runtime - - `*inputs` is converted from list to a list of detached numpy tensors before - being fed to an ONNX Runtime InferenceSession - - TODO: **kwargs are not supported - ''' - - # Convert input to dict of torch tensors - data = self._convert_backward_input_list_to_dict(*inputs) - - # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement) - data = self._convert_dict_torch_to_numpy(data) - return self._backward_session.run(self._onnx_graphs_info.initializer_grad_names_to_train, data) @staticmethod def _get_forward_graph(module, *inputs, **kwargs): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py deleted file mode 100644 index 81ebdcf36a..0000000000 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py +++ /dev/null @@ -1,161 +0,0 @@ -import argparse -import logging -import torch -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.training import ORTModule - - -class NeuralNet(torch.nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super(NeuralNet, self).__init__() - - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -def train(args, model, device, optimizer, loss_fn, train_loader, epoch): - model.train() - for iteration, (data, target) in enumerate(train_loader): - if iteration == args.train_steps: - break - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - optimizer.zero_grad() - if args.pytorch_only: - probability = model(data) - else: - probability = model(data) - - if args.view_graphs: - import torchviz - pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters()))) - pytorch_backward_graph.view() - - loss = loss_fn(probability, target) - loss.backward() - optimizer.step() - - # Stats - if iteration % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, iteration * len(data), len(train_loader.dataset), - 100. * iteration / len(train_loader), loss)) - - -def test(args, model, device, loss_fn, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - output = model(data) - - # Stats - test_loss += loss_fn(output, target, False).item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) - -def my_loss(x, target, is_train=True): - if is_train: - return torch.nn.CrossEntropyLoss()(x, target) - else: - return torch.nn.CrossEntropyLoss(reduction='sum')(x, target) - -def main(): - # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', - help='input batch size for testing (default: 20)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--use_iobinding', action='store_true', default=False, - help='use IO binding') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--log-interval', type=int, default=100, metavar='N', - help='how many batches to wait before logging training status (default: 100)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', - help='Log level (default: WARNING)') - - args = parser.parse_args() - - - # Common setup - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # TODO: CUDA support is broken due to copying from PyTorch into ORT - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - # device = 'cpu' - - ## Data loader - train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.batch_size, - shuffle=True) - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) - - # Model architecture - model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) - if not args.pytorch_only: - print('Training MNIST on ORTModule....') - model = ORTModule(model, device, args.use_iobinding) - - # TODO: change it to False to stop saving ONNX models - model._save_onnx = True - model._save_onnx_prefix = 'MNIST' - - # Set log level - numeric_level = getattr(logging, args.log_level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % args.log_level) - logging.basicConfig(level=numeric_level) - else: - print('Training MNIST on vanilla PyTorch....') - optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args, model, device, optimizer, my_loss, train_loader, epoch) - if args.test_batch_size > 0: - test(args, model, device, my_loss, test_loader) - - -if __name__ == '__main__': - main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py similarity index 96% rename from orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py index a1e37f502e..9147d981d7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py @@ -1,3 +1,5 @@ +import pdb + import argparse import logging import torch @@ -111,17 +113,15 @@ def main(): onnxruntime.set_seed(args.seed) # TODO: CUDA support is broken due to copying from PyTorch into ORT - # if not args.no_cuda and torch.cuda.is_available(): - # device = "cuda" - # else: - # device = "cpu" - device = 'cpu' + if not args.no_cuda and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" ## Data loader train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), - - transforms.Normalize((0.1307,), (0.3081,))])), + transforms.Normalize((0.1307,), (0.3081,))])), batch_size=args.batch_size, shuffle=True) if args.test_batch_size > 0: diff --git a/run_ortmodule_mvp_bert_finetuning.sh b/run_ortmodule_mvp_bert_finetuning.sh index 52c78c875e..5631a5d041 100755 --- a/run_ortmodule_mvp_bert_finetuning.sh +++ b/run_ortmodule_mvp_bert_finetuning.sh @@ -15,4 +15,5 @@ echo "Copying PyTorch frontend source-code to build folder" cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ echo "Running Flexible API (ORTModule)" -python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --no-cuda --epochs 4 --log-interval 20 --log-level=DEBUG +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --help +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py $@ diff --git a/run_ortmodule_mvp_mnist.sh b/run_ortmodule_mvp_mnist.sh index 22e37ad2b9..fefb1800a7 100755 --- a/run_ortmodule_mvp_mnist.sh +++ b/run_ortmodule_mvp_mnist.sh @@ -15,4 +15,5 @@ echo "Copying PyTorch frontend source-code to build folder" cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ echo "Running Flexible API (ORTModule)" -python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py --epochs 10 --log-interval 100 --log-level=DEBUG +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py --help +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist.py $@ diff --git a/run_ortmodule_mvp_mnist_iobinding.sh b/run_ortmodule_mvp_mnist_iobinding.sh deleted file mode 100644 index 45b25f9630..0000000000 --- a/run_ortmodule_mvp_mnist_iobinding.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -cur_dir=$(basename `pwd`) - -if [[ ${cur_dir} != "RelWithDebInfo" ]] -then - echo "Going to build folder (aka build/Linux/RelWithDebInfo)" - cd build/Linux/RelWithDebInfo -fi - -echo "Exporting PYTHONPATH to use build dir as onnxruntime package" -export PYTHONPATH=$(pwd) - -echo "Copying PyTorch frontend source-code to build folder" -cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ - -echo "Running Flexible API (ORTModule)" -python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py --epochs 10 --log-interval 100 --log-level=DEBUG --use_iobinding