Register BiasGelu and BiasDropout for CUDA only. (#5060)

Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
Vincent Wang 2020-09-09 11:46:55 +08:00 committed by GitHub
parent f41614a875
commit 07bf8b968e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -83,7 +83,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(compatible_eps));
#ifdef USE_CUDA
// We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point
// as the session is not initialized yet. So using macro for now.
transformers.emplace_back(onnxruntime::make_unique<BiasGeluFusion>(compatible_eps));
#endif
if (config.enable_gelu_approximation) {
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(compatible_eps));
@ -154,6 +158,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
switch (level) {
case TransformerLevel::Level1: {
std::unordered_set<std::string> l1_execution_providers = {};
std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
// TODO hack - constant folding currently doesn't work after mixed precision transformation so it's disabled for now
// ORT uses CPU kernels to evaluate constant values but some of them don't support fp16
@ -161,7 +166,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(onnxruntime::make_unique<MatMulAddFusion>(l1_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<FreeDimensionOverrideTransformer>(free_dimension_overrides));
transformers.emplace_back(onnxruntime::make_unique<MatmulTransposeFusion>(l1_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasDropoutFusion>(l1_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasDropoutFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<MatMulScaleFusion>(l1_execution_providers, weights_to_train));
rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);