mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
enable MatMulScale and cast propagation for ROCm EP. (#7657)
This commit is contained in:
parent
5d9885f706
commit
9241f62e4c
4 changed files with 7 additions and 2 deletions
|
|
@ -171,6 +171,7 @@ bool IsMatMulInputTypeSupported(const Node& node) {
|
|||
// if no matching key is present, any data type is allowed
|
||||
static const std::map<std::string, std::vector<std::string>> k_supported_data_types{
|
||||
{kCudaExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}},
|
||||
{kRocmExecutionProvider, {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}},
|
||||
{kCpuExecutionProvider, {"tensor(float)"}},
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -3756,6 +3756,8 @@ static void RunGatherNDE2EGraph(std::vector<OrtValue>& run_results, const PathSt
|
|||
execution_provider = DefaultCpuExecutionProvider();
|
||||
else if (provider_type == onnxruntime::kCudaExecutionProvider)
|
||||
execution_provider = DefaultCudaExecutionProvider();
|
||||
else if (provider_type == onnxruntime::kRocmExecutionProvider)
|
||||
execution_provider = DefaultRocmExecutionProvider();
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
|
||||
Status st;
|
||||
|
|
@ -3851,6 +3853,8 @@ TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) {
|
|||
onnxruntime::kCpuExecutionProvider,
|
||||
#ifdef USE_CUDA
|
||||
onnxruntime::kCudaExecutionProvider,
|
||||
#elif USE_ROCM
|
||||
onnxruntime::kRocmExecutionProvider,
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ TEST(RuleBasedGraphTransformerTest, TestCompatibleProviders) {
|
|||
Graph& graph = model->MainGraph();
|
||||
|
||||
// Create rule based transformer with a dummy rewrite rule and register it with Cuda as compatible provider
|
||||
std::unordered_set<std::string> compatible_provider{onnxruntime::kCudaExecutionProvider};
|
||||
std::unordered_set<std::string> compatible_provider{onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
|
||||
auto dummy_rule = std::make_unique<DummyRewriteRule>("DummyRule");
|
||||
const auto* dummy_rule_ptr = dummy_rule.get();
|
||||
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
config.number_recompute_layers, compatible_eps));
|
||||
}
|
||||
if (config.propagate_cast_ops_config.level >= 0) {
|
||||
std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider};
|
||||
std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
|
||||
transformers.emplace_back(std::make_unique<PropagateCastOps>(config.propagate_cast_ops_config.strategy,
|
||||
static_cast<size_t>(config.propagate_cast_ops_config.level),
|
||||
config.propagate_cast_ops_config.allow,
|
||||
|
|
|
|||
Loading…
Reference in a new issue