enable MatMulScale and cast propagation for ROCm EP. (#7657)

This commit is contained in:
Weixing Zhang 2021-05-12 13:43:24 -07:00 committed by GitHub
parent 5d9885f706
commit 9241f62e4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 2 deletions

View file

@@ -171,6 +171,7 @@ bool IsMatMulInputTypeSupported(const Node& node) {
 // if no matching key is present, any data type is allowed
 static const std::map<std::string, std::vector<std::string>> k_supported_data_types{
 {kCudaExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}},
+{kRocmExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}},
 {kCpuExecutionProvider, {"tensor(float)"}},
 };

View file

@@ -3756,6 +3756,8 @@ static void RunGatherNDE2EGraph(std::vector<OrtValue>& run_results, const PathSt
 execution_provider = DefaultCpuExecutionProvider();
 else if (provider_type == onnxruntime::kCudaExecutionProvider)
 execution_provider = DefaultCudaExecutionProvider();
+else if (provider_type == onnxruntime::kRocmExecutionProvider)
+execution_provider = DefaultRocmExecutionProvider();
 EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
 Status st;
@@ -3851,6 +3853,8 @@ TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) {
 onnxruntime::kCpuExecutionProvider,
 #ifdef USE_CUDA
 onnxruntime::kCudaExecutionProvider,
+#elif USE_ROCM
+onnxruntime::kRocmExecutionProvider,
 #endif
 };

View file

@@ -27,7 +27,7 @@ TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) {
 Graph& graph = model->MainGraph();
 // Create rule based transformer with a dummy rewrite rule and register it with Cuda as compatible provider
-std::unordered_set<std::string> compatible_provider{onnxruntime::kCudaExecutionProvider};
+std::unordered_set<std::string> compatible_provider{onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
 auto dummy_rule = std::make_unique<DummyRewriteRule>("DummyRule");
 const auto* dummy_rule_ptr = dummy_rule.get();

View file

@@ -120,7 +120,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
 config.number_recompute_layers, compatible_eps));
 }
 if (config.propagate_cast_ops_config.level >= 0) {
-std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider};
+std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
 transformers.emplace_back(std::make_unique<PropagateCastOps>(config.propagate_cast_ops_config.strategy,
 static_cast<size_t>(config.propagate_cast_ops_config.level),
 config.propagate_cast_ops_config.allow,