Fix on-device training ExportModelForInferencing api (#13510)

This commit is contained in:
Baiju Meswani 2022-10-31 21:29:06 -07:00 committed by GitHub
parent 17f0ffd1c8
commit c557a55816
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 14 deletions

View file

@ -691,7 +691,12 @@ else()
message("AVX instruction set is not supported.")
endif()
if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX))
# Emit a configure-time notice when targeting Android with on-device training enabled.
# This block only prints a message; it does not change any build variable itself.
if (CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE)
message("F16C, FMA and AVX flags are not supported on Android for on-device training.")
endif()
# Disable CPU FP16 ops when either of the following holds:
#   * at least one of COMPILER_SUPPORT_MF16C, COMPILER_SUPPORT_FMA or COMPILER_SUPPORT_AVX is false, or
#   * the target system is Android and on-device training is enabled.
# NOT applies only to the first parenthesized group; the OR then combines it with the Android check.
if (NOT (COMPILER_SUPPORT_MF16C AND COMPILER_SUPPORT_FMA AND COMPILER_SUPPORT_AVX) OR
(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_ENABLE_TRAINING_ON_DEVICE))
message("One or more AVX/F16C instruction flags are not supported. ")
set(onnxruntime_ENABLE_CPU_FP16_OPS FALSE)
endif()

View file

@ -41,9 +41,6 @@ class Model(building_blocks.Block):
# Build the graph outputs
graph_utils.build_graph_outputs(accessor.global_accessor.model, output)
# validate and check the model
onnx.checker.check_model(accessor.global_accessor.model, True)
return output
@ -129,7 +126,4 @@ class TrainingModel(building_blocks.Block):
# add gradient accumulation nodes
graph_utils.build_gradient_accumulation_graph(accessor.global_accessor.model, all_args_requiring_gradient_names)
# validate and check the model
onnx.checker.check_model(accessor.global_accessor.model, True)
return output

View file

@ -111,17 +111,17 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& pr
ASSERT_EQ(outputs.size(), 1U);
}
/// Asserts that `output` is close to `expected` under two separate bounds.
///
/// @param expected Reference value.
/// @param output   Value under test.
/// @param rtol     Relative tolerance; the second check allows a difference of rtol * |expected|.
/// @param atol     Absolute tolerance; the first check allows a difference of atol.
///
/// Both assertions must pass, so the effective tolerance is the smaller of the two bounds.
/// NOTE(review): when `expected` is 0 the relative bound is 0, so only an exact match passes.
/// Confirm this is intended rather than a combined bound such as atol + rtol * |expected|.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
// Absolute-tolerance check.
ASSERT_NEAR(expected, output, atol);
// Relative-tolerance check, scaled by the magnitude of the expected value.
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
}
#if defined(USE_CUDA) || defined(USE_ROCM)
const int64_t total_step_count = 100;
const float initial_lr = 1e-3f;
const int64_t resume_step = total_step_count / 2;
/// Asserts that `output` is close to `expected` under two separate bounds.
///
/// @param expected Reference value.
/// @param output   Value under test.
/// @param rtol     Relative tolerance; the second check allows a difference of rtol * |expected|.
/// @param atol     Absolute tolerance; the first check allows a difference of atol.
///
/// Both assertions must pass, so the effective tolerance is the smaller of the two bounds.
/// NOTE(review): when `expected` is 0 the relative bound is 0, so only an exact match passes.
/// Confirm this is intended rather than a combined bound such as atol + rtol * |expected|.
void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) {
// Absolute-tolerance check.
ASSERT_NEAR(expected, output, atol);
// Relative-tolerance check, scaled by the magnitude of the expected value.
ASSERT_NEAR(expected, output, rtol * std::abs(expected));
}
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
int64_t warmup_step_count) {
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.

View file

@ -96,8 +96,9 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session,
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
const std::vector<std::string>& graph_output_names) {
std::vector<const char*> output_names(graph_output_names.size(), nullptr);
for (auto& output_name : graph_output_names) {
std::vector<const char*> output_names;
output_names.reserve(graph_output_names.size());
for (const auto& output_name : graph_output_names) {
output_names.push_back(output_name.c_str());
}
ThrowOnError(GetTrainingApi().ExportModelForInferencing(

View file

@ -344,6 +344,12 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrain
_In_reads_(graph_outputs_len) const char* const* graph_output_names) {
API_IMPL_BEGIN
if (graph_outputs_len == 0U) {
return OrtApis::CreateStatus(
ORT_INVALID_ARGUMENT,
"Empty array of graph output names is not valid. Please provide valid graph output names");
}
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
onnxruntime::InlinedVector<std::string> output_names(graph_outputs_len);