enable eager mode with ortmodule (#8961)

* initial change for eager/ortmodule integration

* pdate to latest pytorch api

* add test model;fix torch version issue

* fix comments in pr

* fix python test break

* fix api change

* fix comments in PR

* pass device into the fw function
This commit is contained in:
Tang, Cheng 2021-09-10 15:09:23 -07:00 committed by GitHub
parent 29d6573f3d
commit 8eb6546e8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 158 additions and 29 deletions

View file

@ -11,13 +11,15 @@ from onnxruntime.capi import _pybind_state as C
def get_ort_device_type(device):
device = device.lower()
if device == 'cuda':
device_type = device if type(device) is str else device.type.lower()
if device_type == 'cuda':
return C.OrtDevice.cuda()
elif device == 'cpu':
elif device_type == 'cpu':
return C.OrtDevice.cpu()
elif device_type == 'ort':
return C.get_ort_device(device.index).device_type()
else:
raise Exception('Unsupported device type: ' + device)
raise Exception('Unsupported device type: ' + device_type)
def check_and_normalize_provider_args(providers, provider_options, available_provider_names):

View file

@ -60,6 +60,13 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
return onnxruntime::Status::OK();
}
OrtDevice ORTBackendsManager::GetOrtDeviceInfo(size_t torch_device_index){
auto lookup = backends_.find(torch_device_index);
ORT_ENFORCE(lookup != backends_.end());
auto allocator = lookup->second->GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault);
return allocator->Info().device;
}
onnxruntime::ORTInvoker& ORTBackendsManager::GetInvoker(const at::Device device) {
ORT_LOG_FN(device);

View file

@ -24,6 +24,8 @@ public:
onnxruntime::ORTInvoker& GetInvoker(const at::Device device);
OrtDevice GetOrtDeviceInfo(size_t torch_device_index);
private:
std::map<at::DeviceIndex, std::unique_ptr<onnxruntime::ORTInvoker>> backends_;
const onnxruntime::logging::Logger& logger_;

View file

@ -59,6 +59,9 @@ void addObjectMethodsForEager(py::module& m){
if (!status.IsOK())
throw std::runtime_error(status.ErrorMessage());
});
m.def("get_ort_device", [](size_t torch_device_index){
return torch_ort::eager::GetORTBackendsManager().GetOrtDeviceInfo(torch_device_index);
});
auto customop_module = m.def_submodule("custom_ops");
torch_ort::eager::GenerateCustomOpsBindings(customop_module);

View file

