mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
[CANN] Multi-stream execution support for CANN EP. (#14058)
### Description

**Multi-stream** execution support for **CANN EP**.

### Motivation and Context

**CANN EP** is currently **unavailable** for three reasons: a new mechanism for multi-stream execution was introduced in [#13495](https://github.com/microsoft/onnxruntime/pull/13495), the Fence-based synchronization mechanism was deleted, and the relevant **CANN EP** logic was not updated at the same time. This PR fixes that.
This commit is contained in:
parent
febc69e1b2
commit
ecb89ed752
49 changed files with 560 additions and 439 deletions
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
#include "core/framework/arena_extend_strategy.h"
|
||||
|
||||
|
|
@ -11,9 +13,13 @@ struct OrtCANNProviderOptions {
|
|||
int device_id; // CANN device id
|
||||
size_t npu_mem_limit; // BFC Arena memory limit for CANN
|
||||
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena
|
||||
int do_copy_in_default_stream; // Flag indicating if copying needs to take place on the
|
||||
// same stream as the compute stream in the CANN EP
|
||||
int enable_cann_graph; // Flag indicating if prioritizing the use of
|
||||
// CANN's graph-running capabilities
|
||||
int dump_graphs; // Flag indicating if dumping graphs
|
||||
std::string precision_mode; // Operator Precision Mode
|
||||
std::string op_select_impl_mode; // Operator-level model compilation options:
|
||||
// Mode selection
|
||||
std::string optypelist_for_implmode; // Operator-level model compilation options:
|
||||
// Operator list
|
||||
OrtArenaCfg* default_memory_arena_cfg; // CANN memory arena configuration parameters
|
||||
};
|
||||
|
|
|
|||
|
|
@ -181,6 +181,19 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
|
|||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_CANN
|
||||
// For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
|
||||
// which is different from CUDA Runtime API, but similar to CUDA Driver API.
|
||||
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
|
||||
for (auto& xp : execution_providers) {
|
||||
auto status = xp->OnRunStart();
|
||||
if (!status.IsOK()) {
|
||||
ctx.SetStatus(status);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// get logic stream
|
||||
auto& execution_plan = ctx.GetSessionState().GetExecutionPlan()->execution_plan;
|
||||
auto& logic_stream = execution_plan[stream_idx];
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons
|
|||
|
||||
#define REGISTER_ACTIVATION_TYPED_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
|
||||
CannPreparation prepare; \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
|
||||
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
|
||||
prepare.inputDesc_.size(), \
|
||||
prepare.inputDesc_.data(), \
|
||||
|
|
@ -46,7 +46,7 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons
|
|||
ACL_ENGINE_SYS, \
|
||||
ACL_COMPILE_SYS, \
|
||||
NULL, \
|
||||
Stream())); \
|
||||
Stream(ctx))); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ template <typename T>
|
|||
class Relu final : public Activations {
|
||||
public:
|
||||
Relu(const OpKernelInfo& info) : Activations(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -7,18 +7,10 @@
|
|||
#include "core/providers/cann/cann_call.h"
|
||||
#include "core/providers/cann/cann_allocator.h"
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/providers/cann/cann_fence.h"
|
||||
#include "core/providers/cann/npu_data_transfer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
static const NPUDataTransfer* GetNPUDataTransfer(const SessionState* session_state) {
|
||||
OrtDevice npu_device(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, 0);
|
||||
OrtDevice cpu_device;
|
||||
return static_cast<const NPUDataTransfer*>(
|
||||
session_state->GetDataTransferMgr().GetDataTransfer(npu_device, cpu_device));
|
||||
}
|
||||
|
||||
void* CANNAllocator::Alloc(size_t size) {
|
||||
void* p = nullptr;
|
||||
aclrtMemMallocPolicy policy = ACL_MEM_MALLOC_HUGE_FIRST;
|
||||
|
|
@ -32,10 +24,6 @@ void CANNAllocator::Free(void* p) {
|
|||
aclrtFree(p);
|
||||
}
|
||||
|
||||
FencePtr CANNAllocator::CreateFence(const SessionState* session_state) {
|
||||
return std::make_shared<CANNFence>(GetNPUDataTransfer(session_state));
|
||||
}
|
||||
|
||||
void* CANNPinnedAllocator::Alloc(size_t size) {
|
||||
void* p = nullptr;
|
||||
if (size > 0) {
|
||||
|
|
@ -48,8 +36,4 @@ void CANNPinnedAllocator::Free(void* p) {
|
|||
CANN_CALL_THROW(aclrtFreeHost(p));
|
||||
}
|
||||
|
||||
FencePtr CANNPinnedAllocator::CreateFence(const SessionState* session_state) {
|
||||
return std::make_shared<CANNFence>(GetNPUDataTransfer(session_state));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ class CANNAllocator : public IAllocator {
|
|||
device_id, OrtMemTypeDefault)) {}
|
||||
void* Alloc(size_t size) override;
|
||||
void Free(void* p) override;
|
||||
FencePtr CreateFence(const SessionState* session_state) override;
|
||||
};
|
||||
|
||||
class CANNPinnedAllocator : public IAllocator {
|
||||
|
|
@ -32,7 +31,6 @@ class CANNPinnedAllocator : public IAllocator {
|
|||
|
||||
void* Alloc(size_t size) override;
|
||||
void Free(void* p) override;
|
||||
FencePtr CreateFence(const SessionState* session_state) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -101,6 +101,8 @@ template <>
|
|||
const char* CannErrString<ge::graphStatus>(ge::graphStatus e) {
|
||||
using namespace ge;
|
||||
|
||||
aclrtSynchronizeDevice();
|
||||
|
||||
switch (e) {
|
||||
CASE_ENUM_TO_STR(GRAPH_FAILED);
|
||||
CASE_ENUM_TO_STR(GRAPH_SUCCESS);
|
||||
|
|
|
|||
|
|
@ -16,14 +16,14 @@
|
|||
#include "core/providers/cann/cann_inc.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
#include "core/providers/cann/cann_allocator.h"
|
||||
#include "core/providers/cann/cann_fence.h"
|
||||
#include "core/providers/cann/cann_fwd.h"
|
||||
#include "core/providers/cann/cann_stream_handle.h"
|
||||
#include "core/providers/cann/npu_data_transfer.h"
|
||||
|
||||
using onnxruntime::cann::BuildONNXModel;
|
||||
using onnxruntime::cann::CannModelPreparation;
|
||||
using onnxruntime::cann::ParserONNXModel;
|
||||
using onnxruntime::cann::SupportONNXModel;
|
||||
using onnxruntime::cann::CannModelPreparation;
|
||||
using onnxruntime::common::Status;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -42,49 +42,63 @@ class Memcpy final : public OpKernel {
|
|||
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
|
||||
Tensor* Y = ctx->Output(0, X->Shape());
|
||||
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
|
||||
return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
|
||||
} else if (X_type->IsSparseTensorType()) {
|
||||
const auto* X = ctx->Input<SparseTensor>(0);
|
||||
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
|
||||
SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
|
||||
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
|
||||
return X->Copy(Info().GetDataTransferManager(), Info().GetKernelDef().ExecQueueId(), *Y);
|
||||
} else if (X_type->IsTensorSequenceType()) {
|
||||
const TensorSeq* X = ctx->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
|
||||
TensorSeq* Y = ctx->Output<TensorSeq>(0);
|
||||
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
|
||||
auto X_dtype = X->DataType();
|
||||
Y->SetType(X_dtype);
|
||||
AllocatorPtr alloc;
|
||||
|
||||
if (Node().OpType() == "MemcpyFromHost") {
|
||||
auto status = ctx->GetTempSpaceAllocator(&alloc);
|
||||
if (!status.IsOK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Memcpy cann: unable to get an allocator.");
|
||||
}
|
||||
} else {
|
||||
auto status = ctx->GetTempSpaceCPUAllocator(&alloc);
|
||||
if (!status.IsOK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Memcpy cann: unable to get the CPU allocator.");
|
||||
}
|
||||
}
|
||||
auto X_size = X->Size();
|
||||
for (size_t i = 0; i < X_size; ++i) {
|
||||
const Tensor& source_tensor = X->Get(i);
|
||||
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
|
||||
Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor,
|
||||
Info().GetKernelDef().ExecQueueId());
|
||||
if (!retval.IsOK()) {
|
||||
return retval;
|
||||
}
|
||||
Y->Add(std::move(*target_tensor));
|
||||
}
|
||||
auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device,
|
||||
Y->Location().device);
|
||||
ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()));
|
||||
return Status::OK();
|
||||
} else {
|
||||
if (X_type->IsSparseTensorType()) {
|
||||
aclrtSynchronizeStream(static_cast<aclrtStream>(ctx->GetComputeStream()->GetHandle()));
|
||||
const auto* X = ctx->Input<SparseTensor>(0);
|
||||
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
|
||||
SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
|
||||
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
|
||||
return X->Copy(Info().GetDataTransferManager(), *Y);
|
||||
} else if (X_type->IsTensorSequenceType()) {
|
||||
const TensorSeq* X = ctx->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
|
||||
TensorSeq* Y = ctx->Output<TensorSeq>(0);
|
||||
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
|
||||
auto X_dtype = X->DataType();
|
||||
Y->SetType(X_dtype);
|
||||
AllocatorPtr alloc;
|
||||
|
||||
// If we are copying contents to CANN, the allocator to use
|
||||
// to allocate the buffers of the new tensors in the sequence
|
||||
// can be temp space allocator associated with the CANN EP
|
||||
if (Node().OpType() == "MemcpyFromHost") {
|
||||
auto status = ctx->GetTempSpaceAllocator(&alloc);
|
||||
if (!status.IsOK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Memcpy cann: unable to get an allocator.");
|
||||
}
|
||||
} else {
|
||||
// If we are copying contents to CPU (op type is "MemcpyToHost"),
|
||||
// the allocator to use to allocate the buffers of the new tensors
|
||||
// in the sequence will be the allocator from the CPU EP
|
||||
auto status = ctx->GetTempSpaceCPUAllocator(&alloc);
|
||||
if (!status.IsOK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Memcpy cann: unable to get the CPU allocator.");
|
||||
}
|
||||
}
|
||||
auto X_size = X->Size();
|
||||
for (size_t i = 0; i < X_size; ++i) {
|
||||
const Tensor& source_tensor = X->Get(i);
|
||||
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
|
||||
source_tensor.Shape(),
|
||||
alloc);
|
||||
auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device,
|
||||
target_tensor->Location().device);
|
||||
ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(source_tensor,
|
||||
*target_tensor,
|
||||
*ctx->GetComputeStream()));
|
||||
Y->Add(std::move(*target_tensor));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
|
||||
}
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -97,7 +111,6 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kCannExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0)
|
||||
.ExecQueueId(kCannStreamCopyIn)
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
|
||||
Memcpy);
|
||||
|
||||
|
|
@ -108,7 +121,6 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kCannExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
|
||||
.ExecQueueId(kCannStreamCopyOut)
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
|
||||
Memcpy);
|
||||
|
||||
|
|
@ -1013,27 +1025,20 @@ CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& in
|
|||
InitProviderOrtApi();
|
||||
|
||||
CANN_CALL_THROW(aclrtSetDevice(info_.device_id));
|
||||
CANN_CALL_THROW(aclrtCreateStream(&stream_));
|
||||
|
||||
soc_name_ = aclrtGetSocName();
|
||||
ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr");
|
||||
}
|
||||
|
||||
CANNExecutionProvider::~CANNExecutionProvider() {
|
||||
CANN_CALL_THROW(aclrtDestroyStream(stream_));
|
||||
for (auto modelID : modelIDs_) {
|
||||
CANN_CALL_THROW(aclmdlUnload(modelID.second));
|
||||
}
|
||||
}
|
||||
|
||||
// All threads share the same context and stream
|
||||
Status CANNExecutionProvider::OnRunStart() {
|
||||
CANN_CALL_THROW(aclrtSetDevice(info_.device_id));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CANNExecutionProvider::OnRunEnd(bool sync_stream) {
|
||||
if (sync_stream) {
|
||||
CANN_CALL_THROW(aclrtSynchronizeStream(stream_));
|
||||
}
|
||||
CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -1050,6 +1055,8 @@ void InitializeRegistry() {
|
|||
void DeleteRegistry() {
|
||||
s_kernel_registry.reset();
|
||||
|
||||
ge::aclgrphBuildFinalize();
|
||||
|
||||
CANN_CALL_THROW(aclFinalize());
|
||||
}
|
||||
|
||||
|
|
@ -1058,8 +1065,7 @@ std::shared_ptr<KernelRegistry> CANNExecutionProvider::GetKernelRegistry() const
|
|||
}
|
||||
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> CANNExecutionProvider::GetDataTransfer() const {
|
||||
return std::make_unique<onnxruntime::NPUDataTransfer>(static_cast<aclrtStream>(GetComputeStream()),
|
||||
info_.do_copy_in_default_stream);
|
||||
return std::make_unique<onnxruntime::NPUDataTransfer>();
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
||||
|
|
@ -1072,9 +1078,9 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
}
|
||||
|
||||
// Get parent graph output names
|
||||
std::vector<std::string> graph_output_names;
|
||||
std::unordered_set<std::string> graph_output_names;
|
||||
for (const auto* output_arg : graph_viewer.GetOutputs()) {
|
||||
graph_output_names.push_back(output_arg->Name());
|
||||
graph_output_names.insert(output_arg->Name());
|
||||
}
|
||||
|
||||
// Find inputs and outputs of the subgraph
|
||||
|
|
@ -1084,26 +1090,35 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
int input_order = 0;
|
||||
int output_order = 0;
|
||||
|
||||
std::vector<std::string> initializers;
|
||||
for (const auto& index : graph_nodes_index) {
|
||||
sub_graph->Nodes().push_back(index);
|
||||
const auto& node = graph_viewer.GetNode(index);
|
||||
for (const auto& input : node->InputDefs()) {
|
||||
if (graph_viewer.IsConstantInitializer(input->Name(), true)) {
|
||||
initializers.push_back(input->Name());
|
||||
continue;
|
||||
}
|
||||
const auto& it = fused_outputs.find(input);
|
||||
if (it != fused_outputs.end()) {
|
||||
fused_outputs.erase(it);
|
||||
erased.insert(input);
|
||||
} else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) {
|
||||
} else if (erased.find(input) == erased.end()) {
|
||||
// Only when input is neither in output list nor erased list, add the input to input list
|
||||
fused_inputs[input] = input_order++;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& input : node->ImplicitInputDefs()) {
|
||||
if (graph_viewer.IsConstantInitializer(input->Name(), true)) {
|
||||
initializers.push_back(input->Name());
|
||||
continue;
|
||||
}
|
||||
const auto& it = fused_outputs.find(input);
|
||||
if (it != fused_outputs.end()) {
|
||||
fused_outputs.erase(it);
|
||||
erased.insert(input);
|
||||
} else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) {
|
||||
} else if (erased.find(input) == erased.end()) {
|
||||
// Only when input is neither in output list nor erased list, add the input to input list
|
||||
fused_inputs[input] = input_order++;
|
||||
}
|
||||
|
|
@ -1118,15 +1133,20 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
|
||||
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
|
||||
const auto& node_idx = it->GetNode().Index();
|
||||
const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
|
||||
const onnxruntime::NodeArg* output;
|
||||
if (it->GetDstArgIndex() < static_cast<int>(it->GetNode().InputDefs().size())) {
|
||||
output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
|
||||
} else {
|
||||
auto index = it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size());
|
||||
output = (it->GetNode()).ImplicitInputDefs()[index];
|
||||
}
|
||||
if (node_set.find(node_idx) != node_set.end()) {
|
||||
const auto& iter = fused_inputs.find(output);
|
||||
if (iter != fused_inputs.end()) {
|
||||
fused_inputs.erase(iter);
|
||||
erased.insert(output);
|
||||
} else if (erased.find(output) == erased.end()) {
|
||||
auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name());
|
||||
if (it != graph_output_names.end()) {
|
||||
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
|
||||
graph_outputs_to_add[output] = output_order;
|
||||
}
|
||||
fused_outputs[output] = output_order++;
|
||||
|
|
@ -1144,8 +1164,7 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
} else {
|
||||
// Only when output is neither in input list nor erased list, add the output to output list
|
||||
if (erased.find(output) == erased.end()) {
|
||||
auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name());
|
||||
if (it != graph_output_names.end()) {
|
||||
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
|
||||
graph_outputs_to_add[output] = output_order;
|
||||
}
|
||||
fused_outputs[output] = output_order++;
|
||||
|
|
@ -1168,27 +1187,6 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
outputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
|
||||
}
|
||||
|
||||
// It is possible that an output of a node is put behind the output of a later
|
||||
// node in the graph output list. So we should sort the output name according
|
||||
// to the graph output names
|
||||
std::vector<std::string> output_names;
|
||||
std::unordered_set<std::string> graph_out_names;
|
||||
for (const auto& output : outputs) {
|
||||
if (output.second->Exists()) {
|
||||
auto name = output.second->Name();
|
||||
if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) {
|
||||
output_names.push_back(name);
|
||||
} else {
|
||||
graph_out_names.insert(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& name : graph_output_names) {
|
||||
if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end())
|
||||
output_names.push_back(name);
|
||||
}
|
||||
|
||||
// Generate unique kernel name for CANN subgraph
|
||||
HashValue model_hash = 0;
|
||||
int id = GenerateMetaDefId(graph_viewer, model_hash);
|
||||
|
|
@ -1202,8 +1200,14 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
|
|||
}
|
||||
}
|
||||
|
||||
for (const auto& output : output_names) {
|
||||
meta_def->outputs().push_back(output);
|
||||
for (const auto& initializer : initializers) {
|
||||
meta_def->constant_initializers().push_back(initializer);
|
||||
}
|
||||
|
||||
for (const auto& output : outputs) {
|
||||
if (output.second->Exists()) {
|
||||
meta_def->outputs().push_back(output.second->Name());
|
||||
}
|
||||
}
|
||||
|
||||
meta_def->domain() = kMSDomain;
|
||||
|
|
@ -1306,13 +1310,13 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
|
||||
const std::string node_name = fused_node.Name();
|
||||
|
||||
std::unordered_map<size_t, std::string> names2index;
|
||||
std::unordered_map<size_t, std::string> index2name;
|
||||
const auto& input_defs = fused_node.InputDefs();
|
||||
names2index.reserve(input_defs.size());
|
||||
index2name.reserve(input_defs.size());
|
||||
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
|
||||
names2index[i] = input_defs[i]->Name();
|
||||
index2name[i] = input_defs[i]->Name();
|
||||
}
|
||||
names_[node_name] = names2index;
|
||||
names_[node_name] = index2name;
|
||||
|
||||
std::string string_model;
|
||||
auto model = cann::CreateModel(graph_body_viewer, *GetLogger());
|
||||
|
|
@ -1322,6 +1326,11 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
model_proto->SerializeToString(string_model);
|
||||
models_[node_name] = string_model;
|
||||
|
||||
if (info_.dump_graphs) {
|
||||
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
model_proto->SerializeToOstream(dump);
|
||||
}
|
||||
|
||||
NodeComputeInfo compute_info;
|
||||
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
|
||||
std::unique_ptr<CannFuncState> p = std::make_unique<CannFuncState>();
|
||||
|
|
@ -1335,20 +1344,19 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
delete static_cast<CannFuncState*>(state);
|
||||
};
|
||||
|
||||
compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
|
||||
compute_info.compute_func = [this](FunctionState state, const OrtApi*, OrtKernelContext* context) {
|
||||
Ort::KernelContext ctx(context);
|
||||
|
||||
CannFuncState* cann_state = reinterpret_cast<CannFuncState*>(state);
|
||||
std::string& string_model = models_[cann_state->node_name];
|
||||
std::unordered_map<size_t, std::string>& names2index = names_[cann_state->node_name];
|
||||
std::unordered_map<size_t, std::string>& index2name = names_[cann_state->node_name];
|
||||
|
||||
std::string input_shape = [&ctx, &names2index]() -> std::string {
|
||||
std::string input_shape = [&ctx, &index2name]() -> std::string {
|
||||
std::string res;
|
||||
for (size_t i = 0; i < ctx.GetInputCount(); i++) {
|
||||
auto&& shape = ctx.GetInput(i).GetTensorTypeAndShapeInfo().GetShape();
|
||||
auto name = names2index[i];
|
||||
|
||||
std::string s = name + ":";
|
||||
std::string s = index2name[i] + ":";
|
||||
for (auto& d : shape) {
|
||||
s += std::to_string(d) + ",";
|
||||
}
|
||||
|
|
@ -1366,8 +1374,13 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
std::string filename = cann_state->node_name + "_" + std::to_string(hash);
|
||||
std::string filename_with_suffix = filename + ".om";
|
||||
|
||||
// TODO(FFFrog): Resource Management
|
||||
// It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by
|
||||
// device memory exhaustion
|
||||
uint32_t modelID;
|
||||
{
|
||||
if (modelIDs_.find(filename) != modelIDs_.end()) {
|
||||
modelID = modelIDs_[filename];
|
||||
} else {
|
||||
std::lock_guard<OrtMutex> lock(g_mutex);
|
||||
|
||||
if (cann::FileExist(filename_with_suffix)) {
|
||||
|
|
@ -1377,10 +1390,12 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
ORT_RETURN_IF_ERROR(ParserONNXModel(string_model, graph));
|
||||
|
||||
ge::ModelBufferData model;
|
||||
ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, model));
|
||||
ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, info_, model));
|
||||
|
||||
CANN_RETURN_IF_ERROR(aclmdlLoadFromMem(model.data.get(), model.length, &modelID));
|
||||
}
|
||||
|
||||
modelIDs_.emplace(filename, modelID);
|
||||
}
|
||||
|
||||
CannModelPreparation prepare(modelID);
|
||||
|
|
@ -1407,7 +1422,8 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
|
||||
}
|
||||
|
||||
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream_));
|
||||
aclrtStream stream = static_cast<aclrtStream>(ctx.GetGPUComputeStream());
|
||||
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
|
@ -1436,7 +1452,12 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage
|
|||
true,
|
||||
{info_.default_memory_arena_cfg ? *info_.default_memory_arena_cfg
|
||||
: OrtArenaCfg(info_.npu_mem_limit,
|
||||
static_cast<int>(info_.arena_extend_strategy), -1, -1, -1)});
|
||||
static_cast<int>(info_.arena_extend_strategy),
|
||||
-1,
|
||||
-1,
|
||||
-1)},
|
||||
true,
|
||||
false);
|
||||
|
||||
cann_alloc = CreateAllocator(default_memory_info);
|
||||
allocator_manager.InsertAllocator(cann_alloc);
|
||||
|
|
@ -1484,4 +1505,8 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage
|
|||
}
|
||||
}
|
||||
|
||||
void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const {
|
||||
RegisterCannStreamHandles(stream_handle_registry, OrtDevice::NPU);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -36,36 +36,33 @@ class CANNExecutionProvider : public IExecutionProvider {
|
|||
|
||||
Status OnRunStart() override;
|
||||
|
||||
Status OnRunEnd(bool sync_stream) override;
|
||||
|
||||
void* GetComputeStream() const override { return static_cast<void*>(stream_); }
|
||||
|
||||
template <typename T>
|
||||
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
|
||||
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes, Stream* stream, WaitNotificationFn wait_fn) const {
|
||||
if (count_or_bytes == 0)
|
||||
return nullptr;
|
||||
|
||||
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeDefault), count_or_bytes);
|
||||
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeDefault), count_or_bytes, false, stream, wait_fn);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
IAllocatorUniquePtr<T> GetScratchBufferOnCANNPinned(size_t count_or_bytes) const {
|
||||
if (count_or_bytes == 0)
|
||||
return nullptr;
|
||||
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeCPU),
|
||||
count_or_bytes);
|
||||
|
||||
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeCPU), count_or_bytes);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Fill(Tensor* y, void* addr) const {
|
||||
return cann::Fill<T>(y, addr, stream_);
|
||||
Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
|
||||
return cann::Fill<T>(y, addr, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Broadcast(const Tensor* x, Tensor* y, void* addr) const {
|
||||
return cann::Broadcast<T>(x, y, addr, stream_);
|
||||
Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const {
|
||||
return cann::Broadcast<T>(x, y, addr, stream);
|
||||
}
|
||||
|
||||
int GetDeviceId() const override { return info_.device_id; }
|
||||
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
|
||||
|
||||
|
|
@ -86,11 +83,13 @@ class CANNExecutionProvider : public IExecutionProvider {
|
|||
|
||||
void RegisterAllocator(AllocatorManager& allocator_manager) override;
|
||||
|
||||
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override;
|
||||
|
||||
private:
|
||||
CANNExecutionProviderInfo info_;
|
||||
aclrtStream stream_ = nullptr;
|
||||
const char* soc_name_ = nullptr;
|
||||
|
||||
std::unordered_map<std::string, uint32_t> modelIDs_;
|
||||
std::unordered_map<std::string, std::string> models_;
|
||||
std::unordered_map<std::string, std::unordered_map<std::size_t, std::string>> names_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -19,8 +19,11 @@ namespace provider_option_names {
|
|||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kMemLimit = "npu_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream";
|
||||
constexpr const char* kEnableCannGraph = "enable_cann_graph";
|
||||
constexpr const char* kDumpGraphs = "dump_graphs";
|
||||
constexpr const char* kPrecisionMode = "precision_mode";
|
||||
constexpr const char* kOpSelectImplMode = "op_select_impl_mode";
|
||||
constexpr const char* kOpTypeListForImplMode = "optypelist_for_implmode";
|
||||
} // namespace provider_option_names
|
||||
} // namespace cann
|
||||
|
||||
|
|
@ -53,8 +56,11 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P
|
|||
.AddAssignmentToEnumReference(
|
||||
cann::provider_option_names::kArenaExtendStrategy,
|
||||
arena_extend_strategy_mapping, info.arena_extend_strategy)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kOpSelectImplMode, info.op_select_impl_mode)
|
||||
.AddAssignmentToReference(cann::provider_option_names::kOpTypeListForImplMode, info.optypelist_for_implmode)
|
||||
.Parse(options));
|
||||
return info;
|
||||
}
|
||||
|
|
@ -65,9 +71,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution
|
|||
{cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)},
|
||||
{cann::provider_option_names::kArenaExtendStrategy,
|
||||
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
{cann::provider_option_names::kDoCopyInDefaultStream,
|
||||
MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
|
||||
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}};
|
||||
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
|
||||
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
|
||||
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
|
||||
{cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)},
|
||||
{cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}};
|
||||
return options;
|
||||
}
|
||||
|
||||
|
|
@ -77,9 +85,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid
|
|||
{cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)},
|
||||
{cann::provider_option_names::kArenaExtendStrategy,
|
||||
EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))},
|
||||
{cann::provider_option_names::kDoCopyInDefaultStream,
|
||||
MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
|
||||
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}};
|
||||
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
|
||||
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
|
||||
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
|
||||
{cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)},
|
||||
{cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}};
|
||||
return options;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
#include "core/framework/arena_extend_strategy.h"
|
||||
#include "core/framework/ortdevice.h"
|
||||
|
|
@ -16,8 +17,11 @@ struct CANNExecutionProviderInfo {
|
|||
OrtDevice::DeviceId device_id{0};
|
||||
size_t npu_mem_limit{std::numeric_limits<size_t>::max()};
|
||||
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
|
||||
bool do_copy_in_default_stream{true};
|
||||
bool enable_cann_graph{true};
|
||||
bool dump_graphs{false};
|
||||
std::string precision_mode;
|
||||
std::string op_select_impl_mode;
|
||||
std::string optypelist_for_implmode;
|
||||
OrtArenaCfg* default_memory_arena_cfg{nullptr};
|
||||
|
||||
static CANNExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Huawei. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
#include "core/providers/cann/npu_data_transfer.h"
|
||||
#include "core/providers/cann/cann_fence.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Builds a fence bound to `data_transfer` (used later to look up streams by queue id)
// and creates the two ACL events the fence records and waits on.
CANNFence::CANNFence(const NPUDataTransfer* data_transfer)
    : data_transfer_(data_transfer) {
  // Event recorded after the buffer has been consumed as an input.
  CANN_CALL_THROW(aclrtCreateEvent(&read_event_));

  // Event recorded after the buffer has been produced as an output.
  CANN_CALL_THROW(aclrtCreateEvent(&write_event_));
}
|
||||
|
||||
// Releases both ACL events owned by the fence.
//
// Destructors are implicitly noexcept, so letting CANN_CALL_THROW raise from here
// would terminate the process instead of reporting an error. Event destruction is
// therefore best-effort: the ACL status codes are deliberately discarded.
CANNFence::~CANNFence() {
  static_cast<void>(aclrtDestroyEvent(read_event_));
  static_cast<void>(aclrtDestroyEvent(write_event_));
}
|
||||
|
||||
// Makes the consumer wait for the write event (recorded in AfterUsedAsOutput)
// before the buffer is read.
//   provider_type  - execution provider that is about to read the buffer.
//   async_queue_id - queue id used to look up the consumer's CANN stream.
void CANNFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) {
  // Any provider other than CANN waits on the host side.
  if (provider_type != onnxruntime::kCannExecutionProvider) {
    CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_));
    return;
  }

  // The CANN provider queues the wait on its own stream instead of blocking the host.
  CANN_CALL_THROW(aclrtStreamWaitEvent(data_transfer_->GetStream(async_queue_id), write_event_));
}
|
||||
|
||||
// Makes the producer wait for both the read event and the write event before the
// buffer is overwritten.
//   provider_type - execution provider that is about to write the buffer.
//   queue_id      - queue id used to look up the producer's CANN stream.
void CANNFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) {
  // Any provider other than CANN waits on the host side, reads first and then writes.
  if (provider_type != onnxruntime::kCannExecutionProvider) {
    CANN_CALL_THROW(aclrtSynchronizeEvent(read_event_));
    CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_));
    return;
  }

  // The CANN provider queues both waits on the stream that will do the write.
  aclrtStream target_stream = data_transfer_->GetStream(queue_id);
  CANN_CALL_THROW(aclrtStreamWaitEvent(target_stream, read_event_));
  CANN_CALL_THROW(aclrtStreamWaitEvent(target_stream, write_event_));
}
|
||||
|
||||
bool CANNFence::CanRelease() {
|
||||
aclrtEventRecordedStatus read_status;
|
||||
aclrtEventRecordedStatus write_status;
|
||||
|
||||
return aclrtQueryEventStatus(read_event_, &read_status) == ACL_SUCCESS &&
|
||||
aclrtQueryEventStatus(write_event_, &write_status) == ACL_SUCCESS &&
|
||||
read_status == ACL_EVENT_RECORDED_STATUS_COMPLETE &&
|
||||
write_status == ACL_EVENT_RECORDED_STATUS_COMPLETE;
|
||||
}
|
||||
|
||||
void CANNFence::AfterUsedAsInput(int queue_id) {
|
||||
aclrtStream stream = data_transfer_->GetStream(queue_id);
|
||||
CANN_CALL_THROW(aclrtRecordEvent(read_event_, stream));
|
||||
}
|
||||
|
||||
void CANNFence::AfterUsedAsOutput(int queue_id) {
|
||||
aclrtStream stream = data_transfer_->GetStream(queue_id);
|
||||
CANN_CALL_THROW(aclrtRecordEvent(write_event_, stream));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Huawei. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/fence.h"
|
||||
#include "core/providers/cann/cann_inc.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class NPUDataTransfer;
|
||||
|
||||
class CANNFence : public IFence {
|
||||
public:
|
||||
explicit CANNFence(const NPUDataTransfer* data_transfer);
|
||||
virtual ~CANNFence();
|
||||
void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override;
|
||||
void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override;
|
||||
void AfterUsedAsInput(int queue_id) override;
|
||||
void AfterUsedAsOutput(int queue_id) override;
|
||||
bool CanRelease() override;
|
||||
|
||||
private:
|
||||
aclrtEvent read_event_;
|
||||
aclrtEvent write_event_;
|
||||
const NPUDataTransfer* data_transfer_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -13,6 +13,8 @@ namespace cann {
|
|||
static int lower_bound = 8; // Supported domain version lower bounds
|
||||
static int upper_bound = 15; // Supported domain version upper bounds
|
||||
|
||||
std::once_flag flag;
|
||||
|
||||
/**
|
||||
* This function will be changed with the evolution of ONNX and CANN
|
||||
* and will be replaced by the corresponding API provided by CANN in the future, probably.
|
||||
|
|
@ -32,7 +34,7 @@ std::vector<NodeIndex> SupportONNXModel(const GraphViewer& graph_viewer) {
|
|||
"GatherElements", "GatherND", "Gemm", "GlobalAveragePool",
|
||||
"GlobalLpPool", "GlobalMaxPool", "Greater", "GreaterOrEqual",
|
||||
"Hardmax", "HardSigmoid", "HardSwish", "Identity",
|
||||
"If", "InstanceNormalization", "LeakyRelu", "Less",
|
||||
"InstanceNormalization", "LeakyRelu", "Less",
|
||||
"LessOrEqual", "Log", "LogSoftmax", "LpNormalization",
|
||||
"LpPool", "LRN", "LSTM", "MatMul",
|
||||
"Max", "MaxPool", "MaxRoiPool", "MaxUnpool",
|
||||
|
|
@ -94,20 +96,27 @@ Status ParserONNXModel(std::string string_model, ge::Graph& graph) {
|
|||
}
|
||||
|
||||
Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name,
|
||||
ge::ModelBufferData& model) {
|
||||
CANNExecutionProviderInfo& info, ge::ModelBufferData& model) {
|
||||
std::call_once(flag, [&soc_name, &info]() {
|
||||
std::map<ge::AscendString, ge::AscendString> options;
|
||||
options.emplace(ge::ir_option::SOC_VERSION, soc_name);
|
||||
|
||||
if (!info.precision_mode.empty())
|
||||
options.emplace(ge::ir_option::PRECISION_MODE, info.precision_mode.c_str());
|
||||
if (!info.op_select_impl_mode.empty())
|
||||
options.emplace(ge::ir_option::OP_SELECT_IMPL_MODE, info.op_select_impl_mode.c_str());
|
||||
if (!info.optypelist_for_implmode.empty())
|
||||
options.emplace(ge::ir_option::OPTYPELIST_FOR_IMPLMODE, info.optypelist_for_implmode.c_str());
|
||||
|
||||
CANN_CALL_THROW(ge::aclgrphBuildInitialize(options));
|
||||
});
|
||||
|
||||
std::map<ge::AscendString, ge::AscendString> options;
|
||||
|
||||
options.emplace(ge::ir_option::SOC_VERSION, soc_name);
|
||||
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildInitialize(options));
|
||||
|
||||
options.clear();
|
||||
options.emplace(ge::ir_option::INPUT_SHAPE, input_shape.c_str());
|
||||
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildModel(graph, options, model));
|
||||
|
||||
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphSaveModel(file_name.c_str(), model));
|
||||
|
||||
ge::aclgrphBuildFinalize();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/providers/cann/cann_common.h"
|
||||
#include "core/providers/cann/cann_inc.h"
|
||||
#include "core/providers/cann/cann_utils.h"
|
||||
#include "core/providers/cann/cann_execution_provider_info.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cann {
|
||||
|
|
@ -72,7 +73,7 @@ struct CannModelPreparation {
|
|||
std::vector<NodeIndex> SupportONNXModel(const GraphViewer& graph_viewer);
|
||||
Status ParserONNXModel(std::string string_model, ge::Graph& graph);
|
||||
Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name,
|
||||
ge::ModelBufferData& model);
|
||||
CANNExecutionProviderInfo& info, ge::ModelBufferData& model);
|
||||
|
||||
} // namespace cann
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/providers/cann/cann_inc.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
#include "core/providers/cann/cann_execution_provider.h"
|
||||
#include "core/providers/cann/cann_fwd.h"
|
||||
#include "core/providers/cann/cann_utils.h"
|
||||
#include "core/providers/cann/cann_stream_handle.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cann {
|
||||
|
|
@ -35,11 +37,14 @@ class CannKernel : public OpKernel {
|
|||
|
||||
virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0;
|
||||
|
||||
inline aclrtStream Stream() const { return static_cast<aclrtStream>(provider_->GetComputeStream()); }
|
||||
// Returns the ACL stream behind the kernel context's compute stream, or nullptr
// when the context has no compute stream.
inline aclrtStream Stream(OpKernelContext* ctx) const {
  auto* compute_stream = ctx->GetComputeStream();
  if (!compute_stream) {
    return nullptr;
  }
  return static_cast<aclrtStream>(compute_stream->GetHandle());
}
|
||||
|
||||
template <typename T>
|
||||
inline IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
|
||||
return provider_->GetScratchBuffer<T>(count_or_bytes);
|
||||
inline IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const {
|
||||
return provider_->GetScratchBuffer<T>(count_or_bytes, stream, WaitCannNotificationOnDevice);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -48,13 +53,13 @@ class CannKernel : public OpKernel {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
inline Status Fill(Tensor* y, void* addr) const {
|
||||
return provider_->Fill<T>(y, addr);
|
||||
inline Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
|
||||
return provider_->Fill<T>(y, addr, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Status Broadcast(const Tensor* x, Tensor* y, void* addr) const {
|
||||
return provider_->Broadcast<T>(x, y, addr);
|
||||
inline Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const {
|
||||
return provider_->Broadcast<T>(x, y, addr, stream);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
|||
|
|
@ -53,8 +53,11 @@ struct CANN_Provider : Provider {
|
|||
info.device_id = static_cast<OrtDevice::DeviceId>(params->device_id);
|
||||
info.npu_mem_limit = params->npu_mem_limit;
|
||||
info.arena_extend_strategy = params->arena_extend_strategy;
|
||||
info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0;
|
||||
info.enable_cann_graph = params->enable_cann_graph != 0;
|
||||
info.dump_graphs = params->dump_graphs != 0;
|
||||
info.precision_mode = params->precision_mode;
|
||||
info.op_select_impl_mode = params->op_select_impl_mode;
|
||||
info.optypelist_for_implmode = params->optypelist_for_implmode;
|
||||
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
|
||||
|
||||
return std::make_shared<CANNProviderFactory>(info);
|
||||
|
|
@ -67,8 +70,11 @@ struct CANN_Provider : Provider {
|
|||
cann_options.device_id = internal_options.device_id;
|
||||
cann_options.npu_mem_limit = internal_options.npu_mem_limit;
|
||||
cann_options.arena_extend_strategy = internal_options.arena_extend_strategy;
|
||||
cann_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream;
|
||||
cann_options.enable_cann_graph = internal_options.enable_cann_graph;
|
||||
cann_options.dump_graphs = internal_options.dump_graphs;
|
||||
cann_options.precision_mode = internal_options.precision_mode;
|
||||
cann_options.op_select_impl_mode = internal_options.op_select_impl_mode;
|
||||
cann_options.optypelist_for_implmode = internal_options.optypelist_for_implmode;
|
||||
cann_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
|
||||
}
|
||||
|
||||
|
|
|
|||
81
onnxruntime/core/providers/cann/cann_stream_handle.cc
Normal file
81
onnxruntime/core/providers/cann/cann_stream_handle.cc
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Huawei. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cann/cann_stream_handle.h"
|
||||
#include "core/common/spin_pause.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CannNotification : public synchronize::Notification {
|
||||
explicit CannNotification(Stream& s) : Notification(s) {
|
||||
CANN_CALL_THROW(aclrtCreateEvent(&event_));
|
||||
}
|
||||
|
||||
~CannNotification() {
|
||||
if (event_)
|
||||
CANN_CALL_THROW(aclrtDestroyEvent(event_));
|
||||
}
|
||||
|
||||
void Activate() override {
|
||||
CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast<aclrtStream>(stream_.GetHandle())));
|
||||
}
|
||||
|
||||
void wait_on_device(Stream& device_stream) {
|
||||
ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::NPU);
|
||||
CANN_CALL_THROW(aclrtStreamWaitEvent(static_cast<aclrtStream>(device_stream.GetHandle()), event_));
|
||||
}
|
||||
|
||||
void wait_on_host() {
|
||||
CANN_CALL_THROW(aclrtSynchronizeEvent(event_));
|
||||
}
|
||||
|
||||
aclrtEvent event_;
|
||||
};
|
||||
|
||||
// Wraps an existing ACL stream for `device`. `own_flag` records whether this
// object is responsible for destroying the stream handle.
CannStream::CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag)
    : Stream(stream, device),
      own_stream_(own_flag) {
}
|
||||
|
||||
// Runs end-of-run cleanup and, when this object owns the underlying ACL stream,
// destroys it. Neither step reports failure from the destructor.
CannStream::~CannStream() {
  ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd());

  // Streams that are not owned are left alone.
  if (!own_stream_) {
    return;
  }

  auto* owned_handle = GetHandle();
  if (!owned_handle) {
    return;
  }
  aclrtDestroyStream(static_cast<aclrtStream>(owned_handle));
}
|
||||
|
||||
std::unique_ptr<synchronize::Notification> CannStream::CreateNotification(size_t /*num_consumers*/) {
|
||||
return std::make_unique<CannNotification>(*this);
|
||||
}
|
||||
|
||||
void CannStream::Flush() {
|
||||
if (own_stream_)
|
||||
CANN_CALL_THROW(aclrtSynchronizeStream(static_cast<aclrtStream>(GetHandle())));
|
||||
}
|
||||
|
||||
// CPU Stream command handles
|
||||
// Wait handler for a device-side consumer: makes `stream` wait on the CANN event
// behind `notification`. `notification` is downcast without a runtime check, so it
// must be a CannNotification.
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
  // Address-of operator restored; it had been mangled into a '¬' character.
  static_cast<CannNotification*>(&notification)->wait_on_device(stream);
}
|
||||
|
||||
// Wait handler for a host-side consumer: blocks the calling thread until the CANN
// event behind `notification` has completed. The stream argument is unused, and
// `notification` is downcast without a runtime check, so it must be a CannNotification.
void WaitCannNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) {
  // Address-of operator restored; it had been mangled into a '¬' character.
  static_cast<CannNotification*>(&notification)->wait_on_host();
}
|
||||
|
||||
// Registers the CANN wait handlers and the stream factory for `device_type` in
// `stream_handle_registry`.
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
                               const OrtDevice::DeviceType device_type) {
  // Factory producing a fresh ACL stream wrapped in a CannStream that owns it.
  auto create_owned_stream = [](const OrtDevice& device) {
    aclrtStream new_stream = nullptr;
    CANN_CALL_THROW(aclrtCreateStream(&new_stream));
    return std::make_unique<CannStream>(new_stream, device, true);
  };

  // CANN producer -> CANN consumer: wait on the device stream.
  stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCannNotificationOnDevice);

  // CANN producer -> CPU consumer: wait on the host.
  stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCannNotificationOnHost);

  stream_handle_registry.RegisterCreateStreamFn(device_type, create_owned_stream);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
