Change DropoutGrad.input[1].input_type and del logits_grad from backward graph

This commit is contained in:
Thiago Crepaldi 2020-11-04 16:38:43 -08:00
parent f1dc6e4007
commit ea5871ac15
2 changed files with 18 additions and 12 deletions

View file

@ -575,7 +575,7 @@ class ORTModule(torch.nn.Module):
for output in node.output:
backward_graph_outputs.add(output)
else:
# nodes belogs to forward graph
# nodes belongs to forward graph
for input in node.input:
if input in initializers:
forward_graph_initializer_names.add(input)
@ -589,7 +589,7 @@ class ORTModule(torch.nn.Module):
if initializer_name in weight_names_to_train:
add_input_from_initializer(forward_model, initializers[initializer_name])
# outputs from forward graph that are also inputs of backwoard graph need to be added as graph output.
# outputs from forward graph that are also inputs of backward graph need to be added as graph output.
for output in forward_graph_outputs:
if output in backward_graph_inputs:
add_output(forward_model, output)
@ -607,10 +607,15 @@ class ORTModule(torch.nn.Module):
for initializer in backward_model.graph.initializer:
initializers[initializer.name] = initializer
# Remove nodes from forward graph
nodes_to_remove_from_backward_graph = []
dropoutgrad_input1_set = set()
for node in backward_model.graph.node:
if node.doc_string != 'Backward pass':
nodes_to_remove_from_backward_graph.append(node)
# TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type)
if node.op_type == 'DropoutGrad':
dropoutgrad_input1_set.add(node.input[1])
backward_graph_initializer_names = set()
for input in backward_graph_inputs:
@ -618,8 +623,7 @@ class ORTModule(torch.nn.Module):
# inputs of backward graph that are also outputs from forward graph need to be added to backward graph input
# TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type)
input_type = tensor_elem_types[input] if input in tensor_elem_types else 1
if input in {'1835', '1813', '1781','1760', '1683','1651','1630','1553','1521','1500','1423','1391','1370','1293','1261','1240','1163','1131','1110','1033','1001','980','871',
'267','330','351','383','460','481','513','590','611','643','720','741','773','850','903'}:
if input in dropoutgrad_input1_set:
input_type = 9
add_input(backward_model, input, input_type)
elif input in forward_graph_initializer_names:
@ -630,7 +634,7 @@ class ORTModule(torch.nn.Module):
# gradient of forward graph output will be the input of backward graph
for output in backward_model.graph.output:
if output.name + '_grad' in backward_graph_inputs:
if output.name + '_grad' in backward_graph_inputs and output.name != '1835': # TODO: 1835_grad is grad of logits and must be computed by ONNX's loss grad subgraph
add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type)
backward_model.graph.ClearField('initializer')

View file

@ -82,11 +82,8 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# have provided the `labels`.
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
outputs = model(b_input_ids,
token_type_ids=None,
attention_mask=b_input_mask,
labels=b_labels)
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels)
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
@ -163,9 +160,14 @@ def test(model, validation_dataloader, device):
# differentiates sentence 1 and 2 in 2-sentence tasks.
# The documentation for this `model` function is here:
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
outputs = model(b_input_ids,
token_type_ids=None,
attention_mask=b_input_mask)
b_input_mask,
None,
None,
None,
None,
None)
# Get the "logits" output by the model. The "logits" are the output
# values prior to applying an activation function like the softmax.