From 07bf8b968e31483d86d9bd5d88cd3b4bd903bb16 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 9 Sep 2020 11:46:55 +0800 Subject: [PATCH] Register BiasGelu and BiasDropout for CUDA only. (#5060) Co-authored-by: Vincent Wang --- .../orttraining/core/optimizer/graph_transformer_utils.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index fc16345af3..22ca244b85 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -83,7 +83,11 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); +#ifdef USE_CUDA + // Ideally the execution provider would be used as the indicator, but the registered EPs are not accessible at this point + // because the session is not initialized yet, so a macro is used for now. 
transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); +#endif if (config.enable_gelu_approximation) { transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); @@ -154,6 +158,7 @@ std::vector> GenerateTransformers( switch (level) { case TransformerLevel::Level1: { std::unordered_set l1_execution_providers = {}; + std::unordered_set cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider}; // TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now // ORT uses CPU kernels to evaluate constant values but some of them don't support fp16 @@ -161,7 +166,7 @@ std::vector> GenerateTransformers( transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(free_dimension_overrides)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers, weights_to_train)); rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);