Move CopyTensor out of IExecutionProvider interface. (#1268)

* add ortdevice class

* add data transfer manager for copying tensors.

* update

* add data transfer for gpu

* fix constexpr build break.

* update

* remove unnecessary header files.

* remove unnecessary header files.

* add dependency

* add dependency

* add dependency

* add dependency

* fix linux build break.

* update

* fix build break

* fix build break

* fix build break

* update

* update

* update c api.

* update to not use OrtCreateAllocatorInfo

* change to all eps .

* fix linux build break

* remove useless codes.

* update

* move datatransfermanager in session state

* update

* fix cuda build break.

* fix comments

* fix windows GPU build.

* fix comments

* fix build break

* fix comments

* fix test failure

* update

* fix comments

* fix onnx runtime server.

* update

* fix test failure.

* fix comments

* fix comment
This commit is contained in:
Ke Zhang 2019-07-11 14:49:20 -07:00 committed by GitHub
parent e580b76305
commit 3bf0e364e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
58 changed files with 459 additions and 253 deletions

View file

@ -31,6 +31,10 @@ add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEP
target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT})
onnxruntime_add_include_to_target(onnxruntime gsl)
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()
if(UNIX)
if (APPLE)
set(BEGIN_WHOLE_ARCHIVE -Xlinker -all_load)

View file

@ -15,6 +15,9 @@ onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxrun
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
add_definitions(-DENABLE_LANGUAGE_INTEROP_OPS)

View file

@ -15,21 +15,80 @@
#include "core/framework/fence.h"
#include "core/session/onnxruntime_c_api.h"
// Describes a device as a (device type, memory type, device id) triple.
struct OrtDevice {
  using DeviceType = int8_t;
  using MemoryType = int8_t;
  using DeviceId = int16_t;

  // Built-in device type values.
  static const DeviceType CPU = 0;
  static const DeviceType GPU = 1;  // CUDA
  static const DeviceType FPGA = 2;

  // Built-in memory type values, grouped under their own scope.
  struct MemType {
    static const MemoryType DEFAULT = 0;
    static const MemoryType CUDA_PINNED = 1;
  };

  // Builds a device from an explicit type / memory type / id triple.
  constexpr OrtDevice(DeviceType type_value, MemoryType mem_value, DeviceId id_value)
      : device_type(type_value),
        memory_type(mem_value),
        device_id(id_value) {}

  // A default-constructed device is CPU, default memory, id 0.
  constexpr OrtDevice() : device_type(CPU), memory_type(MemType::DEFAULT), device_id(0) {}

  // Device type value (CPU, GPU, FPGA, ...).
  DeviceType Type() const { return device_type; }

  // Memory type value (see struct MemType).
  MemoryType MemType() const { return memory_type; }

  // Device index.
  DeviceId Id() const { return device_id; }

  // Human-readable form, e.g. "Device: [ type:0 memory_type:0 device_id:0]".
  // The int8_t fields are widened to int so they print as numbers, not characters.
  std::string ToString() const {
    std::ostringstream text;
    text << "Device: [";
    text << " type:" << static_cast<int>(device_type);
    text << " memory_type:" << static_cast<int>(memory_type);
    text << " device_id:" << device_id;
    text << "]";
    return text.str();
  }