@ -0,0 +1,110 @@
## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py
## with modification to do training using onnxruntime as backend on cuda device.
## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo.
## Model testing is not complete.
from __future__ import print_function
import argparse
import torch
from onnxruntime.training import ORTModule
from onnxruntime.capi import _pybind_state as torch_ort_eager
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import os
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
def my_loss(x, target):
return F.nll_loss(F.log_softmax(x, dim=1), target)
def train_with_eager(args, model, optimizer, device, train_loader, epoch):
for batch_idx, (data, target) in enumerate(train_loader):
data_cpu = data.reshape(data.shape[0], -1)
data = data_cpu.to(device)
x = model(data)
loss = my_loss(x.cpu(), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss.
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data_cpu), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss))
def main():
#Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
kwargs = {'num_workers': 0, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
# set device
torch_ort_eager.set_device(0, 'CPUExecutionProvider', {})
device = torch.device('ort', index=0)
input_size = 784
hidden_size = 500
num_classes = 10
model = NeuralNet(input_size, hidden_size, num_classes)
model.to(device)
model = ORTModule(model)
optimizer = optim.SGD(model.parameters(), lr=0.01)
print('\nStart Training.')
for epoch in range(1, args.epochs + 1):
train_with_eager(args, model, optimizer, device, train_loader, epoch)
if __name__ == '__main__':
main()

View file

@ -48,20 +48,20 @@ class GradientAccumulationManager(object):
"""
return self._enabled
def extract_outputs_and_maybe_update_cache(self, forward_outputs):
def extract_outputs_and_maybe_update_cache(self, forward_outputs, device):
"""Extract the user outputs from the forward outputs as torch tensor and update cache, if needed
Args:
forward_outputs (OrtValueVector): List of outputs returned by forward function
"""
if not self.enabled:
return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)
return tuple(_utils._ortvalue_to_torch_tensor(forward_output, device) for forward_output in forward_outputs)
if self._update_cache:
for i in range(self._cache_start, len(forward_outputs)):
self.cache.insert(
self._cached_node_arg_names[i-self._cache_start], forward_outputs[i])
self._update_cache = False
return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start))
return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i], device) for i in range(self._cache_start))
def maybe_update_cache_before_run(self):
"""Update cache when model parameters are modified and optimization is enabled.

View file

@ -43,7 +43,7 @@ class InferenceManager(GraphExecutionManager):
ort_output = execution_session.run_forward(io_binding, run_options)
forward_outputs, run_id = ort_output.ortvalues, ort_output.run_id
user_outputs = tuple(_utils._ortvalue_to_torch_tensor(
forward_output._ortvalue) for forward_output in forward_outputs)
forward_output._ortvalue, device) for forward_output in forward_outputs)
state = None
output_info = [(output.shape, output.device, output.dtype)

View file

@ -14,8 +14,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_typ
import torch
import warnings
from torch.utils.dlpack import from_dlpack, to_dlpack
class TrainingManager(GraphExecutionManager):
"""Concrete instance of GraphExecutionManager that is able to manage the training model
@ -28,7 +26,7 @@ class TrainingManager(GraphExecutionManager):
self._export_mode = torch.onnx.TrainingMode.TRAINING
@staticmethod
def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs):
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
"""Runs the forward graph on execution_session with given model inputs and device"""
# TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
@ -39,12 +37,13 @@ class TrainingManager(GraphExecutionManager):
forward_inputs = C.OrtValueVector()
forward_inputs.reserve(len(inputs))
for input in inputs:
forward_inputs.push_back(to_dlpack(input), input.dtype == torch.bool)
forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
forward_outputs = C.OrtValueVector()
# Run and return module outputs.
execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs)
user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)
output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
run_info = _RunStateInfo(state, output_info)
@ -145,6 +144,7 @@ class TrainingManager(GraphExecutionManager):
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
self._onnx_models.optimized_model,
self._device,
self._gradient_accumulation_manager,
*inputs)
@ -204,7 +204,7 @@ class TrainingManager(GraphExecutionManager):
grad_output = torch.tensor(0., device=device, dtype=dtype)
elif not grad_output.is_contiguous():
grad_output = grad_output.contiguous()
backward_inputs.push_back(to_dlpack(grad_output), grad_output.dtype == torch.bool)
backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), grad_output.dtype == torch.bool)
backward_inputs.shrink_to_fit()
# Run and get results
@ -223,7 +223,7 @@ class TrainingManager(GraphExecutionManager):
if input_name in require_grad_names_set:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(require_grad_names_index),
backward_outputs[require_grad_names_index]))
backward_outputs[require_grad_names_index], self._device))
require_grad_names_index += 1
else:
# input_name is not found in the self._input_info.require_grad_names list
@ -237,7 +237,7 @@ class TrainingManager(GraphExecutionManager):
if initializer_name in self._graph_initializer_names_to_train:
results.append(_utils._torch_tensor_from_dl_pack(
backward_outputs.dlpack_at(initializer_index),
backward_outputs[initializer_index]))
backward_outputs[initializer_index], self._device))
initializer_index += 1
else:
results.append(None)
@ -285,7 +285,7 @@ class TrainingManager(GraphExecutionManager):
session_options, providers, provider_options = self._get_session_config()
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
fw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device.type),
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * (len(self._graph_info.user_output_names) +
@ -293,7 +293,7 @@ class TrainingManager(GraphExecutionManager):
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
bw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device.type),
C.OrtDevice(get_ort_device_type(self._device),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(bw_fetches_names)

View file

@ -16,24 +16,29 @@ from typing import List
import types
import warnings
def _ortvalue_to_torch_tensor(ortvalue):
# PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
# and convert the config to torch.uint8 tensor duing from_dlpack().
# So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor.
torch_tensor = from_dlpack(ortvalue.to_dlpack())
return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor
def _ortvalue_from_torch_tensor(torch_tensor):
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool)
def _torch_tensor_from_dl_pack(dlpack, ortvalue):
torch_tensor = from_dlpack(dlpack)
def _torch_tensor_from_dl_pack(dlpack, ortvalue, device):
torch_tensor = from_dlpack(dlpack) if device.type != 'ort' else C.ort_from_dlpack(dlpack)
return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor
def _ortvalue_to_torch_tensor(ortvalue, device):
# PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8,
# and convert the config to torch.uint8 tensor duing from_dlpack().
# So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor.
dlpack_tensor = ortvalue.to_dlpack()
return _torch_tensor_from_dl_pack(dlpack_tensor, ortvalue, device)
def _torch_tensor_to_dlpack(tensor):
if tensor.device.type == 'ort':
return C.ort_to_dlpack(tensor)
else:
return to_dlpack(tensor)
def _check_same_device(device, argument_str, *args):
'''Check that all tensor arguments in *args reside on the same device as the input device'''