Partition initial optimizer state for Zero-1 (#6093)

* Initial changes

* Working changes

* Working changes

* Cleanup

* fix windows CI

* Review comments

* review comments
This commit is contained in:
ashbhandare 2020-12-16 15:27:42 -05:00 committed by GitHub
parent 8fd085801a
commit 82690486c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 335 additions and 45 deletions

View file

@ -37,13 +37,12 @@ Status AdamOptimizerBuilder::Build(
// In distributed training, some weights may not be updated by all ranks.
if (opt_configs[i].enabled) {
// The type proto initializer for Update Count
const std::string uc_prefix = "Update_Count";
const std::string update_count_string = uc_prefix + "_" + weight_name; // per weight optimizer requires a per weight update count
const std::string update_count_string = ADAM_UC_PREFIX + "_" + weight_name; // per weight optimizer requires a per weight update count
TensorProto uc_tensor_proto;
// Update 'Update_Count' initializer with init value
const auto& initial_states = opt_configs[i].initial_states;
const auto uc_state_it = initial_states.find(uc_prefix);
const auto uc_state_it = initial_states.find(ADAM_UC_PREFIX);
if (uc_state_it != initial_states.end()) {
const auto& init_tensor = uc_state_it->second.Get<Tensor>();
ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, ONNX_NAMESPACE::TensorProto_DataType_INT64, {1}));
@ -82,8 +81,7 @@ Status AdamOptimizerBuilder::Build(
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
// Add first- and second-order momentums to input list.
const std::vector<std::string> moments_prefixes({"Moment_1", "Moment_2"});
for (const auto& moments_prefix : moments_prefixes) {
for (const auto& moments_prefix : MOMENTS_PREFIXES) {
const std::string gradient_moment_name = moments_prefix + "_" + weight_name;
TensorProto moment_tensor_proto;

View file

@ -63,24 +63,23 @@ Status LambOptimizerBuilder::Build(
// Update count, which should be 1 at the first training iteration.
// At the end of each Lamb call, the update count may be increased by one.
const std::string step_tensor_name = "Step"; // per weight optimizer requires a per weight update count
// Add step as an initializer.
TensorProto step_tensor_proto;
const auto& shared_optim_state = config.shared_optimizer_states;
const auto step_state_it = shared_optim_state.find(step_tensor_name);
const auto step_state_it = shared_optim_state.find(LAMB_STEP_TENSOR_NAME);
if (step_state_it != shared_optim_state.end()) {
const auto& init_tensor = step_state_it->second.Get<Tensor>();
ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, ONNX_NAMESPACE::TensorProto_DataType_INT64, {1}));
step_tensor_proto = utils::TensorToTensorProto(init_tensor, step_tensor_name);
step_tensor_proto = utils::TensorToTensorProto(init_tensor, LAMB_STEP_TENSOR_NAME);
} else {
step_tensor_proto = CreateTensorProto<int64_t>(step_tensor_name, 1);
step_tensor_proto = CreateTensorProto<int64_t>(LAMB_STEP_TENSOR_NAME, 1);
}
new_external_initializers.emplace_back(step_tensor_proto);
input_argdefs.emplace_back(ArgDef(step_tensor_name));
input_argdefs.emplace_back(ArgDef(LAMB_STEP_TENSOR_NAME));
// Add the first output, which is the updated step.
TypeProto* step_type_proto = graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_INT64);
output_argdefs.emplace_back(ArgDef(step_tensor_name + "_Out", step_type_proto));
output_argdefs.emplace_back(ArgDef(LAMB_STEP_TENSOR_NAME + "_Out", step_type_proto));
// Lamb optimizer's attributes.
std::vector<float> alpha;
@ -205,8 +204,7 @@ Status LambOptimizerBuilder::Build(
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
// m1 & m2 & m1_new & m2_new
const std::vector<std::string> moments_prefixes({"Moment_1", "Moment_2"});
for (const auto& moment_prefix : moments_prefixes) {
for (const auto& moment_prefix : MOMENTS_PREFIXES) {
const std::string gradient_moment_name = moment_prefix + "_" + weight_name;
// Construct type of momentum tensor.

View file

@ -19,10 +19,11 @@ void OptimizerBuilderRegistry::RegisterBuilders() {
/**
 * Checks that `tensor` has the given element type and the given dimensions.
 *
 * Each check goes through ORT_RETURN_IF_NOT; when both hold, Status::OK() is
 * returned. The shape-mismatch message prints the expected shape first and the
 * tensor's actual shape second.
 */
Status IsMatchingTypeAndShape(
const onnxruntime::Tensor& tensor,
const int32_t element_type,
const std::vector<int64_t>& expected_shape) {
const std::vector<int64_t>& expected_shape_dims) {
ORT_RETURN_IF_NOT(tensor.GetElementType() == element_type);
const std::vector<int64_t>& tensor_shape = tensor.Shape().GetDims();
ORT_RETURN_IF_NOT(tensor_shape == expected_shape);
const TensorShape& tensor_shape = tensor.Shape();
TensorShape expected_shape(expected_shape_dims);
ORT_RETURN_IF_NOT(tensor_shape == expected_shape, "Mismatch: expected:[", expected_shape, "], actual:[", tensor_shape, "]");
return Status::OK();
}

View file

@ -14,6 +14,10 @@
namespace onnxruntime {
namespace training {
// Name prefixes of the per-weight first- and second-order moment tensors.
// Optimizer builders form the tensor name as "<prefix>_<weight_name>" and use
// the bare prefix as the key when looking up a weight's initial state.
const std::vector<std::string> MOMENTS_PREFIXES({"Moment_1", "Moment_2"});
// Name of the Lamb step initializer; also the key used to look it up in the
// shared optimizer states.
const std::string LAMB_STEP_TENSOR_NAME = "Step";
// Name prefix of Adam's per-weight update count ("<prefix>_<weight_name>");
// the bare prefix is the key used to look it up in a weight's initial state.
const std::string ADAM_UC_PREFIX = "Update_Count";
template <class T>
ONNX_NAMESPACE::TensorProto CreateTensorProto(
const std::string& name,

View file

@ -104,7 +104,7 @@ static std::vector<ArgDef> AddPartitionsForParameter(
GraphAugmenter::GraphDefs& graph_defs,
const std::string& initializer_name,
const std::vector<TensorShape>& shapes,
std::unordered_map<std::string, std::string>& updated_weight_names_map){
std::unordered_map<std::string, std::string>& updated_weight_names_map) {
ORT_ENFORCE(shapes.size() == 3, "Invalid shapes vector passed for partitioning.");
int64_t partition_offset = shapes[0].GetDims()[0];
int64_t partition_size = shapes[1].GetDims()[0];
@ -135,7 +135,7 @@ static std::vector<ArgDef> AddPartitionsForParameter(
graph.RemoveInitializedTensor(initializer_name);
graph.AddInitializedTensor(initializer_partition);
//add the modified weight name to get state
//add the modified weight name to get state
updated_weight_names_map[initializer_name] = partition_name;
auto partition_argdef = ArgDef(partition_name, graph_defs.CreateTypeProto({partition_size}, dtype));
@ -185,6 +185,45 @@ static std::vector<ArgDef> AddViewForParameter(
return view_outputs;
}
void PartitionOptimizerState(
const int64_t partition_offset,
const int64_t partition_size,
NameMLValMap& initial_states) {
for (const auto& moments_prefix : MOMENTS_PREFIXES) {
const auto initial_state_it = initial_states.find(moments_prefix);
if (initial_state_it != initial_states.end()) {
auto* init_tensor = initial_state_it->second.GetMutable<Tensor>();
OrtValue partitioned;
TensorShape shape({partition_size});
auto element_type = init_tensor->DataType();
const OrtMemoryInfo& info = init_tensor->Location();
std::unique_ptr<Tensor> p_tensor;
if (utils::IsPrimitiveDataType<float>(element_type)) {
float* data_buffer = init_tensor->MutableData<float>();
p_tensor = onnxruntime::make_unique<Tensor>(element_type,
shape,
data_buffer + partition_offset,
info);
} else if (utils::IsPrimitiveDataType<MLFloat16>(element_type)) {
MLFloat16* data_buffer = init_tensor->MutableData<MLFloat16>();
p_tensor = onnxruntime::make_unique<Tensor>(element_type,
shape,
data_buffer + partition_offset,
info);
} else {
ORT_THROW("Unsupported type: ", element_type, "for initial optimizer moments.");
}
partitioned.Init(p_tensor.release(),
DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
initial_states[moments_prefix] = std::move(partitioned);
}
}
}
static Status AddParameterPartition(
Graph& graph,
GraphAugmenter::GraphDefs& graph_defs,
@ -202,8 +241,8 @@ static Status AddParameterPartition(
if (opt_config.mixed_precision_weight_arg != nullptr) {
const NodeArg* weight_arg = graph.GetNodeArg(weight_argdef.name);
ORT_ENFORCE(weight_arg != nullptr, "Could not find nodearg in graph: " + weight_argdef.name);
ORT_ENFORCE(!graph_utils::IsGraphInput(graph, weight_arg), "Cannot partition weight that is a part of graph inputs for "+weight_argdef.name);
ORT_ENFORCE(!graph_utils::IsGraphInput(graph, weight_arg), "Cannot partition weight that is a part of graph inputs for " + weight_argdef.name);
//Partition the FP32 weight
weight_views = AddPartitionsForParameter(graph, graph_defs, weight_argdef.name, view_shapes, updated_weight_names_map);
ORT_ENFORCE(weight_views.size() == enabled.size());
@ -223,12 +262,22 @@ static Status AddParameterPartition(
weight_argdefs.insert(weight_argdefs.end(), weight_views.begin(), weight_views.end());
gradient_argdefs.insert(gradient_argdefs.end(), gradient_views.begin(), gradient_views.end());
const auto& initial_states = opt_config.initial_states;
// Update Optimizer node configs.
ORT_ENFORCE(weight_views.size() == gradient_views.size());
for (size_t i = 0; i < weight_views.size(); i++) {
OptimizerNodeConfig new_config = opt_config;
new_config.enabled = enabled[i];
// Partition initial optimizer state
if (enabled[i] && !initial_states.empty()) {
ORT_ENFORCE(view_shapes.size() == 3, "Invalid view_shapes vector passed for partitioning.");
int64_t partition_offset = view_shapes[0].GetDims()[0];
int64_t partition_size = view_shapes[1].GetDims()[0];
new_config.initial_states = opt_config.initial_states;
PartitionOptimizerState(partition_offset, partition_size, new_config.initial_states);
}
if (opt_config.mixed_precision_weight_arg != nullptr) {
new_config.mixed_precision_weight_arg = &graph.GetOrCreateNodeArg(mixed_precision_weight_views[i].name, mixed_precision_weight_views[i].type_proto);
}
@ -306,14 +355,14 @@ static Status ModifyParametersForOptimizerPartitioning(
std::vector<TensorShape> view_shapes = {{size_for_previous_rank}, {size_for_current_rank}, {0}};
std::vector<bool> enabled = {false, true};
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
} else if (offset >= rank_start && offset + tensor_count > rank_end) {
int64_t size_for_current_rank = rank_end - offset;
int64_t size_for_next_rank = offset + tensor_count - rank_end;
std::vector<TensorShape> view_shapes = {{0}, {size_for_current_rank}, {size_for_next_rank}};
std::vector<bool> enabled = {true, false};
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
} else { // offset < rank_start && offset + tensor_count > rank_end
int64_t size_for_previous_rank = rank_start - offset;
int64_t size_for_current_rank = rank_end - rank_start;
@ -321,7 +370,7 @@ static Status ModifyParametersForOptimizerPartitioning(
std::vector<TensorShape> view_shapes = {{size_for_previous_rank}, {size_for_current_rank}, {size_for_next_rank}};
std::vector<bool> enabled = {false, true, false};
AddParameterPartition(graph, graph_defs, weight_argdef, gradient_argdef, opt_config, view_shapes, enabled,
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
new_opt_configs, new_weight_argdefs, new_gradient_argdefs, updated_weight_names_map);
}
} else {
// Parameter is handled by a different rank.

View file

@ -28,5 +28,19 @@ class ZeROOptimizerGraphBuilder : public OptimizerGraphBuilder {
OptimizerOutputKeyMap<std::string>& optimizer_graph_outputs) override;
};
/**
* Partitions the initial optimizer states of a weight according to the
* offset and size provided, for use when the optimizer state for that
* weight is to be partitioned in Zero stage 1.
*
* Only the moment entries (those keyed by MOMENTS_PREFIXES) are replaced;
* each becomes a 1-D tensor of partition_size elements. Moments must hold
* float or MLFloat16 data; any other element type raises an error.
*
* @param partition_offset The offset, in elements, of the start of the partition
* @param partition_size The size (number of elements) of the partition
* @param[in,out] initial_states The optimizer initial states, modified in-place.
*/
void PartitionOptimizerState(
const int64_t partition_offset,
const int64_t partition_size,
NameMLValMap& initial_states);
} // namespace training
} // namespace onnxruntime

View file

@ -23,6 +23,7 @@
#include "test/util/include/asserts.h"
#include "test/test_environment.h"
#include "orttraining/test/session/training_session_test_utils.h"
#include "orttraining/core/graph/optimizer_builder.h"
using onnxruntime::test::CountOpsInGraph;
using onnxruntime::test::CreateMLValue;
@ -168,17 +169,17 @@ static void TestOptimizerGraphBuilderWithInitialStates(OptimizerGraphConfig conf
NameMLValMap per_weight_states;
OrtValue ml_value;
for (const auto key : MOMENT_PREFIX) {
for (const auto key : MOMENTS_PREFIXES) {
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, values, &ml_value);
per_weight_states.insert(std::make_pair(key, std::move(ml_value)));
}
if (optimizer_op_name == k_adam_optimizer_op_name) {
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, uc_value, &ml_value);
per_weight_states.insert(std::make_pair(UC_PREFIX, std::move(ml_value)));
per_weight_states.insert(std::make_pair(ADAM_UC_PREFIX, std::move(ml_value)));
} else if (optimizer_op_name == k_lamb_optimizer_op_name) {
// add "Step" for lamb
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims, uc_value, &ml_value);
shared_states.insert(std::make_pair(STEP_TENSOR_NAME, std::move(ml_value)));
shared_states.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, std::move(ml_value)));
config.shared_optimizer_states = std::move(shared_states);
}
opt_config_it.second.initial_states = std::move(per_weight_states);
@ -217,6 +218,37 @@ TEST_F(OptimizerGraphBuilderTest, LoadOptimState_FullPrecision_Lamb) {
TestOptimizerGraphBuilderWithInitialStates(config, graph_, k_lamb_optimizer_op_name);
}
// Checks that PartitionOptimizerState() narrows every moment tensor to the
// requested element range and reshapes it to a 1-D tensor of that size.
TEST_F(OptimizerGraphBuilderTest, ZeroSplitInitialOptimizerState) {
  std::vector<int64_t> param_dims = {784, 128};
  const int64_t element_count = std::accumulate(param_dims.begin(), param_dims.end(),
                                                static_cast<int64_t>(1), std::multiplies<int64_t>());

  // Fill with 0, 1, 2, ... so a wrong offset shows up as wrong values.
  std::vector<float> full_values(element_count);
  std::iota(full_values.begin(), full_values.end(), static_cast<float>(0));

  NameMLValMap initial_states;
  for (const auto& prefix : MOMENTS_PREFIXES) {
    MLValue moment_value;
    TrainingUtil::CreateCpuMLValue<float>(param_dims, full_values, &moment_value);
    initial_states.insert(std::make_pair(prefix, std::move(moment_value)));
  }

  int64_t partition_offset = 10;
  int64_t partition_size = 500;
  PartitionOptimizerState(partition_offset, partition_size, initial_states);

  const auto slice_begin = full_values.begin() + partition_offset;
  std::vector<float> expected_vec(slice_begin, slice_begin + partition_size);
  std::vector<int64_t> expected_shape = {partition_size};

  for (const auto& state : initial_states) {
    const auto& partitioned_tensor = state.second.Get<Tensor>();
    const auto& actual_dims = partitioned_tensor.Shape().GetDims();
    ASSERT_EQ(actual_dims, expected_shape);

    const float* first = partitioned_tensor.Data<float>();
    const std::vector<float> found(first, first + partition_size);
    ASSERT_EQ(expected_vec, found);
  }
}
static void TestDefaultOptimizerGraphBuilder(OptimizerGraphConfig config, Graph& graph) {
std::unordered_map<std::string, std::string> updated_weight_names_map;
OptimizerGraphBuilder optimizer_graph_builder(

View file

@ -4,6 +4,8 @@ import sys
import subprocess
import copy
import numpy as np
import torch
import onnx
from onnxruntime.training import optim
@ -83,6 +85,7 @@ def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power
def generate_dummy_optim_state(model, optimizer):
np.random.seed(0)
if not (isinstance(optimizer, optim.AdamConfig) or isinstance(optimizer, optim.LambConfig)):
return dict()
@ -92,26 +95,31 @@ def generate_dummy_optim_state(model, optimizer):
shared_state_key = "shared_optimizer_state"
optim_state = dict()
name_to_initializer_map = {n.name:n for n in model.graph.initializer}
initializers_names = name_to_initializer_map.keys()
for weight in initializers_names:
weight_shape_map = dict()
if isinstance(model, torch.nn.Module):
weight_shape_map = {name: param.size() for name, param in model.named_parameters()}
elif isinstance(model, onnx.ModelProto):
weight_shape_map = {n.name: n.dims for n in model.graph.initializer}
else:
raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")
for weight_name, weight_shape in weight_shape_map.items():
per_weight_state = dict()
weight_shape = name_to_initializer_map[weight].dims
for moment in moment_keys:
per_weight_state[moment] = np.full(weight_shape, 2.5, dtype=np.float32)
per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32)
if isinstance(optimizer, optim.AdamConfig):
per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64)
optim_state[weight] = copy.deepcopy(per_weight_state)
optim_state[weight_name] = copy.deepcopy(per_weight_state)
if isinstance(optimizer, optim.LambConfig):
step_val = np.full([1], 5, dtype=np.int64)
optim_state[shared_state_key] = {step_key : step_val}
optim_state[shared_state_key] = {step_key: step_val}
return optim_state
def get_optim_state_from_state_dict(state_dict, optimizer):
if not (isinstance(optimizer, optim.AdamConfig) or isinstance(optimizer, optim.LambConfig)):
return dict()
moment_keys = ["Moment_1", "Moment_2"]
uc_key = "Update_Count"
step_key = "Step"
@ -119,6 +127,9 @@ def get_optim_state_from_state_dict(state_dict, optimizer):
optim_state = dict()
for param_name, v in state_dict.items():
if '_view_' in param_name:
param_name = param_name.split('_view_')[0]
for moment in moment_keys:
if param_name.startswith(moment):
fp32_name = param_name.split(moment + '_')[-1]

View file

@ -10,6 +10,8 @@ from onnxruntime.training import amp, checkpoint, optim, orttrainer
from orttraining_test_orttrainer_frontend import _load_pytorch_transformer_model
from onnxruntime.capi._pybind_state import set_cuda_device_id, get_mpi_context_world_rank, get_mpi_context_world_size
from _test_commons import generate_dummy_optim_state
global_fp16_fp32_atol = 1e-3
def _train(trainer, train_data, batcher_fn, total_batch_steps = 5, seed = 1):
@ -103,7 +105,7 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir,
def split_state_dict(state_dict):
"""Given a flat state dictionary, split it into optimizer, fp32_param, fp16_param hierarchical dictionary and return"""
optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step_']
optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step']
split_sd = {'optimizer': {}, 'fp32_param': {}, 'fp16_param': {}}
for k, v in state_dict.items():
mode = 'fp32_param'
@ -175,3 +177,25 @@ def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir,
# save current model parameters as a checkpoint
if checkpoint_dir:
_save(trainer, checkpoint_dir, state_dict_key_name)
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    """Load a dummy optimizer state into a new ORTTrainer and run one eval step.

    Args:
        device: device passed to the transformer model loader.
        trainer_opts: dict of options wrapped in ORTTrainerOptions.
        use_lamb: build a LambConfig when True, an AdamConfig otherwise.

    Returns:
        A tuple ``(dummy_init_state, state_dict)``: the optimizer state that was
        loaded, and the trainer's state dict read back after the eval step.
    """
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    if use_lamb:
        optim_config = optim.LambConfig(lr=learning_rate)
    else:
        optim_config = optim.AdamConfig(lr=learning_rate)

    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    options = orttrainer.ORTTrainerOptions(trainer_opts)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=options)

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    loaded_state = checkpoint.experimental_state_dict(trainer)
    return dummy_init_state, loaded_state

View file

@ -0,0 +1,153 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
################################################################################
# Refer to orttraining_test_checkpoint.py for an overview about Checkpoint tests
################################################################################
import os
import pickle
from numpy.testing import assert_allclose
import argparse
import glob
import torch
import torch.distributed as dist
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import onnxruntime
from onnxruntime.training import checkpoint, optim
from _test_helpers import distributed_setup, load_model_optim_state_and_eval, split_state_dict, aggregate_states, global_fp16_fp32_atol
from _test_commons import get_optim_state_from_state_dict
def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False):
    """Check that the optimizer state loaded into a distributed trainer matches what was loaded.

    Every rank loads a dummy optimizer state, runs one eval step and pickles its
    split state dict into ``checkpoint_dir``. Rank 0 then reads every rank's
    file, aggregates the optimizer states and compares them against the state
    that was loaded.

    Args:
        device: device passed through to the trainer setup.
        opts: trainer options dict.
        checkpoint_dir: existing directory used for the temporary per-rank files.
        world_rank: rank of the calling process.
        use_lamb: compare Lamb state when True, Adam state otherwise.

    Raises:
        AssertionError: on rank 0 when the aggregated state does not match.
    """
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)

    def state_path(rank):
        return os.path.join(checkpoint_dir, 'distributed_state_'+str(rank)+'.pkl')

    # Roundabout way of checking optimizer states: save the state dicts into a
    # temporary folder, then read them back and aggregate them.
    with open(state_path(world_rank), "wb") as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    try:
        if world_rank == 0:
            # glob.escape keeps glob characters in the directory name literal.
            pattern = os.path.join(glob.escape(checkpoint_dir), "distributed_state*")
            num_states = len(glob.glob(pattern))

            optimizer_states = dict()
            for rank in range(num_states):
                # These pickles were written by this same test run just above.
                with open(state_path(rank), 'rb') as f:
                    rank_state_dict = pickle.load(f)
                # collect optimizer states for later comparison since they are sharded
                aggregate_states(optimizer_states, rank_state_dict['optimizer'])

            # compare optimizer states
            optimizer_config = optim.LambConfig() if use_lamb else optim.AdamConfig()
            actual_optim_state = get_optim_state_from_state_dict(optimizer_states, optimizer_config)
            assert actual_optim_state.keys() == expected_optim_state.keys()
            for param_name, a_state in actual_optim_state.items():
                for k, v in a_state.items():
                    assert_allclose(v.reshape(expected_optim_state[param_name][k].shape),
                                    expected_optim_state[param_name][k],
                                    err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")
    finally:
        # Reach the barrier and clean up even when the comparison fails, so the
        # other ranks are not left waiting and no stale file remains.
        dist.barrier()
        os.remove(state_path(world_rank))
@distributed_setup
def test_optim_load_to_distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/adam/'):
    """Adam, full precision: verify loaded optimizer state under ZeRO stage 1."""
    zero_stage_1 = {'stage': 1}
    opts = {
        'device': {'id': device},
        'distributed': {
            'world_rank': world_rank,
            'world_size': world_size,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': zero_stage_1,
        },
        'debug': {'deterministic_compute': True},
    }
    verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False)
@distributed_setup
def test_optim_load_to_distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/adam/'):
    """Adam, mixed precision: verify loaded optimizer state under ZeRO stage 1."""
    zero_stage_1 = {'stage': 1}
    opts = {
        'device': {'id': device},
        'mixed_precision': {'enabled': True},
        'distributed': {
            'world_rank': world_rank,
            'world_size': world_size,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': zero_stage_1,
        },
        'debug': {'deterministic_compute': True},
    }
    verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False)
@distributed_setup
def test_optim_load_to_distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'):
    """Lamb, full precision: verify loaded optimizer state under ZeRO stage 1."""
    zero_stage_1 = {'stage': 1}
    opts = {
        'device': {'id': device},
        'distributed': {
            'world_rank': world_rank,
            'world_size': world_size,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': zero_stage_1,
        },
        'debug': {'deterministic_compute': True},
    }
    verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True)
@distributed_setup
def test_optim_load_to_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb/'):
    """Lamb, mixed precision: verify loaded optimizer state under ZeRO stage 1."""
    zero_stage_1 = {'stage': 1}
    opts = {
        'device': {'id': device},
        'mixed_precision': {'enabled': True},
        'distributed': {
            'world_rank': world_rank,
            'world_size': world_size,
            'allreduce_post_accumulation': True,
            'deepspeed_zero_optimization': zero_stage_1,
        },
        'debug': {'deterministic_compute': True},
    }
    verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True)
# Maps each accepted --scenario value to the test entry point that runs it.
function_map = {
    # load to zero configs
    'test_optim_load_to_distributed_zero_full_precision_adam': test_optim_load_to_distributed_zero_full_precision_adam,
    'test_optim_load_to_distributed_zero_mixed_precision_adam': test_optim_load_to_distributed_zero_mixed_precision_adam,
    'test_optim_load_to_distributed_zero_mixed_precision_lamb': test_optim_load_to_distributed_zero_mixed_precision_lamb,
    'test_optim_load_to_distributed_zero_full_precision_lamb': test_optim_load_to_distributed_zero_full_precision_lamb,
}

arg_parser = argparse.ArgumentParser(description='Test loading of initial optimizer state for Zero-1')
arg_parser.add_argument(
    '--scenario',
    choices=function_map.keys(),
    help='training scenario to test loaded states',
    required=True)
arg_parser.add_argument(
    '--checkpoint_dir',
    help='path to the saved states directory',
    required=True)
cli_args = arg_parser.parse_args()

# Run the selected scenario against the requested directory.
selected_scenario = function_map[cli_args.scenario]
selected_scenario(checkpoint_dir=cli_args.checkpoint_dir)

View file

@ -56,6 +56,7 @@ ngpus = torch.cuda.device_count()
save_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_save_checkpoint.py')
load_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_load_checkpoint.py')
aggregate_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_checkpoint_aggregation.py')
optim_state_file = os.path.join('checkpoint', 'orttraining_test_load_optimizer_state.py')
single_node_full_precision_path = os.path.join(checkpoint_dir, 'single_node', 'full_precision')
single_node_mixed_precision_path = os.path.join(checkpoint_dir, 'single_node', 'mixed_precision')
@ -125,4 +126,10 @@ _single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_m
_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path)
_single_run(aggregate_checkpoint_file, 'test_aggregation_from_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path)
# optimizer state loading into model-parallel tests
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_adam', distributed_zero_full_precision_adam_path)
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_adam', distributed_zero_mixed_precision_adam_path)
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path)
_distributed_run(optim_state_file, 'test_optim_load_to_distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path)
shutil.rmtree(checkpoint_dir)

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/test/session/training_session_test_utils.h"
#include "orttraining/core/graph/optimizer_builder.h"
using namespace onnxruntime::logging;
using namespace onnxruntime::training;
@ -38,21 +39,21 @@ void GenerateOptimizerInitialState(const std::string& optimizer_op_name, const T
std::vector<int64_t> param_dims = WEIGHT_TO_SHAPE_MAP.at(weight_name);
int64_t num_ele = std::accumulate(param_dims.begin(), param_dims.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
for (auto& param_prefix : MOMENT_PREFIX) {
for (auto& param_prefix : MOMENTS_PREFIXES) {
std::vector<T> param_value(num_ele, init_moment_value);
TrainingUtil::CreateCpuMLValue<T>(param_dims, param_value, &mlValue);
optim_state.insert(std::make_pair(param_prefix, std::move(mlValue)));
}
if (optimizer_op_name == k_adam_optimizer_op_name) {
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, uc_value, &mlValue);
optim_state.insert(std::make_pair(UC_PREFIX, std::move(mlValue)));
optim_state.insert(std::make_pair(ADAM_UC_PREFIX, std::move(mlValue)));
}
result.insert(std::make_pair(weight_name, std::move(optim_state)));
}
if (optimizer_op_name == k_lamb_optimizer_op_name) {
// add "Step" for lamb
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, uc_value, &mlValue);
shared_states.insert(std::make_pair(STEP_TENSOR_NAME, std::move(mlValue)));
shared_states.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, std::move(mlValue)));
result.insert(std::make_pair(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY, std::move(shared_states)));
}
optimizer_state = std::move(result);
@ -73,7 +74,9 @@ void SeparateStateTensors(const NameMLValMap& training_state, NameMLValMap& mode
model_state = std::move(result);
for (auto& weight_name : WEIGHT_NAMES) {
NameMLValMap optim_state;
for (auto& param_prefix : MOMENT_UC_PREFIX) {
std::vector<std::string> per_weight_states_prefixes(MOMENTS_PREFIXES);
per_weight_states_prefixes.push_back(ADAM_UC_PREFIX);
for (auto& param_prefix : per_weight_states_prefixes) {
std::string param_name = param_prefix + "_" + weight_name;
const auto& param_state_it = training_state.find(param_name);
if (param_state_it != training_state.end()) {
@ -83,9 +86,9 @@ void SeparateStateTensors(const NameMLValMap& training_state, NameMLValMap& mode
optimizer_state.insert(std::make_pair(weight_name, optim_state));
}
NameMLValMap shared_optim_state;
const auto& param_state_it = training_state.find(STEP_TENSOR_NAME);
const auto& param_state_it = training_state.find(LAMB_STEP_TENSOR_NAME);
if (param_state_it != training_state.end()) {
shared_optim_state.insert(std::make_pair(STEP_TENSOR_NAME, param_state_it->second));
shared_optim_state.insert(std::make_pair(LAMB_STEP_TENSOR_NAME, param_state_it->second));
optimizer_state.insert(std::make_pair(onnxruntime::training::SHARED_OPTIMIZER_STATES_KEY, shared_optim_state));
}
}

View file

@ -36,10 +36,6 @@ const std::unordered_map<std::string, std::vector<int64_t>> WEIGHT_TO_SHAPE_MAP
{"B2", {32}},
{"W3", {32, 10}},
{"B1", {128}}};
const std::vector<std::string> MOMENT_PREFIX = {"Moment_1", "Moment_2"};
const std::vector<std::string> MOMENT_UC_PREFIX = {"Moment_1", "Moment_2", "Update_Count"};
constexpr char STEP_TENSOR_NAME[] = "Step";
constexpr char UC_PREFIX[] = "Update_Count";
void GenerateOptimizerConfig(const std::string optimizer_name,
const bool use_mixed_precision_moments,