 private:
  DeviceType device_type;  // device type value
  MemoryType memory_type;  // memory type value
  DeviceId device_id;      // device index
};
struct OrtAllocatorInfo {
// use string for name, so we could have customized allocator in execution provider.
const char* name;
int id;
OrtMemType mem_type;
OrtAllocatorType type;
OrtDevice device;
constexpr OrtAllocatorInfo(const char* name_, OrtAllocatorType type_, int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
constexpr OrtAllocatorInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
#if (defined(__GNUC__) || defined(__clang__))
__attribute__((nonnull))
#endif
: name(name_),
id(id_),
mem_type(mem_type_),
type(type_) {
type(type_),
device(device_) {
}
// To make OrtAllocatorInfo become a valid key in std map
@ -67,6 +126,8 @@ std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info);
namespace onnxruntime {
constexpr const char* CPU = "Cpu";
constexpr const char* CUDA = "Cuda";
constexpr const char* CUDA_PINNED = "CudaPinned";
// forward declaration
class SessionState;

View file

@ -84,20 +84,6 @@ class IExecutionProvider {
*/
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const;
/**
* Copy tensor between execution providers. It's always a deep copy
* Either src.location is CPU, or dst.location is CPU. They can't be both on CPU.
*/
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const = 0;
/**
* Copy tensor between execution providers on specified exec queue
* It's always a deep copy
* Either src.location is CPU, or dst.location is CPU. They can't be both on CPU.
*/
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst,
int exec_queue_id) const;
/**
Returns an opaque handle whose exact type varies based on the provider
and is interpreted accordingly by the corresponding kernel implementation.

View file

@ -24,9 +24,12 @@ class KernelRegistry {
// for its clients unless the factory is managing the lifecycle of the pointer
// itself.
// TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent
Status TryCreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider,
Status TryCreateKernel(const onnxruntime::Node& node,
const IExecutionProvider& execution_provider,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& mlvalue_name_idx_map, const FuncManager& funcs_mgr,
const OrtValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
std::unique_ptr<OpKernel>& op_kernel) const;
// Check if an execution provider can create kernel for a node and return

View file

@ -15,16 +15,20 @@ namespace onnxruntime {
class OrtValueNameIdxMap;
class FuncManager;
class DataTransferManager;
// A very light-weight class, which works as an aggregated
// view of all data needed for constructing a Kernel instance.
// NOTE: it does not own/hold any objects.
class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
public:
explicit OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kernel_def,
explicit OpKernelInfo(const onnxruntime::Node& node,
const KernelDef& kernel_def,
const IExecutionProvider& execution_provider,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& mlvalue_name_idx_map, const FuncManager& funcs_mgr);
const OrtValueNameIdxMap& mlvalue_name_idx_map,
const FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr);
OpKernelInfo(const OpKernelInfo& other);
@ -36,6 +40,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const IExecutionProvider* GetExecutionProvider() const noexcept;
const DataTransferManager& GetDataTransferManager() const noexcept;
const onnxruntime::Node& node() const noexcept;
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
@ -56,6 +62,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
const OrtValueNameIdxMap& ort_value_name_idx_map_;
const FuncManager& funcs_mgr_;
const DataTransferManager& data_transfer_mgr_;
ProtoHelperNodeContext proto_helper_context_;
};

View file

@ -18,7 +18,7 @@ void* CPUAllocator::Alloc(size_t size) {
#elif defined(__AVX__)
size_t alignment = 32;
#else
size_t alignment = 32; //Indeed, the default one(8 or 16) should be enough
size_t alignment = 32; //Indeed, the default one(8 or 16) should be enough
#endif
#if _MSC_VER
p = _aligned_malloc(size, alignment);
@ -52,7 +52,15 @@ std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info) {
ORT_API_STATUS_IMPL(OrtCreateAllocatorInfo, _In_ const char* name1, OrtAllocatorType type, int id1,
OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out) {
*out = new OrtAllocatorInfo(name1, type, id1, mem_type1);
if (strcmp(name1, onnxruntime::CPU) == 0) {
*out = new OrtAllocatorInfo(name1, type, OrtDevice(), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0) {
*out = new OrtAllocatorInfo(name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
*out = new OrtAllocatorInfo(name1, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)), id1, mem_type1);
} else {
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
}
return nullptr;
}

View file

@ -37,7 +37,7 @@ class DummyArena : public IArenaAllocator {
public:
explicit DummyArena(std::unique_ptr<IDeviceAllocator> resource_allocator)
: allocator_(std::move(resource_allocator)),
info_(allocator_->Info().name, OrtAllocatorType::OrtArenaAllocator, allocator_->Info().id) {
info_(allocator_->Info().name, OrtAllocatorType::OrtArenaAllocator, allocator_->Info().device, allocator_->Info().id) {
}
~DummyArena() override = default;

View file

@ -9,7 +9,7 @@ BFCArena::BFCArena(std::unique_ptr<IDeviceAllocator> resource_allocator,
: device_allocator_(std::move(resource_allocator)),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1),
info_(device_allocator_->Info().name, OrtAllocatorType::OrtArenaAllocator, device_allocator_->Info().id, device_allocator_->Info().mem_type) {
info_(device_allocator_->Info().name, OrtAllocatorType::OrtArenaAllocator, device_allocator_->Info().device, device_allocator_->Info().id, device_allocator_->Info().mem_type) {
curr_region_allocation_bytes_ = RoundedBytes(std::min(total_memory, size_t{1048576}));
// Allocate the requested amount of memory.

View file

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/data_transfer.h"
namespace onnxruntime {
// Two-argument convenience overload: forwards to the three-argument overload
// with exec_queue_id fixed to 0.
common::Status IDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
  constexpr int exec_queue_id = 0;
  return this->CopyTensor(src, dst, exec_queue_id);
}
// Accepts a copy only when both the source and the destination device report
// the OrtDevice::CPU device type.
bool CPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
  if (src_device.Type() != OrtDevice::CPU) {
    return false;
  }
  return dst_device.Type() == OrtDevice::CPU;
}
// Copies the raw bytes of src into dst with memcpy. The exec queue id is unused.
// If both tensors already refer to the same buffer nothing is copied.
// ORT_ENFORCE fires when the two tensors do not report the same Size().
common::Status CPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int /*exec_queue_id*/) const {
  const void* const source_bytes = src.DataRaw();
  void* const target_bytes = dst.MutableDataRaw();

  if (source_bytes != target_bytes) {
    // Copying only happens between two same size tensors.
    ORT_ENFORCE(src.Size() == dst.Size());
    memcpy(target_bytes, source_bytes, src.Size());
  }
  // Identical pointers fall through: the data is already in place.
  return Status::OK();
}
}; // namespace onnxruntime

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/status.h"
#include "core/framework/tensor.h"
namespace onnxruntime {
// Data transfer interface.
// An implementation advertises which (source device, destination device) pairs
// it handles through CanCopy() and performs the copy in CopyTensor().
class IDataTransfer {
public:
virtual ~IDataTransfer() = default;
// Returns true when this implementation can copy from src_device to dst_device.
virtual bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const = 0;
// Convenience overload. The default implementation forwards to the
// three-argument overload with exec_queue_id == 0.
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
// Copies src into dst; exec_queue_id is passed through to the implementation
// (the CPU implementation in this change ignores it).
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const = 0;
};
class CPUDataTransfer : public IDataTransfer {
public:
CPUDataTransfer() = default;
bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/data_transfer_manager.h"
namespace onnxruntime {
using namespace common;
// Takes ownership of data_transfer and appends it to the registered list.
// A null pointer is rejected with an INVALID_ARGUMENT status.
Status DataTransferManager::RegisterDataTransfer(std::unique_ptr<IDataTransfer> data_transfer) {
  if (data_transfer) {
    datatransfers_.emplace_back(std::move(data_transfer));
    return Status::OK();
  }
  return Status(ONNXRUNTIME, INVALID_ARGUMENT, "data_transfer registered is nullptr.");
}
// Returns the first registered transfer (in registration order) whose CanCopy()
// accepts the device pair, or nullptr when none does. The returned pointer is
// non-owning; the object stays owned by this manager.
const IDataTransfer* DataTransferManager::GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) const {
  for (const auto& candidate : datatransfers_) {
    if (candidate->CanCopy(src_device, dst_device)) {
      return candidate.get();
    }
  }
  return nullptr;
}
// Two-argument convenience overload: forwards to the three-argument overload
// with exec_queue_id fixed to 0.
Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst) const {
  constexpr int exec_queue_id = 0;
  return this->CopyTensor(src, dst, exec_queue_id);
}
// Copies src into dst using the first registered transfer that accepts the
// tensors' device pair.
// Returns FAIL when the two shapes report different Size() values, or when no
// registered transfer can handle the source/destination devices.
Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
  const bool same_size = src.Shape().Size() == dst.Shape().Size();
  if (!same_size) {
    return Status(ONNXRUNTIME, FAIL, "Tensor size mismatch");
  }

  for (const auto& candidate : datatransfers_) {
    if (candidate->CanCopy(src.Location().device, dst.Location().device)) {
      // First match wins; later registrations are not consulted.
      return candidate->CopyTensor(src, dst, exec_queue_id);
    }
  }

  // Nothing registered for this device pair: report both devices in the error.
  return ORT_MAKE_STATUS(ONNXRUNTIME,
                         FAIL,
                         "There's no data transfer registered for copying tensors from ",
                         src.Location().device.ToString(),
                         " to ",
                         dst.Location().device.ToString());
}
} // namespace onnxruntime

View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/tensor.h"
namespace onnxruntime {
// Data transfer manager: holds the registered IDataTransfer implementations and
// dispatches tensor copies to the one that supports the tensors' locations.
// It's not thread-safe.
class DataTransferManager {
public:
DataTransferManager() = default;
//static DataTransferManager& Instance();
// Takes ownership of data_transfer and adds it to the registry.
// Returns an INVALID_ARGUMENT status when data_transfer is nullptr.
common::Status RegisterDataTransfer(std::unique_ptr<IDataTransfer> data_transfer);
// Returns the first registered transfer whose CanCopy() accepts the device
// pair, or nullptr when there is none. The returned pointer is non-owning.
const IDataTransfer* GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) const;
// Same as the three-argument overload with exec_queue_id == 0.
common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
// Copies src into dst via the first registered transfer that supports the
// tensors' devices. Returns FAIL on a shape-size mismatch or when no
// registered transfer supports the device pair.
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const;
private:
// NOTE(review): presumably deletes copy/move construction and assignment, as the macro name suggests — confirm against its definition.
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
// It's assumed that data transfers in this array have no overlap in terms of copying functionality.
// Lookups scan in registration order and stop at the first match.
std::vector<std::unique_ptr<IDataTransfer>> datatransfers_;
};
} // namespace onnxruntime

View file

@ -43,14 +43,6 @@ IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
}
common::Status IExecutionProvider::CopyTensor(const Tensor& src,
Tensor& dst,
int exec_queue_id) const {
// execution provider may override this to support different exec queues
ORT_ENFORCE(exec_queue_id == 0);
return CopyTensor(src, dst);
}
common::Status IExecutionProvider::Sync() const { return Status::OK(); };
common::Status IExecutionProvider::OnRunStart() { return Status::OK(); }

View file

