diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 697bec53cb..2909114aad 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -171,6 +171,7 @@ bool IsMatMulInputTypeSupported(const Node& node) { // if no matching key is present, any data type is allowed static const std::map> k_supported_data_types{ {kCudaExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}}, + {kRocmExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}}, {kCpuExecutionProvider, {"tensor(float)"}}, }; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 22b9e13545..c31267b1ba 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -3756,6 +3756,8 @@ static void RunGatherNDE2EGraph(std::vector& run_results, const PathSt execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) execution_provider = DefaultCudaExecutionProvider(); + else if (provider_type == onnxruntime::kRocmExecutionProvider) + execution_provider = DefaultRocmExecutionProvider(); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); Status st; @@ -3851,6 +3853,8 @@ TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) { onnxruntime::kCpuExecutionProvider, #ifdef USE_CUDA onnxruntime::kCudaExecutionProvider, +#elif USE_ROCM + onnxruntime::kRocmExecutionProvider, #endif }; diff --git a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc index e9b6d0a454..2017598a7f 100644 --- a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc +++ b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc @@ -27,7 +27,7 @@ TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) { Graph& graph = model->MainGraph(); // Create rule based transformer with a dummy rewrite rule and register it with Cuda as compatible provider - std::unordered_set compatible_provider{onnxruntime::kCudaExecutionProvider}; + std::unordered_set compatible_provider{onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; auto dummy_rule = std::make_unique("DummyRule"); const auto* dummy_rule_ptr = dummy_rule.get(); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 26292d4d68..9eb7b764c4 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -120,7 +120,7 @@ std::vector> GeneratePreTrainingTransformers( config.number_recompute_layers, compatible_eps)); } if (config.propagate_cast_ops_config.level >= 0) { - std::unordered_set cuda_execution_provider = {onnxruntime::kCudaExecutionProvider}; + std::unordered_set cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; transformers.emplace_back(std::make_unique(config.propagate_cast_ops_config.strategy, static_cast(config.propagate_cast_ops_config.level), config.propagate_cast_ops_config.allow,