mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
add io binding
This commit is contained in:
parent
ff79e8743f
commit
39ac95b2fc
5 changed files with 324 additions and 88 deletions
|
|
@ -14,8 +14,7 @@ namespace training {
|
|||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
void GetInputAndOutputNames(const Node& node,
|
||||
std::unordered_set<std::string>& input_names,
|
||||
void GetInputAndOutputNames(const Node& node, std::unordered_set<std::string>& input_names,
|
||||
std::unordered_set<std::string>& output_names) {
|
||||
std::for_each(node.InputDefs().begin(), node.InputDefs().end(),
|
||||
[&input_names](const NodeArg* node_arg) { input_names.insert(node_arg->Name()); });
|
||||
|
|
@ -63,7 +62,8 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
split_graphs_info_.user_output_names.emplace_back(node_arg->Name());
|
||||
}
|
||||
|
||||
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end());
|
||||
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
|
||||
config.initializer_names_to_train.end());
|
||||
|
||||
// Register and apply transformers for pre-training.
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
|
|
@ -76,8 +76,9 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
|
||||
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
std::unordered_map<std::string, std::string> updated_weight_names{};
|
||||
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
|
||||
level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, {});
|
||||
level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, updated_weight_names, {});
|
||||
for (auto& entry : transformers_to_register) {
|
||||
graph_transformation_mgr.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
@ -101,13 +102,11 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
|
||||
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end());
|
||||
GradientGraphBuilder grad_graph_builder(&model_->MainGraph(),
|
||||
y_node_arg_names,
|
||||
x_node_arg_names,
|
||||
"", // not support loss name for now.
|
||||
gradient_graph_config,
|
||||
*logger_);
|
||||
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(),
|
||||
split_graphs_info_.user_output_names.end());
|
||||
GradientGraphBuilder grad_graph_builder(&model_->MainGraph(), y_node_arg_names, x_node_arg_names,
|
||||
"", // not support loss name for now.
|
||||
gradient_graph_config, *logger_);
|
||||
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
|
||||
|
||||
// Fix inputs/outputs related to gradients.
|
||||
|
|
@ -152,6 +151,7 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
std::string initializer_gradient_name = initializer_name + "_grad";
|
||||
if (output_names.find(initializer_gradient_name) != output_names.end()) {
|
||||
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
|
||||
output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name));
|
||||
}
|
||||
}
|
||||
|
|
@ -188,17 +188,11 @@ std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, con
|
|||
return model_str;
|
||||
}
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetGradientModel() const {
|
||||
return SerializeModel(model_, "gradient");
|
||||
}
|
||||
std::string ModuleGradientGraphBuilder::GetGradientModel() const { return SerializeModel(model_, "gradient"); }
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetForwardModel() const {
|
||||
return SerializeModel(forward_model_, "forward");
|
||||
}
|
||||
std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); }
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetBackwardModel() const {
|
||||
return SerializeModel(backward_model_, "backward");
|
||||
}
|
||||
std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); }
|
||||
|
||||
Status ModuleGradientGraphBuilder::Split() {
|
||||
// Get forward model, also collect some information for backward model generation.
|
||||
|
|
@ -253,8 +247,8 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
// Add intermediate args to forward graph outputs.
|
||||
for (const auto& intermediate_arg_name : intermediate_arg_names) {
|
||||
// Ignore the user outputs.
|
||||
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(), intermediate_arg_name)
|
||||
== split_graphs_info_.user_output_names.end()) {
|
||||
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(),
|
||||
intermediate_arg_name) == split_graphs_info_.user_output_names.end()) {
|
||||
split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
|
||||
forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
|
||||
}
|
||||
|
|
@ -264,7 +258,8 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
|
||||
// Resolve the forward graph, keep the trainable initializers for now.
|
||||
Graph::ResolveOptions options;
|
||||
std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(), split_graphs_info_.initializer_names_to_train.end());
|
||||
std::unordered_set<std::string> initializer_names_to_train_set(split_graphs_info_.initializer_names_to_train.begin(),
|
||||
split_graphs_info_.initializer_names_to_train.end());
|
||||
options.initializer_names_to_preserve = &initializer_names_to_train_set;
|
||||
forward_graph.Resolve(options);
|
||||
|
||||
|
|
@ -292,15 +287,9 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
}
|
||||
|
||||
// Grad of user outputs to backward graph inputs.
|
||||
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
|
||||
}
|
||||
|
||||
// Add initializer args to backward graph inputs if any node uses them.
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
// Some initializers will be inputs for backward graph.
|
||||
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_name + "_grad");
|
||||
if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
|
||||
split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
|
||||
|
|
@ -315,6 +304,11 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
backward_input_args.emplace_back(intermediate_node_arg);
|
||||
}
|
||||
|
||||
// Grad of user outputs to backward graph inputs.
|
||||
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
|
||||
}
|
||||
|
||||
backward_graph.SetInputs(backward_input_args);
|
||||
|
||||
// Exclude user outputs from the backward graph.
|
||||
|
|
|
|||
|
|
@ -44,15 +44,12 @@ struct SplitGraphsInfo {
|
|||
|
||||
class ModuleGradientGraphBuilder {
|
||||
public:
|
||||
Status BuildAndSplit(std::istream& model_istream,
|
||||
const ModuleGradientGraphBuilderConfiguration& config);
|
||||
Status BuildAndSplit(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
|
||||
|
||||
std::string GetGradientModel() const;
|
||||
std::string GetForwardModel() const;
|
||||
std::string GetBackwardModel() const;
|
||||
SplitGraphsInfo GetSplitGraphsInfo() const {
|
||||
return split_graphs_info_;
|
||||
}
|
||||
SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; }
|
||||
|
||||
private:
|
||||
Status Split();
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import onnxruntime
|
|||
import os
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from inspect import signature
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
|
@ -15,12 +16,44 @@ from . import _utils
|
|||
ONNX_OPSET_VERSION = 12
|
||||
|
||||
|
||||
def get_device_index(device):
|
||||
if type(device) == str:
|
||||
# could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
|
||||
device = torch.device(device)
|
||||
return 0 if device.index is None else device.index
|
||||
|
||||
|
||||
def prepare_io_binding(io_binding, inputs, model, output_buffers, device):
|
||||
idx = 0
|
||||
for value_info in model.graph.input:
|
||||
io_binding.bind_input(value_info.name, inputs[idx].device.type, get_device_index(inputs[idx].device),
|
||||
_utils.dtype_torch_to_numpy(inputs[idx].dtype), list(inputs[idx].size()),
|
||||
inputs[idx].data_ptr())
|
||||
idx += 1
|
||||
|
||||
for value_info in model.graph.output:
|
||||
name = value_info.name
|
||||
output_tensor = output_buffers[name]
|
||||
io_binding.bind_output(name, output_tensor.device.type, get_device_index(device),
|
||||
_utils.dtype_torch_to_numpy(output_tensor.dtype), list(output_tensor.size()),
|
||||
output_tensor.data_ptr())
|
||||
|
||||
|
||||
def value_info_to_buffer_tensor(value_info, device):
|
||||
shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
|
||||
dtype = _utils.dtype_onnx_to_torch(value_info.type.tensor_type.elem_type)
|
||||
return torch.zeros(shape, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module, device="cpu", use_iobinding=False):
|
||||
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
self._device = device
|
||||
self._use_iobinding = use_iobinding
|
||||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._onnx_training = None
|
||||
|
|
@ -52,9 +85,10 @@ class ORTModule(torch.nn.Module):
|
|||
if not self._onnx_forward:
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config)
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
self._onnx_graphs_info.initializer_grad_names_to_train = [ p[0]+'_grad' for p in self._original_module.named_parameters()]
|
||||
# Use the order in original module
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
self._onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config, initializer_names)
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
|
|
@ -65,9 +99,19 @@ class ORTModule(torch.nn.Module):
|
|||
# TODO: Consider moving this to the backend. We don't want to append '_grad' to get correct tensor names
|
||||
self._onnx_graphs_types = ORTModule._get_io_info_from_onnx_graph(self._onnx_forward, self._onnx_graphs_info)
|
||||
|
||||
# TODO: hard-coding to CPU only
|
||||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=['CPUExecutionProvider'])
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=['CPUExecutionProvider'])
|
||||
execution_providers = ['CPUExecutionProvider'] if self._device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
self._forward_session = onnxruntime.InferenceSession(self._onnx_forward.SerializeToString(), providers=execution_providers)
|
||||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString(), providers=execution_providers)
|
||||
|
||||
if self._use_iobinding:
|
||||
self._forward_io_binding = self._forward_session.io_binding()
|
||||
self._forward_output_buffers = {}
|
||||
for output in self._onnx_forward.graph.output:
|
||||
self._forward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device)
|
||||
self._backward_io_binding = self._backward_session.io_binding()
|
||||
self._backward_output_buffers = {}
|
||||
for output in self._onnx_backward.graph.output:
|
||||
self._backward_output_buffers[output.name] = value_info_to_buffer_tensor(output, self._device)
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
|
|
@ -85,30 +129,44 @@ class ORTModule(torch.nn.Module):
|
|||
* Intermediate tensors
|
||||
'''
|
||||
|
||||
# Convert input to dict of torch tensors
|
||||
data_dict = self._convert_forward_input_list_to_dict(*inputs)
|
||||
if not self._use_iobinding:
|
||||
# Convert input to dict of torch tensors
|
||||
data_dict = self._convert_forward_input_list_to_dict(*inputs)
|
||||
|
||||
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
|
||||
data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict)
|
||||
# Convert dict of torch tensors to dict of numpy arrays (ORT BE requirement)
|
||||
data_dict_numpy = self._convert_dict_torch_to_numpy(data_dict)
|
||||
|
||||
# Feed forward
|
||||
outputs, intermediate = self._run_forward_graph(data_dict_numpy)
|
||||
outputs = tuple(torch.from_numpy(item) for item in outputs)
|
||||
# Feed forward
|
||||
outputs, intermediate = self._run_forward_graph(data_dict_numpy)
|
||||
outputs = tuple(torch.from_numpy(item) for item in outputs)
|
||||
|
||||
# Save input, initializers and intermediate tensors to be used during backward
|
||||
user_input = self._onnx_graphs_info.user_input_names
|
||||
backward_user_input = self._onnx_graphs_info.backward_user_input_names
|
||||
ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input)
|
||||
forward_initializer = self._onnx_graphs_info.initializer_names_to_train
|
||||
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
|
||||
ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer)
|
||||
intermediate = tuple(torch.from_numpy(item) for item in intermediate)
|
||||
ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate])
|
||||
# Save input, initializers and intermediate tensors to be used during backward
|
||||
user_input = self._onnx_graphs_info.user_input_names
|
||||
backward_user_input = self._onnx_graphs_info.backward_user_input_names
|
||||
ctx_input = tuple(data_dict[name] for name in user_input if name in backward_user_input)
|
||||
forward_initializer = self._onnx_graphs_info.initializer_names_to_train
|
||||
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
|
||||
ctx_initializer = tuple(data_dict[name] for name in forward_initializer if name in backward_intializer)
|
||||
intermediate = tuple(torch.from_numpy(item) for item in intermediate)
|
||||
ctx.save_for_backward(*[*ctx_input, *ctx_initializer, *intermediate])
|
||||
|
||||
# TODO: Support original module output (currently dict is not supported)
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
# TODO: Support original module output (currently dict is not supported)
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
# Use IO binding.
|
||||
prepare_io_binding(self._forward_io_binding, inputs, self._onnx_forward, self._forward_output_buffers, self._device)
|
||||
self._forward_session.run_with_iobinding(self._forward_io_binding)
|
||||
|
||||
forward_input_dict = self._convert_forward_input_list_to_dict(*inputs)
|
||||
ctx_inputs = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_user_input_names)
|
||||
ctx_initializers = tuple(forward_input_dict[name] for name in self._onnx_graphs_info.backward_intializer_names_as_input)
|
||||
ctx_intermediates = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.intermediate_tensor_names)
|
||||
ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates])
|
||||
|
||||
outputs = tuple(self._forward_output_buffers[name] for name in self._onnx_graphs_info.user_output_names)
|
||||
return outputs[0] if len(outputs) == 1 else outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
|
|
@ -120,12 +178,23 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
TODO: Input gradient is hard-coded to torch.tensor([1.])
|
||||
'''
|
||||
saved_tensors = ctx.saved_tensors
|
||||
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
|
||||
if not self._use_iobinding:
|
||||
saved_tensors = ctx.saved_tensors
|
||||
grad_weights = self._run_backward_graph(*[*saved_tensors, *grad_output])
|
||||
|
||||
result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names)
|
||||
result += [torch.from_numpy(grad) for grad in grad_weights]
|
||||
return tuple(result)
|
||||
result = [torch.tensor([1])]* len(self._onnx_graphs_info.user_input_names)
|
||||
result += [torch.from_numpy(grad) for grad in grad_weights]
|
||||
return tuple(result)
|
||||
|
||||
# Use IO binding.
|
||||
grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output))
|
||||
backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names)
|
||||
prepare_io_binding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output], self._onnx_backward, self._backward_output_buffers, self._device)
|
||||
self._backward_session.run_with_iobinding(self._backward_io_binding)
|
||||
|
||||
results = [torch.tensor([1])] * len(self._onnx_graphs_info.user_input_names)
|
||||
results += [self._backward_output_buffers[name] for name in self._onnx_graphs_info.initializer_grad_names_to_train]
|
||||
return tuple(results)
|
||||
|
||||
proc_inputs = [data for data in inputs if data is not None]
|
||||
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs))
|
||||
|
|
@ -297,13 +366,16 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config):
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[]):
|
||||
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
|
||||
if not config.initializer_names_to_train:
|
||||
initializer_names_to_train = []
|
||||
for initializer in forward_graph.graph.initializer:
|
||||
initializer_names_to_train.append(initializer.name)
|
||||
config.initializer_names_to_train = initializer_names_to_train
|
||||
if not initializer_names:
|
||||
initializer_names_to_train = []
|
||||
for initializer in forward_graph.graph.initializer:
|
||||
initializer_names_to_train.append(initializer.name)
|
||||
config.initializer_names_to_train = initializer_names_to_train
|
||||
else:
|
||||
config.initializer_names_to_train = initializer_names
|
||||
|
||||
# TODO: Add support to input with grad required
|
||||
config.input_names_require_grad = []
|
||||
|
|
@ -323,27 +395,21 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _get_io_info_from_onnx_graph(model, graphs_info):
|
||||
type_map = {}
|
||||
for name in graphs_info.user_input_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.initializer_names_to_train:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.user_output_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_user_input_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_intializer_names_as_input:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.intermediate_tensor_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.user_output_grad_names:
|
||||
type_map[name] = None
|
||||
for name in graphs_info.backward_output_grad_names:
|
||||
type_map[name] = None
|
||||
type_map = {key: None for key in [
|
||||
*graphs_info.user_input_names,
|
||||
*graphs_info.initializer_names_to_train,
|
||||
*graphs_info.initializer_grad_names_to_train,
|
||||
*graphs_info.user_output_names,
|
||||
*graphs_info.intermediate_tensor_names,
|
||||
*graphs_info.user_output_grad_names
|
||||
]}
|
||||
|
||||
for input in model.graph.input:
|
||||
if input.name in type_map and type_map[input.name] is None:
|
||||
type_map[input.name] = input.type
|
||||
input_grad_name = input.name + '_grad'
|
||||
if input_grad_name in type_map and type_map[input_grad_name] is None:
|
||||
type_map[input_grad_name] = input.type
|
||||
|
||||
for output in model.graph.output:
|
||||
if output.name in type_map and type_map[output.name] is None:
|
||||
|
|
@ -352,4 +418,4 @@ class ORTModule(torch.nn.Module):
|
|||
if output_grad_name in type_map and type_map[output_grad_name] is None:
|
||||
type_map[output_grad_name] = output.type
|
||||
|
||||
return type_map
|
||||
return type_map
|
||||
|
|
|
|||
|
|
@ -0,0 +1,161 @@
|
|||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import ORTModule
|
||||
|
||||
|
||||
class NeuralNet(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(NeuralNet, self).__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1):
|
||||
out = self.fc1(input1)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
|
||||
model.train()
|
||||
for iteration, (data, target) in enumerate(train_loader):
|
||||
if iteration == args.train_steps:
|
||||
break
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
if args.pytorch_only:
|
||||
probability = model(data)
|
||||
else:
|
||||
probability = model(data)
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters())))
|
||||
pytorch_backward_graph.view()
|
||||
|
||||
loss = loss_fn(probability, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Stats
|
||||
if iteration % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, iteration * len(data), len(train_loader.dataset),
|
||||
100. * iteration / len(train_loader), loss))
|
||||
|
||||
|
||||
def test(args, model, device, loss_fn, test_loader):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
data = data.reshape(data.shape[0], -1)
|
||||
output = model(data)
|
||||
|
||||
# Stats
|
||||
test_loss += loss_fn(output, target, False).item()
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
test_loss /= len(test_loader.dataset)
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
def my_loss(x, target, is_train=True):
|
||||
if is_train:
|
||||
return torch.nn.CrossEntropyLoss()(x, target)
|
||||
else:
|
||||
return torch.nn.CrossEntropyLoss(reduction='sum')(x, target)
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
|
||||
help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for training (default: 20)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
|
||||
help='input batch size for testing (default: 20)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--use_iobinding', action='store_true', default=False,
|
||||
help='use IO binding')
|
||||
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
||||
help='random seed (default: 42)')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 100)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
|
||||
help='Log level (default: WARNING)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Common setup
|
||||
torch.manual_seed(args.seed)
|
||||
onnxruntime.set_seed(args.seed)
|
||||
|
||||
# TODO: CUDA support is broken due to copying from PyTorch into ORT
|
||||
if not args.no_cuda and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
# device = 'cpu'
|
||||
|
||||
## Data loader
|
||||
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True,
|
||||
transform=transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True)
|
||||
if args.test_batch_size > 0:
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('./data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
|
||||
batch_size=args.test_batch_size, shuffle=True)
|
||||
|
||||
# Model architecture
|
||||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
|
||||
if not args.pytorch_only:
|
||||
print('Training MNIST on ORTModule....')
|
||||
model = ORTModule(model, device, args.use_iobinding)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
model._save_onnx_prefix = 'MNIST'
|
||||
|
||||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError('Invalid log level: %s' % args.log_level)
|
||||
logging.basicConfig(level=numeric_level)
|
||||
else:
|
||||
print('Training MNIST on vanilla PyTorch....')
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
|
||||
|
||||
# Train loop
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, optimizer, my_loss, train_loader, epoch)
|
||||
if args.test_batch_size > 0:
|
||||
test(args, model, device, my_loss, test_loader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
18
run_ortmodule_mvp_mnist_iobinding.sh
Normal file
18
run_ortmodule_mvp_mnist_iobinding.sh
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
|
||||
cur_dir=$(basename `pwd`)
|
||||
|
||||
if [[ ${cur_dir} != "RelWithDebInfo" ]]
|
||||
then
|
||||
echo "Going to build folder (aka build/Linux/RelWithDebInfo)"
|
||||
cd build/Linux/RelWithDebInfo
|
||||
fi
|
||||
|
||||
echo "Exporting PYTHONPATH to use build dir as onnxruntime package"
|
||||
export PYTHONPATH=$(pwd)
|
||||
|
||||
echo "Copying PyTorch frontend source-code to build folder"
|
||||
cp -Rf ../../../orttraining/orttraining/python/training/* ../../../build/Linux/RelWithDebInfo/onnxruntime/training/
|
||||
|
||||
echo "Running Flexible API (ORTModule)"
|
||||
python ../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_iobinding.py --epochs 10 --log-interval 100 --log-level=DEBUG --use_iobinding
|
||||
Loading…
Reference in a new issue