@ -240,9 +240,12 @@ Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
return Status::OK();
}
Status KernelRegistry::TryCreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider,
Status KernelRegistry::TryCreateKernel(const onnxruntime::Node& node,
const IExecutionProvider& execution_provider,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& ort_value_name_idx_map, const FuncManager& funcs_mgr,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) const {
const KernelCreateInfo* kernel_create_info = TryFindKernel(node, execution_provider.Type());
@ -250,8 +253,13 @@ Status KernelRegistry::TryCreateKernel(const onnxruntime::Node& node, const IExe
return Status(ONNXRUNTIME, FAIL, "Failed to find kernel for " + node.OpType());
}
OpKernelInfo kernel_info(node, *kernel_create_info->kernel_def, execution_provider, constant_initialized_tensors,
ort_value_name_idx_map, funcs_mgr);
OpKernelInfo kernel_info(node,
*kernel_create_info->kernel_def,
execution_provider,
constant_initialized_tensors,
ort_value_name_idx_map,
funcs_mgr,
data_transfer_mgr);
op_kernel.reset(kernel_create_info->kernel_create_func(kernel_info));
return Status::OK();
}

View file

@ -23,7 +23,7 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
{
for (auto& registry : custom_kernel_registries_) {
status = registry->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), op_kernel);
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel);
if (status.IsOK()) {
return status;
}
@ -35,7 +35,7 @@ Status KernelRegistryManager::CreateKernel(const onnxruntime::Node& node,
if (iter != provider_type_to_registry_.end()) p = iter->second.get();
if (p != nullptr) {
status = p->TryCreateKernel(node, execution_provider, session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), op_kernel);
session_state.GetOrtValueNameIdxMap(), session_state.GetFuncMgr(), session_state.GetDataTransferMgr(), op_kernel);
if (status.IsOK()) {
return status;
}

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/data_transfer_manager.h"
#include "memcpy.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
@ -13,7 +14,7 @@ Memcpy::Memcpy(const OpKernelInfo& info)
Status Memcpy::Compute(OpKernelContext* ctx) const {
const auto* X = ctx->Input<Tensor>(0);
Tensor* Y = ctx->Output(0, X->Shape());
Status retval = provider_->CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
return retval;
}

View file

@ -8,10 +8,13 @@
namespace onnxruntime {
OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kernel_def,
OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
const KernelDef& kernel_def,
const IExecutionProvider& execution_provider,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& ort_value_name_idx_map, const FuncManager& funcs_mgr)
const OrtValueNameIdxMap& ort_value_name_idx_map,
const FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr)
: OpNodeProtoHelper(&proto_helper_context_),
node_(node),
kernel_def_(kernel_def),
@ -19,11 +22,12 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kerne
constant_initialized_tensors_(constant_initialized_tensors),
ort_value_name_idx_map_(ort_value_name_idx_map),
funcs_mgr_(funcs_mgr),
data_transfer_mgr_(data_transfer_mgr),
proto_helper_context_(node) {}
OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
other.ort_value_name_idx_map_, other.funcs_mgr_) {}
other.ort_value_name_idx_map_, other.funcs_mgr_, other.data_transfer_mgr_) {}
const OrtAllocatorInfo& OpKernelInfo::GetAllocatorInfo(int device_id, OrtMemType mem_type) const {
AllocatorPtr alloc = GetAllocator(device_id, mem_type);
@ -43,6 +47,10 @@ const IExecutionProvider* OpKernelInfo::GetExecutionProvider() const noexcept {
return execution_provider_;
}
// Returns the DataTransferManager reference this OpKernelInfo was constructed
// with. The object is referenced, not owned, by OpKernelInfo.
const DataTransferManager& OpKernelInfo::GetDataTransferManager() const noexcept {
return data_transfer_mgr_;
}
const onnxruntime::Node& OpKernelInfo::node() const noexcept {
return node_;
}

View file

@ -14,6 +14,7 @@
#include "core/common/logging/logging.h"
#include "core/common/profiler.h"
#include "core/framework/allocation_planner.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/kernel_registry_manager.h"
@ -182,8 +183,10 @@ class SessionState {
const FuncManager& GetFuncMgr() const { return fused_funcs_mgr_; }
FuncManager& GetMutableFuncMgr() { return fused_funcs_mgr_; }
std::vector<BufferUniquePtr>& GetMutableWeightsBuffers() { return weights_buffers_; }
// Returns the manager stored in data_transfer_mgr_. The pointer is dereferenced without a null check, so SetDataTransferMgr() must have been called first.
const DataTransferManager& GetDataTransferMgr() const { return *data_transfer_mgr_; }
// Stores the given pointer as-is for later use by GetDataTransferMgr(). NOTE(review): the caller presumably keeps ownership and must keep the manager alive — confirm.
void SetDataTransferMgr(const DataTransferManager* data_transfer_mgr) { data_transfer_mgr_ = data_transfer_mgr; }
std::vector<BufferUniquePtr>& GetMutableWeightsBuffers() { return weights_buffers_; }
void CalculateNodeIndexInfo();
const NodeIndexInfo& GetNodeIndexInfo() const;
@ -232,6 +235,7 @@ class SessionState {
bool export_fused_dll_ = false;
FuncManager fused_funcs_mgr_;
// Raw pointer to the data transfer manager, installed via SetDataTransferMgr().
// Initialized to nullptr so that an unset manager is a well-defined null
// rather than an indeterminate pointer value.
const DataTransferManager* data_transfer_mgr_ = nullptr;
std::unique_ptr<NodeIndexInfo> node_index_info_;
std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;

View file

@ -12,6 +12,7 @@
#include "core/common/logging/logging.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/data_transfer_manager.h"
#include "core/graph/graph_utils.h"
#include "core/framework/graph_partitioner.h"
#include "core/framework/ml_value.h"
@ -36,7 +37,8 @@ static common::Status SaveInitializedTensors(const Env& env, const std::basic_st
const onnxruntime::Graph& graph, const ExecutionProviders& exec_providers,
const OrtValueNameIdxMap& ort_value_name_idx_map,
ITensorAllocator* planner, const T& save_tensor_func,
const logging::Logger& logger);
const logging::Logger& logger,
const DataTransferManager& data_transfer_mgr);
static common::Status SaveKernels(const ExecutionProviders& execution_providers,
SessionState& session_state,
@ -111,8 +113,7 @@ common::Status SessionStateInitializer::InitializeAndSave(
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
return session_state_.AddInitializedTensor(idx, value, &d, constant);
},
logger_));
logger_, session_state_.GetDataTransferMgr()));
// remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was
// preallocated with the some other tensors in a single 'allocate' call, which is very common.
// TODO: make it better
@ -180,7 +181,8 @@ common::Status SaveMLValueNameIndexMapping(const GraphViewer& graph_viewer, OrtV
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m,
const ExecutionProviders& exec_providers, OrtValue& ort_value,
OrtCallback& deleter) {
OrtCallback& deleter,
const DataTransferManager& data_transfer_mgr) {
const OrtAllocatorInfo& alloc_info = m.GetAllocInfo();
if (strcmp(alloc_info.name, CPU) == 0 || alloc_info.mem_type == OrtMemTypeCPUOutput) {
// deserialize directly to CPU tensor
@ -219,7 +221,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
p_tensor = std::make_unique<Tensor>(p_deserialize_tensor.DataType(), p_deserialize_tensor.Shape(), m.GetBuffer(),
m.GetAllocInfo());
// TODO: does this function work for string tensor?
Status copy_status = provider->CopyTensor(p_deserialize_tensor, *p_tensor);
Status copy_status = data_transfer_mgr.CopyTensor(p_deserialize_tensor, *p_tensor);
if (d.f) d.f(d.param);
if (!copy_status.IsOK()) {
if (copy_status.ErrorMessage().empty()) {
@ -239,7 +241,8 @@ template <typename T>
common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
const Graph& graph, const ExecutionProviders& exec_providers,
const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator* planner,
const T& save_tensor_func, const logging::Logger& logger) {
const T& save_tensor_func, const logging::Logger& logger,
const DataTransferManager& data_transfer_mgr) {
LOGS(logger, INFO) << "Saving initialized tensors.";
ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > 0, "OrtValue indexes should have been populated.");
@ -272,7 +275,7 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string<PA
ORT_ENFORCE(m->GetBuffer() != nullptr || m->GetLen() == 0);
#endif
OrtValue ort_value;
Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, *m, exec_providers, ort_value, deleter);
Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, *m, exec_providers, ort_value, deleter, data_transfer_mgr);
if (!st.IsOK()) {
std::ostringstream oss;
oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();

View file

@ -6,7 +6,7 @@
#include <iomanip>
#include "core/graph/graph_viewer.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_frame.h"
#include "core/framework/execution_providers.h"
#include "core/framework/feeds_fetches_manager.h"
@ -58,7 +58,9 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
return required_provider_type;
}
static Status CopyMLValue(const FeedsFetchesManager::MLValueCopyInfo& copy_info, const OrtValue& source_mlvalue,
static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
if (copy_info.copy_provider == nullptr) {
target_mlvalue = source_mlvalue;
@ -72,7 +74,7 @@ static Status CopyMLValue(const FeedsFetchesManager::MLValueCopyInfo& copy_info,
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
ORT_RETURN_IF_ERROR(copy_info.copy_provider->CopyTensor(source_tensor, *p_output_tensor));
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
}
return Status::OK();
@ -150,7 +152,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
copy_info.allocation_provider = required_provider;
copy_info.copy_provider = p_copy_provider;
ORT_RETURN_IF_ERROR(CopyMLValue(copy_info, orig_mlvalue, new_mlvalue));
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));
needed_copy = true;
@ -205,14 +207,15 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
// copies inputs across devices only if required using cached copy_info
static common::Status CachedCopyInputsAcrossDevices(
const std::vector<OrtValue>& orig_feeds, std::vector<OrtValue>& new_feeds,
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info) {
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info,
const DataTransferManager& data_transfer_mgr) {
size_t num_feeds = orig_feeds.size();
ORT_ENFORCE(copy_info.size() == num_feeds);
new_feeds.resize(num_feeds);
for (size_t idx = 0; idx < num_feeds; ++idx) {
ORT_RETURN_IF_ERROR(CopyMLValue(copy_info[idx], orig_feeds[idx], new_feeds[idx]));
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx]));
}
return Status::OK();
@ -379,7 +382,7 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
const int device_id = 0; // TODO: As per comment in the copy input code, make this configurable.
FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
ORT_RETURN_IF_ERROR(CopyMLValue(copy_info, fetched_mlvalue, output_mlvalue));
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));
if (copiers) {
(*copiers)[idx] = copy_info;
@ -391,7 +394,8 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
static common::Status CachedCopyOutputsAcrossDevices(
const std::vector<OrtValue>& fetches, std::vector<OrtValue>& user_fetches,
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info) {
const std::vector<FeedsFetchesManager::MLValueCopyInfo>& copy_info,
const DataTransferManager& data_transfer_mgr) {
auto num_outputs = fetches.size();
// internal logic error if these are mismatched
@ -399,7 +403,7 @@ static common::Status CachedCopyOutputsAcrossDevices(
// used the cached copy logic if available
for (size_t idx = 0; idx < num_outputs; ++idx) {
ORT_RETURN_IF_ERROR(CopyMLValue(copy_info[idx], fetches[idx], user_fetches[idx]));
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx]));
}
return Status::OK();
@ -456,7 +460,8 @@ common::Status ExecuteGraphWithCachedInfo(
// Copy inputs
if (device_copy_checks.input_copy_needed == DeviceCopyCheck::Copy) {
ORT_RETURN_IF_ERROR(CachedCopyInputsAcrossDevices(feeds, device_feeds,
feeds_fetches_manager.GetFeedsDeviceCopiers()));
feeds_fetches_manager.GetFeedsDeviceCopiers(),
session_state.GetDataTransferMgr()));
p_feeds = &device_feeds;
}
@ -480,7 +485,8 @@ common::Status ExecuteGraphWithCachedInfo(
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
ORT_RETURN_IF_ERROR(CachedCopyOutputsAcrossDevices(*p_fetches, fetches,
feeds_fetches_manager.GetFetchesDeviceCopiers()));
feeds_fetches_manager.GetFetchesDeviceCopiers(),
session_state.GetDataTransferMgr()));
}
}

