diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 0412928772..4addfc8571 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -11,13 +11,15 @@ from onnxruntime.capi import _pybind_state as C def get_ort_device_type(device): - device = device.lower() - if device == 'cuda': + device_type = device if type(device) is str else device.type.lower() + if device_type == 'cuda': return C.OrtDevice.cuda() - elif device == 'cpu': + elif device_type == 'cpu': return C.OrtDevice.cpu() + elif device_type == 'ort': + return C.get_ort_device(device.index).device_type() else: - raise Exception('Unsupported device type: ' + device) + raise Exception('Unsupported device type: ' + device_type) def check_and_normalize_provider_args(providers, provider_options, available_provider_names): diff --git a/orttraining/orttraining/eager/ort_backends.cpp b/orttraining/orttraining/eager/ort_backends.cpp index fc32f1d06a..bd03e535d3 100644 --- a/orttraining/orttraining/eager/ort_backends.cpp +++ b/orttraining/orttraining/eager/ort_backends.cpp @@ -60,6 +60,13 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st return onnxruntime::Status::OK(); } +OrtDevice ORTBackendsManager::GetOrtDeviceInfo(size_t torch_device_index){ + auto lookup = backends_.find(torch_device_index); + ORT_ENFORCE(lookup != backends_.end()); + auto allocator = lookup->second->GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault); + return allocator->Info().device; +} + onnxruntime::ORTInvoker& ORTBackendsManager::GetInvoker(const at::Device device) { ORT_LOG_FN(device); diff --git a/orttraining/orttraining/eager/ort_backends.h b/orttraining/orttraining/eager/ort_backends.h index a5090b1373..4dd64dfaf6 100644 --- a/orttraining/orttraining/eager/ort_backends.h +++ b/orttraining/orttraining/eager/ort_backends.h @@ -24,6 +24,8 @@ public: onnxruntime::ORTInvoker& GetInvoker(const at::Device device); + OrtDevice GetOrtDeviceInfo(size_t torch_device_index); + private: std::map> backends_; const onnxruntime::logging::Logger& logger_; diff --git a/orttraining/orttraining/eager/ort_eager.cpp b/orttraining/orttraining/eager/ort_eager.cpp index 556e8751a8..ad9d6d0780 100644 --- a/orttraining/orttraining/eager/ort_eager.cpp +++ b/orttraining/orttraining/eager/ort_eager.cpp @@ -59,6 +59,9 @@ void addObjectMethodsForEager(py::module& m){ if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); }); + m.def("get_ort_device", [](size_t torch_device_index){ + return torch_ort::eager::GetORTBackendsManager().GetOrtDeviceInfo(torch_device_index); + }); auto customop_module = m.def_submodule("custom_ops"); torch_ort::eager::GenerateCustomOpsBindings(customop_module); diff --git a/orttraining/orttraining/eager/test_model/mnist_fc_training.py b/orttraining/orttraining/eager/test_model/mnist_fc_training.py new file mode 100644 index 0000000000..32bb94328f --- /dev/null +++ b/orttraining/orttraining/eager/test_model/mnist_fc_training.py @@ -0,0 +1,110 @@ +## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py +## with modification to do training using onnxruntime as backend on cuda device. +## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo. + +## Model testing is not complete. + +from __future__ import print_function +import argparse +import torch +from onnxruntime.training import ORTModule +from onnxruntime.capi import _pybind_state as torch_ort_eager +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import numpy as np +import os + +class NeuralNet(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNet, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + +def my_loss(x, target): + return F.nll_loss(F.log_softmax(x, dim=1), target) + +def train_with_eager(args, model, optimizer, device, train_loader, epoch): + for batch_idx, (data, target) in enumerate(train_loader): + data_cpu = data.reshape(data.shape[0], -1) + data = data_cpu.to(device) + + x = model(data) + loss = my_loss(x.cpu(), target) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + + # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data_cpu), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss)) + +def main(): +#Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + + + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + kwargs = {'num_workers': 0, 'pin_memory': True} + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('./data', train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.batch_size, shuffle=True, **kwargs) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('./data', train=False, transform=transforms.Compose([ + transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), + batch_size=args.test_batch_size, shuffle=True, **kwargs) + + # set device + torch_ort_eager.set_device(0, 'CPUExecutionProvider', {}) + + device = torch.device('ort', index=0) + input_size = 784 + hidden_size = 500 + num_classes = 10 + model = NeuralNet(input_size, hidden_size, num_classes) + model.to(device) + model = ORTModule(model) + optimizer = optim.SGD(model.parameters(), lr=0.01) + + print('\nStart Training.') + + for epoch in range(1, args.epochs + 1): + train_with_eager(args, model, optimizer, device, train_loader, epoch) + + +if __name__ == '__main__': + main() diff --git a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py index 22fce2fe18..4358379010 100644 --- a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py @@ -48,20 +48,20 @@ class GradientAccumulationManager(object): """ return self._enabled - def extract_outputs_and_maybe_update_cache(self, forward_outputs): + def extract_outputs_and_maybe_update_cache(self, forward_outputs, device): """Extract the user outputs from the forward outputs as torch tensor and update cache, if needed Args: forward_outputs (OrtValueVector): List of outputs returned by forward function """ if not self.enabled: - return tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs) + return tuple(_utils._ortvalue_to_torch_tensor(forward_output, device) for forward_output in forward_outputs) if self._update_cache: for i in range(self._cache_start, len(forward_outputs)): self.cache.insert( self._cached_node_arg_names[i-self._cache_start], forward_outputs[i]) self._update_cache = False - return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i]) for i in range(self._cache_start)) + return tuple(_utils._ortvalue_to_torch_tensor(forward_outputs[i], device) for i in range(self._cache_start)) def maybe_update_cache_before_run(self): """Update cache when model parameters are modified and optimization is enabled. diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index a81f83422c..8a754891f4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -43,7 +43,7 @@ class InferenceManager(GraphExecutionManager): ort_output = execution_session.run_forward(io_binding, run_options) forward_outputs, run_id = ort_output.ortvalues, ort_output.run_id user_outputs = tuple(_utils._ortvalue_to_torch_tensor( - forward_output._ortvalue) for forward_output in forward_outputs) + forward_output._ortvalue, device) for forward_output in forward_outputs) state = None output_info = [(output.shape, output.device, output.dtype) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 6b1696981d..f2d2726812 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -14,8 +14,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_typ import torch import warnings -from torch.utils.dlpack import from_dlpack, to_dlpack - class TrainingManager(GraphExecutionManager): """Concrete instance of GraphExecutionManager that is able to manage the training model @@ -28,7 +26,7 @@ class TrainingManager(GraphExecutionManager): self._export_mode = torch.onnx.TrainingMode.TRAINING @staticmethod - def execution_session_run_forward(execution_session, onnx_model, gradient_accumulation_manager, *inputs): + def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs): """Runs the forward graph on execution_session with given model inputs and device""" # TODO: Try to reuse the output buffers as some of the output tensors are same sizes, @@ -39,12 +37,13 @@ class TrainingManager(GraphExecutionManager): forward_inputs = C.OrtValueVector() forward_inputs.reserve(len(inputs)) for input in inputs: - forward_inputs.push_back(to_dlpack(input), input.dtype == torch.bool) + forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool) forward_outputs = C.OrtValueVector() # Run and return module outputs. execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache) - user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs) + + user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device) output_info = [(output.shape, output.device, output.dtype) for output in user_outputs] run_info = _RunStateInfo(state, output_info) @@ -145,6 +144,7 @@ class TrainingManager(GraphExecutionManager): user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent, self._onnx_models.optimized_model, + self._device, self._gradient_accumulation_manager, *inputs) @@ -204,7 +204,7 @@ class TrainingManager(GraphExecutionManager): grad_output = torch.tensor(0., device=device, dtype=dtype) elif not grad_output.is_contiguous(): grad_output = grad_output.contiguous() - backward_inputs.push_back(to_dlpack(grad_output), grad_output.dtype == torch.bool) + backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), grad_output.dtype == torch.bool) backward_inputs.shrink_to_fit() # Run and get results @@ -223,7 +223,7 @@ class TrainingManager(GraphExecutionManager): if input_name in require_grad_names_set: results.append(_utils._torch_tensor_from_dl_pack( backward_outputs.dlpack_at(require_grad_names_index), - backward_outputs[require_grad_names_index])) + backward_outputs[require_grad_names_index], self._device)) require_grad_names_index += 1 else: # input_name is not found in the self._input_info.require_grad_names list @@ -237,7 +237,7 @@ class TrainingManager(GraphExecutionManager): if initializer_name in self._graph_initializer_names_to_train: results.append(_utils._torch_tensor_from_dl_pack( backward_outputs.dlpack_at(initializer_index), - backward_outputs[initializer_index])) + backward_outputs[initializer_index], self._device)) initializer_index += 1 else: results.append(None) @@ -285,7 +285,7 @@ class TrainingManager(GraphExecutionManager): session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] fw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device.type), + C.OrtDevice(get_ort_device_type(self._device), C.OrtDevice.default_memory(), _utils.get_device_index(self._device) )] * (len(self._graph_info.user_output_names) + @@ -293,7 +293,7 @@ class TrainingManager(GraphExecutionManager): bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output] bw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device.type), + C.OrtDevice(get_ort_device_type(self._device), C.OrtDevice.default_memory(), _utils.get_device_index(self._device) )] * len(bw_fetches_names) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index c1a146b2ff..ec3af45b0a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -16,24 +16,29 @@ from typing import List import types import warnings - -def _ortvalue_to_torch_tensor(ortvalue): - # PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8, - # and convert the config to torch.uint8 tensor duing from_dlpack(). - # So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor. - torch_tensor = from_dlpack(ortvalue.to_dlpack()) - return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor - - def _ortvalue_from_torch_tensor(torch_tensor): return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool) -def _torch_tensor_from_dl_pack(dlpack, ortvalue): - torch_tensor = from_dlpack(dlpack) +def _torch_tensor_from_dl_pack(dlpack, ortvalue, device): + torch_tensor = from_dlpack(dlpack) if device.type != 'ort' else C.ort_from_dlpack(dlpack) return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor +def _ortvalue_to_torch_tensor(ortvalue, device): + # PyTorch's to_dlpack() uses same config for both torch.bool and torch.uint8, + # and convert the config to torch.uint8 tensor duing from_dlpack(). + # So we need to convert the torch tensor to torch.bool type if OrtValue is bool tensor. + dlpack_tensor = ortvalue.to_dlpack() + return _torch_tensor_from_dl_pack(dlpack_tensor, ortvalue, device) + +def _torch_tensor_to_dlpack(tensor): + if tensor.device.type == 'ort': + return C.ort_to_dlpack(tensor) + else: + return to_dlpack(tensor) + + def _check_same_device(device, argument_str, *args): '''Check that all tensor arguments in *args reside on the same device as the input device'''