From 0c3e88944d789a6779f3722eb0cf67710b3dab3f Mon Sep 17 00:00:00 2001 From: Anh Nguyen <94985387+anhnguyen7198@users.noreply.github.com> Date: Tue, 15 Feb 2022 10:40:44 -0800 Subject: [PATCH] Fix create ort value hardcoded memory info to CPU (#10510) * Fix create ort value hardcoded memory info to CPU * Remove unneeded check * Remove unneeded header * Remove unneeded header * Update ort_ops.cpp * Update ort_ops.cpp * Update ort_ops.cpp * Update ort_ops.cpp Co-authored-by: root --- orttraining/orttraining/eager/ort_aten.cpp | 6 +++++- orttraining/orttraining/eager/ort_ops.cpp | 3 ++- orttraining/orttraining/eager/ort_util.cpp | 20 +++++++++++++------- orttraining/orttraining/eager/ort_util.h | 15 ++++++++++++--- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index b9d6630437..98e69596f1 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -129,11 +129,15 @@ OrtValue create_ort_value( return impl->tensor(); } + OrtMemoryInfo *mem_info; + Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info)); + OrtValue ort_tensor; CreateMLValue( tensor.data_ptr(), ort_scalar_type_from_aten(tensor.scalar_type()), tensor.sizes().vec(), + *mem_info, &ort_tensor); return ort_tensor; } @@ -544,4 +548,4 @@ at::Tensor& add__Tensor( //#pragma endregion } // namespace eager -} // namespace torch_ort \ No newline at end of file +} // namespace torch_ort diff --git a/orttraining/orttraining/eager/ort_ops.cpp b/orttraining/orttraining/eager/ort_ops.cpp index 80fcf4432a..07a8de9d7f 100644 --- a/orttraining/orttraining/eager/ort_ops.cpp +++ b/orttraining/orttraining/eager/ort_ops.cpp @@ -26,7 +26,8 @@ void createInplaceOutputValue(OrtValue& input, V shape, OrtValue* p_mlv onnxruntime::ReshapeHelper helper(input.Get().Shape(), target_shape); onnxruntime::TensorShape new_shape(target_shape); CreateMLValue(input_ort_tensor->MutableDataRaw(), - input_ort_tensor->DataType(), new_shape, p_mlvalue); + input_ort_tensor->DataType(), new_shape, + input_ort_tensor->Location(), p_mlvalue); } template void createInplaceOutputValue(OrtValue& input, c10::ArrayRef shape, OrtValue* p_mlvalue); diff --git a/orttraining/orttraining/eager/ort_util.cpp b/orttraining/orttraining/eager/ort_util.cpp index d9ca7ee268..8f34c7434b 100644 --- a/orttraining/orttraining/eager/ort_util.cpp +++ b/orttraining/orttraining/eager/ort_util.cpp @@ -25,22 +25,28 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc, onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); } -void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue){ - OrtMemoryInfo *cpu_info; - Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &cpu_info)); +void CreateMLValue(void* data_ptr, + onnxruntime::MLDataType element_type, + onnxruntime::TensorShape& shape, + const OrtMemoryInfo& memory_info, + OrtValue* p_mlvalue) { std::unique_ptr p_tensor = std::make_unique(element_type, shape, data_ptr, - *cpu_info); + memory_info); p_mlvalue->Init(p_tensor.release(), onnxruntime::DataTypeImpl::GetType(), onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); } -void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector& dims, OrtValue* p_mlvalue) { +void CreateMLValue(void* data_ptr, + onnxruntime::MLDataType element_type, + const std::vector& dims, + const OrtMemoryInfo& memory_info, + OrtValue* p_mlvalue) { onnxruntime::TensorShape shape(dims); - CreateMLValue(data_ptr, element_type, shape, p_mlvalue); + CreateMLValue(data_ptr, element_type, shape, memory_info, p_mlvalue); } std::vector GetStrides(gsl::span shape) { @@ -52,4 +58,4 @@ std::vector GetStrides(gsl::span shape) { } } // namespace eager -} // namespace torch_ort \ No newline at end of file +} // namespace torch_ort diff --git a/orttraining/orttraining/eager/ort_util.h b/orttraining/orttraining/eager/ort_util.h index f34d86d2b5..dceed27758 100644 --- a/orttraining/orttraining/eager/ort_util.h +++ b/orttraining/orttraining/eager/ort_util.h @@ -15,8 +15,17 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc, const std::vector& dims, OrtValue* p_mlvalue); -void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector& dims, OrtValue* p_mlvalue); -void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue); +void CreateMLValue(void* data_ptr, + onnxruntime::MLDataType element_type, + const std::vector& dims, + const OrtMemoryInfo& memory_info, + OrtValue* p_mlvalue); + +void CreateMLValue(void* data_ptr, + onnxruntime::MLDataType element_type, + onnxruntime::TensorShape& shape, + const OrtMemoryInfo& memory_info, + OrtValue* p_mlvalue); template inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker, @@ -57,4 +66,4 @@ inline void CopyVectorToTensor(onnxruntime::ORTInvoker& /*invoker*/, std::vector GetStrides(gsl::span shape); } // namespace eager -} // namespace torch_ort \ No newline at end of file +} // namespace torch_ort