From ff79e8743fbf71e7fa54ff618ef29383f19c5338 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 13 Nov 2020 09:41:59 -0800 Subject: [PATCH] Add support to BERT fine tuning (MVP 3) Additional changes include major refactoring to use new backend API --- .../orttraining/python/training/ortmodule.py | 383 +++++------------- .../orttraining_test_ortmodule_basic.py | 2 +- ...training_test_ortmodule_bert_classifier.py | 101 +++-- run_ortmodule_mvp_bert_finetuning.sh | 2 +- run_ortmodule_mvp_mnist.sh | 2 +- 5 files changed, 182 insertions(+), 308 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index d8c1cabe46..cf4ffa6358 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -6,8 +6,9 @@ import onnxruntime import os import torch import warnings -from onnxruntime.capi import _pybind_state as C +from inspect import signature +from onnxruntime.capi import _pybind_state as C from . import _utils @@ -22,29 +23,16 @@ class ORTModule(torch.nn.Module): # User module is wrapped to use its initializers and save computed gradients self._original_module = module - self._original_module_grad_output_len = -1 - self._original_module_forward_input_grads = [] self._onnx_training = None - self._onnx_training_inputs_desc = [] - self._onnx_training_outputs_desc = [] self._onnx_gradient = None - self._grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() # Forward pass self._onnx_forward = None self._forward_session = None - self._onnx_forward_initializers_desc = [] - self._onnx_forward_inputs_desc = [] - self._onnx_forward_outputs_desc = [] - self._onnx_forward_intermediate_outputs_desc = [] # Backward pass self._onnx_backward = None self._backward_session = None - self._onnx_backward_initializers_desc = [] - self._onnx_backward_inputs_desc = [] - self._onnx_backward_gradient_inputs_desc = [] - self._onnx_backward_outputs_desc = [] # Log level self._loglevel = getattr(logging, 'WARNING') @@ -60,14 +48,13 @@ class ORTModule(torch.nn.Module): ONNX model is exported the first time this method is executed. Next, a full training graph is splitted in forward and backward graph which are used to instantiate ONNX Runtime InferenceSession`s - - TODO: #ImproveGraphSplitting - Additionally to that, several descriptor lists are generated to help identify - model input, output, initializer, intermediate and gradient tensors. ''' if not self._onnx_forward: self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs) - self._onnx_gradient, self._onnx_forward, self._onnx_backward = ORTModule._build_gradient_graph(self._onnx_training, self._grad_builder_config) + grad_builder_config = C.ModuleGradientGraphBuilderConfiguration() + self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config) + # TODO: PyTorch exporter bug: changes the initializer order + self._onnx_graphs_info.initializer_grad_names_to_train = [ p[0]+'_grad' for p in self._original_module.named_parameters()] if self._save_onnx: onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx') @@ -75,44 +62,13 @@ class ORTModule(torch.nn.Module): onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx') onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx') + # TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names + self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info) + # TODO: hard-coding to CPU only self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider']) self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider']) - # Forward I/O description - if not self._onnx_training_inputs_desc: - self._onnx_training_inputs_desc = self._get_input_from_graph(self._onnx_training) - logging.debug(f'Training inputs:\n\t {self._onnx_training_inputs_desc}') - if not self._onnx_training_outputs_desc: - self._onnx_training_outputs_desc = self._get_output_from_graph(self._onnx_training) - logging.debug(f'Training outputs:\n\t {self._onnx_training_outputs_desc}') - if not self._onnx_forward_initializers_desc: - self._onnx_forward_initializers_desc = self._get_initializer_from_graph(self._onnx_forward) - logging.debug(f'Forward initializers:\n\t {self._onnx_forward_initializers_desc}') - if not self._onnx_forward_inputs_desc: - self._onnx_forward_inputs_desc = self._get_input_from_graph(self._onnx_forward) - logging.debug(f'Forward inputs:\n\t {self._onnx_forward_inputs_desc}') - if not self._onnx_forward_outputs_desc: - self._onnx_forward_outputs_desc = self._get_output_from_graph(self._onnx_forward) - logging.debug(f'Forward outputs:\n\t {self._onnx_forward_outputs_desc}') - if not self._onnx_forward_intermediate_outputs_desc: - self._onnx_forward_intermediate_outputs_desc = self._get_intermediate_from_forward_graph(self._onnx_forward) - logging.debug(f'Forward intermediate outputs:\n\t {self._onnx_forward_intermediate_outputs_desc}') - - # Backward I/O description - if not self._onnx_backward_initializers_desc: - self._onnx_backward_initializers_desc = self._get_input_from_graph(self._onnx_backward, True) - logging.debug(f'Backward initializers: {self._onnx_backward_initializers_desc}') - if not self._onnx_backward_inputs_desc: - self._onnx_backward_inputs_desc = self._get_input_from_graph(self._onnx_backward, False, self._onnx_backward_initializers_desc) - logging.debug(f'Backward inputs: {self._onnx_backward_inputs_desc}') - if not self._onnx_backward_gradient_inputs_desc: - self._onnx_backward_gradient_inputs_desc = self._get_gradient_input_from_graph(self._onnx_backward, self._onnx_forward_inputs_desc, self._onnx_forward_initializers_desc, self._onnx_forward_intermediate_outputs_desc) - logging.debug(f'Backward gradient inputs: {self._onnx_backward_gradient_inputs_desc}') - if not self._onnx_backward_outputs_desc: - self._onnx_backward_outputs_desc = self._get_output_from_graph(self._onnx_backward) - logging.debug(f'Backward outputs: {self._onnx_backward_outputs_desc}') - # Use a custom torch.autograd.Function to associate self.backward_graph as the # gradient implementation for self.forward_graph. class _ORTModuleFunction(torch.autograd.Function): @@ -127,9 +83,6 @@ class ORTModule(torch.nn.Module): * (Partial) user input * (Partial) Initializers * Intermediate tensors - - TODO: #ImproveGraphSplitting - String matching to separate user input from initializer ''' # Convert input to dict of torch tensors @@ -143,10 +96,12 @@ class ORTModule(torch.nn.Module): outputs = tuple(torch.from_numpy(item) for item in outputs) # Save input, initializers and intermediate tensors to be used during backward - initializer_names = [item['name'] for item in self._onnx_backward_initializers_desc] - input_names = [item['name'] for item in self._onnx_backward_inputs_desc if item['name'] not in initializer_names] - ctx_input = tuple(v for k,v in data_dict.items() if k in input_names) - ctx_initializer = tuple(v for k,v in data_dict.items() if k in initializer_names) + user_input = self._onnx_graphs_info.user_input_names + backward_user_input = self._onnx_graphs_info.backward_user_input_names + ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input) + forward_initializer = self._onnx_graphs_info.initializer_names_to_train + backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input + ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer) intermediate = tuple(torch.from_numpy(item) for item in intermediate) ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate]) @@ -163,23 +118,17 @@ class ORTModule(torch.nn.Module): * Tensor stashed (in a particular order) during forward: * (partial) user input, (partial) initializers and intermediate tensors - TODO: #ImproveGraphSplitting - Length of `*grad_output` is needed to detect intermediate tensors during backward pass - TODO: Input gradient is hard-coded to torch.tensor([1.]) ''' saved_tensors = ctx.saved_tensors - # Used to create backward input - if self._original_module_grad_output_len == -1: - self._original_module_grad_output_len = len(grad_output) - grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output]) - result = [torch.tensor([1])]* len(self._onnx_training_inputs_desc) + result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names) result += [torch.from_numpy(grad) for grad in grad_weights] return tuple(result) - return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*inputs, **kwargs)) + proc_inputs = [data for data in inputs if data is not None] + return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs)) def _convert_forward_input_to_list(self, *inputs, **kwargs): '''Creates forward `*inputs` list from user input and PyTorch initializers @@ -189,8 +138,10 @@ class ORTModule(torch.nn.Module): ONNX Runtime forward requires an order list of: * User input: computed from forward InferenceSession * Initializers: computed from original PyTorch model parameters - ''' + This codes assumes the exported model's inputs and initializers + are the same as the original PyTorch model + ''' # List containing both user inputs and initializers, in this order result = [] @@ -219,12 +170,8 @@ class ORTModule(torch.nn.Module): def _convert_forward_input_list_to_dict(self, *inputs): '''Convert forward `*inputs` list to dict - TODO: #ImproveGraphSplitting - Additionally, a list of gradient names of initializers are created to be used by backprop - TODO: Input gradient is being ignored for MVP ''' - # Dictionary containing both inputs and initializers result = {} @@ -237,15 +184,6 @@ class ORTModule(torch.nn.Module): # Initializers for param in self._original_module.named_parameters(): result.update({param[0]: inputs[result_len]}) - # TODO: Create order list of input grads to use during backward. - # (for scenarios where gradients of input is required - not covered on MVP) - # if len(self._original_module_forward_input_grads) < len(self._onnx_training_inputs_desc): - # self._original_module_forward_input_grads.append(param[0]+'_grad') - - # TODO: Create order list of initializer grads to use during backward. - # if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc) + len(self._onnx_training_inputs_desc): - if len(self._original_module_forward_input_grads) < len(self._onnx_backward_outputs_desc): - self._original_module_forward_input_grads.append(param[0]+'_grad') result_len += 1 return result @@ -253,46 +191,45 @@ class ORTModule(torch.nn.Module): def _convert_backward_input_list_to_dict(self, *inputs): '''Convert backward `*inputs` list to dict - ONNX Runtime backend requires dict as input, which is composed of: + ONNX Runtime backward requires dict as input, which is composed of: * User input Although not necessary, all user inputs are used for simplicity * (Partial) Initializers init_begin = len(user_input) init_count = len(Pre-computed list of initializer) - * Intermediate tensors TODO: #ImproveGraphSplitting - Intermediate tensors are inferred from input position: - interm_begin = len(user_input) + len(initializer) - interm_count = len(all_inputs) - len(user_input) - len(initializer) - len(grad_output) - * Gradient wrt outputs TODO: #ImproveGraphSplitting - Gradient tensors are inferred from input position: - grads_begin = len(user_input) + len(initializer) + len(intermediate) - grads_count = len(all_inputs) - len(user_input) - len(initializer) - len(intermediate) + * Intermediate tensors + * Gradient wrt outputs ''' # Dictionary containing both inputs and initializers result = {} + backward_user_input = self._onnx_graphs_info.backward_user_input_names + backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input + intermediate = self._onnx_graphs_info.intermediate_tensor_names + backward_output_grad_names = self._onnx_graphs_info.backward_output_grad_names + + # Extract info about stashed input and grad output # Inputs - result_len = 0 - for idx, input_data in enumerate(self._forward_session.get_inputs()): - result.update({ input_data.name : inputs[idx]}) - result_len += 1 + inputs_pos = 0 + for idx, name in enumerate(backward_user_input): + result.update({ name : inputs[idx]}) + inputs_pos += 1 # Initializers - for initializer in self._onnx_backward_initializers_desc: - result.update({initializer['name']: inputs[result_len]}) - result_len += 1 + for idx, name in enumerate(backward_intializer, inputs_pos): + result.update({name: inputs[idx]}) + inputs_pos += 1 # Intermediate - intermediate_len = len(inputs) - result_len - self._original_module_grad_output_len - for idx in range(intermediate_len): - result.update({self._onnx_forward_intermediate_outputs_desc[idx]['name']: inputs[result_len]}) - result_len += 1 + for idx, name in enumerate(intermediate, inputs_pos): + result.update({name: inputs[idx]}) + inputs_pos += 1 # Grad outputs - for idx in range(len(inputs)-result_len): - result.update({self._onnx_backward_gradient_inputs_desc[idx]['name']: inputs[result_len]}) - result_len += 1 + for idx, name in enumerate(backward_output_grad_names, inputs_pos): + result.update({name: inputs[idx]}) + inputs_pos += 1 return result @@ -303,10 +240,10 @@ class ORTModule(torch.nn.Module): to distinguish intermediate from output tensors ''' - output_names = [out['name'] for out in self._onnx_forward_outputs_desc] - forward_output = self._forward_session.run(output_names, inputs) - output = forward_output[:len(self._onnx_training_outputs_desc)] - intermediates = forward_output[len(self._onnx_training_outputs_desc):] + forward_output = self._forward_session.run([*self._onnx_graphs_info.user_output_names, + *self._onnx_graphs_info.intermediate_tensor_names], inputs) + output = forward_output[:len(self._onnx_graphs_info.user_output_names)] + intermediates = forward_output[len(self._onnx_graphs_info.user_output_names):] return output, intermediates def _run_backward_graph(self, *inputs, **kwargs): @@ -323,7 +260,7 @@ class ORTModule(torch.nn.Module): # Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement) data = self._convert_dict_torch_to_numpy(data) - return self._backward_session.run(self._original_module_forward_input_grads, data) + return self._backward_session.run(self._onnx_graphs_info.initializer_grad_names_to_train, data) @staticmethod def _get_forward_graph(module, *inputs, **kwargs): @@ -338,6 +275,11 @@ class ORTModule(torch.nn.Module): # Deepcopy inputs, since input values may change after model run. sample_inputs_copy = copy.deepcopy(inputs) + # Ignore optional *inputs explicitly specified as None + sig = signature(module.forward) + all_input_names = sig.parameters.keys() + input_names = [name for idx, name in enumerate(all_input_names) if inputs[idx] is not None] + # TODO: Support contrib OPs support? user model has no hint # from onnxruntime.training import register_custom_ops_pytorch_exporter # register_custom_ops_pytorch_exporter.register_custom_op() @@ -346,173 +288,68 @@ class ORTModule(torch.nn.Module): torch.onnx.export(module, tuple(sample_inputs_copy), f, + input_names=input_names, opset_version=ONNX_OPSET_VERSION, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING) + return onnx.load_model_from_string(f.getvalue()) - def _get_initializer_from_graph(self, graph): - '''Returns a descriptor list of initializers for `graph` - - The list descriptor has the following format: - [{ 'name': name, 'shape':[int1,...,intN], 'dtype': ]}] - - For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461 - ''' - - # TODO: There is a tradeoff between memory footprint and total model export time - # Ideally we want to export the model using torch.onnx.export(.., export_params=False, keep_initializers_as_inputs=True) - # to obtain an ONNX model with minimal size and initializers as input. - # However, this results in (guessing) assuming only initializer's name end with '.weight' and '.bias'. - # Otherwise, it is not possible to separate input from initializer after the model is exported - # Options are: - # 1) If memory footprint is more important, we can export ONNX twice, varying keep_initializers_as_inputs flag - # ONNX model is small (400 bytes vs 1.6MB for MNIST), but export takes twice the time - # 2) If total export time is more important, we can export ONNX once, using export_params=True - # ONNX model is bigger, but export takes half the time - - # As performance is not the main goal in this first deliverable, using approach 2) for simplicity - initializers = [] - for initializer in graph.graph.initializer: - name = initializer.name - shape = initializer.dims - dtype = _utils.dtype_onnx_to_torch(initializer.data_type) - initializers.append({'name': name, 'shape': shape, 'dtype': dtype}) - return initializers - - def _get_input_from_graph(self, graph, initializers_only=False, append_initializers=[]): - '''Returns a descriptor list of input tensors for an ONNX `graph` - - When `initializers_only=True`, only input initializers are returned. Otherwise, both - user input and initializers are considered. - This is being used to get backward initializer list TODO: #ImproveGraphSplitting - - When `append_initializers` is not empty, this list is appended to the end of the result list - This is being used to get backward input list TODO: #ImproveGraphSplitting - - The list descriptor has the following format: - [{ 'name': name, 'shape':[int1,...,intN], 'dtype': ]}] - - For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461 - ''' - - inputs = [] - for elem in graph.graph.input: - for initializer in self._onnx_forward_initializers_desc: - if elem.name == initializer['name']: - if initializers_only: - name = elem.name - shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim] - dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type) - inputs.append({'name': name, 'shape': shape, 'dtype': dtype}) - break - else: - if not initializers_only: - name = elem.name - shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim] - dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type) - inputs.append({'name': name, 'shape': shape, 'dtype': dtype}) - if append_initializers: - inputs.extend(append_initializers) - return inputs - - def _get_gradient_input_from_graph(self, backward_graph, forward_input, forward_initializer, forward_intermediate): - '''Returns a descriptor list of gradient output for `backward_graph` - - Gradient output tensors are found through an elimination process, that cross reference - inputs from the backward graph to the forward input, initializer and intermediate tensors. - - The list descriptor has the following format: - [{ 'name': name, 'shape':[int1,...,intN], 'dtype': ]}] - - For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461 - - TODO: #ImproveGraphSplitting - ''' - grads = [] - found = False - for elem in backward_graph.graph.input: - for item in forward_input: - if elem.name == item['name']: - # skip output - break - else: - for item in forward_initializer: - if elem.name == item['name']: - # skip output - break - else: - for item in forward_intermediate: - if elem.name == item['name']: - # skip output - break - else: - name = elem.name - shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim] - dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type) - grads.append({'name': name, 'shape': shape, 'dtype': dtype}) - return grads - - def _get_output_from_graph(self, graph): - '''Returns a descriptor list of output tensors for an ONNX `graph` - - The list descriptor has the following format: - [{ 'name': name, 'shape':[int1,...,intN], 'dtype': ]}] - - For ONNX types, refer to https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L461 - ''' - outputs = [] - for elem in graph.graph.output: - for initializer in self._onnx_forward_initializers_desc: - if elem.name == initializer['name']: - # skip initializers - break - else: - name = elem.name - shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim] - dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type) - outputs.append({'name': name, 'shape': shape, 'dtype': dtype}) - return outputs - - def _get_intermediate_from_forward_graph(self, forward_graph): - '''Returns a descriptor list with all intermediate tensors for `forward_graph` - - Intermediate tensors are found through an elimination process, that cross reference - outputs from the forward graph to the original model (exported to ONNX) - - The list descriptor has the following format: - [{ 'name': name, 'shape':[int1,...,intN], 'dtype': ]}] - - TODO: #ImproveGraphSplitting - ''' - intermediates = [] - for elem in forward_graph.graph.output: - for output in self._onnx_training_outputs_desc: - if elem.name == output['name']: - # skip output - break - else: - name = elem.name - shape = [dim.dim_value for dim in elem.type.tensor_type.shape.dim] - dtype = _utils.dtype_onnx_to_torch(elem.type.tensor_type.elem_type) - intermediates.append({'name': name, 'shape': shape, 'dtype': dtype}) - return intermediates @staticmethod - def _build_gradient_graph(forward_graph, config): - '''Adds gradient nodes on top of an existing ONNX graph (with training flag) - - TODO: #SplittingGraphAtFrontend - ''' - if not config.weight_names_to_train: - weight_names_to_train = set() + def _build_fw_bw_grad_graphs(forward_graph, config): + '''Adds gradient nodes on top of an existing ONNX graph (with training flag)''' + if not config.initializer_names_to_train: + initializer_names_to_train = [] for initializer in forward_graph.graph.initializer: - weight_names_to_train.add(initializer.name) - config.weight_names_to_train = weight_names_to_train - output_names = set() - for output in forward_graph.graph.output: - output_names.add(output.name) - config.output_names = output_names - models = [onnx.load_model_from_string(model_as_string) - for model_as_string in C.ModuleGradientGraphBuilder().build_and_split(forward_graph.SerializeToString(), config)] - return models[0], models[1], models[2] + initializer_names_to_train.append(initializer.name) + config.initializer_names_to_train = initializer_names_to_train + + # TODO: Add support to input with grad required + config.input_names_require_grad = [] + # input_names_require_grad = [] + # input_names_require_grad.append('input.1') + # config.input_names_require_grad = input_names_require_grad + + module_gradient_graph_builder = C.ModuleGradientGraphBuilder() + module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config) + forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model()) + backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model()) + gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model()) + split_graphs_info = module_gradient_graph_builder.get_split_graphs_info() + + return gradient_model, forward_model, backward_model, split_graphs_info + + + @staticmethod + def _get_io_info_from_onnx_graph(model, graphs_info): + type_map = {} + for name in graphs_info.user_input_names: + type_map[name] = None + for name in graphs_info.initializer_names_to_train: + type_map[name] = None + for name in graphs_info.user_output_names: + type_map[name] = None + for name in graphs_info.backward_user_input_names: + type_map[name] = None + for name in graphs_info.backward_intializer_names_as_input: + type_map[name] = None + for name in graphs_info.intermediate_tensor_names: + type_map[name] = None + for name in graphs_info.user_output_grad_names: + type_map[name] = None + for name in graphs_info.backward_output_grad_names: + type_map[name] = None + + for input in model.graph.input: + if input.name in type_map and type_map[input.name] is None: + type_map[input.name] = input.type + + for output in model.graph.output: + if output.name in type_map and type_map[output.name] is None: + type_map[output.name] = output.type + output_grad_name = output.name + '_grad' + if output_grad_name in type_map and type_map[output_grad_name] is None: + type_map[output_grad_name] = output.type + + return type_map \ No newline at end of file diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py index fd6fea783c..a1e37f502e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py @@ -143,7 +143,7 @@ def main(): # Set log level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % loglevel) + raise ValueError('Invalid log level: %s' % args.log_level) logging.basicConfig(level=numeric_level) else: print('Training MNIST on vanilla PyTorch....') diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 3a8871731d..2b8a6336d6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -1,13 +1,11 @@ - -import pdb - +import logging import argparse import torch import wget import os import pandas as pd import zipfile -from transformers import BertTokenizer +from transformers import BertTokenizer, AutoConfig from keras.preprocessing.sequence import pad_sequences from sklearn.model_selection import train_test_split from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler @@ -50,13 +48,10 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): if step == args.train_steps: break - # Progress update every 40 batches. - if step % args.log_interval == 0 and not step == 0: - # Calculate elapsed time in minutes. - elapsed = format_time(time.time() - t0) - - # Report progress. - print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed)) + # TODO: Dynamic axis is not supported yet + if batch[0].shape[0] != args.batch_size: + logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}') + continue # Unpack this training batch from our dataloader. # @@ -78,12 +73,23 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): model.zero_grad() # Perform a forward pass (evaluate the model on this training batch). - # This will return the loss (rather than the model output) because we - # have provided the `labels`. + # This will return the loss (rather than the model output) because we have provided the `labels`. # The documentation for this `model` function is here: - # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification + # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification + # TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule - outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels) + # outputs = model(b_input_ids, + # token_type_ids = None, + # attention_mask = b_input_mask, + # labels = b_labels) + outputs = model(b_input_ids, + b_input_mask, + None, + None, + None, + None, + b_labels) + if args.view_graphs: import torchviz pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters()))) @@ -91,15 +97,21 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # The call to `model` always returns a tuple, so we need to pull the # loss value out of the tuple. - # pdb.set_trace() loss = outputs[0] + # Progress update every 40 batches. + if step % args.log_interval == 0 and not step == 0: + # Calculate elapsed time in minutes. + elapsed = format_time(time.time() - t0) + + # Report progress. + print(f'Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}. Loss: {loss.item()}') + # Accumulate the training loss over all of the batches so that we can # calculate the average loss at the end. `loss` is a Tensor containing a # single value; the `.item()` function just returns the Python value # from the tensor. total_loss += loss.item() - # total_loss += loss # Perform a backward pass to calculate the gradients. loss.backward() @@ -122,7 +134,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): print("\n Average training loss: {0:.2f}".format(avg_train_loss)) print(" Training epoch took: {:}".format(format_time(time.time() - t0))) -def test(model, validation_dataloader, device): +def test(model, validation_dataloader, device, args): # ======================================== # Validation # ======================================== @@ -143,12 +155,16 @@ def test(model, validation_dataloader, device): # Evaluate data for one epoch for batch in validation_dataloader: + # TODO: Dynamic axis is not supported yet + if batch[0].shape[0] != args.test_batch_size: + logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}') + continue + # Add batch to GPU batch = tuple(t.to(device) for t in batch) # Unpack the inputs from our dataloader b_input_ids, b_input_mask, b_labels = batch - # Telling the model not to compute or store gradients, saving memory and # speeding up validation with torch.no_grad(): @@ -160,18 +176,22 @@ def test(model, validation_dataloader, device): # differentiates sentence 1 and 2 in 2-sentence tasks. # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification + # TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule + # TODO: original sample had the last argument equal to None, but b_labels is because model was + # exported using 3 inputs for training, so validation must follow. + # Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint outputs = model(b_input_ids, b_input_mask, None, None, None, None, - None) + b_labels) # Get the "logits" output by the model. The "logits" are the output # values prior to applying an activation function like the softmax. - logits = outputs[0] + logits = outputs[1] # Move logits and labels to CPU logits = logits.detach().cpu().numpy() @@ -190,7 +210,7 @@ def test(model, validation_dataloader, device): print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps)) print(" Validation took: {:}".format(format_time(time.time() - t0))) -def load_dataset(): +def load_dataset(args): # 2. Loading CoLA Dataset print('Downloading dataset...') @@ -276,18 +296,15 @@ def load_dataset(): train_masks = torch.tensor(train_masks) validation_masks = torch.tensor(validation_masks) - # The DataLoader needs to know our batch size for training, so we specify it - batch_size = 32 - # Create the DataLoader for our training set. train_data = TensorDataset(train_inputs, train_masks, train_labels) train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) + train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size) # Create the DataLoader for our validation set. validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels) validation_sampler = SequentialSampler(validation_data) - validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size) + validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=args.test_batch_size) return train_dataloader, validation_dataloader @@ -310,6 +327,10 @@ def main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--pytorch-only', action='store_true', default=False, help='disables ONNX Runtime training') + parser.add_argument('--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') + parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', + help='input batch size for testing (default: 32)') parser.add_argument('--view-graphs', action='store_true', default=False, help='views forward and backward graphs') parser.add_argument('--no-cuda', action='store_true', default=False, @@ -322,6 +343,11 @@ def main(): help='how many batches to wait before logging training status (default: 40)') parser.add_argument('--train-steps', type=int, default=-1, metavar='N', help='number of steps to train. Set -1 to run through whole dataset (default: -1)') + parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', + help='Log level (default: WARNING)') + parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H', + help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)') + args = parser.parse_args() # Device (CPU vs CUDA) @@ -333,17 +359,28 @@ def main(): print('No GPU available, using the CPU instead.') device = torch.device("cpu") + # Set log level + numeric_level = getattr(logging, args.log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError('Invalid log level: %s' % args.log_level) + logging.basicConfig(level=numeric_level) + # 2. Dataloader - train_dataloader, validation_dataloader = load_dataset() + train_dataloader, validation_dataloader = load_dataset(args) # 3. Modeling # Load BertForSequenceClassification, the pretrained BERT model with a single # linear classification layer on top. + config = AutoConfig.from_pretrained( + "bert-base-uncased", + num_labels=2, + num_hidden_layers=args.num_hidden_layers, + output_attentions = False, # Whether the model returns attentions weights. + output_hidden_states = False, # Whether the model returns all hidden-states. + ) model = BertForSequenceClassification.from_pretrained( "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab. - num_labels = 2, # The number of output labels--2 for binary classification. - output_attentions = False, # Whether the model returns attentions weights. - output_hidden_states = False, # Whether the model returns all hidden-states. + config=config, ) if not args.pytorch_only: @@ -382,7 +419,7 @@ def main(): # 4. Train loop (fine-tune) for epoch_i in range(0, args.epochs): train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args) - test(model, validation_dataloader, device) + test(model, validation_dataloader, device, args) if __name__ == '__main__': main() diff --git a/run_ortmodule_mvp_bert_finetuning.sh b/run_ortmodule_mvp_bert_finetuning.sh index b60de0bf60..52c78c875e 100755 --- a/run_ortmodule_mvp_bert_finetuning.sh +++ b/run_ortmodule_mvp_bert_finetuning.sh @@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder" cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ echo "Running Flexible API (ORTModule)" -python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py --no-cuda --epochs 4 --log-interval 20 --log-level=DEBUG diff --git a/run_ortmodule_mvp_mnist.sh b/run_ortmodule_mvp_mnist.sh index 8a14061d41..22e37ad2b9 100755 --- a/run_ortmodule_mvp_mnist.sh +++ b/run_ortmodule_mvp_mnist.sh @@ -15,4 +15,4 @@ echo "Copying PyTorch frontend source-code to build folder" cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/ echo "Running Flexible API (ORTModule)" -python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py +python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py --epochs 10 --log-interval 100 --log-level=DEBUG