View file

@ -3,6 +3,7 @@
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/macros.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/data_types.h"
#include "core/framework/mldata_type_utils.h"
@ -22,6 +23,8 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
allocator_ptr_ = cpu_execution_provider_->GetAllocator(device_id_, mem_type_);
ORT_ENFORCE(allocator_ptr_ != nullptr, "Failed to get allocator for optimizer");
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
// Create MLValues related maps
auto initialize_maps = [this, &initialized_tensor_set](const NodeArg& arg, size_t /*index*/) -> Status {
int idx = ort_value_name_idx_map_.Add(arg.Name());
@ -63,7 +66,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
std::unique_ptr<OpKernel> op_kernel;
std::shared_ptr<KernelRegistry> kernel_registry = cpu_execution_provider_->GetKernelRegistry();
auto status = kernel_registry->TryCreateKernel(*node, *cpu_execution_provider_, initializers_,
ort_value_name_idx_map_, FuncManager(), op_kernel);
ort_value_name_idx_map_, FuncManager(), data_transfer_mgr_, op_kernel);
kernels_[node->Index()] = std::move(op_kernel);
}
}

View file

@ -7,12 +7,14 @@
#include "core/graph/graph.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_frame.h"
#include "core/framework/ort_value_name_idx_map.h"
#include "core/framework/ml_value.h"
#include "core/common/callback.h"
namespace onnxruntime {
class DataTransferManager;
class OptimizerExecutionFrame final : public IExecutionFrame {
public:
@ -54,7 +56,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
const int device_id_{0};
const OrtMemType mem_type_{OrtMemTypeDefault};
AllocatorPtr allocator_ptr_;
DataTransferManager data_transfer_mgr_;
// MLValues for optimizer
OrtValueNameIdxMap ort_value_name_idx_map_;
std::unordered_map<int, const NodeArg*> ort_value_idx_nodearg_map_;

View file

@ -46,10 +46,6 @@ class CPUExecutionProvider : public IExecutionProvider {
#endif
}
Status CopyTensor(const Tensor&, Tensor&) const override {
return Status(common::ONNXRUNTIME, common::FAIL, "Shouldn't reach here. CPUExecutionProvider doesn't support CopyTensor");
}
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;

View file

@ -35,5 +35,6 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessio
}
ORT_API_STATUS_IMPL(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, _Out_ OrtAllocatorInfo** out) {
return OrtCreateAllocatorInfo(onnxruntime::CPU, type, 0, mem_type, out);
*out = new OrtAllocatorInfo(onnxruntime::CPU, type, OrtDevice(), 0, mem_type);
return nullptr;
}

View file

@ -1,17 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "cuda_common.h"
#include "cuda_allocator.h"
#include "cuda_common.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/session_state.h"
#include "cuda_fence.h"
#include "gpu_data_transfer.h"
namespace onnxruntime {
static const CUDAExecutionProvider* GetCUDAExecutionProvider(const SessionState* session_state) {
return dynamic_cast<const CUDAExecutionProvider*>(
session_state->GetExecutionProviders().Get(onnxruntime::kCudaExecutionProvider));
static const GPUDataTransfer* GetGPUDataTransfer(const SessionState* session_state) {
OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
OrtDevice cpu_device;
return dynamic_cast<const GPUDataTransfer*>(session_state->GetDataTransferMgr().GetDataTransfer(gpu_device, cpu_device));
}
void CUDAAllocator::CheckDevice() const {
@ -43,7 +45,7 @@ const OrtAllocatorInfo& CUDAAllocator::Info() const {
}
FencePtr CUDAAllocator::CreateFence(const SessionState* session_state) {
return std::make_shared<CUDAFence>(GetCUDAExecutionProvider(session_state));
return std::make_shared<CUDAFence>(GetGPUDataTransfer(session_state));
}
void* CUDAPinnedAllocator::Alloc(size_t size) {
@ -59,12 +61,12 @@ void CUDAPinnedAllocator::Free(void* p) {
}
const OrtAllocatorInfo& CUDAPinnedAllocator::Info() const {
static constexpr OrtAllocatorInfo cuda_allocator_info(CUDA_PINNED, OrtDeviceAllocator, 0, OrtMemTypeCPUOutput);
static constexpr OrtAllocatorInfo cuda_allocator_info(CUDA_PINNED, OrtDeviceAllocator, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0), 0, OrtMemTypeCPUOutput);
return cuda_allocator_info;
}
FencePtr CUDAPinnedAllocator::CreateFence(const SessionState* session_state) {
return std::make_shared<CUDAFence>(GetCUDAExecutionProvider(session_state));
return std::make_shared<CUDAFence>(GetGPUDataTransfer(session_state));
}
} // namespace onnxruntime

View file

@ -6,12 +6,10 @@
#include "core/framework/allocator.h"
namespace onnxruntime {
constexpr const char* CUDA = "Cuda";
constexpr const char* CUDA_PINNED = "CudaPinned";
class CUDAAllocator : public IDeviceAllocator {
public:
CUDAAllocator(int device_id) : info_(CUDA, OrtAllocatorType::OrtDeviceAllocator, device_id, OrtMemTypeDefault) {}
CUDAAllocator(int device_id) : info_(CUDA, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeDefault) {}
virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
virtual const OrtAllocatorInfo& Info() const override;

View file

@ -4,6 +4,7 @@
#pragma once
#include "cuda_pch.h"
#include "core/common/status.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/op_kernel.h"
#include "core/graph/graph_viewer.h"
#include "shared_inc/cuda_call.h"
@ -137,7 +138,7 @@ class CudaKernel : public OpKernel {
}
inline Status CopyTensor(const Tensor& src, Tensor& dst) const {
return provider_->CopyTensor(src, dst);
return Info().GetDataTransferManager().CopyTensor(src, dst);
}
inline int GetDeviceId() const { return provider_->GetDeviceId(); }

View file

@ -199,52 +199,6 @@ Status CUDAExecutionProvider::OnRunEnd() {
return Status::OK();
}
Status CUDAExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst) const {
return CopyTensor(src, dst, kCudaStreamDefault);
}
Status CUDAExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
if (src.Shape().Size() != dst.Shape().Size()) {
return Status(ONNXRUNTIME, FAIL, "Tensor size mismatch");
}
if (strcmp(src.Location().name, CUDA) != 0 && strcmp(src.Location().name, CUDA_PINNED) != 0 &&
strcmp(dst.Location().name, CUDA) != 0 && strcmp(dst.Location().name, CUDA_PINNED) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: src_location is: ", src.Location().name, " and dst_location is: ", dst.Location().name);
}
size_t bytes = src.Size();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
if (strcmp(dst.Location().name, CUDA) == 0) {
if (strcmp(src.Location().name, CUDA_PINNED) == 0) {
// copy from pinned memory to GPU, this is non-blocking
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id]));
} else if (strcmp(src.Location().name, CUDA) == 0) {
// copying between GPU, this is non-blocking
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
} else {
// copy from other CPU memory to GPU, this is blocking
CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
}
} else if (strcmp(src.Location().name, CUDA) == 0) {
if (strcmp(dst.Location().name, CUDA_PINNED) == 0) {
// copying from GPU to pinned memory, this is non-blocking
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, streams_[exec_queue_id]));
} else {
// copying from GPU to CPU memory, this is blocking
CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost));
}
} else {
// copying between cpu memory
memcpy(dst_data, src_data, bytes);
}
return Status::OK();
}
namespace cuda {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyToHost);

View file

@ -7,6 +7,7 @@
#include "core/graph/constants.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "shared_inc/cuda_utils.h"
#include <deque>
@ -17,13 +18,6 @@ struct CUDAExecutionProviderInfo {
int device_id{0};
};
enum CUDAStreamType : int {
kCudaStreamDefault = 0,
kCudaStreamCopyIn,
kCudaStreamCopyOut,
kTotalCudaStreams,
};
// Logical device representation.
class CUDAExecutionProvider : public IExecutionProvider {
public:
@ -38,10 +32,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
Status OnRunEnd() override;
Status CopyTensor(const Tensor& src, Tensor& dst) const override;
Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
cublasHandle_t PerThreadCublasHandle() {
return GetPerThreadContext().CublasHandle();
}

View file

@ -3,10 +3,11 @@
#include "cuda_common.h"
#include "cuda_fence.h"
#include "gpu_data_transfer.h"
namespace onnxruntime {
CUDAFence::CUDAFence(const CUDAExecutionProvider* provider) : provider_(provider) {
CUDAFence::CUDAFence(const GPUDataTransfer* data_transfer) : data_transfer_(data_transfer) {
// NOTE: cudaEventBlockingSync may lead to longer wait times because of thread yielding/switching in the kernel
// if lower CPU usage is more important than latency, we should use this flag to avoid spin-loop in WaitOnCPU
int event_flags = /*cudaEventBlockingSync |*/ cudaEventDisableTiming;
@ -22,7 +23,7 @@ CUDAFence::~CUDAFence() {
void CUDAFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) {
if (provider_type == onnxruntime::kCudaExecutionProvider) {
// sync in GPU, the call is non-blocking on CPU
CUDA_CALL_THROW(cudaStreamWaitEvent(provider_->GetStream(async_queue_id), write_event_, 0));
CUDA_CALL_THROW(cudaStreamWaitEvent(data_transfer_->GetStream(async_queue_id), write_event_, 0));
} else {
// sync on CPU for all other providers, this is blocking
CUDA_CALL_THROW(cudaEventSynchronize(write_event_));
@ -32,7 +33,7 @@ void CUDAFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int
void CUDAFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) {
if (provider_type == onnxruntime::kCudaExecutionProvider) {
// sync in GPU, the call is non-blocking on CPU
cudaStream_t stream = provider_->GetStream(queue_id);
cudaStream_t stream = data_transfer_->GetStream(queue_id);
CUDA_CALL_THROW(cudaStreamWaitEvent(stream, read_event_, 0));
CUDA_CALL_THROW(cudaStreamWaitEvent(stream, write_event_, 0));
} else {
@ -49,13 +50,13 @@ bool CUDAFence::CanRelease() {
void CUDAFence::AfterUsedAsInput(int queue_id) {
// update read fence
cudaStream_t stream = provider_->GetStream(queue_id);
cudaStream_t stream = data_transfer_->GetStream(queue_id);
CUDA_CALL_THROW(cudaEventRecord(read_event_, stream));
}
void CUDAFence::AfterUsedAsOutput(int queue_id) {
// update write fence
cudaStream_t stream = provider_->GetStream(queue_id);
cudaStream_t stream = data_transfer_->GetStream(queue_id);
CUDA_CALL_THROW(cudaEventRecord(write_event_, stream));
}

View file

@ -3,12 +3,14 @@
#pragma once
#include "core/framework/tensor.h"
#include "cuda_execution_provider.h"
#include "core/graph/basic_types.h"
namespace onnxruntime {
class GPUDataTransfer;
class CUDAFence : public IFence {
public:
CUDAFence(const CUDAExecutionProvider* provider);
CUDAFence(const GPUDataTransfer* data_transfer);
virtual ~CUDAFence();
virtual void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override;
virtual void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override;
@ -19,7 +21,7 @@ class CUDAFence : public IFence {
private:
cudaEvent_t read_event_;
cudaEvent_t write_event_;
const CUDAExecutionProvider* provider_;
const GPUDataTransfer* data_transfer_;
};
} // namespace onnxruntime

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/gpu_data_transfer.h"
#include "cuda_common.h"
namespace onnxruntime {
// Fills the streams_ table that GetStream() hands out and CopyTensor() indexes.
GPUDataTransfer::GPUDataTransfer() {
// The kCudaStreamDefault slot is left as nullptr; no dedicated stream is created for it.
streams_[kCudaStreamDefault] = nullptr;
// The copy-in and copy-out slots each receive their own stream, created with the
// cudaStreamNonBlocking flag.
// NOTE(review): CUDA_CALL_THROW presumably throws when the CUDA call fails - confirm against its definition.
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking));
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));
}
bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED
|| dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
}
// Copies the raw contents of `src` into `dst`, picking the copy primitive from
// the device recorded in each tensor's location.
//
//   src            dst            primitive
//   pinned host -> GPU            cudaMemcpyAsync on streams_[exec_queue_id]
//   GPU         -> GPU            cudaMemcpyAsync on streams_[kCudaStreamDefault]
//   other host  -> GPU            cudaMemcpy (blocking)
//   GPU         -> pinned host    cudaMemcpyAsync on streams_[exec_queue_id]
//   GPU         -> other host     cudaMemcpy (blocking)
//   host        -> host           memcpy
//
// @param src            tensor to read from
// @param dst            tensor to write to; must hold as many elements as `src`
// @param exec_queue_id  index into streams_, consulted only by the two
//                       pinned-memory paths
// @return Status::OK() on success; a FAIL status when the element counts
//         differ, when a pinned-memory path is given an out-of-range
//         exec_queue_id, or when a CUDA call reports an error.
common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
  // src.Size() bytes are written into dst below, so a dst with a different
  // element count must be rejected up front. The CUDA provider implementation
  // this method replaces performed the same check.
  if (src.Shape().Size() != dst.Shape().Size()) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Tensor size mismatch");
  }

  const size_t bytes = src.Size();
  const void* src_data = src.DataRaw();
  void* dst_data = dst.MutableDataRaw();

  const auto& src_device = src.Location().device;
  const auto& dst_device = dst.Location().device;

  const bool src_on_gpu = src_device.Type() == OrtDevice::GPU;
  const bool dst_on_gpu = dst_device.Type() == OrtDevice::GPU;
  const bool src_pinned = src_device.Type() == OrtDevice::CPU &&
                          src_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
  const bool dst_pinned = dst_device.Type() == OrtDevice::CPU &&
                          dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;

  // Only the pinned-memory paths index streams_ with the caller-supplied id.
  // Validate it there instead of reading past the end of the array; the other
  // paths never look at exec_queue_id, so they keep accepting any value.
  const bool uses_exec_queue = (dst_on_gpu && src_pinned) || (!dst_on_gpu && src_on_gpu && dst_pinned);
  if (uses_exec_queue && (exec_queue_id < 0 || exec_queue_id >= kTotalCudaStreams)) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Invalid CUDA stream id for tensor copy");
  }

  if (dst_on_gpu) {
    if (src_pinned) {
      // copy from pinned memory to GPU, this is non-blocking
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id]));
    } else if (src_on_gpu) {
      // copying between GPU, this is non-blocking
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
    } else {
      // copy from other CPU memory to GPU, this is blocking
      CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
    }
  } else if (src_on_gpu) {
    if (dst_pinned) {
      // copying from GPU to pinned memory, this is non-blocking
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, streams_[exec_queue_id]));
    } else {
      // copying from GPU to CPU memory, this is blocking
      CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost));
    }
  } else {
    // copying between cpu memory
    memcpy(dst_data, src_data, bytes);
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "cuda_pch.h"
#include "core/framework/data_transfer.h"
namespace onnxruntime {
// Indices into the per-transfer CUDA stream table. The values are used
// directly as array subscripts, so they must stay dense and start at zero.
enum CUDAStreamType : int {
  kCudaStreamDefault = 0,  // slot for the default stream
  kCudaStreamCopyIn = 1,   // slot for the copy-in stream
  kCudaStreamCopyOut = 2,  // slot for the copy-out stream
  kTotalCudaStreams = 3,   // number of slots; not a stream itself
};
class GPUDataTransfer : public IDataTransfer {
public:
GPUDataTransfer();
bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
cudaStream_t GetStream(int queue_id) const {
ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCudaStreams);
return streams_[queue_id];
}
private:
cudaStream_t streams_[kTotalCudaStreams];
};
} // namespace onnxruntime

View file

@ -41,10 +41,10 @@ ONNX_OPERATOR_KERNEL_EX(
MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kMklDnnExecutionProvider} {
DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault,
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(MKLDNN, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault)); }, std::numeric_limits<size_t>::max()});
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(MKLDNN, OrtAllocatorType::OrtDeviceAllocator)); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUOutput,
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(MKLDNN_CPU, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeCPUOutput)); }, std::numeric_limits<size_t>::max()});
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(MKLDNN_CPU, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); }, std::numeric_limits<size_t>::max()});
if (info.create_arena) {
InsertAllocator(CreateAllocator(default_allocator_info));
@ -62,23 +62,6 @@ MKLDNNExecutionProvider::MKLDNNExecutionProvider(const MKLDNNExecutionProviderIn
MKLDNNExecutionProvider::~MKLDNNExecutionProvider() {
}
Status MKLDNNExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst) const {
// Support CPU <-> MKLDNN for now
if (!(strcmp(src.Location().name, MKLDNN) == 0 && strcmp(dst.Location().name, CPU) == 0) &&
!(strcmp(src.Location().name, CPU) == 0 && strcmp(dst.Location().name, MKLDNN) == 0) &&
!(strcmp(src.Location().name, MKLDNN) == 0 && strcmp(dst.Location().name, MKLDNN_CPU) == 0)) {
ORT_NOT_IMPLEMENTED(src.Location().name, " copy to ", dst.Location().name, " is not implemented");
}
// Todo: Copy for now. May optimize later to avoid copy.
size_t bytes = src.DataType()->Size() * src.Shape().Size();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
memcpy(dst_data, src_data, bytes);
return Status::OK();
}
namespace mkl_dnn {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, Conv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, Gemm);

View file

@ -35,8 +35,6 @@ class MKLDNNExecutionProvider : public IExecutionProvider {
explicit MKLDNNExecutionProvider(const MKLDNNExecutionProviderInfo& info);
virtual ~MKLDNNExecutionProvider();
Status CopyTensor(const Tensor& src, Tensor& dst) const override;
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::shared_ptr<mkldnn::memory> GetWeightsMemoryBuffer(const std::string& weight_key) {

View file

@ -34,13 +34,13 @@ constexpr const char* NGRAPH = "nGraph";
NGRAPHExecutionProvider::NGRAPHExecutionProvider(const NGRAPHExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kNGraphExecutionProvider} {
DeviceAllocatorRegistrationInfo default_allocator_info({OrtMemTypeDefault,
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(NGRAPH, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault)); },
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(NGRAPH, OrtAllocatorType::OrtDeviceAllocator)); },
std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(default_allocator_info));
DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUOutput,
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(NGRAPH, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeCPUOutput)); },
[](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(NGRAPH, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); },
std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(cpu_allocator_info));
@ -76,24 +76,6 @@ bool TensorCopyPossible(const std::string& src_location, const std::string& dst_
});
}
Status NGRAPHExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst) const {
const size_t src_bytes = src.DataType()->Size() * src.Shape().Size();
const size_t dst_bytes = dst.DataType()->Size() * dst.Shape().Size();
if (src_bytes != dst_bytes) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"nGraph: Source and Destination data sizes are not equal - cannot copy tensors");
}
if (!TensorCopyPossible(src.Location().name, dst.Location().name)) {
ORT_NOT_IMPLEMENTED("Copying tensors between '", src.Location().name, "' and '", dst.Location().name,
"' is not implemented in NGRAPHExecutionProvider");
}
MEMCPY_S(dst.MutableDataRaw(), src.DataRaw(), dst_bytes, src_bytes);
return Status::OK();
}
// Returns true only if op is in a mode that is not currently supported
static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer& graph_viewer) {
const auto& optype = node->OpType();

View file

@ -24,8 +24,6 @@ class NGRAPHExecutionProvider : public IExecutionProvider {
explicit NGRAPHExecutionProvider(const NGRAPHExecutionProviderInfo& info);
~NGRAPHExecutionProvider() = default;
Status CopyTensor(const Tensor& src, Tensor& dst) const override;
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const override;

View file

@ -29,7 +29,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(OpenVINOExecutionProviderIn
: IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider} {
ORT_UNUSED_PARAMETER(info);
DeviceAllocatorRegistrationInfo device_info({OrtMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(OPENVINO, OrtDeviceAllocator, 0, OrtMemTypeDefault)); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo device_info({OrtMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>(OPENVINO, OrtDeviceAllocator)); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(device_info));
}

View file

@ -47,16 +47,6 @@ class OpenVINOExecutionProvider : public IExecutionProvider {
return std::make_shared<KernelRegistry>();
}
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override {
// TODO: Copy for now. May optimize later to avoid copy.
size_t bytes = src.DataType()->Size() * src.Shape().Size();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
memcpy(dst_data, src_data, bytes);
return Status::OK();
}
const void* GetExecutionHandle() const noexcept override {
return nullptr;
}

View file

@ -12,7 +12,7 @@ class TensorrtPinnedAllocator : public CPUAllocator {
public:
virtual const OrtAllocatorInfo& Info() const override {
static OrtAllocatorInfo tensorrt_cpu_allocator_info(TRT,
OrtAllocatorType::OrtDeviceAllocator, 0,
OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0,
OrtMemType::OrtMemTypeCPU);
return tensorrt_cpu_allocator_info;
}
@ -25,8 +25,7 @@ class TensorrtAllocator : public CPUAllocator {
public:
virtual const OrtAllocatorInfo& Info() const override {
static OrtAllocatorInfo tensorrt_default_allocator_info(TRT,
OrtAllocatorType::OrtDeviceAllocator, 0,
OrtMemType::OrtMemTypeDefault);
OrtAllocatorType::OrtDeviceAllocator);
return tensorrt_default_allocator_info;
}
};

View file

@ -280,12 +280,6 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
}
common::Status TensorrtExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst) const {
ORT_UNUSED_PARAMETER(src);
ORT_UNUSED_PARAMETER(dst);
return Status::OK();
}
common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {

View file

@ -62,8 +62,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
Status CopyTensor(const Tensor& src, Tensor& dst) const override;
void SetMaxBatchSize(const int batch_size) {
max_batch_size_ = batch_size;
}

View file

@ -44,7 +44,7 @@ class IOBinding {
* copy it to the desired location. This copy may or may not be async. It depends on the exec provider.
* If the input ort_value is not at the desired location, it should be preallocated
* If the input ort_value isn't preallocated, it should have memtype of OrtMemTypeDefault
* For copying it leverages IExecutionProvider::CopyTensor().
* For copying it leverages DataTransferManager::CopyTensor().
*/
common::Status BindInput(const std::string& name, const OrtValue& ml_value);

View file

@ -17,7 +17,7 @@ struct OrtDefaultAllocator : OrtAllocatorImpl {
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<OrtDefaultAllocator*>(this_)->Alloc(size); };
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<OrtDefaultAllocator*>(this_)->Free(p); };
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const OrtDefaultAllocator*>(this_)->Info(); };
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpuAllocatorInfo));
}
~OrtDefaultAllocator() override { OrtReleaseAllocatorInfo(cpuAllocatorInfo); }

