mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
enable eager mode with ortmodule (#8961)
* initial change for eager/ortmodule integration * pdate to latest pytorch api * add test model;fix torch version issue * fix comments in pr * fix python test break * fix api change * fix comments in PR * pass device into the fw function
This commit is contained in:
parent
29d6573f3d
commit
8eb6546e8e
9 changed files with 158 additions and 29 deletions
|
|
@ -11,13 +11,15 @@ from onnxruntime.capi import _pybind_state as C
|
|||
|
||||
|
||||
def get_ort_device_type(device):
|
||||
device = device.lower()
|
||||
if device == 'cuda':
|
||||
device_type = device if type(device) is str else device.type.lower()
|
||||
if device_type == 'cuda':
|
||||
return C.OrtDevice.cuda()
|
||||
elif device == 'cpu':
|
||||
elif device_type == 'cpu':
|
||||
return C.OrtDevice.cpu()
|
||||
elif device_type == 'ort':
|
||||
return C.get_ort_device(device.index).device_type()
|
||||
else:
|
||||
raise Exception('Unsupported device type: ' + device)
|
||||
raise Exception('Unsupported device type: ' + device_type)
|
||||
|
||||
|
||||
def check_and_normalize_provider_args(providers, provider_options, available_provider_names):
|
||||
|
|
|
|||
|
|
@ -60,6 +60,13 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
|
|||
return onnxruntime::Status::OK();
|
||||
}
|
||||
|
||||
OrtDevice ORTBackendsManager::GetOrtDeviceInfo(size_t torch_device_index){
|
||||
auto lookup = backends_.find(torch_device_index);
|
||||
ORT_ENFORCE(lookup != backends_.end());
|
||||
auto allocator = lookup->second->GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault);
|
||||
return allocator->Info().device;
|
||||
}
|
||||
|
||||
onnxruntime::ORTInvoker& ORTBackendsManager::GetInvoker(const at::Device device) {
|
||||
ORT_LOG_FN(device);
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ public:
|
|||
|
||||
onnxruntime::ORTInvoker& GetInvoker(const at::Device device);
|
||||
|
||||
OrtDevice GetOrtDeviceInfo(size_t torch_device_index);
|
||||
|
||||
private:
|
||||
std::map<at::DeviceIndex, std::unique_ptr<onnxruntime::ORTInvoker>> backends_;
|
||||
const onnxruntime::logging::Logger& logger_;
|
||||
|
|
|
|||
|
|
@ -59,6 +59,9 @@ void addObjectMethodsForEager(py::module& m){
|
|||
if (!status.IsOK())
|
||||
throw std::runtime_error(status.ErrorMessage());
|
||||
});
|
||||
m.def("get_ort_device", [](size_t torch_device_index){
|
||||
return torch_ort::eager::GetORTBackendsManager().GetOrtDeviceInfo(torch_device_index);
|
||||
});
|
||||
|
||||
auto customop_module = m.def_submodule("custom_ops");
|
||||
torch_ort::eager::GenerateCustomOpsBindings(customop_module);
|
||||
|
|
|
|||
110
orttraining/orttraining/eager/test_model/mnist_fc_training.py
Normal file
110
orttraining/orttraining/eager/test_model/mnist_fc_training.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
## with modification to do training using onnxruntime as backend on cuda device.
|
||||
## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo.
|
||||
|
||||
## Model testing is not complete.
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import torch
|
||||
from onnxruntime.training import ORTModule
|
||||
from onnxruntime.capi import _pybind_state as torch_ort_eager
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNet, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
def my_loss(x, target):
|
||||
return F.nll_loss(F.log_softmax(x, dim=1), target)
|
||||
|
||||
def train_with_eager(args, model, optimizer, device, train_loader, epoch):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data_cpu = data.reshape(data.shape[0], -1)
|
||||
data = data_cpu.to(device)
|
||||
|
||||
x = model(data)
|
||||
loss = my_loss(x.cpu(), target)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data_cpu), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss))
|
||||
|
||||
def main():
|
||||
#Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for training (default: 64)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
||||
help='input batch size for testing (default: 1000)')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
||||
help='how many batches to wait before logging training status')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
kwargs = {'num_workers': 0, 'pin_memory': True}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('./data', train=True, download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=args.batch_size, shuffle=True, **kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('./data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.test_batch_size, shuffle=True, **kwargs)
|
||||
|
||||
# set device
|
||||
torch_ort_eager.set_device(0, 'CPUExecutionProvider', {})
|
||||
|
||||
device = torch.device('ort', index=0)
|
||||
input_size = 784
|
||||
hidden_size = 500
|
||||
num_classes = 10
|
||||
model = NeuralNet(input_size, hidden_size, num_classes)
|
||||
model.to(device)
|
||||
model = ORTModule(model)
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
print('\nStart Training.')
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train_with_eager(args, model, optimizer, device, train_loader, epoch)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -48,20 +48,20 @@ class GradientAccumulationManager(object):
|
|||
"""
|
||||
return self._enabled
|
||||
|
||||
def extract_outputs_and_maybe_update_cache(self, forward_outputs):
|
||||
def extract_outputs_and_maybe_update_cache(self, forward_outputs, device):
|
||||
"""Extract the user outputs from the forward outputs as torch tensor and update cache, if needed
|
||||
|
||||
Args:
|
||||
forward_outputs (OrtValueVector): List of outputs returned by forward function
|
||||
"""
|
||||
if not self.enabled:
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_output, device) for forward_output in forward_outputs)
|
||||
if self._update_cache:
|
||||
for i in range(self._cache_start, len(forward_outputs)):
|
||||
self.cache.insert(
|
||||
self._cached_node_arg_names[i-self._cache_start], forward_outputs[i])
|
||||
self._update_cache = False
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start))
|
||||
return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i], device) for i in range(self._cache_start))
|
||||
|
||||
def maybe_update_cache_before_run(self):
|
||||
"""Update cache when model parameters are modified and optimization is enabled.
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
ort_output = execution_session.run_forward(io_binding, run_options)
|
||||
forward_outputs, run_id = ort_output.ortvalues, ort_output.run_id
|
||||
user_outputs = tuple(_utils._ortvalue_to_torch_tensor(
|
||||
forward_output._ortvalue) for forward_output in forward_outputs)
|
||||
forward_output._ortvalue, device) for forward_output in forward_outputs)
|
||||
state = None
|
||||
|
||||
output_info = [(output.shape, output.device, output.dtype)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_typ
|
|||
|
||||
import torch
|
||||
import warnings
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
|
||||
class TrainingManager(GraphExecutionManager):
|
||||
"""Concrete instance of GraphExecutionManager that is able to manage the training model
|
||||
|
|
@ -28,7 +26,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._export_mode = torch.onnx.TrainingMode.TRAINING
|
||||
|
||||
@staticmethod
|
||||
def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs):
|
||||
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
|
||||
"""Runs the forward graph on execution_session with given model inputs and device"""
|
||||
|
||||
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
|
||||
|
|
@ -39,12 +37,13 @@ class TrainingManager(GraphExecutionManager):
|
|||
forward_inputs = C.OrtValueVector()
|
||||
forward_inputs.reserve(len(inputs))
|
||||
for input in inputs:
|
||||
forward_inputs.push_back(to_dlpack(input), input.dtype == torch.bool)
|
||||
forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
|
||||
|
||||
forward_outputs = C.OrtValueVector()
|
||||
# Run and return module outputs.
|
||||
execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
|
||||
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs)
|
||||
|
||||
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)
|
||||
|
||||
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
|
||||
run_info = _RunStateInfo(state, output_info)
|
||||
|
|
@ -145,6 +144,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
|
||||
self._onnx_models.optimized_model,
|
||||
self._device,
|
||||
self._gradient_accumulation_manager,
|
||||
*inputs)
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
grad_output = torch.tensor(0., device=device, dtype=dtype)
|
||||
elif not grad_output.is_contiguous():
|
||||
grad_output = grad_output.contiguous()
|
||||
backward_inputs.push_back(to_dlpack(grad_output), grad_output.dtype == torch.bool)
|
||||
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), grad_output.dtype == torch.bool)
|
||||
backward_inputs.shrink_to_fit()
|
||||
|
||||
# Run and get results
|
||||
|
|
@ -223,7 +223,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
if input_name in require_grad_names_set:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(require_grad_names_index),
|
||||
backward_outputs[require_grad_names_index]))
|
||||
backward_outputs[require_grad_names_index], self._device))
|
||||
require_grad_names_index += 1
|
||||
else:
|
||||
# input_name is not found in the self._input_info.require_grad_names list
|
||||
|
|
@ -237,7 +237,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
if initializer_name in self._graph_initializer_names_to_train:
|
||||
results.append(_utils._torch_tensor_from_dl_pack(
|
||||
backward_outputs.dlpack_at(initializer_index),
|
||||
backward_outputs[initializer_index]))
|
||||
backward_outputs[initializer_index], self._device))
|
||||
initializer_index += 1
|
||||
else:
|
||||
results.append(None)
|
||||
|
|
@ -285,7 +285,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
session_options, providers, provider_options = self._get_session_config()
|
||||
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
|
||||
fw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device.type),
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * (len(self._graph_info.user_output_names) +
|
||||
|
|
@ -293,7 +293,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
|
||||
bw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device.type),
|
||||
C.OrtDevice(get_ort_device_type(self._device),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(bw_fetches_names)
|
||||
|
|
|
|||
|
|
@ -16,24 +16,29 @@ from typing import List
|
|||
import types
|
||||
import warnings
|
||||
|
||||
|
||||
def _ortvalue_to_torch_tensor(ortvalue):
|
||||
# PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
|
||||
# and convert the config to torch.uint8 tensor duing from_dlpack().
|
||||
# So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor.
|
||||
torch_tensor = from_dlpack(ortvalue.to_dlpack())
|
||||
return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor
|
||||
|
||||
|
||||
def _ortvalue_from_torch_tensor(torch_tensor):
|
||||
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool)
|
||||
|
||||
|
||||
def _torch_tensor_from_dl_pack(dlpack, ortvalue):
|
||||
torch_tensor = from_dlpack(dlpack)
|
||||
def _torch_tensor_from_dl_pack(dlpack, ortvalue, device):
|
||||
torch_tensor = from_dlpack(dlpack) if device.type != 'ort' else C.ort_from_dlpack(dlpack)
|
||||
return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor
|
||||
|
||||
|
||||
def _ortvalue_to_torch_tensor(ortvalue, device):
|
||||
# PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
|
||||
# and convert the config to torch.uint8 tensor duing from_dlpack().
|
||||
# So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor.
|
||||
dlpack_tensor = ortvalue.to_dlpack()
|
||||
return _torch_tensor_from_dl_pack(dlpack_tensor, ortvalue, device)
|
||||
|
||||
def _torch_tensor_to_dlpack(tensor):
|
||||
if tensor.device.type == 'ort':
|
||||
return C.ort_to_dlpack(tensor)
|
||||
else:
|
||||
return to_dlpack(tensor)
|
||||
|
||||
|
||||
def _check_same_device(device, argument_str, *args):
|
||||
'''Check that all tensor arguments in *args reside on the same device as the input device'''
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue