mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Update PyTorch TransformerModel sample (#5275)
This commit is contained in:
parent
71b52ad5de
commit
ebeeff22dd
3 changed files with 191 additions and 2 deletions
|
|
@ -10,12 +10,24 @@ This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.T
|
|||
|
||||
## Running PyTorch version
|
||||
|
||||
```python
|
||||
```bash
|
||||
python pt_model.py
|
||||
```
|
||||
|
||||
## Running ONNX Runtime version
|
||||
|
||||
```python
|
||||
```bash
|
||||
python ort_model.py
|
||||
```
|
||||
|
||||
## Optional arguments
|
||||
|
||||
| Argument | Description | Default |
|
||||
| :---------------- | :-----------------------------------------------------: | --------: |
|
||||
| --batch-size | input batch size for training | 20 |
|
||||
| --test-batch-size | input batch size for testing | 20 |
|
||||
| --epochs | number of epochs to train | 2 |
|
||||
| --lr | learning rate | 0.001 |
|
||||
| --no-cuda | disables CUDA training | False |
|
||||
| --seed | random seed | 1 |
|
||||
| --log-interval | how many batches to wait before logging training status | 200 |
|
||||
|
|
|
|||
85
samples/python/pytorch_transformer/ort_train.py
Normal file
85
samples/python/pytorch_transformer/ort_train.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import argparse
|
||||
import math
|
||||
import torch
|
||||
import onnxruntime
|
||||
|
||||
from utils import prepare_data, get_batch
|
||||
from ort_utils import my_loss, transformer_model_description_dynamic_axes
|
||||
from pt_model import TransformerModel
|
||||
|
||||
|
||||
def train(trainer, data_source, device, epoch, args, bptt=35):
|
||||
total_loss = 0.
|
||||
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
|
||||
data, targets = get_batch(data_source, i)
|
||||
|
||||
loss, pred = trainer.train_step(data, targets)
|
||||
total_loss += loss.item()
|
||||
if batch % args.log_interval == 0 and batch > 0:
|
||||
cur_loss = total_loss / args.log_interval
|
||||
print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch,
|
||||
batch,
|
||||
len(data_source) // bptt,
|
||||
cur_loss))
|
||||
total_loss = 0
|
||||
|
||||
|
||||
def evaluate(trainer, data_source, bptt=35):
|
||||
total_loss = 0.
|
||||
with torch.no_grad():
|
||||
for i in range(0, data_source.size(0) - 1, bptt):
|
||||
data, targets = get_batch(data_source, i)
|
||||
loss, pred = trainer.eval_step(data, targets)
|
||||
total_loss += len(data) * loss.item()
|
||||
return total_loss / (len(data_source) - 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch TransformerModel example')
|
||||
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for training (default: 20)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for testing (default: 20)')
|
||||
parser.add_argument('--epochs', type=int, default=2, metavar='N',
|
||||
help='number of epochs to train (default: 2)')
|
||||
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
|
||||
help='learning rate (default: 0.001)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 200)')
|
||||
|
||||
# Basic setup
|
||||
args = parser.parse_args()
|
||||
if not args.no_cuda and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
torch.manual_seed(args.seed)
|
||||
onnxruntime.set_seed(args.seed)
|
||||
|
||||
# Model
|
||||
optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr)
|
||||
model_desc = transformer_model_description_dynamic_axes()
|
||||
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
|
||||
|
||||
# Preparing data
|
||||
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
|
||||
trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss)
|
||||
|
||||
# Train
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(trainer, train_data, device, epoch, args)
|
||||
val_loss = evaluate(trainer, val_data)
|
||||
print('-' * 89)
|
||||
print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss))
|
||||
print('-' * 89)
|
||||
|
||||
# Evaluate
|
||||
test_loss = evaluate(trainer, test_data)
|
||||
print('=' * 89)
|
||||
print('| End of training | test loss {:5.2f}'.format(test_loss))
|
||||
print('=' * 89)
|
||||
92
samples/python/pytorch_transformer/pt_train.py
Normal file
92
samples/python/pytorch_transformer/pt_train.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import argparse
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils import prepare_data, get_batch
|
||||
from pt_model import TransformerModel
|
||||
|
||||
|
||||
def train(model, data_source, device, epoch, args, bptt=35):
|
||||
total_loss = 0.
|
||||
model.train()
|
||||
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
|
||||
data, targets = get_batch(data_source, i)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output.view(-1, 28785), targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
if batch % args.log_interval == 0 and batch > 0:
|
||||
cur_loss = total_loss / args.log_interval
|
||||
print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch,
|
||||
batch,
|
||||
len(data_source) // bptt,
|
||||
cur_loss))
|
||||
total_loss = 0
|
||||
|
||||
|
||||
def evaluate(model, data_source, criterion, bptt=35):
|
||||
total_loss = 0.
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, data_source.size(0) - 1, bptt):
|
||||
data, targets = get_batch(data_source, i)
|
||||
output = model(data)
|
||||
output_flat = output.view(-1, 28785)
|
||||
total_loss += len(data) * criterion(output_flat, targets).item()
|
||||
return total_loss / (len(data_source) - 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch TransformerModel example')
|
||||
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for training (default: 20)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for testing (default: 20)')
|
||||
parser.add_argument('--epochs', type=int, default=2, metavar='N',
|
||||
help='number of epochs to train (default: 2)')
|
||||
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
|
||||
help='learning rate (default: 0.001)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 200)')
|
||||
|
||||
# Basic setup
|
||||
args = parser.parse_args()
|
||||
if not args.no_cuda and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# Model
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
lr = 0.001
|
||||
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
|
||||
|
||||
# Preparing data
|
||||
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
|
||||
|
||||
# Train
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(model, train_data, device, epoch, args)
|
||||
val_loss = evaluate(model, val_data, criterion)
|
||||
print('-' * 89)
|
||||
print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss))
|
||||
print('-' * 89)
|
||||
|
||||
# Evaluate
|
||||
test_loss = evaluate(model, test_data, criterion)
|
||||
print('=' * 89)
|
||||
print('| End of training | test loss {:5.2f}'.format(test_loss))
|
||||
print('=' * 89)
|
||||
Loading…
Reference in a new issue