diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index fc16345af3..22ca244b85 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -83,7 +83,11 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); +#ifdef USE_CUDA + // We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point + // as the session is not initialized yet. So using macro for now. transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); +#endif if (config.enable_gelu_approximation) { transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); @@ -154,6 +158,7 @@ std::vector> GenerateTransformers( switch (level) { case TransformerLevel::Level1: { std::unordered_set l1_execution_providers = {}; + std::unordered_set cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider}; // TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now // ORT uses CPU kernels to evaluate constant values but some of them don't support fp16 @@ -161,7 +166,7 @@ std::vector> GenerateTransformers( transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(free_dimension_overrides)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers, weights_to_train)); rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);