mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Partition initial optimizer state for Zero-1 (#6093)
* Initial changes * Working changes * Working changes * Cleanup * fix windows CI * Review comments * review comments
This commit is contained in:
parent
8fd085801a
commit
82690486c1
13 changed files with 335 additions and 45 deletions
|
|
@ -37,13 +37,12 @@ Status AdamOptimizerBuilder::Build(
|
|||
// In distributed training, some weights may not be updated by all ranks.
|
||||
if (opt_configs[i].enabled) {
|
||||
// The type proto initializer for Update Count
|
||||
const std::string uc_prefix = "Update_Count";
|
||||
const std::string update_count_string = uc_prefix + "_" + weight_name; // per weight optimizer requires a per weight update count
|
||||
const std::string update_count_string = ADAM_UC_PREFIX + "_" + weight_name; // per weight optimizer requires a per weight update count
|
||||
TensorProto uc_tensor_proto;
|
||||
|
||||
// Update 'Update_Count' initializer with init value
|
||||
const auto& initial_states = opt_configs[i].initial_states;
|
||||
const auto uc_state_it = initial_states.find(uc_prefix);
|
||||
const auto uc_state_it = initial_states.find(ADAM_UC_PREFIX);
|
||||
if (uc_state_it != initial_states.end()) {
|
||||
const auto& init_tensor = uc_state_it->second.Get<Tensor>();
|
||||
ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, ONNX_NAMESPACE::TensorProto_DataType_INT64, {1}));
|
||||
|
|
@ -82,8 +81,7 @@ Status AdamOptimizerBuilder::Build(
|
|||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
|
||||
// Add first- and second-order momentums to input list.
|
||||
const std::vector<std::string> moments_prefixes({"Moment_1", "Moment_2"});
|
||||
for (const auto& moments_prefix : moments_prefixes) {
|
||||
for (const auto& moments_prefix : MOMENTS_PREFIXES) {
|
||||
const std::string gradient_moment_name = moments_prefix + "_" + weight_name;
|
||||
|
||||
TensorProto moment_tensor_proto;
|
||||
|
|
|
|||
|
|
@ -63,24 +63,23 @@ Status LambOptimizerBuilder::Build(
|
|||
|
||||
// Update count, which should be 1 at the first training iteration.
|
||||
// At the end of each Lamb call, the update count may be increased by one.
|
||||
const std::string step_tensor_name = "Step"; // per weight optimizer requires a per weight update count
|
||||
// Add step as an initializer.
|
||||
TensorProto step_tensor_proto;
|
||||
const auto& shared_optim_state = config.shared_optimizer_states;
|
||||
const auto step_state_it = shared_optim_state.find(step_tensor_name);
|
||||
const auto step_state_it = shared_optim_state.find(LAMB_STEP_TENSOR_NAME);
|
||||
if (step_state_it != shared_optim_state.end()) {
|
||||
const auto& init_tensor = step_state_it->second.Get<Tensor>();
|
||||
ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, ONNX_NAMESPACE::TensorProto_DataType_INT64, {1}));
|
||||
step_tensor_proto = utils::TensorToTensorProto(init_tensor, step_tensor_name);
|
||||
step_tensor_proto = utils::TensorToTensorProto(init_tensor, LAMB_STEP_TENSOR_NAME);
|
||||
} else {
|
||||
step_tensor_proto = CreateTensorProto<int64_t>(step_tensor_name, 1);
|
||||
step_tensor_proto = CreateTensorProto<int64_t>(LAMB_STEP_TENSOR_NAME, 1);
|
||||
}
|
||||
new_external_initializers.emplace_back(step_tensor_proto);
|
||||
input_argdefs.emplace_back(ArgDef(step_tensor_name));
|
||||
input_argdefs.emplace_back(ArgDef(LAMB_STEP_TENSOR_NAME));
|
||||
|
||||
// Add the first output, which is the updated step.
|
||||
TypeProto* step_type_proto = graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
output_argdefs.emplace_back(ArgDef(step_tensor_name + "_Out", step_type_proto));
|
||||
output_argdefs.emplace_back(ArgDef(LAMB_STEP_TENSOR_NAME + "_Out", step_type_proto));
|
||||
|
||||
// Lamb optimizer's attributes.
|
||||
std::vector<float> alpha;
|
||||
|
|
@ -205,8 +204,7 @@ Status LambOptimizerBuilder::Build(
|
|||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
|
||||
// m1 & m2 & m1_new & m2_new
|
||||
const std::vector<std::string> moments_prefixes({"Moment_1", "Moment_2"});
|
||||
for (const auto& moment_prefix : moments_prefixes) {
|
||||
for (const auto& moment_prefix : MOMENTS_PREFIXES) {
|
||||
const std::string gradient_moment_name = moment_prefix + "_" + weight_name;
|
||||
|
||||
// Construct type of momentum tensor.
|
||||
|
|
|
|||
|
|
@ -19,10 +19,11 @@ void OptimizerBuilderRegistry::RegisterBuilders() {
|
|||
Status IsMatchingTypeAndShape(
|
||||
const onnxruntime::Tensor& tensor,
|
||||
const int32_t element_type,
|
||||
const std::vector<int64_t>& expected_shape) {
|
||||
const std::vector<int64_t>& expected_shape_dims) {
|
||||
ORT_RETURN_IF_NOT(tensor.GetElementType() == element_type);
|
||||
const std::vector<int64_t>& tensor_shape = tensor.Shape().GetDims();
|
||||
ORT_RETURN_IF_NOT(tensor_shape == expected_shape);
|
||||
const TensorShape& tensor_shape = tensor.Shape();
|
||||
TensorShape expected_shape(expected_shape_dims);
|
||||
ORT_RETURN_IF_NOT(tensor_shape == expected_shape, "Mismatch: expected:[", tensor_shape, "], actual:[", expected_shape, "]");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@
|
|||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
const std::vector<std::string> MOMENTS_PREFIXES({"Moment_1", "Moment_2"});
|
||||
const std::string LAMB_STEP_TENSOR_NAME = "Step";
|
||||
const std::string ADAM_UC_PREFIX = "Update_Count";
|
||||
|
||||
template <class T>
|
||||
ONNX_NAMESPACE::TensorProto CreateTensorProto(
|
||||
const std::string& name,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ static std::vector<ArgDef> AddPartitionsForParameter(
|
|||
GraphAugmenter::GraphDefs& graph_defs,
|
||||
const std::string& initializer_name,
|
||||
const std::vector<TensorShape>& shapes,
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names_map){
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names_map) {
|
||||
ORT_ENFORCE(shapes.size() == 3, "Invalid shapes vector passed for partitioning.");
|
||||
int64_t partition_offset = shapes[0].GetDims()[0];
|
||||
int64_t partition_size = shapes[1].GetDims()[0];
|
||||
|
|
@ -135,7 +135,7 @@ static std::vector<ArgDef> AddPartitionsForParameter(
|
|||
graph.RemoveInitializedTensor(initializer_name);
|
||||
graph.AddInitializedTensor(initializer_partition);
|
||||
|
||||
//add the modified weight name to get state
|
||||
//add the modified weight name to get state
|
||||
updated_weight_names_map[initializer_name] = partition_name;
|
||||
|
||||
auto partition_argdef = ArgDef(partition_name, graph_defs.CreateTypeProto({partition_size}, dtype));
|
||||
|
|
@ -185,6 +185,45 @@ static std::vector<ArgDef> AddViewForParameter(
|
|||
return view_outputs;
|
||||
}
|
||||
|
||||
void PartitionOptimizerState(
|
||||
const int64_t partition_offset,
|
||||
const int64_t partition_size,
|
||||
NameMLValMap& initial_states) {
|
||||
for (const auto& moments_prefix : MOMENTS_PREFIXES) {
|
||||
const auto initial_state_it = initial_states.find(moments_prefix);
|
||||
if (initial_state_it != initial_states.end()) {
|
||||
auto* init_tensor = initial_state_it->second.GetMutable<Tensor>();
|
||||
|
||||
OrtValue partitioned;
|
||||
TensorShape shape({partition_size});
|
||||
auto element_type = init_tensor->DataType();
|
||||
const OrtMemoryInfo& info = init_tensor->Location();
|
||||
std::unique_ptr<Tensor> p_tensor;
|
||||
|
||||
if (utils::IsPrimitiveDataType<float>(element_type)) {
|
||||
float* data_buffer = init_tensor->MutableData<float>();
|
||||
p_tensor = onnxruntime::make_unique<Tensor>(element_type,
|
||||
shape,
|
||||
data_buffer + partition_offset,
|
||||
info);
|
||||
} else if (utils::IsPrimitiveDataType<MLFloat16>(element_type)) {
|
||||
MLFloat16* data_buffer = init_tensor->MutableData<MLFloat16>();
|
||||
p_tensor = onnxruntime::make_unique<Tensor>(element_type,
|
||||
shape,
|
||||
data_buffer + partition_offset,
|
||||
info);
|
||||
|
||||
} else {
|
||||
ORT_THROW("Unsupported type: ", element_type, "for initial optimizer moments.");
|
||||
}
|
||||
partitioned.Init(p_tensor.release(),
|
||||
DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
initial_states[moments_prefix] = std::move(partitioned);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static Status AddParameterPartition(
|
||||
Graph& graph,
|
||||
GraphAugmenter::GraphDefs& graph_defs,
|
||||
|
|
@ -202,8 +241,8 @@ static Status AddParameterPartition(
|
|||
if (opt_config.mixed_precision_weight_arg != nullptr) {
|
||||
const NodeArg* weight_arg = graph.GetNodeArg(weight_argdef.name);
|
||||
ORT_ENFORCE(weight_arg != nullptr, "Could not find nodearg in graph: " + weight_argdef.name);
|
||||
ORT_ENFORCE(!graph_utils::IsGraphInput(graph, weight_arg), "Cannot partition weight that is a part of graph inputs for "+weight_argdef.name);
|
||||
|
||||
ORT_ENFORCE(!graph_utils::IsGraphInput(graph, weight_arg), "Cannot partition weight that is a part of graph inputs for " + weight_argdef.name);
|
||||
|
||||
//Partition the FP32 weight
|
||||
weight_views = AddPartitionsForParameter(graph, graph_defs, weight_argdef.name, view_shapes, updated_weight_names_map);
|
||||
ORT_ENFORCE(weight_views.size() == enabled.size());
|
||||
|
|
@ -223,12 +262,22 @@ static Status AddParameterPartition(
|
|||
weight_argdefs.insert(weight_argdefs.end(), weight_views.begin(), weight_views.end());
|
||||
gradient_argdefs.insert(gradient_argdefs.end(), gradient_views.begin(), gradient_views.end());
|
||||
|
||||
const auto& initial_states = opt_config.initial_states;
|
||||
// Update Optimizer node configs.
|
||||
ORT_ENFORCE(weight_views.size() == gradient_views.size());
|
||||
for (size_t i = 0; i < weight_views.size(); i++) {
|
||||
OptimizerNodeConfig new_config = opt_config;
|
||||
new_config.enabled = enabled[i];
|
||||
|
||||
// Partition initial optimizer state
|
||||
if (enabled[i] && !initial_states.empty()) {
|
||||
ORT_ENFORCE(view_shapes.size() == 3, "Invalid view_shapes vector passed for partitioning.");
|
||||
int64_t partition_offset = view_shapes[0].GetDims()[0];
|
||||
int64_t partition_size = view_shapes[1].GetDims()[0];
|
||||
new_config.initial_states = opt_config.initial_states;
|
||||
PartitionOptimizerState(partition_offset, partition_size, new_config.initial_states);
|
||||
}
|
||||
|
||||
if (opt_config.mixed_precision_weight_arg != nullptr) {
|
||||
new_config.mixed_precision_weight_arg = &graph.GetOrCreateNodeArg(mixed_precision_weight_views[i].name, mixed_precision_weight_views[i].type_proto);
|
||||
}
|
||||
|
|
@ -306,14 +355,14 @@ static Status ModifyParametersForOptimizerPartitioning(
|
|||
std::vector<TensorShape> view_shapes = {{size_for_previous_rank}, {size_for_current_rank}, {0}};
|
||||
std::vector<bool> enabled = {false, true};
|
||||
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
} else if (offset >= rank_start && offset + tensor_count > rank_end) {
|
||||
int64_t size_for_current_rank = rank_end - offset;
|
||||
int64_t size_for_next_rank = offset + tensor_count - rank_end;
|
||||
std::vector<TensorShape> view_shapes = {{0}, {size_for_current_rank}, {size_for_next_rank}};
|
||||
std::vector<bool> enabled = {true, false};
|
||||
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
} else { // offset < rank_start && offset + tensor_count > rank_end
|
||||
int64_t size_for_previous_rank = rank_start - offset;
|
||||
int64_t size_for_current_rank = rank_end - rank_start;
|
||||
|
|
@ -321,7 +370,7 @@ static Status ModifyParametersForOptimizerPartitioning(
|
|||
std::vector<TensorShape> view_shapes = {{size_for_previous_rank}, {size_for_current_rank}, {size_for_next_rank}};
|
||||
std::vector<bool> enabled = {false, true, false};
|
||||
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
|
||||
}
|
||||
} else {
|
||||
// Parameter is handled by a different rank.
|
||||
|
|
|
|||
|
|
@ -28,5 +28,19 @@ class ZeROOptimizerGraphBuilder : public OptimizerGraphBuilder {
|
|||
OptimizerOutputKeyMap<std::string>& optimizer_graph_outputs) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* Partitions the initial states according to the offset and
|
||||
* size provided when the optimizer state for a weight is to be
|
||||
* partitioned in Zero stage 1.
|
||||
*
|
||||
* @param partition_offset The offset for start of partition
|
||||
* @param partition_size The size(number of elements) of the partition
|
||||
* @param[out] initial_states The optimizer initial states modified in-place.
|
||||
*/
|
||||
void PartitionOptimizerState(
|
||||
const int64_t partition_offset,
|
||||
const int64_t partition_size,
|
||||
NameMLValMap& initial_states);
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include "test/util/include/asserts.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "orttraining/test/session/training_session_test_utils.h"
|
||||
#include "orttraining/core/graph/optimizer_builder.h"
|
||||
|
||||
using onnxruntime::test::CountOpsInGraph;
|
||||
using onnxruntime::test::CreateMLValue;
|
||||
|
|
@ -168,17 +169,17 @@ static void TestOptimizerGraphBuilderWithInitialStates(OptimizerGraphConfig conf
|
|||
NameMLValMap per_weight_states;
|
||||
OrtValue ml_value;
|
||||
|
||||
for (const auto key : MOMENT_PREFIX) {
|
||||
for (const auto key : MOMENTS_PREFIXES) {
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, values, &ml_value);
|
||||
per_weight_states.insert(std::make_pair(key, std::move(ml_value)));
|
||||
}
|
||||
if (optimizer_op_name == k_adam_optimizer_op_name) {
|
||||
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, uc_value, &ml_value);
|
||||
per_weight_states.insert(std::make_pair(UC_PREFIX, std::move(ml_value)));
|
||||
per_weight_states.insert(std::make_pair(ADAM_UC_PREFIX, std::move(ml_value)));
|
||||
} else if (optimizer_op_name == k_lamb_optimizer_op_name) {
|
||||
// add "Step" for lamb
|
||||
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, uc_value, &ml_value);
|
||||
shared_states.insert(std::make_pair(STEP_TENSOR_NAME, std::move(ml_value)));
|
||||
shared_states.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, std::move(ml_value)));
|
||||
config.shared_optimizer_states = std::move(shared_states);
|
||||
}
|
||||
opt_config_it.second.initial_states = std::move(per_weight_states);
|
||||
|
|
@ -217,6 +218,37 @@ TEST_F(OptimizerGraphBuilderTest, LoadOptimState_FullPrecision_Lamb) {
|
|||
TestOptimizerGraphBuilderWithInitialStates(config, graph_, k_lamb_optimizer_op_name);
|
||||
}
|
||||
|
||||
TEST_F(OptimizerGraphBuilderTest, ZeroSplitInitialOptimizerState) {
|
||||
NameMLValMap initial_states;
|
||||
std::vector<int64_t> param_dims = {784, 128};
|
||||
int64_t num_ele = std::accumulate(param_dims.begin(), param_dims.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
|
||||
|
||||
MLValue mlValue;
|
||||
std::vector<float> init_value(num_ele);
|
||||
std::iota(init_value.begin(), init_value.end(), static_cast<float>(0));
|
||||
|
||||
for (const auto& param_prefix : MOMENTS_PREFIXES) {
|
||||
TrainingUtil::CreateCpuMLValue<float>(param_dims, init_value, &mlValue);
|
||||
initial_states.insert(std::make_pair(param_prefix, std::move(mlValue)));
|
||||
}
|
||||
|
||||
int64_t partition_offset = 10;
|
||||
int64_t partition_size = 500;
|
||||
PartitionOptimizerState(partition_offset, partition_size, initial_states);
|
||||
|
||||
std::vector<float> expected_vec(init_value.begin() + partition_offset, init_value.begin() + partition_offset + partition_size);
|
||||
std::vector<int64_t> expected_shape = {partition_size};
|
||||
|
||||
for (const auto& state : initial_states) {
|
||||
const auto& init_tensor = state.second.Get<Tensor>();
|
||||
const auto& shape = init_tensor.Shape().GetDims();
|
||||
ASSERT_EQ(shape, expected_shape);
|
||||
const std::vector<float> found(init_tensor.Data<float>(),
|
||||
init_tensor.Data<float>() + partition_size);
|
||||
ASSERT_EQ(expected_vec, found);
|
||||
}
|
||||
}
|
||||
|
||||
static void TestDefaultOptimizerGraphBuilder(OptimizerGraphConfig config, Graph& graph) {
|
||||
std::unordered_map<std::string, std::string> updated_weight_names_map;
|
||||
OptimizerGraphBuilder optimizer_graph_builder(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import sys
|
|||
import subprocess
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import onnx
|
||||
|
||||
from onnxruntime.training import optim
|
||||
|
||||
|
|
@ -83,6 +85,7 @@ def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power
|
|||
|
||||
|
||||
def generate_dummy_optim_state(model, optimizer):
|
||||
np.random.seed(0)
|
||||
if not (isinstance(optimizer, optim.AdamConfig) or isinstance(optimizer, optim.LambConfig)):
|
||||
return dict()
|
||||
|
||||
|
|
@ -92,26 +95,31 @@ def generate_dummy_optim_state(model, optimizer):
|
|||
shared_state_key = "shared_optimizer_state"
|
||||
|
||||
optim_state = dict()
|
||||
name_to_initializer_map = {n.name:n for n in model.graph.initializer}
|
||||
initializers_names = name_to_initializer_map.keys()
|
||||
for weight in initializers_names:
|
||||
weight_shape_map = dict()
|
||||
if isinstance(model, torch.nn.Module):
|
||||
weight_shape_map = {name: param.size() for name, param in model.named_parameters()}
|
||||
elif isinstance(model, onnx.ModelProto):
|
||||
weight_shape_map = {n.name: n.dims for n in model.graph.initializer}
|
||||
else:
|
||||
raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")
|
||||
|
||||
for weight_name, weight_shape in weight_shape_map.items():
|
||||
per_weight_state = dict()
|
||||
weight_shape = name_to_initializer_map[weight].dims
|
||||
for moment in moment_keys:
|
||||
per_weight_state[moment] = np.full(weight_shape, 2.5, dtype=np.float32)
|
||||
per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32)
|
||||
if isinstance(optimizer, optim.AdamConfig):
|
||||
per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64)
|
||||
optim_state[weight] = copy.deepcopy(per_weight_state)
|
||||
optim_state[weight_name] = copy.deepcopy(per_weight_state)
|
||||
if isinstance(optimizer, optim.LambConfig):
|
||||
step_val = np.full([1], 5, dtype=np.int64)
|
||||
optim_state[shared_state_key] = {step_key : step_val}
|
||||
optim_state[shared_state_key] = {step_key: step_val}
|
||||
return optim_state
|
||||
|
||||
|
||||
def get_optim_state_from_state_dict(state_dict, optimizer):
|
||||
if not (isinstance(optimizer, optim.AdamConfig) or isinstance(optimizer, optim.LambConfig)):
|
||||
return dict()
|
||||
|
||||
|
||||
moment_keys = ["Moment_1", "Moment_2"]
|
||||
uc_key = "Update_Count"
|
||||
step_key = "Step"
|
||||
|
|
@ -119,6 +127,9 @@ def get_optim_state_from_state_dict(state_dict, optimizer):
|
|||
|
||||
optim_state = dict()
|
||||
for param_name, v in state_dict.items():
|
||||
if '_view_' in param_name:
|
||||
param_name = param_name.split('_view_')[0]
|
||||
|
||||
for moment in moment_keys:
|
||||
if param_name.startswith(moment):
|
||||
fp32_name = param_name.split(moment + '_')[-1]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from onnxruntime.training import amp, checkpoint, optim, orttrainer
|
|||
from orttraining_test_orttrainer_frontend import _load_pytorch_transformer_model
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id, get_mpi_context_world_rank, get_mpi_context_world_size
|
||||
|
||||
from _test_commons import generate_dummy_optim_state
|
||||
|
||||
global_fp16_fp32_atol = 1e-3
|
||||
|
||||
def _train(trainer, train_data, batcher_fn, total_batch_steps = 5, seed = 1):
|
||||
|
|
@ -103,7 +105,7 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir,
|
|||
def split_state_dict(state_dict):
|
||||
"""Given a flat state dictionary, split it into optimizer, fp32_param, fp16_param hierarchical dictionary and return"""
|
||||
|
||||
optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step_']
|
||||
optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step']
|
||||
split_sd = {'optimizer': {}, 'fp32_param': {}, 'fp16_param': {}}
|
||||
for k, v in state_dict.items():
|
||||
mode = 'fp32_param'
|
||||
|
|
@ -175,3 +177,25 @@ def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir,
|
|||
# save current model parameters as a checkpoint
|
||||
if checkpoint_dir:
|
||||
_save(trainer, checkpoint_dir, state_dict_key_name)
|
||||
|
||||
|
||||
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
|
||||
learning_rate = 0.1
|
||||
seed = 1
|
||||
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
|
||||
optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
|
||||
model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))
|
||||
|
||||
# load dummy state
|
||||
dummy_init_state = generate_dummy_optim_state(model, optim_config)
|
||||
checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)
|
||||
|
||||
# run an eval step to innitialize the graph
|
||||
data, targets = batcher_fn(train_data, 0)
|
||||
trainer.eval_step(data, targets)
|
||||
|
||||
return dummy_init_state, checkpoint.experimental_state_dict(trainer)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,153 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
################################################################################
|
||||
# Refer to orttraining_test_checkpoint.py for an overview about Checkpoint tests
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from numpy.testing import assert_allclose
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training import checkpoint, optim
|
||||
from _test_helpers import distributed_setup, load_model_optim_state_and_eval, split_state_dict, aggregate_states, global_fp16_fp32_atol
|
||||
from _test_commons import get_optim_state_from_state_dict
|
||||
|
||||
def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False):
|
||||
expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, opts, use_lamb)
|
||||
trainer_state = split_state_dict(trainer_state)
|
||||
# round about way of checking optimizer states. Save state dicts into temporary folder, read them and aggregate them.
|
||||
with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f:
|
||||
pickle.dump(trainer_state, f)
|
||||
dist.barrier()
|
||||
|
||||
if world_rank == 0:
|
||||
num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
|
||||
optimizer_states = dict()
|
||||
for rank in range(num_states):
|
||||
rank_state_dict = None
|
||||
with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(rank)+'.pkl'), 'rb') as f:
|
||||
rank_state_dict = pickle.load(f)
|
||||
|
||||
# collect optimizer states for later comparison since they are sharded
|
||||
aggregate_states(optimizer_states, rank_state_dict['optimizer'])
|
||||
|
||||
# compare optimizer states
|
||||
optimizer_config = optim.LambConfig() if use_lamb else optim.AdamConfig()
|
||||
actual_optim_state = get_optim_state_from_state_dict(optimizer_states, optimizer_config)
|
||||
assert actual_optim_state.keys() == expected_optim_state.keys()
|
||||
for param_name, a_state in actual_optim_state.items():
|
||||
for k, v in a_state.items():
|
||||
assert_allclose(v.reshape(expected_optim_state[param_name][k].shape),
|
||||
expected_optim_state[param_name][k],
|
||||
err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")
|
||||
|
||||
dist.barrier()
|
||||
os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'))
|
||||
|
||||
|
||||
@distributed_setup
|
||||
def test_optim_load_to_distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/adam/'):
|
||||
opts = {
|
||||
'device' : {'id' : device},
|
||||
'distributed' :
|
||||
{
|
||||
'world_rank' : world_rank,
|
||||
'world_size' : world_size,
|
||||
'allreduce_post_accumulation' : True,
|
||||
'deepspeed_zero_optimization':
|
||||
{
|
||||
'stage': 1
|
||||
}
|
||||
},
|
||||
'debug' : {'deterministic_compute': True}
|
||||
}
|
||||
verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False)
|
||||
|
||||
|
||||
@distributed_setup
|
||||
def test_optim_load_to_distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/adam/'):
|
||||
opts = {
|
||||
'device' : {'id' : device},
|
||||
'mixed_precision':
|
||||
{
|
||||
'enabled': True
|
||||
},
|
||||
'distributed' :
|
||||
{
|
||||
'world_rank' : world_rank,
|
||||
'world_size' : world_size,
|
||||
'allreduce_post_accumulation' : True,
|
||||
'deepspeed_zero_optimization':
|
||||
{
|
||||
'stage': 1
|
||||
}
|
||||
},
|
||||
'debug' : {'deterministic_compute': True}
|
||||
}
|
||||
verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False)
|
||||
|
||||
|
||||
@distributed_setup
|
||||
def test_optim_load_to_distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'):
|
||||
opts = {
|
||||
'device' : {'id' : device},
|
||||
'distributed' :
|
||||
{
|
||||
'world_rank' : world_rank,
|
||||
'world_size' : world_size,
|
||||
'allreduce_post_accumulation' : True,
|
||||
'deepspeed_zero_optimization':
|
||||
{
|
||||
'stage': 1
|
||||
}
|
||||
},
|
||||
'debug' : {'deterministic_compute': True}
|
||||
}
|
||||
verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True)
|
||||
|
||||
@distributed_setup
|
||||
def test_optim_load_to_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb/'):
|
||||
opts = {
|
||||
'device' : {'id' : device},
|
||||
'mixed_precision':
|
||||
{
|
||||
'enabled': True
|
||||
},
|
||||
'distributed' :
|
||||
{
|
||||
'world_rank' : world_rank,
|
||||
'world_size' : world_size,
|
||||
'allreduce_post_accumulation' : True,
|
||||
'deepspeed_zero_optimization':
|
||||
{
|
||||
'stage': 1
|
||||
}
|
||||
},
|
||||
'debug' : {'deterministic_compute': True}
|
||||
}
|
||||
verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True)
|
||||
|
||||
|
||||
function_map = {
|
||||
# load to zero configs
|
||||
'test_optim_load_to_distributed_zero_full_precision_adam': test_optim_load_to_distributed_zero_full_precision_adam,
|
||||
'test_optim_load_to_distributed_zero_mixed_precision_adam': test_optim_load_to_distributed_zero_mixed_precision_adam,
|
||||
'test_optim_load_to_distributed_zero_mixed_precision_lamb': test_optim_load_to_distributed_zero_mixed_precision_lamb,
|
||||
'test_optim_load_to_distributed_zero_full_precision_lamb': test_optim_load_to_distributed_zero_full_precision_lamb
|
||||
}
|
||||
parser = argparse.ArgumentParser(description='Test loading of initial optimizer state for Zero-1')
|
||||
parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test loaded states', required=True)
|
||||
parser.add_argument('--checkpoint_dir', help='path to the saved states directory', required=True)
|
||||
args = parser.parse_args()
|
||||
function_map[args.scenario](checkpoint_dir=args.checkpoint_dir)
|
||||
|
|
@ -56,6 +56,7 @@ ngpus = torch.cuda.device_count()
|
|||
save_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_save_checkpoint.py')
|
||||
load_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_load_checkpoint.py')
|
||||
aggregate_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_checkpoint_aggregation.py')
|
||||
optim_state_file = os.path.join('checkpoint', 'orttraining_test_load_optimizer_state.py')
|
||||
|
||||
single_node_full_precision_path = os.path.join(checkpoint_dir, 'single_node', 'full_precision')
|
||||
single_node_mixed_precision_path = os.path.join(checkpoint_dir, 'single_node', 'mixed_precision')
|
||||
|
|
@ -125,4 +126,10 @@ _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_m
|
|||
_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path)
|
||||
_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path)
|
||||
|
||||
# optimizer state loading into model-parallel tests
|
||||
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path)
|
||||
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path)
|
||||
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path)
|
||||
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path)
|
||||
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/test/session/training_session_test_utils.h"
|
||||
#include "orttraining/core/graph/optimizer_builder.h"
|
||||
|
||||
using namespace onnxruntime::logging;
|
||||
using namespace onnxruntime::training;
|
||||
|
|
@ -38,21 +39,21 @@ void GenerateOptimizerInitialState(const std::string& optimizer_op_name, const T
|
|||
std::vector<int64_t> param_dims = WEIGHT_TO_SHAPE_MAP.at(weight_name);
|
||||
int64_t num_ele = std::accumulate(param_dims.begin(), param_dims.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
|
||||
|
||||
for (auto& param_prefix : MOMENT_PREFIX) {
|
||||
for (auto& param_prefix : MOMENTS_PREFIXES) {
|
||||
std::vector<T> param_value(num_ele, init_moment_value);
|
||||
TrainingUtil::CreateCpuMLValue<T>(param_dims, param_value, &mlValue);
|
||||
optim_state.insert(std::make_pair(param_prefix, std::move(mlValue)));
|
||||
}
|
||||
if (optimizer_op_name == k_adam_optimizer_op_name) {
|
||||
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, uc_value, &mlValue);
|
||||
optim_state.insert(std::make_pair(UC_PREFIX, std::move(mlValue)));
|
||||
optim_state.insert(std::make_pair(ADAM_UC_PREFIX, std::move(mlValue)));
|
||||
}
|
||||
result.insert(std::make_pair(weight_name, std::move(optim_state)));
|
||||
}
|
||||
if (optimizer_op_name == k_lamb_optimizer_op_name) {
|
||||
// add "Step" for lamb
|
||||
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, uc_value, &mlValue);
|
||||
shared_states.insert(std::make_pair(STEP_TENSOR_NAME, std::move(mlValue)));
|
||||
shared_states.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, std::move(mlValue)));
|
||||
result.insert(std::make_pair(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY, std::move(shared_states)));
|
||||
}
|
||||
optimizer_state = std::move(result);
|
||||
|
|
@ -73,7 +74,9 @@ void SeparateStateTensors(const NameMLValMap& training_state, NameMLValMap& mode
|
|||
model_state = std::move(result);
|
||||
for (auto& weight_name : WEIGHT_NAMES) {
|
||||
NameMLValMap optim_state;
|
||||
for (auto& param_prefix : MOMENT_UC_PREFIX) {
|
||||
std::vector<std::string> per_weight_states_prefixes(MOMENTS_PREFIXES);
|
||||
per_weight_states_prefixes.push_back(ADAM_UC_PREFIX);
|
||||
for (auto& param_prefix : per_weight_states_prefixes) {
|
||||
std::string param_name = param_prefix + "_" + weight_name;
|
||||
const auto& param_state_it = training_state.find(param_name);
|
||||
if (param_state_it != training_state.end()) {
|
||||
|
|
@ -83,9 +86,9 @@ void SeparateStateTensors(const NameMLValMap& training_state, NameMLValMap& mode
|
|||
optimizer_state.insert(std::make_pair(weight_name, optim_state));
|
||||
}
|
||||
NameMLValMap shared_optim_state;
|
||||
const auto& param_state_it = training_state.find(STEP_TENSOR_NAME);
|
||||
const auto& param_state_it = training_state.find(LAMB_STEP_TENSOR_NAME);
|
||||
if (param_state_it != training_state.end()) {
|
||||
shared_optim_state.insert(std::make_pair(STEP_TENSOR_NAME, param_state_it->second));
|
||||
shared_optim_state.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, param_state_it->second));
|
||||
optimizer_state.insert(std::make_pair(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY, shared_optim_state));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,10 +36,6 @@ const std::unordered_map<std::string, std::vector<int64_t>> WEIGHT_TO_SHAPE_MAP
|
|||
{"B2", {32}},
|
||||
{"W3", {32, 10}},
|
||||
{"B1", {128}}};
|
||||
const std::vector<std::string> MOMENT_PREFIX = {"Moment_1", "Moment_2"};
|
||||
const std::vector<std::string> MOMENT_UC_PREFIX = {"Moment_1", "Moment_2", "Update_Count"};
|
||||
constexpr char STEP_TENSOR_NAME[] = "Step";
|
||||
constexpr char UC_PREFIX[] = "Update_Count";
|
||||
|
||||
void GenerateOptimizerConfig(const std::string optimizer_name,
|
||||
const bool use_mixed_precision_moments,
|
||||
|
|
|
|||
Loading…
Reference in a new issue