diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index eeb581d16a..f1998290b1 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -575,7 +575,7 @@ class ORTModule(torch.nn.Module): for output in node.output: backward_graph_outputs.add(output) else: - # nodes belogs to forward graph + # nodes belong to forward graph for input in node.input: if input in initializers: forward_graph_initializer_names.add(input) @@ -589,7 +589,7 @@ class ORTModule(torch.nn.Module): if initializer_name in weight_names_to_train: add_input_from_initializer(forward_model, initializers[initializer_name]) - # outputs from forward graph that are also inputs of backwoard graph need to be added as graph output. + # outputs from forward graph that are also inputs of backward graph need to be added as graph output. for output in forward_graph_outputs: if output in backward_graph_inputs: add_output(forward_model, output) @@ -607,10 +607,15 @@ class ORTModule(torch.nn.Module): for initializer in backward_model.graph.initializer: initializers[initializer.name] = initializer + # Remove nodes from forward graph nodes_to_remove_from_backward_graph = [] + dropoutgrad_input1_set = set() for node in backward_model.graph.node: if node.doc_string != 'Backward pass': nodes_to_remove_from_backward_graph.append(node) + # TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type) + if node.op_type == 'DropoutGrad': + dropoutgrad_input1_set.add(node.input[1]) backward_graph_initializer_names = set() for input in backward_graph_inputs: @@ -618,8 +623,7 @@ class ORTModule(torch.nn.Module): # inputs of backward graph that are also outputs from forward graph need to be added to backward graph input # TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type) input_type = 
tensor_elem_types[input] if input in tensor_elem_types else 1 - if input in {'1835', '1813', '1781','1760', '1683','1651','1630','1553','1521','1500','1423','1391','1370','1293','1261','1240','1163','1131','1110','1033','1001','980','871', - '267','330','351','383','460','481','513','590','611','643','720','741','773','850','903'}: + if input in dropoutgrad_input1_set: input_type = 9 add_input(backward_model, input, input_type) elif input in forward_graph_initializer_names: @@ -630,7 +634,7 @@ class ORTModule(torch.nn.Module): # gradient of forward graph output will be the input of backward graph for output in backward_model.graph.output: - if output.name + '_grad' in backward_graph_inputs: + if output.name + '_grad' in backward_graph_inputs and output.name != '1835': # TODO: 1835_grad is grad of logits and must be computed by ONNX's loss grad subgraph add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type) backward_model.graph.ClearField('initializer') diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 63200a12f5..3a8871731d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -82,11 +82,8 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # have provided the `labels`. 
# The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification - outputs = model(b_input_ids, - token_type_ids=None, - attention_mask=b_input_mask, - labels=b_labels) - + # TODO: explicitly setting (optional) inputs to work around the *input, **kwargs limitation on ORTModule + outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels) if args.view_graphs: import torchviz pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters()))) @@ -163,9 +160,14 @@ def test(model, validation_dataloader, device): # differentiates sentence 1 and 2 in 2-sentence tasks. # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification + # TODO: explicitly setting (optional) inputs to work around the *input, **kwargs limitation on ORTModule outputs = model(b_input_ids, - token_type_ids=None, - attention_mask=b_input_mask) + b_input_mask, + None, + None, + None, + None, + None) # Get the "logits" output by the model. The "logits" are the output # values prior to applying an activation function like the softmax.