From ea7bbd667d14332a9c8f1c4f6e832a1663296773 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Thu, 19 Jan 2023 10:26:53 -0800 Subject: [PATCH] fix headers for training apis (#14350) ### Description Minor refactor PR for fixing header placement for training apis --- cmake/onnxruntime_session.cmake | 1 + onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../orttraining/python/orttraining_pybind_state.cc | 4 ++-- .../test/training_api/core/checkpoint_test.cc | 10 +++++----- .../test/training_api/core/training_api_tests.cc | 12 ++++++------ orttraining/orttraining/training_api/checkpoint.cc | 4 ++-- .../training_api/{include => }/checkpoint.h | 6 +++--- .../orttraining/training_api/checkpoint_property.cc | 2 +- .../training_api/{include => }/checkpoint_property.h | 0 orttraining/orttraining/training_api/lr_scheduler.cc | 2 +- .../training_api/{include => }/lr_scheduler.h | 2 +- orttraining/orttraining/training_api/module.cc | 4 ++-- .../orttraining/training_api/{include => }/module.h | 0 .../training_api/onnxruntime_training_c_api.cc | 6 +++--- orttraining/orttraining/training_api/optimizer.cc | 4 ++-- .../training_api/{include => }/optimizer.h | 2 +- .../training_api/{include => }/ort_training_apis.h | 0 .../orttraining/training_api/training_session.cc | 2 +- .../training_api/{include => }/training_session.h | 0 orttraining/orttraining/training_api/utils.cc | 2 +- .../orttraining/training_api/{include => }/utils.h | 0 21 files changed, 33 insertions(+), 32 deletions(-) rename orttraining/orttraining/training_api/{include => }/checkpoint.h (95%) rename orttraining/orttraining/training_api/{include => }/checkpoint_property.h (100%) rename orttraining/orttraining/training_api/{include => }/lr_scheduler.h (97%) rename orttraining/orttraining/training_api/{include => }/module.h (100%) rename orttraining/orttraining/training_api/{include => }/optimizer.h (98%) rename orttraining/orttraining/training_api/{include => }/ort_training_apis.h (100%) rename orttraining/orttraining/training_api/{include => }/training_session.h (100%) rename orttraining/orttraining/training_api/{include => }/utils.h (100%) diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 821c545ba4..5120517acf 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -10,6 +10,7 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS if (onnxruntime_ENABLE_TRAINING_APIS) file(GLOB_RECURSE training_api_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_api/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_api/*.h" "${ORTTRAINING_SOURCE_DIR}/core/framework/checkpoint_common.cc" "${ORTTRAINING_SOURCE_DIR}/core/framework/checkpoint_common.h" ) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a1a5bc7893..265823b09a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -49,7 +49,7 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); #ifdef ENABLE_TRAINING_APIS #include "orttraining/training_api/include/onnxruntime_training_c_api.h" -#include "orttraining/training_api/include/ort_training_apis.h" +#include "orttraining/training_api/ort_training_apis.h" #endif #ifdef USE_CANN diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 05216d2bdc..14a407dc0c 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -35,8 +35,8 @@ #endif #ifdef ENABLE_TRAINING_APIS -#include "orttraining/training_api/include/checkpoint.h" -#include "orttraining/training_api/include/lr_scheduler.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/lr_scheduler.h" #endif diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 1c1ea31929..4416f75724 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -19,11 +19,11 @@ #include "core/platform/path_lib.h" #include "orttraining/core/framework/checkpoint_common.h" -#include "orttraining/training_api/include/module.h" -#include "orttraining/training_api/include/optimizer.h" -#include "orttraining/training_api/include/checkpoint_property.h" -#include "orttraining/training_api/include/checkpoint.h" -#include "orttraining/training_api/include/lr_scheduler.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/checkpoint_property.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/lr_scheduler.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index a6ac041594..6dee94801a 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -10,12 +10,12 @@ #include "test/framework/test_utils.h" #include "test/util/include/asserts.h" #include "core/framework/tensorprotoutils.h" -#include "orttraining/training_api/include/utils.h" -#include "orttraining/training_api/include/module.h" -#include "orttraining/training_api/include/optimizer.h" -#include "orttraining/training_api/include/checkpoint_property.h" -#include "orttraining/training_api/include/checkpoint.h" -#include "orttraining/training_api/include/lr_scheduler.h" +#include "orttraining/training_api/utils.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/checkpoint_property.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/lr_scheduler.h" #include "orttraining/test/training_api/core/data_utils.h" #include "test/util/include/temp_dir.h" #include "default_providers.h" diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index 7f4494cde9..8b65f17e4d 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -14,8 +14,8 @@ #include "orttraining/core/framework/checkpoint_common.h" #include "orttraining/core/framework/protobuf_message_sequence.h" -#include "orttraining/training_api/include/checkpoint.h" -#include "orttraining/training_api/include/utils.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/checkpoint.h b/orttraining/orttraining/training_api/checkpoint.h similarity index 95% rename from orttraining/orttraining/training_api/include/checkpoint.h rename to orttraining/orttraining/training_api/checkpoint.h index ce6785702b..cbe9b379b9 100644 --- a/orttraining/orttraining/training_api/include/checkpoint.h +++ b/orttraining/orttraining/training_api/checkpoint.h @@ -7,9 +7,9 @@ #include "core/platform/env.h" #include "onnx/defs/tensor_proto_util.h" -#include "orttraining/training_api/include/module.h" -#include "orttraining/training_api/include/optimizer.h" -#include "orttraining/training_api/include/checkpoint_property.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/checkpoint_property.h" /** * There are two representation for checkpoint respectively in memory and files: diff --git a/orttraining/orttraining/training_api/checkpoint_property.cc b/orttraining/orttraining/training_api/checkpoint_property.cc index 6ce573ae3c..dce21003f9 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.cc +++ b/orttraining/orttraining/training_api/checkpoint_property.cc @@ -6,7 +6,7 @@ #include "core/platform/path_lib.h" #include "core/platform/env.h" #include "core/framework/tensorprotoutils.h" -#include "orttraining/training_api/include/checkpoint_property.h" +#include "orttraining/training_api/checkpoint_property.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h similarity index 100% rename from orttraining/orttraining/training_api/include/checkpoint_property.h rename to orttraining/orttraining/training_api/checkpoint_property.h diff --git a/orttraining/orttraining/training_api/lr_scheduler.cc b/orttraining/orttraining/training_api/lr_scheduler.cc index 8ec8538971..35398a2788 100644 --- a/orttraining/orttraining/training_api/lr_scheduler.cc +++ b/orttraining/orttraining/training_api/lr_scheduler.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "orttraining/training_api/include/lr_scheduler.h" +#include "orttraining/training_api/lr_scheduler.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/lr_scheduler.h b/orttraining/orttraining/training_api/lr_scheduler.h similarity index 97% rename from orttraining/orttraining/training_api/include/lr_scheduler.h rename to orttraining/orttraining/training_api/lr_scheduler.h index 5b0b3df56a..738cda30a4 100644 --- a/orttraining/orttraining/training_api/include/lr_scheduler.h +++ b/orttraining/orttraining/training_api/lr_scheduler.h @@ -3,7 +3,7 @@ #pragma once -#include "orttraining/training_api/include/optimizer.h" +#include "orttraining/training_api/optimizer.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index b23a8e9bd7..54da8c236b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -8,8 +8,8 @@ #include "core/session/environment.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "orttraining/training_api/include/module.h" -#include "orttraining/training_api/include/utils.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/utils.h" using namespace onnxruntime; diff --git a/orttraining/orttraining/training_api/include/module.h b/orttraining/orttraining/training_api/module.h similarity index 100% rename from orttraining/orttraining/training_api/include/module.h rename to orttraining/orttraining/training_api/module.h diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 7563c8878f..f8552de920 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -8,9 +8,9 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "orttraining/training_api/include/checkpoint.h" -#include "orttraining/training_api/include/ort_training_apis.h" -#include "orttraining/training_api/include/training_session.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/ort_training_apis.h" +#include "orttraining/training_api/training_session.h" namespace { diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index 6c301874bd..efe6bbd2c9 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -6,8 +6,8 @@ #include "core/session/inference_session.h" #include "core/session/environment.h" -#include "orttraining/training_api/include/utils.h" -#include "orttraining/training_api/include/optimizer.h" +#include "orttraining/training_api/utils.h" +#include "orttraining/training_api/optimizer.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/optimizer.h b/orttraining/orttraining/training_api/optimizer.h similarity index 98% rename from orttraining/orttraining/training_api/include/optimizer.h rename to orttraining/orttraining/training_api/optimizer.h index d81e5d69ab..ac5d35215e 100644 --- a/orttraining/orttraining/training_api/include/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -6,7 +6,7 @@ #include "core/session/inference_session.h" #include "core/session/environment.h" -#include "orttraining/training_api/include/module.h" +#include "orttraining/training_api/module.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h similarity index 100% rename from orttraining/orttraining/training_api/include/ort_training_apis.h rename to orttraining/orttraining/training_api/ort_training_apis.h diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index e69a2e20b2..f05fea1c5f 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "orttraining/training_api/include/training_session.h" +#include "orttraining/training_api/training_session.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/training_session.h b/orttraining/orttraining/training_api/training_session.h similarity index 100% rename from orttraining/orttraining/training_api/include/training_session.h rename to orttraining/orttraining/training_api/training_session.h diff --git a/orttraining/orttraining/training_api/utils.cc b/orttraining/orttraining/training_api/utils.cc index a4c95a53b4..e719c48bea 100644 --- a/orttraining/orttraining/training_api/utils.cc +++ b/orttraining/orttraining/training_api/utils.cc @@ -8,7 +8,7 @@ #include "core/framework/allocator.h" #include "core/framework/tensorprotoutils.h" -#include "orttraining/training_api/include/utils.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/training_api/include/utils.h b/orttraining/orttraining/training_api/utils.h similarity index 100% rename from orttraining/orttraining/training_api/include/utils.h rename to orttraining/orttraining/training_api/utils.h