mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Add initial dynamic axes support
This commit is contained in:
parent
004632ff8d
commit
7729bb3c8d
5 changed files with 107 additions and 100 deletions
|
|
@ -43,13 +43,16 @@ void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& inp
|
|||
}
|
||||
}
|
||||
|
||||
Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
// We need to apply the pre-training transformers before the gradient graph builder so we can build
|
||||
// an optimized gradient graph. The constant folding transformer depends on concrete shapes, without
|
||||
// constant folding with concrete shapes, shapes of some intermediate tensors will fail to infer.
|
||||
// This means we need to "apply transformers -> build gradient graph -> split" each time we have different
|
||||
// concrete input shapes. So this init func is just to save the original graph and config.
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
|
||||
ORT_RETURN_IF_ERROR(model_->MainGraph().Resolve());
|
||||
|
||||
// Handle original model inputs, outputs and trainable initializers.
|
||||
const std::vector<const NodeArg*>& graph_inputs = model_->MainGraph().GetInputsIncludingInitializers();
|
||||
|
|
@ -65,6 +68,35 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
|
||||
config.initializer_names_to_train.end());
|
||||
|
||||
config_ = config;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
// Make a copy of the original model.
|
||||
auto model_proto = model_->ToProto();
|
||||
std::shared_ptr<onnxruntime::Model> model_copied;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_copied, nullptr, *logger_));
|
||||
Graph& graph = model_copied->MainGraph();
|
||||
|
||||
// Replace the input shapes.
|
||||
std::vector<const NodeArg*> input_args;
|
||||
size_t input_index = 0;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
NodeArg* input_node_arg = graph.GetNodeArg(input_name);
|
||||
ONNX_NAMESPACE::TensorShapeProto new_shape;
|
||||
for (size_t i = 0; i < input_shapes[input_index].size(); i++) {
|
||||
new_shape.add_dim()->set_dim_value(input_shapes[input_index][i]);
|
||||
}
|
||||
|
||||
input_node_arg->SetShape(new_shape);
|
||||
input_args.emplace_back(input_node_arg);
|
||||
input_index++;
|
||||
}
|
||||
|
||||
graph.SetInputs(input_args);
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
|
||||
// Register and apply transformers for pre-training.
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
GraphTransformerManager graph_transformation_mgr{2};
|
||||
|
|
@ -72,8 +104,8 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
|
||||
std::unordered_set<std::string> x_node_arg_names;
|
||||
std::set_union(config.initializer_names_to_train.begin(), config.initializer_names_to_train.end(),
|
||||
config.input_names_require_grad.begin(), config.input_names_require_grad.end(),
|
||||
std::set_union(config_.initializer_names_to_train.begin(), config_.initializer_names_to_train.end(),
|
||||
config_.input_names_require_grad.begin(), config_.input_names_require_grad.end(),
|
||||
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
std::unordered_map<std::string, std::string> updated_weight_names{};
|
||||
|
|
@ -91,41 +123,39 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
}
|
||||
}
|
||||
|
||||
Graph& graph = model_->MainGraph();
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// TODO: mixed precision transformer.
|
||||
|
||||
// Build gradient graph.
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config.set_gradients_as_graph_outputs;
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config_.set_gradients_as_graph_outputs;
|
||||
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(),
|
||||
split_graphs_info_.user_output_names.end());
|
||||
GradientGraphBuilder grad_graph_builder(&model_->MainGraph(), y_node_arg_names, x_node_arg_names,
|
||||
"", // not support loss name for now.
|
||||
GradientGraphBuilder grad_graph_builder(&graph, y_node_arg_names, x_node_arg_names,
|
||||
"",
|
||||
gradient_graph_config, *logger_);
|
||||
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
|
||||
|
||||
// Fix inputs/outputs related to gradients.
|
||||
Graph& gradient_graph = model_->MainGraph();
|
||||
GraphViewer gradient_graph_viewer(gradient_graph);
|
||||
const auto& node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<std::string> input_names;
|
||||
std::unordered_set<std::string> output_names;
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& node = *gradient_graph.GetNode(node_index);
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
GetInputAndOutputNames(node, input_names, output_names);
|
||||
}
|
||||
|
||||
std::vector<const NodeArg*> input_args;
|
||||
input_args.clear();
|
||||
for (auto& input_name : split_graphs_info_.user_input_names) {
|
||||
input_args.emplace_back(gradient_graph.GetNodeArg(input_name));
|
||||
input_args.emplace_back(graph.GetNodeArg(input_name));
|
||||
}
|
||||
|
||||
// Add the entry points of gradients (normally loss_gard) to the graph inputs. Using the order of graph outputs.
|
||||
split_graphs_info_.user_output_grad_names.clear();
|
||||
split_graphs_info_.backward_output_grad_names.clear();
|
||||
for (const auto& output_name : split_graphs_info_.user_output_names) {
|
||||
std::string output_gradient_name = output_name + "_grad";
|
||||
if (input_names.find(output_gradient_name) != input_names.end()) {
|
||||
|
|
@ -133,48 +163,48 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(std::istream& model_istream,
|
|||
// Only add to graph input when it's not an output of a node.
|
||||
if (output_names.find(output_gradient_name) == output_names.end()) {
|
||||
split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
|
||||
NodeArg* output_gradient_node_arg = gradient_graph.GetNodeArg(output_gradient_name);
|
||||
output_gradient_node_arg->UpdateTypeAndShape(*gradient_graph.GetNodeArg(output_name), true, true, *logger_);
|
||||
NodeArg* output_gradient_node_arg = graph.GetNodeArg(output_gradient_name);
|
||||
output_gradient_node_arg->UpdateTypeAndShape(*graph.GetNodeArg(output_name), true, true, *logger_);
|
||||
input_args.emplace_back(output_gradient_node_arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gradient_graph.SetInputs(input_args);
|
||||
graph.SetInputs(input_args);
|
||||
|
||||
std::vector<const NodeArg*> output_args;
|
||||
for (auto& output_name : split_graphs_info_.user_output_names) {
|
||||
output_args.emplace_back(gradient_graph.GetNodeArg(output_name));
|
||||
output_args.emplace_back(graph.GetNodeArg(output_name));
|
||||
}
|
||||
|
||||
// Add initializer gradients to graph outputs.
|
||||
split_graphs_info_.initializer_grad_names_to_train.clear();
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
std::string initializer_gradient_name = initializer_name + "_grad";
|
||||
if (output_names.find(initializer_gradient_name) != output_names.end()) {
|
||||
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
|
||||
output_args.emplace_back(gradient_graph.GetNodeArg(initializer_gradient_name));
|
||||
output_args.emplace_back(graph.GetNodeArg(initializer_gradient_name));
|
||||
}
|
||||
}
|
||||
|
||||
// Add input gradients to graph outputs if it's required.
|
||||
for (const auto& input_name : config.input_names_require_grad) {
|
||||
for (const auto& input_name : config_.input_names_require_grad) {
|
||||
std::string input_gradient_name = input_name + "_grad";
|
||||
if (output_names.find(input_gradient_name) != output_names.end()) {
|
||||
output_args.emplace_back(gradient_graph.GetNodeArg(input_gradient_name));
|
||||
output_args.emplace_back(graph.GetNodeArg(input_gradient_name));
|
||||
}
|
||||
}
|
||||
|
||||
gradient_graph.SetOutputs(output_args);
|
||||
graph.SetOutputs(output_args);
|
||||
graph.Resolve();
|
||||
|
||||
gradient_graph.Resolve();
|
||||
|
||||
// Run the transformers again mainly for backward part.
|
||||
// Run the transformers again mainly for backward part, e.g., constant fold from those Shape nodes in backward graph.
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// Create two copies of gradient model for forward and backward models respectively.
|
||||
auto gradient_model_proto = model_->ToProto();
|
||||
auto gradient_model_proto = model_copied->ToProto();
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_));
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
|
||||
|
||||
|
|
@ -193,8 +223,6 @@ std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, con
|
|||
return model_str;
|
||||
}
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetGradientModel() const { return SerializeModel(model_, "gradient"); }
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); }
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); }
|
||||
|
|
@ -251,6 +279,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
|
||||
// Add intermediate args to forward graph outputs.
|
||||
split_graphs_info_.intermediate_tensor_names.clear();
|
||||
for (const auto& intermediate_arg_name : intermediate_arg_names) {
|
||||
// Ignore the user outputs.
|
||||
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(),
|
||||
|
|
@ -261,7 +290,6 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
|
||||
forward_graph.SetOutputs(forward_output_args);
|
||||
|
||||
forward_graph.Resolve();
|
||||
|
||||
// Get backward graph.
|
||||
|
|
@ -279,6 +307,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
RemoveNodes(backward_graph, backward_nodes_to_remove);
|
||||
|
||||
// User inputs to backward graph inputs.
|
||||
split_graphs_info_.backward_user_input_names.clear();
|
||||
std::vector<const NodeArg*> backward_input_args;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
// Only takes those in the backward inputs.
|
||||
|
|
@ -289,6 +318,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
|
||||
// Add initializer args to backward graph inputs if any node uses them.
|
||||
split_graphs_info_.backward_intializer_names_as_input.clear();
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
// Some initializers will be inputs for backward graph.
|
||||
if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
|
||||
|
|
@ -322,11 +352,8 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
|
||||
backward_graph.SetOutputs(backward_output_args);
|
||||
|
||||
FilterInitializers(backward_graph, backward_input_names);
|
||||
|
||||
backward_graph.Resolve();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ struct SplitGraphsInfo {
|
|||
|
||||
class ModuleGradientGraphBuilder {
|
||||
public:
|
||||
Status BuildAndSplit(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
|
||||
Status Initialize(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
|
||||
Status BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes);
|
||||
|
||||
std::string GetGradientModel() const;
|
||||
std::string GetForwardModel() const;
|
||||
std::string GetBackwardModel() const;
|
||||
SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; }
|
||||
|
|
@ -59,7 +59,8 @@ class ModuleGradientGraphBuilder {
|
|||
std::shared_ptr<onnxruntime::Model> backward_model_;
|
||||
SplitGraphsInfo split_graphs_info_;
|
||||
|
||||
const logging::Logger* logger_;
|
||||
ModuleGradientGraphBuilderConfiguration config_;
|
||||
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -380,14 +380,15 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def(py::init([]() {
|
||||
return onnxruntime::make_unique<ModuleGradientGraphBuilder>();
|
||||
}))
|
||||
.def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const py::bytes& serialized_model,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
.def("initialize", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const py::bytes& serialized_model,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
std::istringstream buffer(serialized_model);
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(buffer, config));
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config));
|
||||
})
|
||||
.def("get_gradient_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetGradientModel());
|
||||
.def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(input_shapes));
|
||||
})
|
||||
.def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetForwardModel());
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def _onnx_value_info_to_buffer_tensor(value_info, device):
|
|||
|
||||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
def __init__(self, module, dynamic_axes=None):
|
||||
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
|
|
@ -66,8 +66,12 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._dynamic_axes = dynamic_axes
|
||||
self._onnx_training = None
|
||||
|
||||
self._curr_inputs_size = None
|
||||
self._module_gradient_graph_builder = None
|
||||
|
||||
# Forward pass
|
||||
self._onnx_forward = None
|
||||
self._forward_session = None
|
||||
|
|
@ -154,19 +158,28 @@ class ORTModule(torch.nn.Module):
|
|||
if not self._onnx_forward or self._require_export:
|
||||
self._require_export = False
|
||||
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, self._dynamic_axes, *inputs, **kwargs)
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
|
||||
ORTModule._build_fw_bw_grad_graphs(self._onnx_training, grad_builder_config,
|
||||
initializer_names,
|
||||
self._save_onnx)
|
||||
grad_builder_config.initializer_names_to_train = initializer_names
|
||||
grad_builder_config.input_names_require_grad = []
|
||||
self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
|
||||
self._module_gradient_graph_builder.initialize(self._onnx_training.SerializeToString(), grad_builder_config)
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
onnx.save(onnx_gradient, self._save_onnx_prefix + '_with_grad.onnx')
|
||||
|
||||
inputs_size = [list(input.size()) for input in inputs if input is not None]
|
||||
if self._curr_inputs_size is None or self._curr_inputs_size != inputs_size:
|
||||
self._curr_inputs_size = inputs_size
|
||||
self._module_gradient_graph_builder.build_and_split(self._curr_inputs_size)
|
||||
self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model())
|
||||
self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model())
|
||||
self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info()
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
|
||||
|
|
@ -174,6 +187,7 @@ class ORTModule(torch.nn.Module):
|
|||
self._backward_session = onnxruntime.InferenceSession(self._onnx_backward.SerializeToString())
|
||||
|
||||
# IO binding
|
||||
# TODO: we should try to reuse the output buffers as some of the output tensors are same sizes, expecially the backward graph outputs.
|
||||
self._forward_io_binding = self._forward_session.io_binding()
|
||||
self._forward_output_buffers = {}
|
||||
for output in self._onnx_forward.graph.output:
|
||||
|
|
@ -335,7 +349,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _get_forward_graph(module, *inputs, **kwargs):
|
||||
def _get_forward_graph(module, dynamic_axes, *inputs, **kwargs):
|
||||
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
|
||||
|
||||
TODO: How to support dynamic axes? Dimensions are determined by samples
|
||||
|
|
@ -363,36 +377,7 @@ class ORTModule(torch.nn.Module):
|
|||
input_names=input_names,
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamic_axes=dynamic_axes)
|
||||
|
||||
return onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _build_fw_bw_grad_graphs(forward_graph, config, initializer_names=[], include_gradient_model=False):
|
||||
'''Adds gradient nodes on top of an existing ONNX graph (with training flag)'''
|
||||
if not config.initializer_names_to_train:
|
||||
if not initializer_names:
|
||||
initializer_names_to_train = []
|
||||
for initializer in forward_graph.graph.initializer:
|
||||
initializer_names_to_train.append(initializer.name)
|
||||
config.initializer_names_to_train = initializer_names_to_train
|
||||
else:
|
||||
config.initializer_names_to_train = initializer_names
|
||||
|
||||
# TODO: Add support to input with grad required
|
||||
config.input_names_require_grad = []
|
||||
# input_names_require_grad = []
|
||||
# input_names_require_grad.append('input.1')
|
||||
# config.input_names_require_grad = input_names_require_grad
|
||||
|
||||
module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
|
||||
module_gradient_graph_builder.build_and_split(forward_graph.SerializeToString(), config)
|
||||
forward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_forward_model())
|
||||
backward_model = onnx.load_model_from_string(module_gradient_graph_builder.get_backward_model())
|
||||
gradient_model = None
|
||||
if include_gradient_model:
|
||||
gradient_model = onnx.load_model_from_string(module_gradient_graph_builder.get_gradient_model())
|
||||
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
|
||||
|
||||
return gradient_model, forward_model, backward_model, split_graphs_info
|
||||
|
|
|
|||
|
|
@ -49,11 +49,6 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
if step == args.train_steps:
|
||||
break
|
||||
|
||||
# TODO: Dynamic axis is not supported yet
|
||||
if batch[0].shape[0] != args.batch_size:
|
||||
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
|
||||
continue
|
||||
|
||||
# Unpack this training batch from our dataloader.
|
||||
#
|
||||
# As we unpack the batch, we'll also copy each tensor to the GPU using the
|
||||
|
|
@ -159,12 +154,6 @@ def test(model, validation_dataloader, device, args):
|
|||
|
||||
# Evaluate data for one epoch
|
||||
for batch in validation_dataloader:
|
||||
|
||||
# TODO: Dynamic axis is not supported yet
|
||||
if batch[0].shape[0] != args.test_batch_size:
|
||||
logging.warning(f'Dynamic axis is not supported yet {len(batch)}/{args.batch_size}')
|
||||
continue
|
||||
|
||||
# Add batch to GPU
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
|
||||
|
|
@ -336,8 +325,8 @@ def main():
|
|||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for testing (default: 32)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for testing (default: 64)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
|
|
@ -391,7 +380,11 @@ def main():
|
|||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'},
|
||||
'attention_mask': {0: 'batch_size', 1: 'seq_len'},
|
||||
'labels': {0: 'batch_size'},
|
||||
'210': {0: 'batch'}}
|
||||
model = ORTModule(model, dynamic_axes)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
|
|
|
|||
Loading…
Reference in a new issue