32
onnxruntime/core/providers/cann/cann_stream_handle.h
Normal file
32
onnxruntime/core/providers/cann/cann_stream_handle.h
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Huawei. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "core/framework/stream_handles.h"
|
||||
#include "core/providers/cann/cann_inc.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CannStream : Stream {
|
||||
CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
|
||||
|
||||
~CannStream();
|
||||
|
||||
std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) override;
|
||||
|
||||
void Flush() override;
|
||||
|
||||
bool own_stream_{true};
|
||||
};
|
||||
|
||||
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
const OrtDevice::DeviceType device_type);
|
||||
|
||||
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "core/providers/cann/cann_utils.h"
|
||||
|
||||
|
|
@ -223,5 +224,34 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
|
|||
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
|
||||
}
|
||||
|
||||
// Computes the broadcast result shape of `lhs_shape` and `rhs_shape`.
//
// Dimensions are aligned from the trailing end and a missing leading dimension is
// treated as 1. For each aligned pair the output dimension is the larger of the
// two, except that a 0 on either side yields 0. An operand dimension must equal
// the output dimension or be 1; otherwise a FAIL status naming `node_name` and
// the offending dimension is returned.
//
// On success `out_shape` receives the result and Status::OK() is returned.
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
                          const TensorShape& rhs_shape, TensorShape& out_shape) {
  const size_t lhs_rank = lhs_shape.NumDimensions();
  const size_t rhs_rank = rhs_shape.NumDimensions();
  const size_t out_rank = std::max(lhs_rank, rhs_rank);

  std::vector<int64_t> output_dims(out_rank, 0);

  // `offset` counts dimensions starting from the innermost (last) one.
  for (size_t offset = 0; offset < out_rank; ++offset) {
    const int64_t lhs_dim = (offset < lhs_rank) ? lhs_shape[lhs_rank - 1 - offset] : 1;
    const int64_t rhs_dim = (offset < rhs_rank) ? rhs_shape[rhs_rank - 1 - offset] : 1;

    const int64_t smaller = std::min(lhs_dim, rhs_dim);
    const int64_t larger = std::max(lhs_dim, rhs_dim);

    // special case a dim value of 0.
    int64_t out_dim = larger;
    if (smaller == 0) {
      out_dim = smaller;
    }

    if (lhs_dim != out_dim && lhs_dim != 1) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ",
                             lhs_rank - 1 - offset,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }
    if (rhs_dim != out_dim && rhs_dim != 1) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ",
                             rhs_rank - 1 - offset,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }

    output_dims[out_rank - 1 - offset] = out_dim;
  }

  out_shape = TensorShape(output_dims);
  return Status::OK();
}
|
||||
|
||||
} // namespace cann
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -124,6 +124,9 @@ Status aclrtblasGemmEx(aclTransType transA,
|
|||
|
||||
bool FileExist(const std::string& file_name);
|
||||
void GenerateHashValue(const std::string string, HashValue& hash_value);
|
||||
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape, TensorShape& out_shape);
|
||||
|
||||
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -11,35 +11,6 @@ using onnxruntime::common::Status;
|
|||
namespace onnxruntime {
|
||||
namespace cann {
|
||||
|
||||
// Computes the broadcast result shape of `lhs_shape` and `rhs_shape`.
//
// Dimensions are aligned from the trailing end and a missing leading dimension is
// treated as 1. For each aligned pair the output dimension is the larger of the
// two, except that a 0 on either side yields 0. An operand dimension must equal
// the output dimension or be 1; otherwise a FAIL status naming `node_name` and
// the offending dimension is returned.
//
// On success `out_shape` receives the result and Status::OK() is returned.
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
                          const TensorShape& rhs_shape, TensorShape& out_shape) {
  const size_t lhs_rank = lhs_shape.NumDimensions();
  const size_t rhs_rank = rhs_shape.NumDimensions();
  const size_t out_rank = std::max(lhs_rank, rhs_rank);

  std::vector<int64_t> output_dims(out_rank, 0);

  // `offset` counts dimensions starting from the innermost (last) one.
  for (size_t offset = 0; offset < out_rank; ++offset) {
    const int64_t lhs_dim = (offset < lhs_rank) ? lhs_shape[lhs_rank - 1 - offset] : 1;
    const int64_t rhs_dim = (offset < rhs_rank) ? rhs_shape[rhs_rank - 1 - offset] : 1;

    const int64_t smaller = std::min(lhs_dim, rhs_dim);
    const int64_t larger = std::max(lhs_dim, rhs_dim);

    // special case a dim value of 0.
    int64_t out_dim = larger;
    if (smaller == 0) {
      out_dim = smaller;
    }

    if (lhs_dim != out_dim && lhs_dim != 1) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ",
                             lhs_rank - 1 - offset,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }
    if (rhs_dim != out_dim && rhs_dim != 1) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ",
                             rhs_rank - 1 - offset,
                             " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
    }

