From c557a558160bd6f83c82f05a1bdb667cceb60b85 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 31 Oct 2022 21:29:06 -0700 Subject: [PATCH] Fix on-device training ExportModelForInferencing api (#13510) --- cmake/CMakeLists.txt | 7 ++++++- .../orttraining/python/training/onnxblock/model.py | 6 ------ .../test/training_api/core/training_api_tests.cc | 10 +++++----- .../include/onnxruntime_training_cxx_inline.h | 5 +++-- .../training_api/onnxruntime_training_c_api.cc | 6 ++++++ 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ed44a2c95b..a404d6214a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -691,7 +691,12 @@ else() message("AVX instruction set is not supported.") endif() - if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX)) + if (CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE) + message("F16C, FMA and AVX flags are not supported on Android for on-device training.") + endif() + + if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX) OR + (CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE)) message("One or more AVX/F16C instruction flags are not supported. ") set(onnxruntime_ENABLE_CPU_FP16_OPS FALSE) endif() diff --git a/orttraining/orttraining/python/training/onnxblock/model.py b/orttraining/orttraining/python/training/onnxblock/model.py index 41bd15b118..ae03cb63a0 100644 --- a/orttraining/orttraining/python/training/onnxblock/model.py +++ b/orttraining/orttraining/python/training/onnxblock/model.py @@ -41,9 +41,6 @@ class Model(building_blocks.Block): # Build the graph outputs graph_utils.build_graph_outputs(accessor.global_accessor.model, output) - # validate and check the model - onnx.checker.check_model(accessor.global_accessor.model, True) - return output @@ -129,7 +126,4 @@ class TrainingModel(building_blocks.Block): # add gradient accumulation nodes graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names) - # validate and check the model - onnx.checker.check_model(accessor.global_accessor.model, True) - return output diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 1c8a9decd6..877b62cbfe 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -111,17 +111,17 @@ void TestModuleExport(const std::vector>& pr ASSERT_EQ(outputs.size(), 1U); } -void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) { - ASSERT_NEAR(expected, output, atol); - ASSERT_NEAR(expected, output, rtol * std::abs(expected)); -} - #if defined(USE_CUDA) || defined(USE_ROCM) const int64_t total_step_count = 100; const float initial_lr = 1e-3f; const int64_t resume_step = total_step_count / 2; +void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) { + ASSERT_NEAR(expected, output, atol); + ASSERT_NEAR(expected, output, rtol * std::abs(expected)); +} + void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count, int64_t warmup_step_count) { /// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances. diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 0ebff687f4..8eaa1dbc83 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -96,8 +96,9 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session, inline void TrainingSession::ExportModelForInferencing(const std::basic_string& inference_model_path, const std::vector& graph_output_names) { - std::vector output_names(graph_output_names.size(), nullptr); - for (auto& output_name : graph_output_names) { + std::vector output_names; + output_names.reserve(graph_output_names.size()); + for (const auto& output_name : graph_output_names) { output_names.push_back(output_name.c_str()); } ThrowOnError(GetTrainingApi().ExportModelForInferencing( diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 7d37335a4e..c6eff3a8f6 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -344,6 +344,12 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrain _In_reads_(graph_outputs_len) const char* const* graph_output_names) { API_IMPL_BEGIN + if (graph_outputs_len == 0U) { + return OrtApis::CreateStatus( + ORT_INVALID_ARGUMENT, + "Empty array of graph output names is not valid. Please provide valid graph output names"); + } + auto session = reinterpret_cast(sess); onnxruntime::InlinedVector output_names(graph_outputs_len);