mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
44 lines
1.9 KiB
Python
44 lines
1.9 KiB
Python
|
|
import torch
|
||
|
|
|
||
|
|
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
|
||
|
|
ModelDescription as Legacy_ModelDescription
|
||
|
|
|
||
|
|
|
||
|
|
def my_loss(x, target, ntokens=28785):
    """Compute cross-entropy loss of token logits against class-index targets.

    The logits are flattened into rows of ``ntokens`` class scores before being
    passed to ``torch.nn.CrossEntropyLoss`` with its default settings.

    Args:
        x: Logits tensor whose element count is a multiple of ``ntokens``
            (e.g. a ``[bptt, batch_size, ntokens]`` predictions tensor).
        target: Class-index tensor with one entry per flattened row of ``x``
            (``bptt * batch_size`` entries for the shape above).
        ntokens: Number of classes per row. Defaults to 28785, the value that
            was previously hard-coded here.

    Returns:
        The loss tensor produced by ``torch.nn.CrossEntropyLoss()``.
    """
    # reshape (unlike view) also accepts non-contiguous logits, copying only
    # when a zero-copy view is impossible; results are otherwise identical.
    logits = x.reshape(-1, ntokens)
    return torch.nn.CrossEntropyLoss()(logits, target)
|
||
|
|
|
||
|
|
|
||
|
|
def transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
    """Return a fixed-shape description of the transformer model's I/O.

    The result is a dict with two keys:
      * ``'inputs'``: ``('input1', [bptt, batch_size])`` and
        ``('label', [bptt * batch_size])``.
      * ``'outputs'``: ``('loss', [], True)`` followed by
        ``('predictions', [bptt, batch_size, ntokens])``.

    NOTE(review): the trailing ``True`` on the loss entry presumably marks it
    as the loss output for the consumer of this dict; confirm against caller.
    """
    flattened_len = bptt * batch_size
    input_specs = [
        ('input1', [bptt, batch_size]),
        ('label', [flattened_len]),
    ]
    output_specs = [
        ('loss', [], True),
        ('predictions', [bptt, batch_size, ntokens]),
    ]
    return {'inputs': input_specs, 'outputs': output_specs}
|
||
|
|
|
||
|
|
|
||
|
|
def transformer_model_description_dynamic_axes(ntokens=28785):
    """Return a description of the transformer model's I/O with symbolic axes.

    Same layout as the fixed-shape variant, but the sequence and batch
    dimensions are the strings ``'bptt'`` and ``'batch_size'``, and the label
    length is the single symbolic axis ``'bptt_x_batch_size'``. Only the
    vocabulary dimension of ``predictions`` (``ntokens``) stays concrete.
    """
    seq_axis, batch_axis = 'bptt', 'batch_size'
    return {
        'inputs': [
            ('input1', [seq_axis, batch_axis]),
            ('label', ['bptt_x_batch_size']),
        ],
        'outputs': [
            ('loss', [], True),
            ('predictions', [seq_axis, batch_axis, ntokens]),
        ],
    }
|
||
|
|
|
||
|
|
|
||
|
|
def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
    """Describe the transformer model with the legacy ``ort_trainer`` types.

    Returns:
        A 2-tuple ``(model_description, learning_rate_description)``:
          * a ``Legacy_ModelDescription`` whose inputs are ``input1``
            ``[bptt, batch_size]`` and ``label`` ``[bptt * batch_size]``, and
            whose outputs are ``loss`` ``[]`` and ``predictions``
            ``[bptt, batch_size, ntokens]``;
          * a ``Legacy_IODescription`` named ``'__learning_rate'`` of shape ``[1]``.
    """
    input_descs = [
        Legacy_IODescription('input1', [bptt, batch_size]),
        Legacy_IODescription('label', [bptt * batch_size]),
    ]
    output_descs = [
        Legacy_IODescription('loss', []),
        Legacy_IODescription('predictions', [bptt, batch_size, ntokens]),
    ]
    # Built in the same order as before: model description, then learning rate.
    model_description = Legacy_ModelDescription(input_descs, output_descs)
    learning_rate_description = Legacy_IODescription('__learning_rate', [1])
    return model_description, learning_rate_description
|
||
|
|
|
||
|
|
|
||
|
|
def legacy_transformer_model_description_dynamic_axes(ntokens=28785):
    """Describe the transformer model with legacy types and symbolic axes.

    Returns:
        A 2-tuple ``(model_description, learning_rate_description)``:
          * a ``Legacy_ModelDescription`` whose inputs are ``input1``
            ``['bptt', 'batch_size']`` and ``label`` ``['bptt_x_batch_size']``,
            and whose outputs are ``loss`` ``[]`` and ``predictions``
            ``['bptt', 'batch_size', ntokens]``;
          * a ``Legacy_IODescription`` named ``'__learning_rate'`` of shape ``[1]``.
    """
    seq_axis, batch_axis = 'bptt', 'batch_size'
    input_descs = [
        Legacy_IODescription('input1', [seq_axis, batch_axis]),
        Legacy_IODescription('label', ['bptt_x_batch_size']),
    ]
    output_descs = [
        Legacy_IODescription('loss', []),
        Legacy_IODescription('predictions', [seq_axis, batch_axis, ntokens]),
    ]
    # Built in the same order as before: model description, then learning rate.
    model_description = Legacy_ModelDescription(input_descs, output_descs)
    learning_rate_description = Legacy_IODescription('__learning_rate', [1])
    return model_description, learning_rate_description
|