    output_dims[out_rank - 1 - offset] = out_dim;
  }

  out_shape = TensorShape(output_dims);
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare) const {
|
||||
const aclDataType aclType = getACLType<T>();
|
||||
|
|
@ -56,14 +27,14 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
|
|||
void* B_data = const_cast<void*>(B->DataRaw());
|
||||
|
||||
if (A->Shape() != C->Shape()) {
|
||||
IAllocatorUniquePtr<void> pA = GetScratchBuffer<void>(C->SizeInBytes());
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(A, C, pA.get()));
|
||||
IAllocatorUniquePtr<void> pA = GetScratchBuffer<void>(C->SizeInBytes(), ctx->GetComputeStream());
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(A, C, pA.get(), Stream(ctx)));
|
||||
A_data = pA.get();
|
||||
}
|
||||
|
||||
if (B->Shape() != C->Shape()) {
|
||||
IAllocatorUniquePtr<void> pB = GetScratchBuffer<void>(C->SizeInBytes());
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(B, C, pB.get()));
|
||||
IAllocatorUniquePtr<void> pB = GetScratchBuffer<void>(C->SizeInBytes(), ctx->GetComputeStream());
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(B, C, pB.get(), Stream(ctx)));
|
||||
B_data = pB.get();
|
||||
}
|
||||
|
||||
|
|
@ -85,9 +56,9 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
|
|||
|
||||
#define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
|
||||
CannPreparation prepare; \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
|
||||
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
|
||||
prepare.inputDesc_.size(), \
|
||||
prepare.inputDesc_.data(), \
|
||||
|
|
@ -99,7 +70,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
|
|||
ACL_ENGINE_SYS, \
|
||||
ACL_COMPILE_SYS, \
|
||||
NULL, \
|
||||
Stream())); \
|
||||
Stream(ctx))); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,28 +24,28 @@ template <typename T>
|
|||
class Add final : public BinaryElementwise {
|
||||
public:
|
||||
Add(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Sub final : public BinaryElementwise {
|
||||
public:
|
||||
Sub(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Mul final : public BinaryElementwise {
|
||||
public:
|
||||
Mul(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Div final : public BinaryElementwise {
|
||||
public:
|
||||
Div(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ namespace onnxruntime {
|
|||
namespace cann {
|
||||
|
||||
template <typename T>
|
||||
Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const auto* A = context->Input<Tensor>(0);
|
||||
const auto* B = context->Input<Tensor>(1);
|
||||
const auto* C = context->Input<Tensor>(2);
|
||||
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const auto* A = ctx->Input<Tensor>(0);
|
||||
const auto* B = ctx->Input<Tensor>(1);
|
||||
const auto* C = ctx->Input<Tensor>(2);
|
||||
|
||||
GemmHelper helper(A->Shape(), trans_A_, B->Shape(), trans_B_, C != nullptr ? C->Shape() : TensorShape({}));
|
||||
if (!helper.State().IsOK())
|
||||
|
|
@ -24,13 +24,13 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
int N = gsl::narrow_cast<int>(helper.N());
|
||||
int K = gsl::narrow_cast<int>(helper.K());
|
||||
|
||||
auto* Y = context->Output(0, {M, N});
|
||||
auto* Y = ctx->Output(0, {M, N});
|
||||
|
||||
// broadcast C if needed.
|
||||
if (beta_ != 0 && C != nullptr) {
|
||||
if (C->Shape().Size() == 1) {
|
||||
// C is (), (1,) or (1, 1), fill the scalar to Y
|
||||
ORT_RETURN_IF_ERROR(Fill<T>(Y, const_cast<void*>(C->DataRaw())));
|
||||
ORT_RETURN_IF_ERROR(Fill<T>(Y, const_cast<void*>(C->DataRaw()), Stream(ctx)));
|
||||
} else if (C->Shape() == Y->Shape()) {
|
||||
// C is (M, N), no broadcast needed.
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y->MutableDataRaw(),
|
||||
|
|
@ -38,10 +38,10 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const_cast<void*>(C->DataRaw()),
|
||||
Y->SizeInBytes(),
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
} else {
|
||||
// others, broadcast needed.
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(C, Y, Y->MutableDataRaw()));
|
||||
ORT_RETURN_IF_ERROR(Broadcast<T>(C, Y, Y->MutableDataRaw(), Stream(ctx)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -49,8 +49,8 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
T alpha = ToCannType<T>::FromFloat(alpha_);
|
||||
T beta = ToCannType<T>::FromFloat(beta_);
|
||||
IAllocatorUniquePtr<void> pAlpha = GetScratchBuffer<void>(sizeof(T));
|
||||
IAllocatorUniquePtr<void> pBeta = GetScratchBuffer<void>(sizeof(T));
|
||||
IAllocatorUniquePtr<void> pAlpha = GetScratchBuffer<void>(sizeof(T), ctx->GetComputeStream());
|
||||
IAllocatorUniquePtr<void> pBeta = GetScratchBuffer<void>(sizeof(T), ctx->GetComputeStream());
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpy(pAlpha.get(), sizeof(T), &alpha, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpy(pBeta.get(), sizeof(T), &beta, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
pBeta.get(),
|
||||
Y->MutableDataRaw(), -1, aclType,
|
||||
ACL_COMPUTE_HIGH_PRECISION,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class Gemm final : public CannKernel {
|
|||
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
bool trans_A_;
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class MatMul final : public CannKernel {
|
|||
MatMul(const OpKernelInfo& info)
|
||||
: CannKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare)
|
|||
|
||||
#define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
|
||||
CannPreparation prepare; \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
|
||||
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
|
||||
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
|
||||
prepare.inputDesc_.size(), \
|
||||
prepare.inputDesc_.data(), \
|
||||
|
|
@ -52,7 +52,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare)
|
|||
ACL_ENGINE_SYS, \
|
||||
ACL_COMPILE_SYS, \
|
||||
NULL, \
|
||||
Stream())); \
|
||||
Stream(ctx))); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,84 +29,84 @@ template <typename T>
|
|||
class Abs final : public UnaryElementwise {
|
||||
public:
|
||||
Abs(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Neg final : public UnaryElementwise {
|
||||
public:
|
||||
Neg(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Floor final : public UnaryElementwise {
|
||||
public:
|
||||
Floor(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Ceil final : public UnaryElementwise {
|
||||
public:
|
||||
Ceil(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Reciprocal final : public UnaryElementwise {
|
||||
public:
|
||||
Reciprocal(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Sqrt final : public UnaryElementwise {
|
||||
public:
|
||||
Sqrt(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Log final : public UnaryElementwise {
|
||||
public:
|
||||
Log(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Exp final : public UnaryElementwise {
|
||||
public:
|
||||
Exp(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Erf final : public UnaryElementwise {
|
||||
public:
|
||||
Erf(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Round final : public UnaryElementwise {
|
||||
public:
|
||||
Round(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Sin final : public UnaryElementwise {
|
||||
public:
|
||||
Sin(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Cos final : public UnaryElementwise {
|
||||
public:
|
||||
Cos(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ namespace onnxruntime {
|
|||
namespace cann {
|
||||
|
||||
template <typename T>
|
||||
Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
Status AveragePool<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const auto X_dims = X_shape.GetDims();
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
|
||||
TensorShape Y_shape(Y_dims);
|
||||
Tensor* Y = context->Output(0, Y_shape);
|
||||
Tensor* Y = ctx->Output(0, Y_shape);
|
||||
if (Y_shape.Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class AveragePool : public CannKernel, public PoolBase {
|
|||
public:
|
||||
explicit AveragePool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
// There is only one output in inference mode
|
||||
Tensor* Y = ctx->Output(0, X->Shape());
|
||||
|
||||
IAllocatorUniquePtr<void> pbatch_mean = GetScratchBuffer<void>(mean->SizeInBytes());
|
||||
IAllocatorUniquePtr<void> pbatch_variance = GetScratchBuffer<void>(var->SizeInBytes());
|
||||
IAllocatorUniquePtr<void> preserver_space_1 = GetScratchBuffer<void>(mean->SizeInBytes());
|
||||
IAllocatorUniquePtr<void> preserver_space_2 = GetScratchBuffer<void>(var->SizeInBytes());
|
||||
IAllocatorUniquePtr<void> pbatch_mean = GetScratchBuffer<void>(mean->SizeInBytes(), ctx->GetComputeStream());
|
||||
IAllocatorUniquePtr<void> pbatch_variance = GetScratchBuffer<void>(var->SizeInBytes(), ctx->GetComputeStream());
|
||||
IAllocatorUniquePtr<void> preserver_space_1 = GetScratchBuffer<void>(mean->SizeInBytes(), ctx->GetComputeStream());
|
||||
IAllocatorUniquePtr<void> preserver_space_2 = GetScratchBuffer<void>(var->SizeInBytes(), ctx->GetComputeStream());
|
||||
|
||||
const aclDataType aclType = getACLType<T>();
|
||||
aclFormat format = ACL_FORMAT_NCHW;
|
||||
|
|
@ -76,7 +76,7 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class BatchNorm final : public CannKernel {
|
|||
ORT_ENFORCE(!is_training_mode_, "only supports inference mode");
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ Status Conv<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,17 +26,17 @@ float GetRatioOrDefault(const Tensor* ratio) {
|
|||
} // namespace
|
||||
|
||||
template <typename T1, typename T2>
|
||||
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
|
||||
const Tensor* ratio = context->Input<Tensor>(1);
|
||||
const Tensor* ratio = ctx->Input<Tensor>(1);
|
||||
const float ratio_value = GetRatioOrDefault<T2>(ratio);
|
||||
|
||||
const Tensor* training_mode = context->Input<Tensor>(2);
|
||||
const Tensor* training_mode = ctx->Input<Tensor>(2);
|
||||
|
||||
auto Y = context->Output(0, X_shape);
|
||||
auto mask = context->Output(1, X_shape);
|
||||
auto Y = ctx->Output(0, X_shape);
|
||||
auto mask = ctx->Output(1, X_shape);
|
||||
|
||||
if (ratio_value == 0.f || !training_mode || !(*(training_mode->Data<bool>()))) {
|
||||
const void* X_data = X->DataRaw();
|
||||
|
|
@ -44,22 +44,22 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
if (Y_data != X_data) {
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y_data, Y->SizeInBytes(), X_data, Y->SizeInBytes(),
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask->MutableData<bool>(), mask->SizeInBytes(), true,
|
||||
mask->SizeInBytes(), Stream()));
|
||||
mask->SizeInBytes(), Stream(ctx)));
|
||||
}
|
||||
} else {
|
||||
IAllocatorUniquePtr<void> pmask{};
|
||||
IAllocatorUniquePtr<void> pseed = GetScratchBuffer<void>(sizeof(float));
|
||||
IAllocatorUniquePtr<void> pseed = GetScratchBuffer<void>(sizeof(float), ctx->GetComputeStream());
|
||||
|
||||
void* mask_data = nullptr;
|
||||
if (mask) {
|
||||
mask_data = mask->MutableDataRaw();
|
||||
} else {
|
||||
pmask = GetScratchBuffer<void>(X_shape.Size() * sizeof(bool));
|
||||
pmask = GetScratchBuffer<void>(X_shape.Size() * sizeof(bool), ctx->GetComputeStream());
|
||||
mask_data = pmask.get();
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class Dropout final : public CannKernel {
|
|||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
mutable std::unique_ptr<RandomGenerator> generator_;
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ namespace onnxruntime {
|
|||
namespace cann {
|
||||
|
||||
template <typename T>
|
||||
Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
Status MaxPool<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const auto X_dims = X_shape.GetDims();
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
|
||||
TensorShape Y_shape(Y_dims);
|
||||
Tensor* Y = context->Output(0, Y_shape);
|
||||
Tensor* Y = ctx->Output(0, Y_shape);
|
||||
if (Y_shape.Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class MaxPool : public CannKernel, public PoolBase {
|
|||
public:
|
||||
explicit MaxPool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
||||
} // namespace cann
|
||||
|
|
|
|||
|
|
@ -4,35 +4,62 @@
|
|||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cann/npu_data_transfer.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
NPUDataTransfer::NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream) {
|
||||
do_copy_in_default_stream_ = do_copy_in_default_stream;
|
||||
streams_[kCannStreamDefault] = stream;
|
||||
if (do_copy_in_default_stream) {
|
||||
streams_[kCannStreamCopyIn] = stream;
|
||||
streams_[kCannStreamCopyOut] = stream;
|
||||
} else {
|
||||
CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyIn]));
|
||||
CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyOut]));
|
||||
}
|
||||
}
|
||||
NPUDataTransfer::NPUDataTransfer() {}
|
||||
|
||||
NPUDataTransfer::~NPUDataTransfer() {
|
||||
if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyIn] != nullptr) {
|
||||
CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyIn]));
|
||||
}
|
||||
if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyOut] != nullptr) {
|
||||
CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyOut]));
|
||||
}
|
||||
}
|
||||
NPUDataTransfer::~NPUDataTransfer() {}
|
||||
|
||||
bool NPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
|
||||
return src_device.Type() == OrtDevice::NPU || dst_device.Type() == OrtDevice::NPU;
|
||||
}
|
||||
|
||||
common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
|
||||
common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
|
||||
size_t bytes = src.SizeInBytes();
|
||||
const void* src_data = src.DataRaw();
|
||||
void* dst_data = dst.MutableDataRaw();
|
||||
|
||||
auto& src_device = src.Location().device;
|
||||
auto& dst_device = dst.Location().device;
|
||||
|
||||
// for the sync version of memcpy, launch to cann default stream
|
||||
if (dst_device.Type() == OrtDevice::NPU) {
|
||||
if (src_device.Type() == OrtDevice::NPU) {
|
||||
// Copy only if the two addresses are different.
|
||||
if (dst_data != src_data) {
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
||||
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
|
||||
}
|
||||
} else {
|
||||
// copy from other CPU memory to NPU, this is blocking
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
|
||||
}
|
||||
} else if (src_device.Type() == OrtDevice::NPU) {
|
||||
// copying from NPU to CPU memory, this is blocking
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST));
|
||||
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
|
||||
} else {
|
||||
// copying between cpu memory
|
||||
memcpy(dst_data, src_data, bytes);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const {
|
||||
size_t bytes = src.SizeInBytes();
|
||||
const void* src_data = src.DataRaw();
|
||||
void* dst_data = dst.MutableDataRaw();
|
||||
|
|
@ -41,30 +68,44 @@ common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e
|
|||
auto& dst_device = dst.Location().device;
|
||||
|
||||
if (dst_device.Type() == OrtDevice::NPU) {
|
||||
if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
|
||||
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, GetStream(exec_queue_id)));
|
||||
if (src_device.Type() == OrtDevice::CPU) {
|
||||
// copy from pinned memory to NPU, this is non-blocking
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
static_cast<aclrtStream>(stream.GetHandle())));
|
||||
} else if (src_device.Type() == OrtDevice::NPU) {
|
||||
// copying between NPU, this is non-blocking
|
||||
if (dst_data != src_data) {
|
||||
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, GetStream(kCannStreamDefault)));
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
static_cast<aclrtStream>(stream.GetHandle())));
|
||||
}
|
||||
} else {
|
||||
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, GetStream(kCannStreamDefault)));
|
||||
CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault)));
|
||||
}
|
||||
} else if (src_device.Type() == OrtDevice::NPU) {
|
||||
if (dst_device.Type() == OrtDevice::CPU) {
|
||||
// copying from NPU to pinned memory, this is non-blocking
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
|
||||
bytes,
|
||||
src_data,
|
||||
bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
static_cast<aclrtStream>(stream.GetHandle())));
|
||||
}
|
||||
} else {
|
||||
if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
|
||||
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST, GetStream(exec_queue_id)));
|
||||
} else {
|
||||
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST, GetStream(kCannStreamDefault)));
|
||||
CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault)));
|
||||
if (src_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
|
||||
// sync the stream first to make sure the data arrived
|
||||
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(static_cast<aclrtStream>(stream.GetHandle())));
|
||||
}
|
||||
memcpy(dst_data, src_data, bytes);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,35 +4,23 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "cann_inc.h"
|
||||
#include "core/framework/data_transfer.h"
|
||||
#include "core/providers/cann/cann_inc.h"
|
||||
#include "core/providers/cann/cann_call.h"
|
||||
#include "core/providers/cann/cann_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
enum CANNStreamType : int {
|
||||
kCannStreamDefault = 0,
|
||||
kCannStreamCopyIn,
|
||||
kCannStreamCopyOut,
|
||||
kTotalCannStreams,
|
||||
};
|
||||
|
||||
class NPUDataTransfer : public IDataTransfer {
|
||||
public:
|
||||
explicit NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream = true);
|
||||
NPUDataTransfer();
|
||||
~NPUDataTransfer();
|
||||
|
||||
bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
|
||||
|
||||
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
|
||||
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
|
||||
|
||||
aclrtStream GetStream(int queue_id) const {
|
||||
ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCannStreams);
|
||||
return streams_[queue_id];
|
||||
}
|
||||
|
||||
private:
|
||||
bool do_copy_in_default_stream_;
|
||||
aclrtStream streams_[kTotalCannStreams];
|
||||
common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -42,10 +42,10 @@ aclDataType getACLTypeByMap(ONNX_NAMESPACE::TensorProto_DataType type) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status Cast<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
Status Cast<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
|
||||
Tensor* Y = context->Output(0, X->Shape());
|
||||
Tensor* Y = ctx->Output(0, X->Shape());
|
||||
|
||||
aclFormat format = ACL_FORMAT_ND;
|
||||
const aclDataType aclTypeX = getACLType<T>();
|
||||
|
|
@ -78,7 +78,7 @@ Status Cast<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class Cast final : public CannKernel {
|
|||
to_ = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(to);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
ONNX_NAMESPACE::TensorProto_DataType to_;
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ Status Flatten<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class Flatten final : public CannKernel {
|
|||
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK());
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
|
|
|
|||
|
|
@ -18,16 +18,16 @@ class IdentityOp final : public CannKernel {
|
|||
IdentityOp(const OpKernelInfo& info) : CannKernel(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
auto X_ml_type = context->InputType(0);
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
auto X_ml_type = ctx->InputType(0);
|
||||
if (X_ml_type->IsTensorType()) {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
if (nullptr == X) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"IdentityOp cann: input count mismatch.");
|
||||
}
|
||||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, shape);
|
||||
Tensor* Y = ctx->Output(0, shape);
|
||||
if (nullptr == Y) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"IdentityOp cann: failed to allocate output tensor.");
|
||||
|
|
@ -39,20 +39,20 @@ class IdentityOp final : public CannKernel {
|
|||
if (target != source) {
|
||||
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(target, Y->SizeInBytes(), source,
|
||||
X->Shape().Size() * X->DataType()->Size(),
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
|
||||
}
|
||||
|
||||
if (is_dropout) {
|
||||
Tensor* mask = context->Output(1, shape);
|
||||
Tensor* mask = ctx->Output(1, shape);
|
||||
if (mask != nullptr) {
|
||||
void* mask_data = mask->MutableDataRaw();
|
||||
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream()));
|
||||
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream(ctx)));
|
||||
}
|
||||
}
|
||||
} else if (X_ml_type->IsTensorSequenceType()) {
|
||||
const TensorSeq* X = context->Input<TensorSeq>(0);
|
||||
const TensorSeq* X = ctx->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "IdentityOp cann: input tensor is missing.");
|
||||
TensorSeq* Y = context->Output<TensorSeq>(0);
|
||||
TensorSeq* Y = ctx->Output<TensorSeq>(0);
|
||||
ORT_ENFORCE(Y != nullptr, "IdentityOp cann: failed to allocate output tensor sequence.");
|
||||
if (X == Y) {
|
||||
return Status::OK();
|
||||
|
|
@ -60,7 +60,7 @@ class IdentityOp final : public CannKernel {
|
|||
auto X_type = X->DataType();
|
||||
Y->SetType(X_type);
|
||||
AllocatorPtr alloc;
|
||||
auto status = context->GetTempSpaceAllocator(&alloc);
|
||||
auto status = ctx->GetTempSpaceAllocator(&alloc);
|
||||
if (!status.IsOK()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"IdentityOp cann: unable to get an allocator.");
|
||||
|
|
@ -75,7 +75,7 @@ class IdentityOp final : public CannKernel {
|
|||
target_tensor->SizeInBytes(),
|
||||
source_tensor.DataRaw(),
|
||||
source_tensor.SizeInBytes(),
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
|
||||
Y->Add(std::move(*target_tensor));
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ class Reshape final : public CannKernel {
|
|||
allow_zero_ = (info.GetAttrOrDefault("allowzero", static_cast<int64_t>(0)) == 1);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const Tensor* shapeTensor = context->Input<Tensor>(1);
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
const Tensor* shapeTensor = ctx->Input<Tensor>(1);
|
||||
if (shapeTensor == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "the 0th input is missing");
|
||||
if (shapeTensor->Shape().NumDimensions() != 1)
|
||||
|
|
@ -28,14 +28,14 @@ class Reshape final : public CannKernel {
|
|||
auto data_span = shapeTensor->template DataAsSpan<int64_t>();
|
||||
TensorShapeVector shape(data_span.begin(), data_span.end());
|
||||
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
if (X == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "the 1th input is missing");
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
|
||||
ReshapeHelper helper(X_shape, shape, allow_zero_);
|
||||
|
||||
Tensor* Y = context->Output(0, TensorShape(shape));
|
||||
Tensor* Y = ctx->Output(0, TensorShape(shape));
|
||||
const void* source = X->DataRaw();
|
||||
void* target = Y->MutableDataRaw();
|
||||
if (target != source) {
|
||||
|
|
@ -56,14 +56,14 @@ class Reshape_1 final : public CannKernel {
|
|||
ORT_ENFORCE(status.IsOK(), "Attribute shape is not set.");
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
TensorShapeVector shape = shape_;
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
|
||||
ReshapeHelper helper(X_shape, shape);
|
||||
|
||||
Tensor* Y = context->Output(0, TensorShape(shape));
|
||||
Tensor* Y = ctx->Output(0, TensorShape(shape));
|
||||
const void* source = X->DataRaw();
|
||||
void* target = Y->MutableDataRaw();
|
||||
if (target != source) {
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
ACL_ENGINE_SYS,
|
||||
ACL_COMPILE_SYS,
|
||||
NULL,
|
||||
Stream()));
|
||||
Stream(ctx)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -690,6 +690,10 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const {
|
|||
|
||||
#if defined(USE_CANN)
|
||||
RandomGenerator& RandomGenerator::Default() { return g_host->RandomGenerator__Default(); }
|
||||
void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream,
|
||||
WaitNotificationFn wait_fn) {
|
||||
return g_host->Allocator__AllocateBufferWithOptions(allocator, size, use_reserve, stream, wait_fn);
|
||||
}
|
||||
|
||||
namespace cann {
|
||||
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger) {
|
||||
|
|
|
|||
|
|
@ -1847,8 +1847,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider
|
|||
(*out)->device_id = 0;
|
||||
(*out)->npu_mem_limit = SIZE_MAX;
|
||||
(*out)->arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(0);
|
||||
(*out)->do_copy_in_default_stream = 1;
|
||||
(*out)->enable_cann_graph = 1;
|
||||
(*out)->dump_graphs = 0;
|
||||
(*out)->default_memory_arena_cfg = nullptr;
|
||||
return nullptr;
|
||||
#else
|
||||
|
|
|
|||
Loading…
Reference in a new issue