mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix create ort value hardcoded memory info to CPU (#10510)
* Fix create ort value hardcoded memory info to CPU * Remove unneeded check * Remove unneeded header * Remove unneeded header * Update ort_ops.cpp * Update ort_ops.cpp * Update ort_ops.cpp * Update ort_ops.cpp Co-authored-by: root <root@QTM-ANHNGUYEN-1.northamerica.corp.microsoft.com>
This commit is contained in:
parent
1cdc23aba4
commit
0c3e88944d
4 changed files with 32 additions and 12 deletions
|
|
@ -129,11 +129,15 @@ OrtValue create_ort_value(
|
|||
return impl->tensor();
|
||||
}
|
||||
|
||||
OrtMemoryInfo *mem_info;
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info));
|
||||
|
||||
OrtValue ort_tensor;
|
||||
CreateMLValue(
|
||||
tensor.data_ptr(),
|
||||
ort_scalar_type_from_aten(tensor.scalar_type()),
|
||||
tensor.sizes().vec(),
|
||||
*mem_info,
|
||||
&ort_tensor);
|
||||
return ort_tensor;
|
||||
}
|
||||
|
|
@ -544,4 +548,4 @@ at::Tensor& add__Tensor(
|
|||
//#pragma endregion
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ void createInplaceOutputValue(OrtValue& input, V<int64_t> shape, OrtValue* p_mlv
|
|||
onnxruntime::ReshapeHelper helper(input.Get<onnxruntime::Tensor>().Shape(), target_shape);
|
||||
onnxruntime::TensorShape new_shape(target_shape);
|
||||
CreateMLValue(input_ort_tensor->MutableDataRaw(),
|
||||
input_ort_tensor->DataType(), new_shape, p_mlvalue);
|
||||
input_ort_tensor->DataType(), new_shape,
|
||||
input_ort_tensor->Location(), p_mlvalue);
|
||||
}
|
||||
|
||||
template void createInplaceOutputValue<c10::ArrayRef>(OrtValue& input, c10::ArrayRef<int64_t> shape, OrtValue* p_mlvalue);
|
||||
|
|
|
|||
|
|
@ -25,22 +25,28 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc,
|
|||
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>()->GetDeleteFunc());
|
||||
}
|
||||
|
||||
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue){
|
||||
OrtMemoryInfo *cpu_info;
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &cpu_info));
|
||||
void CreateMLValue(void* data_ptr,
|
||||
onnxruntime::MLDataType element_type,
|
||||
onnxruntime::TensorShape& shape,
|
||||
const OrtMemoryInfo& memory_info,
|
||||
OrtValue* p_mlvalue) {
|
||||
std::unique_ptr<onnxruntime::Tensor> p_tensor = std::make_unique<onnxruntime::Tensor>(element_type,
|
||||
shape,
|
||||
data_ptr,
|
||||
*cpu_info);
|
||||
memory_info);
|
||||
|
||||
p_mlvalue->Init(p_tensor.release(),
|
||||
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>(),
|
||||
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>()->GetDeleteFunc());
|
||||
}
|
||||
|
||||
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector<int64_t>& dims, OrtValue* p_mlvalue) {
|
||||
void CreateMLValue(void* data_ptr,
|
||||
onnxruntime::MLDataType element_type,
|
||||
const std::vector<int64_t>& dims,
|
||||
const OrtMemoryInfo& memory_info,
|
||||
OrtValue* p_mlvalue) {
|
||||
onnxruntime::TensorShape shape(dims);
|
||||
CreateMLValue(data_ptr, element_type, shape, p_mlvalue);
|
||||
CreateMLValue(data_ptr, element_type, shape, memory_info, p_mlvalue);
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape) {
|
||||
|
|
@ -52,4 +58,4 @@ std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape) {
|
|||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -15,8 +15,17 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc,
|
|||
const std::vector<int64_t>& dims,
|
||||
OrtValue* p_mlvalue);
|
||||
|
||||
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector<int64_t>& dims, OrtValue* p_mlvalue);
|
||||
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue);
|
||||
void CreateMLValue(void* data_ptr,
|
||||
onnxruntime::MLDataType element_type,
|
||||
const std::vector<int64_t>& dims,
|
||||
const OrtMemoryInfo& memory_info,
|
||||
OrtValue* p_mlvalue);
|
||||
|
||||
void CreateMLValue(void* data_ptr,
|
||||
onnxruntime::MLDataType element_type,
|
||||
onnxruntime::TensorShape& shape,
|
||||
const OrtMemoryInfo& memory_info,
|
||||
OrtValue* p_mlvalue);
|
||||
|
||||
template <typename T>
|
||||
inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker,
|
||||
|
|
@ -57,4 +66,4 @@ inline void CopyVectorToTensor<bool>(onnxruntime::ORTInvoker& /*invoker*/,
|
|||
std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
Loading…
Reference in a new issue