diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 4e76174d82..624b30ffda 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -41,7 +41,7 @@ def generate_artifacts(
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
     loss: Optional[Union[LossType, onnxblock.Block]] = None,
-    optimizer: Optional[OptimType] = None,
+    optimizer: Optional[Union[OptimType, onnxblock.Block]] = None,
    artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
     prefix: str = "",
     ort_format: bool = False,
@@ -64,8 +64,8 @@ def generate_artifacts(
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
         frozen_params: List of names of model parameters that should be frozen.
-        loss: The loss function enum to be used for training. If None, no loss node is added to the graph.
-        optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
+        loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
+        optimizer: The optimizer enum or onnxblock to be used for training. If None, no optimizer model is generated.
         artifact_directory: The directory to save the generated artifacts. If None, the current working
             directory is used.
         prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
@@ -219,14 +219,6 @@ def generate_artifacts(
         logging.info("No optimizer enum provided. Skipping optimizer model generation.")
         return
 
-    if not isinstance(optimizer, OptimType):
-        raise RuntimeError(
-            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be of type "
-            "onnxruntime.training.artifacts.OptimType."
-        )
-
-    logging.info("Optimizer enum provided: %s", optimizer.name)
-
     opset_version = None
     for domain in model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
@@ -235,8 +227,19 @@ def generate_artifacts(
 
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
 
+    optim_block = None
+    if isinstance(optimizer, OptimType):
+        logging.info("Optimizer enum provided: %s", optimizer.name)
+        optim_block = optim_blocks[optimizer]()
+    elif isinstance(optimizer, onnxblock.Block):
+        logging.info("Optimizer block provided: %s", optimizer.__class__.__name__)
+        optim_block = optimizer
+    else:
+        raise TypeError(
+            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be either "
+            "onnxruntime.training.artifacts.OptimType or onnxruntime.training.onnxblock.Block."
+        )
 
-    optim_block = optim_blocks[optimizer]()
     with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
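Note: with this change, `generate_artifacts` accepts a preconfigured onnxblock in place of the `OptimType` enum. A minimal usage sketch of the new call; the model path and `fc1`/`fc2` parameter names are placeholders, not part of this PR:

```python
import onnx

from onnxruntime.training import artifacts, onnxblock

# Placeholder model; any onnx.ModelProto with trainable parameters works.
base_model = onnx.load("base_model.onnx")

artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
    loss=artifacts.LossType.CrossEntropyLoss,
    # Previously limited to artifacts.OptimType.AdamW / artifacts.OptimType.SGD;
    # a configured block can now override hyperparameters such as weight_decay.
    optimizer=onnxblock.optim.AdamW(weight_decay=1e-2),
    artifact_directory="artifacts",
)
```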
diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
index 11df3fa347..ac49c1c283 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
@@ -1072,3 +1072,30 @@ def test_save_nominal_checkpoint():
         os.stat(os.path.join(temp_dir, "checkpoint")).st_size
         > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size
     )
+
+
+def test_custom_optimizer_block():
+    device = "cpu"
+    batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
+    _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
+    weight_decay = 123
+    optimizer = onnxblock.optim.AdamW(weight_decay=weight_decay)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifacts.generate_artifacts(
+            base_model,
+            requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=optimizer,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+
+        optimizer_model = onnx.load(os.path.join(temp_dir, "optimizer_model.onnx"))
+        for node in optimizer_model.graph.node:
+            if node.op_type == "AdamWOptimizer":
+                for attr in node.attribute:
+                    if attr.name == "weight_decay":
+                        assert attr.f == weight_decay
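Note: since the new check is just `isinstance(optimizer, onnxblock.Block)`, any `Block` subclass is accepted, not only the built-in blocks exercised by the test above. A hypothetical sketch; only the `weight_decay` constructor argument is confirmed by this test, and a fully custom optimizer would instead override `Block.build` to emit its own optimizer graph:

```python
from onnxruntime.training import onnxblock


class AdamWNoDecay(onnxblock.optim.AdamW):
    """Hypothetical custom optimizer block that pins a hyperparameter."""

    def __init__(self):
        # Other AdamW constructor arguments are left at their defaults.
        super().__init__(weight_decay=0.0)


# Passed to generate_artifacts the same way as the enum or built-in blocks:
#     optimizer=AdamWNoDecay()
```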
diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc
index 84c35e6100..4647f89072 100644
--- a/orttraining/orttraining/training_api/optimizer.cc
+++ b/orttraining/orttraining/training_api/optimizer.cc
@@ -61,32 +61,19 @@ Status GraphInputsAreExpected(gsl::span<const std::string> actual_graph_inputs,
 }  // namespace
 
 std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    std::shared_ptr<Model> model, int32_t& group_count) {
+    const GraphViewer& graph_viewer, int32_t& group_count) {
   std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (model != nullptr) {
-    Graph& graph = model->MainGraph();
-    for (auto& node : graph.Nodes()) {
-      if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
-        auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
-        if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
-          opt_type_to_freq_map[domain_type_pair] = 0;
-        }
-        opt_type_to_freq_map[domain_type_pair] += 1;
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
+      auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
+      if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
+        opt_type_to_freq_map[domain_type_pair] = 0;
       }
+
+      opt_type_to_freq_map[domain_type_pair] += 1;
     }
-  } else {
-#else
-  ORT_UNUSED_PARAMETER(model);
-#endif
-    // TODO(baijumeswani): Figure out the best way to extract the optimizer type
-    // from the model (either onnx model or ort format model) or from the checkpoint.
-    // For now, assume that the optimizer type is AdamWOptimizer when using ort format models.
-    opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
-#if !defined(ORT_MINIMAL_BUILD)
   }
-#endif
 
   ORT_ENFORCE(opt_type_to_freq_map.size() == 1U,
               "Only support one type of optimizer algorithm, but got: " +
                   std::to_string(opt_type_to_freq_map.size()));
@@ -105,42 +92,6 @@ std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance
   }
 }
 
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const PathString& optim_path, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModel(optim_path)) {
-    ORT_ENFORCE(Model::Load(optim_path, model, nullptr,
-                            logging::LoggingManager::DefaultLogger())
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_path);
-#endif
-  return CreateInstance(model, group_count);
-}
-
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast<int>(optim_model_data_len))) {
-    ONNX_NAMESPACE::ModelProto model_proto;
-    ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast<int>(optim_model_data_len)) == true,
-                "Failed to load model because protobuf parsing failed.");
-
-    ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr,
-                            logging::LoggingManager::DefaultLogger(), ModelOptions(true, true))
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_model_data);
-  ORT_UNUSED_PARAMETER(optim_model_data_len);
-#endif
-
-  return CreateInstance(model, group_count);
-}
-
 Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) {
   auto group_optimizer_state_it =
       optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -280,17 +231,15 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
     auto optimizer_model = std::get<std::optional<std::string>>(model_identifiers.optim_model);
     // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value()));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()),
-                                                                   group_count_);
   } else {
     auto optimizer_model = std::get<gsl::span<const uint8_t>>(model_identifiers.optim_model);
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(),
                                          static_cast<int>(optimizer_model.size())));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(),
-                                                                   optimizer_model.size(),
-                                                                   group_count_);
   }
 
   ORT_THROW_IF_ERROR(optim_sess_->Initialize());
 
+  optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_sess_->GetSessionState().GetGraphViewer(),
+                                                                 group_count_);
+
   // Make sure that the checkpoint state can copy tensors
   state_->optimizer_checkpoint_state.optimizer_session_data_transfer_mgr = &optim_sess_->GetDataTransferManager();
diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h
index 031b114265..5b908acf7c 100644
--- a/orttraining/orttraining/training_api/optimizer.h
+++ b/orttraining/orttraining/training_api/optimizer.h
@@ -64,11 +64,8 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase {
 };
 
 struct OptimizerAlorithmFactory {
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const PathString& optim_path,
+  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const GraphViewer& graph_viewer,
                                                                 int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const uint8_t* optim_model_data,
-                                                                size_t optim_model_data_len, int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(std::shared_ptr<Model> model, int32_t& group_count);
 };
 
 struct CheckpointState;
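Note: the net effect of the C++ changes is that `Optimizer` now derives the optimizer algorithm by walking the loaded session's `GraphViewer` after `Initialize()`, rather than re-parsing the model file or bytes and assuming AdamW for ORT-format models. A sketch of consuming the generated artifacts with the Python on-device training API; the `training_model.onnx` and `eval_model.onnx` file names are the usual artifact defaults, assumed here rather than shown in this diff:

```python
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Artifacts assumed to have been generated into "artifacts/" as sketched above.
state = CheckpointState.load_checkpoint("artifacts/checkpoint")
module = Module(
    "artifacts/training_model.onnx",
    state,
    "artifacts/eval_model.onnx",
    device="cpu",
)

# Constructing the Optimizer loads and initializes its session; the algorithm
# (an AdamWOptimizer vs. SGDOptimizerV2 node) is discovered from that graph.
optimizer = Optimizer("artifacts/optimizer_model.onnx", module)
optimizer.set_learning_rate(1e-3)
```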