mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Fix on-device training ExportModelForInferencing api (#13510)
This commit is contained in:
parent
17f0ffd1c8
commit
c557a55816
5 changed files with 20 additions and 14 deletions
|
|
@ -691,7 +691,12 @@ else()
|
|||
message("AVX instruction set is not supported.")
|
||||
endif()
|
||||
|
||||
if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX))
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE)
|
||||
message("F16C, FMA and AVX flags are not supported on Android for on-device training.")
|
||||
endif()
|
||||
|
||||
if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX) OR
|
||||
(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE))
|
||||
message("One or more AVX/F16C instruction flags are not supported. ")
|
||||
set(onnxruntime_ENABLE_CPU_FP16_OPS FALSE)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -41,9 +41,6 @@ class Model(building_blocks.Block):
|
|||
# Build the graph outputs
|
||||
graph_utils.build_graph_outputs(accessor.global_accessor.model, output)
|
||||
|
||||
# validate and check the model
|
||||
onnx.checker.check_model(accessor.global_accessor.model, True)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -129,7 +126,4 @@ class TrainingModel(building_blocks.Block):
|
|||
# add gradient accumulation nodes
|
||||
graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names)
|
||||
|
||||
# validate and check the model
|
||||
onnx.checker.check_model(accessor.global_accessor.model, True)
|
||||
|
||||
return output
|
||||
|
|
|
|||
|
|
@ -111,17 +111,17 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& pr
|
|||
ASSERT_EQ(outputs.size(), 1U);
|
||||
}
|
||||
|
||||
// Test helper: asserts that `output` is close to `expected`.
// Two ASSERT_NEAR checks are applied in sequence: one with the absolute bound `atol`,
// and one with the relative bound `rtol * |expected|`.
// NOTE(review): when `expected` is 0 the relative bound evaluates to 0, so the second
// check presumably requires an exact match - confirm this is intended.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
|
||||
ASSERT_NEAR(expected, output, atol);
|
||||
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
const int64_t total_step_count = 100;
|
||||
const float initial_lr = 1e-3f;
|
||||
const int64_t resume_step = total_step_count / 2;
|
||||
|
||||
// Test helper: asserts that `output` is close to `expected`.
// Two ASSERT_NEAR checks are applied in sequence: one with the absolute bound `atol`,
// and one with the relative bound `rtol * |expected|`.
// NOTE(review): when `expected` is 0 the relative bound evaluates to 0, so the second
// check presumably requires an exact match - confirm this is intended.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
|
||||
ASSERT_NEAR(expected, output, atol);
|
||||
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
|
||||
}
|
||||
|
||||
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
|
||||
int64_t warmup_step_count) {
|
||||
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
|
||||
|
|
|
|||
|
|
@ -96,8 +96,9 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session,
|
|||
|
||||
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
||||
const std::vector<std::string>& graph_output_names) {
|
||||
std::vector<const char*> output_names(graph_output_names.size(), nullptr);
|
||||
for (auto& output_name : graph_output_names) {
|
||||
std::vector<const char*> output_names;
|
||||
output_names.reserve(graph_output_names.size());
|
||||
for (const auto& output_name : graph_output_names) {
|
||||
output_names.push_back(output_name.c_str());
|
||||
}
|
||||
ThrowOnError(GetTrainingApi().ExportModelForInferencing(
|
||||
|
|
|
|||
|
|
@ -344,6 +344,12 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrain
|
|||
_In_reads_(graph_outputs_len) const char* const* graph_output_names) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
if (graph_outputs_len == 0U) {
|
||||
return OrtApis::CreateStatus(
|
||||
ORT_INVALID_ARGUMENT,
|
||||
"Empty array of graph output names is not valid. Please provide valid graph output names");
|
||||
}
|
||||
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
|
||||
onnxruntime::InlinedVector<std::string> output_names(graph_outputs_len);
|
||||
|
|
|
|||
Loading…
Reference in a new issue