mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Change DropouGrad.input[1].input_type and del logits_grad from backward graph
This commit is contained in:
parent
f1dc6e4007
commit
ea5871ac15
2 changed files with 18 additions and 12 deletions
|
|
@ -575,7 +575,7 @@ class ORTModule(torch.nn.Module):
|
|||
for output in node.output:
|
||||
backward_graph_outputs.add(output)
|
||||
else:
|
||||
# nodes belogs to forward graph
|
||||
# nodes belongs to forward graph
|
||||
for input in node.input:
|
||||
if input in initializers:
|
||||
forward_graph_initializer_names.add(input)
|
||||
|
|
@ -589,7 +589,7 @@ class ORTModule(torch.nn.Module):
|
|||
if initializer_name in weight_names_to_train:
|
||||
add_input_from_initializer(forward_model, initializers[initializer_name])
|
||||
|
||||
# outputs from forward graph that are also inputs of backwoard graph need to be added as graph output.
|
||||
# outputs from forward graph that are also inputs of backward graph need to be added as graph output.
|
||||
for output in forward_graph_outputs:
|
||||
if output in backward_graph_inputs:
|
||||
add_output(forward_model, output)
|
||||
|
|
@ -607,10 +607,15 @@ class ORTModule(torch.nn.Module):
|
|||
for initializer in backward_model.graph.initializer:
|
||||
initializers[initializer.name] = initializer
|
||||
|
||||
# Remove nodes from forward graph
|
||||
nodes_to_remove_from_backward_graph = []
|
||||
dropoutgrad_input1_set = set()
|
||||
for node in backward_model.graph.node:
|
||||
if node.doc_string != 'Backward pass':
|
||||
nodes_to_remove_from_backward_graph.append(node)
|
||||
# TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type)
|
||||
if node.op_type == 'DropoutGrad':
|
||||
dropoutgrad_input1_set.add(node.input[1])
|
||||
|
||||
backward_graph_initializer_names = set()
|
||||
for input in backward_graph_inputs:
|
||||
|
|
@ -618,8 +623,7 @@ class ORTModule(torch.nn.Module):
|
|||
# inputs of backward graph that are also outputs from forward graph need to be added to backward graph input
|
||||
# TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type)
|
||||
input_type = tensor_elem_types[input] if input in tensor_elem_types else 1
|
||||
if input in {'1835', '1813', '1781','1760', '1683','1651','1630','1553','1521','1500','1423','1391','1370','1293','1261','1240','1163','1131','1110','1033','1001','980','871',
|
||||
'267','330','351','383','460','481','513','590','611','643','720','741','773','850','903'}:
|
||||
if input in dropoutgrad_input1_set:
|
||||
input_type = 9
|
||||
add_input(backward_model, input, input_type)
|
||||
elif input in forward_graph_initializer_names:
|
||||
|
|
@ -630,7 +634,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# gradient of forward graph output will be the input of backward graph
|
||||
for output in backward_model.graph.output:
|
||||
if output.name + '_grad' in backward_graph_inputs:
|
||||
if output.name + '_grad' in backward_graph_inputs and output.name != '1835': # TODO: 1835_grad is grad of logits and must be computed by ONNX's loss grad subgraph
|
||||
add_input(backward_model, output.name + '_grad', output.type.tensor_type.elem_type)
|
||||
|
||||
backward_model.graph.ClearField('initializer')
|
||||
|
|
|
|||
|
|
@ -82,11 +82,8 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
# have provided the `labels`.
|
||||
# The documentation for this `model` function is here:
|
||||
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
|
||||
outputs = model(b_input_ids,
|
||||
token_type_ids=None,
|
||||
attention_mask=b_input_mask,
|
||||
labels=b_labels)
|
||||
|
||||
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
|
||||
outputs = model(b_input_ids, b_input_mask, None, None, None, None, b_labels)
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
|
||||
|
|
@ -163,9 +160,14 @@ def test(model, validation_dataloader, device):
|
|||
# differentiates sentence 1 and 2 in 2-sentence tasks.
|
||||
# The documentation for this `model` function is here:
|
||||
# https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
|
||||
# TODO: explicitly setting (optional) inputs to workaround *input, **kwargs limitation on ORTModule
|
||||
outputs = model(b_input_ids,
|
||||
token_type_ids=None,
|
||||
attention_mask=b_input_mask)
|
||||
b_input_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None)
|
||||
|
||||
# Get the "logits" output by the model. The "logits" are the output
|
||||
# values prior to applying an activation function like the softmax.
|
||||
|
|
|
|||
Loading…
Reference in a new issue