mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Remove more header dependencies
This commit is contained in:
parent
9f1c2ed767
commit
4cf9e0b9a3
3 changed files with 9 additions and 16 deletions
|
|
@ -11,7 +11,6 @@
|
|||
#include "core/common/path_string.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#include "orttraining/models/runner/data_loader.h"
|
||||
#include "orttraining/models/runner/training_util.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
|
|
|||
|
|
@ -15,12 +15,6 @@
|
|||
|
||||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h" // TODO: move with PipelineBatchPlanner
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#elif USE_ROCM
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime::logging;
|
||||
using namespace onnxruntime::training;
|
||||
using namespace google::protobuf::util;
|
||||
|
|
@ -34,7 +28,7 @@ static void RunTrainingSessionLoadOptimTests(std::string optim_name, bool mixed_
|
|||
auto config = MakeBasicTrainingConfig();
|
||||
if (mixed_precision) {
|
||||
TrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{};
|
||||
mp.use_mixed_precision_initializers=true;
|
||||
mp.use_mixed_precision_initializers = true;
|
||||
config.mixed_precision_config = mp;
|
||||
}
|
||||
GenerateOptimizerConfig(optim_name, mixed_precision_moments, config);
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h" // TODO: move with PipelineBatchPlanner
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#include "core/providers/cuda/cuda_execution_provider_info.h"
|
||||
#elif USE_ROCM
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#endif
|
||||
|
|
@ -44,18 +44,18 @@ void GenerateOptimizerConfig(const std::string optimizer_name,
|
|||
training::TrainingSession::TrainingConfiguration& config);
|
||||
|
||||
template <class T>
|
||||
void GenerateOptimizerInitialState(const std::string& optimizer_op_name,
|
||||
const T init_moment_value,
|
||||
training::TrainingSession::OptimizerState& optimizer_state);
|
||||
void GenerateOptimizerInitialState(const std::string& optimizer_op_name,
|
||||
const T init_moment_value,
|
||||
training::TrainingSession::OptimizerState& optimizer_state);
|
||||
|
||||
void SeparateStateTensors(const NameMLValMap& training_state,
|
||||
NameMLValMap& model_state,
|
||||
void SeparateStateTensors(const NameMLValMap& training_state,
|
||||
NameMLValMap& model_state,
|
||||
training::TrainingSession::OptimizerState& optimizer_state);
|
||||
|
||||
void VerifyState(const DataTransferManager& data_transfer_mgr, const NameMLValMap& expected_state, const NameMLValMap& actual_state);
|
||||
|
||||
void VerifyOptimizerState(const DataTransferManager& data_transfer_manager,
|
||||
const training::TrainingSession::OptimizerState& expected_state,
|
||||
void VerifyOptimizerState(const DataTransferManager& data_transfer_manager,
|
||||
const training::TrainingSession::OptimizerState& expected_state,
|
||||
const training::TrainingSession::OptimizerState& actual_state);
|
||||
|
||||
std::unordered_set<std::string> GetModelOutputNames(const InferenceSession& session);
|
||||
|
|
|
|||
Loading…
Reference in a new issue