View file

@ -44,6 +44,9 @@
#include "core/optimizer/insert_cast_transformer.h"
#include "core/optimizer/transformer_memcpy.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#ifdef USE_CUDA
#include "core/providers/cuda/gpu_data_transfer.h"
#endif
#include "core/session/IOBinding.h"
#include "core/session/custom_ops.h"
#include "core/util/protobuf_parsing_utils.h"
@ -100,6 +103,13 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, loggin
InitLogger(logging_manager);
// Register data transfer methods.
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
#ifdef USE_CUDA
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
#endif
session_state_.SetDataTransferMgr(&data_transfer_mgr_);
// The threadpool is currently evolving. We will always create a per session threadpool.
// Beyond this, we will create a global thread pool to share across sessions.
{
@ -396,7 +406,8 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
subgraph_session_state->SetLogger(*session_logger_);
// Pass threadpool to subgraph
subgraph_session_state->SetThreadPool(session_state.GetThreadPool());
// Pass data transfer manager to subgraph.
subgraph_session_state->SetDataTransferMgr(&session_state.GetDataTransferMgr());
// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());

View file

@ -420,6 +420,8 @@ class InferenceSession {
// Threadpool for this session
std::unique_ptr<onnxruntime::concurrency::ThreadPool> thread_pool_;
// Data transfer manager.
DataTransferManager data_transfer_mgr_;
// Number of concurrently running executors
std::atomic<int> current_num_runs_;

View file

@ -525,7 +525,7 @@ ORT_API_STATUS_IMPL(OrtTensorProtoToOrtValue, _In_ const void* input, int input_
_Out_ OrtValue** out, _Out_ OrtCallback** deleter) {
API_IMPL_BEGIN
OrtAllocatorInfo* cpuAllocatorInfo;
auto st = OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo);
auto st = OrtCreateCpuAllocatorInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpuAllocatorInfo);
if (st != nullptr) return st;
::ONNX_NAMESPACE::TensorProto proto;
if (!proto.ParseFromArray(input, input_len)) {

View file

@ -62,8 +62,8 @@ protobufutil::Status Executor::SetNameMLValueMap(std::vector<std::string>& input
auto ort_status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
if (ort_status != nullptr || allocator_info == nullptr) {
logger->error("OrtCreateAllocatorInfo failed");
return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateAllocatorInfo() failed");
logger->error("OrtCreateCpuAllocatorInfo failed");
return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateCpuAllocatorInfo() failed");
}
// Prepare the Value object

View file

@ -196,7 +196,7 @@ class PlannerTest : public ::testing::Test {
void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def) {
auto info = std::make_unique<OpKernelInfo>(*p_node, kernel_def, *execution_providers_.Get(*p_node),
state_.GetInitializedTensors(), state_.GetOrtValueNameIdxMap(),
state_.GetFuncMgr());
state_.GetFuncMgr(), state_.GetDataTransferMgr());
auto dummy = std::make_unique<DummyOpKernel>(*info);
op_kernel_infos_.push_back(std::move(info));
state_.AddKernel(p_node->Index(), std::move(dummy));

View file

@ -49,7 +49,7 @@ class TestAllocator : public IAllocator {
}
virtual const OrtAllocatorInfo& Info() const override {
static OrtAllocatorInfo info("test", OrtDeviceAllocator, 0);
static OrtAllocatorInfo info("test", OrtDeviceAllocator);
return info;
}

View file

@ -18,23 +18,6 @@ class DummyExecutionProvider : public IExecutionProvider {
InsertAllocator(std::make_unique<DummyAllocator>());
}
Status CopyTensor(const Tensor& src, Tensor& dst) const override {
// we can 'copy' from anything we allocated to/from CPU
ORT_ENFORCE(strcmp(dst.Location().name, DummyAllocator::kDummyAllocator) == 0 ||
strcmp(dst.Location().name, CPU) == 0);
ORT_ENFORCE(strcmp(src.Location().name, DummyAllocator::kDummyAllocator) == 0 ||
strcmp(src.Location().name, CPU) == 0);
// no really copy needed.
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
// copying between cpu memory
memcpy(dst_data, src_data, src.Size());
return Status::OK();
}
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
};

View file

@ -14,6 +14,7 @@
#include "core/common/logging/logging.h"
#include "core/common/profiler.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/op_kernel.h"
@ -25,6 +26,9 @@
#include "core/platform/env.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#ifdef USE_CUDA
#include "core/providers/cuda/gpu_data_transfer.h"
#endif
#include "core/session/IOBinding.h"
#include "dummy_provider.h"
#include "test_utils.h"
@ -112,12 +116,6 @@ class FuseExecutionProvider : public IExecutionProvider {
static std::shared_ptr<KernelRegistry> kernel_registry = GetFusedKernelRegistry();
return kernel_registry;
}
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override {
ORT_UNUSED_PARAMETER(src);
ORT_UNUSED_PARAMETER(dst);
return Status::OK();
}
};
namespace test {
@ -284,7 +282,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
std::unique_ptr<Tensor> cpu_tensor = std::make_unique<Tensor>(element_type,
shape,
cpu_allocator);
st = TestCudaExecutionProvider()->CopyTensor(rtensor, *cpu_tensor.get());
st = GPUDataTransfer().CopyTensor(rtensor, *cpu_tensor.get(), 0);
ASSERT_TRUE(st.IsOK());
OrtValue ml_value;
ml_value.Init(cpu_tensor.release(),

View file

@ -21,13 +21,6 @@ namespace test {
class XPUExecutionProvider : public IExecutionProvider {
public:
XPUExecutionProvider() : IExecutionProvider{onnxruntime::kCpuExecutionProvider} {}
Status CopyTensor(const Tensor& src, Tensor& dst) const override {
ORT_UNUSED_PARAMETER(src);
ORT_UNUSED_PARAMETER(dst);
return Status::OK();
}
};
} // namespace test

View file

@ -57,7 +57,7 @@ TEST(SessionStateTest, AddGetKernelTest) {
CPUExecutionProvider execution_provider{CPUExecutionProviderInfo{"CPUExecutionProvider"}};
OpKernelInfo p_info(node, kernel_def, execution_provider, s.GetConstantInitializedTensors(),
s.GetOrtValueNameIdxMap(), s.GetFuncMgr());
s.GetOrtValueNameIdxMap(), s.GetFuncMgr(), s.GetDataTransferMgr());
unique_ptr<TestOpKernel> p_kernel;
p_kernel.reset(new TestOpKernel(p_info));
size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();

View file

@ -92,7 +92,7 @@ OrtValue* CreateTensorWithDataAsOrtValue(OrtAllocatorInfo* info, std::vector<T>&
template <typename key_type, typename value_type>
OrtValue* PbMapToOrtValue(const google::protobuf::Map<key_type, value_type>& map) {
OrtAllocatorInfo* info;
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &info));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtDeviceAllocator, OrtMemTypeDefault, &info));
std::unique_ptr<OrtAllocatorInfo, decltype(&OrtReleaseAllocatorInfo)> rel_info(info, OrtReleaseAllocatorInfo);
const size_t ele_count = map.size();
std::vector<int64_t> dims(1, ele_count);
@ -122,7 +122,7 @@ OrtValue* PbMapToOrtValue(const google::protobuf::Map<key_type, value_type>& map
template <typename T>
void VectorProtoToOrtValue(const RepeatedPtrField<T>& input, ORT_VALUE_HOLDER& output) {
OrtAllocatorInfo* info;
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &info));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtDeviceAllocator, OrtMemTypeDefault, &info));
std::unique_ptr<OrtAllocatorInfo, decltype(&OrtReleaseAllocatorInfo)> rel_info(info, OrtReleaseAllocatorInfo);
OrtValueArray in(input.size());
size_t j = 0;

View file

@ -10,7 +10,7 @@ using namespace onnxruntime;
TEST_F(CApiTest, allocation_info) {
OrtAllocatorInfo *info1, *info2;
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtArenaAllocator, 0, OrtMemTypeDefault, &info1));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &info1));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &info2));
int result;
ORT_THROW_ON_ERROR(OrtCompareAllocatorInfo(info1, info2, &result));

View file

@ -8,7 +8,7 @@ MockedOrtAllocator::MockedOrtAllocator() {
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<MockedOrtAllocator*>(this_)->Alloc(size); };
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<MockedOrtAllocator*>(this_)->Free(p); };
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const MockedOrtAllocator*>(this_)->Info(); };
ORT_THROW_ON_ERROR(OrtCreateAllocatorInfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault, &cpuAllocatorInfo));
ORT_THROW_ON_ERROR(OrtCreateCpuAllocatorInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpuAllocatorInfo));
}
MockedOrtAllocator::~MockedOrtAllocator() {