Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-14 20:48:00 +00:00
Add support for SGD optimizer in minimal build (#19901)
commit 226f60f2f1
parent 1fb6cbddee
4 changed files with 53 additions and 77 deletions
@@ -41,7 +41,7 @@ def generate_artifacts(
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
     loss: Optional[Union[LossType, onnxblock.Block]] = None,
-    optimizer: Optional[OptimType] = None,
+    optimizer: Optional[Union[OptimType, onnxblock.Block]] = None,
     artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
     prefix: str = "",
     ort_format: bool = False,
@@ -64,8 +64,8 @@ def generate_artifacts(
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
         frozen_params: List of names of model parameters that should be frozen.
-        loss: The loss function enum to be used for training. If None, no loss node is added to the graph.
-        optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
+        loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
+        optimizer: The optimizer enum or onnxblock to be used for training. If None, no optimizer model is generated.
         artifact_directory: The directory to save the generated artifacts.
             If None, the current working directory is used.
         prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
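Editorial note: with this change, generate_artifacts accepts either an OptimType
enum value or a preconfigured onnxblock optimizer block. A minimal sketch of both
call styles, assuming base_model is an ONNX ModelProto with fc1/fc2 parameters as
in the tests further down (names and output paths are illustrative):

    from onnxruntime.training import artifacts, onnxblock

    # Enum form: generate_artifacts instantiates the optimizer block itself.
    artifacts.generate_artifacts(
        base_model,
        requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=artifacts.OptimType.SGD,
        artifact_directory="out_sgd",
    )

    # Block form: pass a configured onnxblock optimizer block directly.
    artifacts.generate_artifacts(
        base_model,
        requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=onnxblock.optim.AdamW(weight_decay=1e-2),
        artifact_directory="out_adamw",
    )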
@@ -219,14 +219,6 @@ def generate_artifacts(
         logging.info("No optimizer enum provided. Skipping optimizer model generation.")
         return
 
-    if not isinstance(optimizer, OptimType):
-        raise RuntimeError(
-            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be of type "
-            "onnxruntime.training.artifacts.OptimType."
-        )
-
-    logging.info("Optimizer enum provided: %s", optimizer.name)
-
     opset_version = None
     for domain in model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
@@ -235,8 +227,19 @@ def generate_artifacts(
 
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
-
-    optim_block = optim_blocks[optimizer]()
+    optim_block = None
+    if isinstance(optimizer, OptimType):
+        logging.info("Optimizer enum provided: %s", optimizer.name)
+        optim_block = optim_blocks[optimizer]()
+    elif isinstance(optimizer, onnxblock.Block):
+        logging.info("Optimizer block provided: %s", optimizer.__class__.__name__)
+        optim_block = optimizer
+    else:
+        raise TypeError(
+            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be either one of "
+            "onnxruntime.training.artifacts.OptimType or onnxruntime.training.onnxblock.Block."
+        )
+
     with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
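Editorial note: the dispatch above only requires an onnxblock.Block, so any custom
block can act as the optimizer. A hypothetical sketch (the class name WrappedSGD and
the delegation through build are illustrative assumptions, not part of this commit):

    from onnxruntime.training import onnxblock

    class WrappedSGD(onnxblock.Block):
        # Hypothetical custom optimizer block that delegates graph construction
        # to the built-in SGD block; a real block could add extra nodes here.
        def __init__(self):
            super().__init__()
            self._sgd = onnxblock.optim.SGD()

        def build(self, parameters):
            return self._sgd.build(parameters)

    # Such an instance would be passed straight through by the elif branch above:
    # generate_artifacts(..., optimizer=WrappedSGD(), ...)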
@@ -1072,3 +1072,30 @@ def test_save_nominal_checkpoint():
             os.stat(os.path.join(temp_dir, "checkpoint")).st_size
             > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size
         )
+
+
+def test_custom_optimizer_block():
+    device = "cpu"
+    batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
+    _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
+    weight_decay = 123
+    optimizer = onnxblock.optim.AdamW(weight_decay=weight_decay)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifacts.generate_artifacts(
+            base_model,
+            requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=optimizer,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+
+        optimizer_model = onnx.load(os.path.join(temp_dir, "optimizer_model.onnx"))
+        for node in optimizer_model.graph.node:
+            if node.op_type == "AdamW":
+                for attr in node.attribute:
+                    if attr.name == "weight_decay":
+                        assert attr.f == weight_decay
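Editorial note: the new test verifies that a user-configured block survives into the
serialized optimizer graph. For context, a rough sketch of how the generated
artifacts are typically consumed with the ORT training API (paths and file names
follow generate_artifacts defaults; treat exact signatures as assumptions):

    from onnxruntime.training import api

    state = api.CheckpointState.load_checkpoint("out/checkpoint")
    module = api.Module("out/training_model.onnx", state, "out/eval_model.onnx")
    optimizer = api.Optimizer("out/optimizer_model.onnx", module)

    # Training-loop skeleton: forward/backward through the module, then the
    # optimizer step that executes the AdamW/SGD graph produced above.
    # loss = module(inputs, labels)
    # optimizer.step()
    # module.lazy_reset_grad()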
@@ -61,32 +61,19 @@ Status GraphInputsAreExpected(gsl::span<const std::string> actual_graph_inputs,
 }  // namespace
 
 std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    std::shared_ptr<Model> model, int32_t& group_count) {
+    const GraphViewer& graph_viewer, int32_t& group_count) {
   std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (model != nullptr) {
-    Graph& graph = model->MainGraph();
-    for (auto& node : graph.Nodes()) {
-      if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
-        auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
-        if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
-          opt_type_to_freq_map[domain_type_pair] = 0;
-        }
-
-        opt_type_to_freq_map[domain_type_pair] += 1;
-      }
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
+      auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
+      if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
+        opt_type_to_freq_map[domain_type_pair] = 0;
+      }
+
+      opt_type_to_freq_map[domain_type_pair] += 1;
     }
-  } else {
-#else
-  ORT_UNUSED_PARAMETER(model);
-#endif
-    // TODO(baijumeswani): Figure out the best way to extract the optimizer type
-    // from the model (either onnx model or ort format model) or from the checkpoint.
-    // For now, assume that the optimizer type is AdamWOptimizer when using ort format models.
-    opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
-#if !defined(ORT_MINIMAL_BUILD)
   }
-#endif
 
   ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " +
                                                      std::to_string(opt_type_to_freq_map.size()));
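Editorial note: the loop above now counts optimizer ops directly on the session's
graph, so SGDOptimizerV2 is detected even when the protobuf is unavailable. For
intuition, the same detection expressed over an ONNX optimizer model in Python
(a sketch only; kMSDomain corresponds to "com.microsoft"):

    import onnx

    optim = onnx.load("out/optimizer_model.onnx")  # illustrative path
    freq = {}
    for node in optim.graph.node:
        if node.domain == "com.microsoft" and node.op_type in ("AdamWOptimizer", "SGDOptimizerV2"):
            key = (node.domain, node.op_type)
            freq[key] = freq.get(key, 0) + 1
    # Mirrors the ORT_ENFORCE: exactly one optimizer algorithm type per model.
    assert len(freq) == 1, "Only one type of optimizer algorithm is supported"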
@@ -105,42 +92,6 @@ std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance
   }
 }
 
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const PathString& optim_path, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModel(optim_path)) {
-    ORT_ENFORCE(Model::Load(optim_path, model, nullptr,
-                            logging::LoggingManager::DefaultLogger())
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_path);
-#endif
-  return CreateInstance(model, group_count);
-}
-
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast<int>(optim_model_data_len))) {
-    ONNX_NAMESPACE::ModelProto model_proto;
-    ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast<int>(optim_model_data_len)) == true,
-                "Failed to load model because protobuf parsing failed.");
-
-    ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr,
-                            logging::LoggingManager::DefaultLogger(), ModelOptions(true, true))
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_model_data);
-  ORT_UNUSED_PARAMETER(optim_model_data_len);
-#endif
-
-  return CreateInstance(model, group_count);
-}
-
 Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) {
   auto group_optimizer_state_it =
       optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -280,17 +231,15 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
     auto optimizer_model = std::get<std::optional<std::string>>(model_identifiers.optim_model);
     // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt
    ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value()));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_);
   } else {
     auto optimizer_model = std::get<gsl::span<const uint8_t>>(model_identifiers.optim_model);
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(),
                                          static_cast<int>(optimizer_model.size())));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(),
-                                                                   optimizer_model.size(),
-                                                                   group_count_);
   }
 
   ORT_THROW_IF_ERROR(optim_sess_->Initialize());
+  optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_sess_->GetSessionState().GetGraphViewer(),
+                                                                 group_count_);
 
   // Make sure that the checkpoint state can copy tensors
   state_->optimizer_checkpoint_state.optimizer_session_data_transfer_mgr = &optim_sess_->GetDataTransferManager();
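Editorial note: this is the commit's point end to end. The factory now inspects the
graph of the already-initialized session rather than re-parsing the model from a
path or byte buffer, so optimizer detection (now including SGDOptimizerV2) works in
a minimal, ORT-format-only build. A hedged sketch of producing ORT-format training
artifacts with SGD (ort_format is the flag from the Python signature above; the
other arguments mirror the tests):

    from onnxruntime.training import artifacts

    artifacts.generate_artifacts(
        base_model,  # assumed ONNX ModelProto, as in the tests above
        requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=artifacts.OptimType.SGD,
        artifact_directory="out_minimal",
        ort_format=True,  # emit ORT-format models loadable by a minimal build
    )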
@@ -64,11 +64,8 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase {
 };
 
 struct OptimizerAlorithmFactory {
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const PathString& optim_path,
+  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const GraphViewer& graph_viewer,
                                                                 int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const uint8_t* optim_model_data,
-                                                                size_t optim_model_data_len, int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(std::shared_ptr<Model> model, int32_t& group_count);
 };
 
 struct CheckpointState;