mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[MicroGraph] Address ROCM warning and build failure (#13234)
### Description Address build failures after Public API refactoring ### Motivation and Context Make pipelines health.
This commit is contained in:
parent
6662ece4a1
commit
bb1c133245
1 changed files with 4 additions and 4 deletions
|
|
@ -98,7 +98,6 @@ std::shared_ptr<KernelRegistry> MIGraphXExecutionProvider::GetKernelRegistry() c
|
|||
|
||||
MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info)
|
||||
: IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, true}, device_id_(info.device_id) {
|
||||
|
||||
InitProviderOrtApi();
|
||||
// Set GPU device to be used
|
||||
HIP_CALL_THROW(hipSetDevice(device_id_));
|
||||
|
|
@ -1043,7 +1042,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
|
||||
Ort::KernelContext ctx(context);
|
||||
MIGraphXFuncState* mgx_state = reinterpret_cast<MIGraphXFuncState*>(state);
|
||||
|
||||
|
||||
std::unordered_map<std::string, std::size_t>& map_input_name_index = mgx_state->input_name_indexes;
|
||||
migraphx::target t = mgx_state->t;
|
||||
migraphx::program& prog = mgx_state->prog;
|
||||
|
|
@ -1062,7 +1061,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
auto& index = it.second;
|
||||
auto input_tensor = ctx.GetInput(index);
|
||||
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
|
||||
const auto tensor_shape = tensor_info.GetTensorShape();
|
||||
const auto tensor_shape = tensor_info.GetShape();
|
||||
std::vector<std::size_t> ort_lens(tensor_shape.begin(), tensor_shape.end());
|
||||
cmp_options.set_input_parameter_shape(name, ort_lens);
|
||||
input_shape_match = false;
|
||||
|
|
@ -1130,7 +1129,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
if (mgx_type != mgx_s.type()) {
|
||||
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
|
||||
}
|
||||
m.add(name, migraphx::argument(param_shapes[name], input_tensor.GetTensorMutableRawData()));
|
||||
m.add(name, migraphx::argument(param_shapes[name],
|
||||
const_cast<void*>(input_tensor.GetTensorRawData())));
|
||||
}
|
||||
// It is a output argument
|
||||
else {
|
||||
|
|
|
|||
Loading…
Reference in a new issue