From bb1c1332450bf9ff4d66a5e2c7d8ccab83fd89ad Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 7 Oct 2022 14:30:19 -0700 Subject: [PATCH] [MicroGraph] Address ROCM warning and build failure (#13234) ### Description Address build failures after Public API refactoring ### Motivation and Context Make pipelines health. --- .../providers/migraphx/migraphx_execution_provider.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 91a326ca6a..f5d7b8f2b4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -98,7 +98,6 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, true}, device_id_(info.device_id) { - InitProviderOrtApi(); // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(device_id_)); @@ -1043,7 +1042,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); - + std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; migraphx::target t = mgx_state->t; migraphx::program& prog = mgx_state->prog; @@ -1062,7 +1061,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto& index = it.second; auto input_tensor = ctx.GetInput(index); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetTensorShape(); + const auto tensor_shape = tensor_info.GetShape(); std::vector ort_lens(tensor_shape.begin(), tensor_shape.end()); cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; @@ -1130,7 +1129,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - m.add(name, migraphx::argument(param_shapes[name], input_tensor.GetTensorMutableRawData())); + m.add(name, migraphx::argument(param_shapes[name], + const_cast(input_tensor.GetTensorRawData()))); } // It is a output argument else {