[CANN] Multi-stream execution support for CANN EP. (#14058)

### Description
**Multi-stream** execution support for **CANN EP**.

### Motivation and Context
**CANN EP** is currently **unavailable** due to the introduction of a
new mechanism for multi-stream execution
[#13495](https://github.com/microsoft/onnxruntime/pull/13495), the
deletion of the Fence-based synchronization mechanism, and the failure
to update the relevant logic of **CANN EP** synchronously.

This PR is to fix it.
This commit is contained in:
FFFrog 2023-03-30 02:57:22 +08:00 committed by GitHub
parent febc69e1b2
commit ecb89ed752
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
49 changed files with 560 additions and 439 deletions

View file

@ -4,6 +4,8 @@
#pragma once
#include <string>
#include "onnxruntime_c_api.h"
#include "core/framework/arena_extend_strategy.h"
@ -11,9 +13,13 @@ struct OrtCANNProviderOptions {
int device_id; // CANN device id
size_t npu_mem_limit; // BFC Arena memory limit for CANN
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena
int do_copy_in_default_stream; // Flag indicating if copying needs to take place on the
// same stream as the compute stream in the CANN EP
int enable_cann_graph; // Flag indicating if prioritizing the use of
// CANN's graph-running capabilities
int dump_graphs; // Flag indicating if dumping graphs
std::string precision_mode; // Operator Precision Mode
std::string op_select_impl_mode; // Operator-level model compilation options:
// Mode selection
std::string optypelist_for_implmode; // Operator-level model compilation options:
// Operator list
OrtArenaCfg* default_memory_arena_cfg; // CANN memory arena configuration parameters
};

View file

@ -181,6 +181,19 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
return;
}
#ifdef USE_CANN
// For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
// which is different from CUDA Runtime API, but similar to CUDA Driver API.
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
for (auto& xp : execution_providers) {
auto status = xp->OnRunStart();
if (!status.IsOK()) {
ctx.SetStatus(status);
return;
}
}
#endif
// get logic stream
auto& execution_plan = ctx.GetSessionState().GetExecutionPlan()->execution_plan;
auto& logic_stream = execution_plan[stream_idx];

View file

@ -32,9 +32,9 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons
#define REGISTER_ACTIVATION_TYPED_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
CannPreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
prepare.inputDesc_.size(), \
prepare.inputDesc_.data(), \
@ -46,7 +46,7 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons
ACL_ENGINE_SYS, \
ACL_COMPILE_SYS, \
NULL, \
Stream())); \
Stream(ctx))); \
return Status::OK(); \
}

View file

@ -23,7 +23,7 @@ template <typename T>
class Relu final : public Activations {
public:
Relu(const OpKernelInfo& info) : Activations(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -7,18 +7,10 @@
#include "core/providers/cann/cann_call.h"
#include "core/providers/cann/cann_allocator.h"
#include "core/framework/allocatormgr.h"
#include "core/providers/cann/cann_fence.h"
#include "core/providers/cann/npu_data_transfer.h"
namespace onnxruntime {
static const NPUDataTransfer* GetNPUDataTransfer(const SessionState* session_state) {
OrtDevice npu_device(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, 0);
OrtDevice cpu_device;
return static_cast<const NPUDataTransfer*>(
session_state->GetDataTransferMgr().GetDataTransfer(npu_device, cpu_device));
}
void* CANNAllocator::Alloc(size_t size) {
void* p = nullptr;
aclrtMemMallocPolicy policy = ACL_MEM_MALLOC_HUGE_FIRST;
@ -32,10 +24,6 @@ void CANNAllocator::Free(void* p) {
aclrtFree(p);
}
FencePtr CANNAllocator::CreateFence(const SessionState* session_state) {
return std::make_shared<CANNFence>(GetNPUDataTransfer(session_state));
}
void* CANNPinnedAllocator::Alloc(size_t size) {
void* p = nullptr;
if (size > 0) {
@ -48,8 +36,4 @@ void CANNPinnedAllocator::Free(void* p) {
CANN_CALL_THROW(aclrtFreeHost(p));
}
FencePtr CANNPinnedAllocator::CreateFence(const SessionState* session_state) {
return std::make_shared<CANNFence>(GetNPUDataTransfer(session_state));
}
} // namespace onnxruntime

View file

@ -19,7 +19,6 @@ class CANNAllocator : public IAllocator {
device_id, OrtMemTypeDefault)) {}
void* Alloc(size_t size) override;
void Free(void* p) override;
FencePtr CreateFence(const SessionState* session_state) override;
};
class CANNPinnedAllocator : public IAllocator {
@ -32,7 +31,6 @@ class CANNPinnedAllocator : public IAllocator {
void* Alloc(size_t size) override;
void Free(void* p) override;
FencePtr CreateFence(const SessionState* session_state) override;
};
} // namespace onnxruntime

View file

@ -101,6 +101,8 @@ template <>
const char* CannErrString<ge::graphStatus>(ge::graphStatus e) {
using namespace ge;
aclrtSynchronizeDevice();
switch (e) {
CASE_ENUM_TO_STR(GRAPH_FAILED);
CASE_ENUM_TO_STR(GRAPH_SUCCESS);

View file

@ -16,14 +16,14 @@
#include "core/providers/cann/cann_inc.h"
#include "core/providers/cann/cann_call.h"
#include "core/providers/cann/cann_allocator.h"
#include "core/providers/cann/cann_fence.h"
#include "core/providers/cann/cann_fwd.h"
#include "core/providers/cann/cann_stream_handle.h"
#include "core/providers/cann/npu_data_transfer.h"
using onnxruntime::cann::BuildONNXModel;
using onnxruntime::cann::CannModelPreparation;
using onnxruntime::cann::ParserONNXModel;
using onnxruntime::cann::SupportONNXModel;
using onnxruntime::cann::CannModelPreparation;
using onnxruntime::common::Status;
namespace onnxruntime {
@ -42,49 +42,63 @@ class Memcpy final : public OpKernel {
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
Tensor* Y = ctx->Output(0, X->Shape());
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
} else if (X_type->IsSparseTensorType()) {
const auto* X = ctx->Input<SparseTensor>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
return X->Copy(Info().GetDataTransferManager(), Info().GetKernelDef().ExecQueueId(), *Y);
} else if (X_type->IsTensorSequenceType()) {
const TensorSeq* X = ctx->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
TensorSeq* Y = ctx->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
auto X_dtype = X->DataType();
Y->SetType(X_dtype);
AllocatorPtr alloc;
if (Node().OpType() == "MemcpyFromHost") {
auto status = ctx->GetTempSpaceAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Memcpy cann: unable to get an allocator.");
}
} else {
auto status = ctx->GetTempSpaceCPUAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Memcpy cann: unable to get the CPU allocator.");
}
}
auto X_size = X->Size();
for (size_t i = 0; i < X_size; ++i) {
const Tensor& source_tensor = X->Get(i);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor,
Info().GetKernelDef().ExecQueueId());
if (!retval.IsOK()) {
return retval;
}
Y->Add(std::move(*target_tensor));
}
auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device,
Y->Location().device);
ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()));
return Status::OK();
} else {
if (X_type->IsSparseTensorType()) {
aclrtSynchronizeStream(static_cast<aclrtStream>(ctx->GetComputeStream()->GetHandle()));
const auto* X = ctx->Input<SparseTensor>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
return X->Copy(Info().GetDataTransferManager(), *Y);
} else if (X_type->IsTensorSequenceType()) {
const TensorSeq* X = ctx->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
TensorSeq* Y = ctx->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
auto X_dtype = X->DataType();
Y->SetType(X_dtype);
AllocatorPtr alloc;
// If we are copying contents to CANN, the allocator to use
// to allocate the buffers of the new tensors in the sequence
// can be temp space allocator associated with the CANN EP
if (Node().OpType() == "MemcpyFromHost") {
auto status = ctx->GetTempSpaceAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Memcpy cann: unable to get an allocator.");
}
} else {
// If we are copying contents to CPU (op type is "MemcpyToHost"),
// the allocator to use to allocate the buffers of the new tensors
// in the sequence will be the allocator from the CPU EP
auto status = ctx->GetTempSpaceCPUAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Memcpy cann: unable to get the CPU allocator.");
}
}
auto X_size = X->Size();
for (size_t i = 0; i < X_size; ++i) {
const Tensor& source_tensor = X->Get(i);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
source_tensor.Shape(),
alloc);
auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device,
target_tensor->Location().device);
ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(source_tensor,
*target_tensor,
*ctx->GetComputeStream()));
Y->Add(std::move(*target_tensor));
}
return Status::OK();
}
return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
}
return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
}
};
@ -97,7 +111,6 @@ ONNX_OPERATOR_KERNEL_EX(
kCannExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.ExecQueueId(kCannStreamCopyIn)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
@ -108,7 +121,6 @@ ONNX_OPERATOR_KERNEL_EX(
kCannExecutionProvider,
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
.ExecQueueId(kCannStreamCopyOut)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
@ -1013,27 +1025,20 @@ CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& in
InitProviderOrtApi();
CANN_CALL_THROW(aclrtSetDevice(info_.device_id));
CANN_CALL_THROW(aclrtCreateStream(&stream_));
soc_name_ = aclrtGetSocName();
ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr");
}
CANNExecutionProvider::~CANNExecutionProvider() {
CANN_CALL_THROW(aclrtDestroyStream(stream_));
for (auto modelID : modelIDs_) {
CANN_CALL_THROW(aclmdlUnload(modelID.second));
}
}
// All threads share the same context and stream
Status CANNExecutionProvider::OnRunStart() {
CANN_CALL_THROW(aclrtSetDevice(info_.device_id));
return Status::OK();
}
Status CANNExecutionProvider::OnRunEnd(bool sync_stream) {
if (sync_stream) {
CANN_CALL_THROW(aclrtSynchronizeStream(stream_));
}
CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));
return Status::OK();
}
@ -1050,6 +1055,8 @@ void InitializeRegistry() {
void DeleteRegistry() {
s_kernel_registry.reset();
ge::aclgrphBuildFinalize();
CANN_CALL_THROW(aclFinalize());
}
@ -1058,8 +1065,7 @@ std::shared_ptr<KernelRegistry> CANNExecutionProvider::GetKernelRegistry() const
}
std::unique_ptr<onnxruntime::IDataTransfer> CANNExecutionProvider::GetDataTransfer() const {
return std::make_unique<onnxruntime::NPUDataTransfer>(static_cast<aclrtStream>(GetComputeStream()),
info_.do_copy_in_default_stream);
return std::make_unique<onnxruntime::NPUDataTransfer>();
}
std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
@ -1072,9 +1078,9 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
}
// Get parent graph output names
std::vector<std::string> graph_output_names;
std::unordered_set<std::string> graph_output_names;
for (const auto* output_arg : graph_viewer.GetOutputs()) {
graph_output_names.push_back(output_arg->Name());
graph_output_names.insert(output_arg->Name());
}
// Find inputs and outputs of the subgraph
@ -1084,26 +1090,35 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
int input_order = 0;
int output_order = 0;
std::vector<std::string> initializers;
for (const auto& index : graph_nodes_index) {
sub_graph->Nodes().push_back(index);
const auto& node = graph_viewer.GetNode(index);
for (const auto& input : node->InputDefs()) {
if (graph_viewer.IsConstantInitializer(input->Name(), true)) {
initializers.push_back(input->Name());
continue;
}
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) {
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
}
}
for (const auto& input : node->ImplicitInputDefs()) {
if (graph_viewer.IsConstantInitializer(input->Name(), true)) {
initializers.push_back(input->Name());
continue;
}
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) {
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
}
@ -1118,15 +1133,20 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
const auto& node_idx = it->GetNode().Index();
const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
const onnxruntime::NodeArg* output;
if (it->GetDstArgIndex() < static_cast<int>(it->GetNode().InputDefs().size())) {
output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
} else {
auto index = it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size());
output = (it->GetNode()).ImplicitInputDefs()[index];
}
if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
fused_inputs.erase(iter);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name());
if (it != graph_output_names.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
@ -1144,8 +1164,7 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
} else {
// Only when output is neither in input list nor erased list, add the output to output list
if (erased.find(output) == erased.end()) {
auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name());
if (it != graph_output_names.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
@ -1168,27 +1187,6 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
outputs.insert(std::pair<int, const NodeArg*>(it->second, it->first));
}
// It is possible that an output of an node is put bebind the output of an later
// node in the graph output list. So we should sort the output name according
// to the graph output names
std::vector<std::string> output_names;
std::unordered_set<std::string> graph_out_names;
for (const auto& output : outputs) {
if (output.second->Exists()) {
auto name = output.second->Name();
if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) {
output_names.push_back(name);
} else {
graph_out_names.insert(name);
}
}
}
for (auto& name : graph_output_names) {
if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end())
output_names.push_back(name);
}
// Generate unique kernel name for CANN subgraph
HashValue model_hash = 0;
int id = GenerateMetaDefId(graph_viewer, model_hash);
@ -1202,8 +1200,14 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(
}
}
for (const auto& output : output_names) {
meta_def->outputs().push_back(output);
for (const auto& initializer : initializers) {
meta_def->constant_initializers().push_back(initializer);
}
for (const auto& output : outputs) {
if (output.second->Exists()) {
meta_def->outputs().push_back(output.second->Name());
}
}
meta_def->domain() = kMSDomain;
@ -1306,13 +1310,13 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
const std::string node_name = fused_node.Name();
std::unordered_map<size_t, std::string> names2index;
std::unordered_map<size_t, std::string> index2name;
const auto& input_defs = fused_node.InputDefs();
names2index.reserve(input_defs.size());
index2name.reserve(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
names2index[i] = input_defs[i]->Name();
index2name[i] = input_defs[i]->Name();
}
names_[node_name] = names2index;
names_[node_name] = index2name;
std::string string_model;
auto model = cann::CreateModel(graph_body_viewer, *GetLogger());
@ -1322,6 +1326,11 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
model_proto->SerializeToString(string_model);
models_[node_name] = string_model;
if (info_.dump_graphs) {
std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
NodeComputeInfo compute_info;
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
std::unique_ptr<CannFuncState> p = std::make_unique<CannFuncState>();
@ -1335,20 +1344,19 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
delete static_cast<CannFuncState*>(state);
};
compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
compute_info.compute_func = [this](FunctionState state, const OrtApi*, OrtKernelContext* context) {
Ort::KernelContext ctx(context);
CannFuncState* cann_state = reinterpret_cast<CannFuncState*>(state);
std::string& string_model = models_[cann_state->node_name];
std::unordered_map<size_t, std::string>& names2index = names_[cann_state->node_name];
std::unordered_map<size_t, std::string>& index2name = names_[cann_state->node_name];
std::string input_shape = [&ctx, &names2index]() -> std::string {
std::string input_shape = [&ctx, &index2name]() -> std::string {
std::string res;
for (size_t i = 0; i < ctx.GetInputCount(); i++) {
auto&& shape = ctx.GetInput(i).GetTensorTypeAndShapeInfo().GetShape();
auto name = names2index[i];
std::string s = name + ":";
std::string s = index2name[i] + ":";
for (auto& d : shape) {
s += std::to_string(d) + ",";
}
@ -1366,8 +1374,13 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
std::string filename = cann_state->node_name + "_" + std::to_string(hash);
std::string filename_with_suffix = filename + ".om";
// TODO(FFFrog): Resource Management
// It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by
// device memory exhaustion
uint32_t modelID;
{
if (modelIDs_.find(filename) != modelIDs_.end()) {
modelID = modelIDs_[filename];
} else {
std::lock_guard<OrtMutex> lock(g_mutex);
if (cann::FileExist(filename_with_suffix)) {
@ -1377,10 +1390,12 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
ORT_RETURN_IF_ERROR(ParserONNXModel(string_model, graph));
ge::ModelBufferData model;
ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, model));
ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, info_, model));
CANN_RETURN_IF_ERROR(aclmdlLoadFromMem(model.data.get(), model.length, &modelID));
}
modelIDs_.emplace(filename, modelID);
}
CannModelPreparation prepare(modelID);
@ -1407,7 +1422,8 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
}
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream_));
aclrtStream stream = static_cast<aclrtStream>(ctx.GetGPUComputeStream());
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));
return Status::OK();
};
@ -1436,7 +1452,12 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage
true,
{info_.default_memory_arena_cfg ? *info_.default_memory_arena_cfg
: OrtArenaCfg(info_.npu_mem_limit,
static_cast<int>(info_.arena_extend_strategy), -1, -1, -1)});
static_cast<int>(info_.arena_extend_strategy),
-1,
-1,
-1)},
true,
false);
cann_alloc = CreateAllocator(default_memory_info);
allocator_manager.InsertAllocator(cann_alloc);
@ -1484,4 +1505,8 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage
}
}
void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const {
RegisterCannStreamHandles(stream_handle_registry, OrtDevice::NPU);
}
} // namespace onnxruntime

View file

@ -36,36 +36,33 @@ class CANNExecutionProvider : public IExecutionProvider {
Status OnRunStart() override;
Status OnRunEnd(bool sync_stream) override;
void* GetComputeStream() const override { return static_cast<void*>(stream_); }
template <typename T>
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes, Stream* stream, WaitNotificationFn wait_fn) const {
if (count_or_bytes == 0)
return nullptr;
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeDefault), count_or_bytes);
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeDefault), count_or_bytes, false, stream, wait_fn);
}
template <typename T>
IAllocatorUniquePtr<T> GetScratchBufferOnCANNPinned(size_t count_or_bytes) const {
if (count_or_bytes == 0)
return nullptr;
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeCPU),
count_or_bytes);
return IAllocator::MakeUniquePtr<T>(GetAllocator(OrtMemTypeCPU), count_or_bytes);
}
template <typename T>
Status Fill(Tensor* y, void* addr) const {
return cann::Fill<T>(y, addr, stream_);
Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
return cann::Fill<T>(y, addr, stream);
}
template <typename T>
Status Broadcast(const Tensor* x, Tensor* y, void* addr) const {
return cann::Broadcast<T>(x, y, addr, stream_);
Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const {
return cann::Broadcast<T>(x, y, addr, stream);
}
int GetDeviceId() const override { return info_.device_id; }
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
@ -86,11 +83,13 @@ class CANNExecutionProvider : public IExecutionProvider {
void RegisterAllocator(AllocatorManager& allocator_manager) override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override;
private:
CANNExecutionProviderInfo info_;
aclrtStream stream_ = nullptr;
const char* soc_name_ = nullptr;
std::unordered_map<std::string, uint32_t> modelIDs_;
std::unordered_map<std::string, std::string> models_;
std::unordered_map<std::string, std::unordered_map<std::size_t, std::string>> names_;
};

View file

@ -19,8 +19,11 @@ namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "npu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream";
constexpr const char* kEnableCannGraph = "enable_cann_graph";
constexpr const char* kDumpGraphs = "dump_graphs";
constexpr const char* kPrecisionMode = "precision_mode";
constexpr const char* kOpSelectImplMode = "op_select_impl_mode";
constexpr const char* kOpTypeListForImplMode = "optypelist_for_implmode";
} // namespace provider_option_names
} // namespace cann
@ -53,8 +56,11 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P
.AddAssignmentToEnumReference(
cann::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
.AddAssignmentToReference(cann::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph)
.AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs)
.AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode)
.AddAssignmentToReference(cann::provider_option_names::kOpSelectImplMode, info.op_select_impl_mode)
.AddAssignmentToReference(cann::provider_option_names::kOpTypeListForImplMode, info.optypelist_for_implmode)
.Parse(options));
return info;
}
@ -65,9 +71,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution
{cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)},
{cann::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cann::provider_option_names::kDoCopyInDefaultStream,
MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}};
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
{cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)},
{cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}};
return options;
}
@ -77,9 +85,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid
{cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)},
{cann::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))},
{cann::provider_option_names::kDoCopyInDefaultStream,
MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}};
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
{cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)},
{cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}};
return options;
}
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#pragma once
#include <limits>
#include <string>
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/ortdevice.h"
@ -16,8 +17,11 @@ struct CANNExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t npu_mem_limit{std::numeric_limits<size_t>::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
bool do_copy_in_default_stream{true};
bool enable_cann_graph{true};
bool dump_graphs{false};
std::string precision_mode;
std::string op_select_impl_mode;
std::string optypelist_for_implmode;
OrtArenaCfg* default_memory_arena_cfg{nullptr};
static CANNExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);

View file

@ -1,62 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cann/cann_call.h"
#include "core/providers/cann/npu_data_transfer.h"
#include "core/providers/cann/cann_fence.h"
#include <iostream>
namespace onnxruntime {
CANNFence::CANNFence(const NPUDataTransfer* data_transfer) : data_transfer_(data_transfer) {
CANN_CALL_THROW(aclrtCreateEvent(&read_event_));
CANN_CALL_THROW(aclrtCreateEvent(&write_event_));
}
CANNFence::~CANNFence() {
CANN_CALL_THROW(aclrtDestroyEvent(read_event_));
CANN_CALL_THROW(aclrtDestroyEvent(write_event_));
}
void CANNFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) {
if (provider_type == onnxruntime::kCannExecutionProvider) {
CANN_CALL_THROW(aclrtStreamWaitEvent(data_transfer_->GetStream(async_queue_id), write_event_));
} else {
CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_));
}
}
void CANNFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) {
if (provider_type == onnxruntime::kCannExecutionProvider) {
aclrtStream stream = data_transfer_->GetStream(queue_id);
CANN_CALL_THROW(aclrtStreamWaitEvent(stream, read_event_));
CANN_CALL_THROW(aclrtStreamWaitEvent(stream, write_event_));
} else {
CANN_CALL_THROW(aclrtSynchronizeEvent(read_event_));
CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_));
}
}
bool CANNFence::CanRelease() {
aclrtEventRecordedStatus read_status;
aclrtEventRecordedStatus write_status;
return aclrtQueryEventStatus(read_event_, &read_status) == ACL_SUCCESS &&
aclrtQueryEventStatus(write_event_, &write_status) == ACL_SUCCESS &&
read_status == ACL_EVENT_RECORDED_STATUS_COMPLETE &&
write_status == ACL_EVENT_RECORDED_STATUS_COMPLETE;
}
void CANNFence::AfterUsedAsInput(int queue_id) {
aclrtStream stream = data_transfer_->GetStream(queue_id);
CANN_CALL_THROW(aclrtRecordEvent(read_event_, stream));
}
void CANNFence::AfterUsedAsOutput(int queue_id) {
aclrtStream stream = data_transfer_->GetStream(queue_id);
CANN_CALL_THROW(aclrtRecordEvent(write_event_, stream));
}
} // namespace onnxruntime

View file

@ -1,29 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/fence.h"
#include "core/providers/cann/cann_inc.h"
namespace onnxruntime {
class NPUDataTransfer;
class CANNFence : public IFence {
public:
explicit CANNFence(const NPUDataTransfer* data_transfer);
virtual ~CANNFence();
void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override;
void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override;
void AfterUsedAsInput(int queue_id) override;
void AfterUsedAsOutput(int queue_id) override;
bool CanRelease() override;
private:
aclrtEvent read_event_;
aclrtEvent write_event_;
const NPUDataTransfer* data_transfer_;
};
} // namespace onnxruntime

View file

@ -13,6 +13,8 @@ namespace cann {
static int lower_bound = 8; // Supported domain version lower bounds
static int upper_bound = 15; // Supported domain version upper bounds
std::once_flag flag;
/**
* This function will been changed with the evolution of ONNX and CANN
* and will be replaced by the corresponding API provided by CANN in the future, probably.
@ -32,7 +34,7 @@ std::vector<NodeIndex> SupportONNXModel(const GraphViewer& graph_viewer) {
"GatherElements", "GatherND", "Gemm", "GlobalAveragePool",
"GlobalLpPool", "GlobalMaxPool", "Greater", "GreaterOrEqual",
"Hardmax", "HardSigmoid", "HardSwish", "Identity",
"If", "InstanceNormalization", "LeakyRelu", "Less",
"InstanceNormalization", "LeakyRelu", "Less",
"LessOrEqual", "Log", "LogSoftmax", "LpNormalization",
"LpPool", "LRN", "LSTM", "MatMul",
"Max", "MaxPool", "MaxRoiPool", "MaxUnpool",
@ -94,20 +96,27 @@ Status ParserONNXModel(std::string string_model, ge::Graph& graph) {
}
Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name,
ge::ModelBufferData& model) {
CANNExecutionProviderInfo& info, ge::ModelBufferData& model) {
std::call_once(flag, [&soc_name, &info]() {
std::map<ge::AscendString, ge::AscendString> options;
options.emplace(ge::ir_option::SOC_VERSION, soc_name);
if (!info.precision_mode.empty())
options.emplace(ge::ir_option::PRECISION_MODE, info.precision_mode.c_str());
if (!info.op_select_impl_mode.empty())
options.emplace(ge::ir_option::OP_SELECT_IMPL_MODE, info.op_select_impl_mode.c_str());
if (!info.optypelist_for_implmode.empty())
options.emplace(ge::ir_option::OPTYPELIST_FOR_IMPLMODE, info.optypelist_for_implmode.c_str());
CANN_CALL_THROW(ge::aclgrphBuildInitialize(options));
});
std::map<ge::AscendString, ge::AscendString> options;
options.emplace(ge::ir_option::SOC_VERSION, soc_name);
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildInitialize(options));
options.clear();
options.emplace(ge::ir_option::INPUT_SHAPE, input_shape.c_str());
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildModel(graph, options, model));
CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphSaveModel(file_name.c_str(), model));
ge::aclgrphBuildFinalize();
return Status::OK();
}

View file

@ -12,6 +12,7 @@
#include "core/providers/cann/cann_common.h"
#include "core/providers/cann/cann_inc.h"
#include "core/providers/cann/cann_utils.h"
#include "core/providers/cann/cann_execution_provider_info.h"
namespace onnxruntime {
namespace cann {
@ -72,7 +73,7 @@ struct CannModelPreparation {
std::vector<NodeIndex> SupportONNXModel(const GraphViewer& graph_viewer);
Status ParserONNXModel(std::string string_model, ge::Graph& graph);
Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name,
ge::ModelBufferData& model);
CANNExecutionProviderInfo& info, ge::ModelBufferData& model);
} // namespace cann
} // namespace onnxruntime

View file

@ -4,11 +4,13 @@
#pragma once
#include "core/platform/ort_mutex.h"
#include "core/providers/cann/cann_inc.h"
#include "core/providers/cann/cann_call.h"
#include "core/providers/cann/cann_execution_provider.h"
#include "core/providers/cann/cann_fwd.h"
#include "core/providers/cann/cann_utils.h"
#include "core/providers/cann/cann_stream_handle.h"
namespace onnxruntime {
namespace cann {
@ -35,11 +37,14 @@ class CannKernel : public OpKernel {
virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0;
inline aclrtStream Stream() const { return static_cast<aclrtStream>(provider_->GetComputeStream()); }
inline aclrtStream Stream(OpKernelContext* ctx) const {
auto* stream = ctx->GetComputeStream();
return stream ? static_cast<aclrtStream>(stream->GetHandle()) : nullptr;
}
template <typename T>
inline IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
return provider_->GetScratchBuffer<T>(count_or_bytes);
inline IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const {
return provider_->GetScratchBuffer<T>(count_or_bytes, stream, WaitCannNotificationOnDevice);
}
template <typename T>
@ -48,13 +53,13 @@ class CannKernel : public OpKernel {
}
template <typename T>
inline Status Fill(Tensor* y, void* addr) const {
return provider_->Fill<T>(y, addr);
inline Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
return provider_->Fill<T>(y, addr, stream);
}
template <typename T>
inline Status Broadcast(const Tensor* x, Tensor* y, void* addr) const {
return provider_->Broadcast<T>(x, y, addr);
inline Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const {
return provider_->Broadcast<T>(x, y, addr, stream);
}
protected:

View file

@ -53,8 +53,11 @@ struct CANN_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(params->device_id);
info.npu_mem_limit = params->npu_mem_limit;
info.arena_extend_strategy = params->arena_extend_strategy;
info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0;
info.enable_cann_graph = params->enable_cann_graph != 0;
info.dump_graphs = params->dump_graphs != 0;
info.precision_mode = params->precision_mode;
info.op_select_impl_mode = params->op_select_impl_mode;
info.optypelist_for_implmode = params->optypelist_for_implmode;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
return std::make_shared<CANNProviderFactory>(info);
@ -67,8 +70,11 @@ struct CANN_Provider : Provider {
cann_options.device_id = internal_options.device_id;
cann_options.npu_mem_limit = internal_options.npu_mem_limit;
cann_options.arena_extend_strategy = internal_options.arena_extend_strategy;
cann_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream;
cann_options.enable_cann_graph = internal_options.enable_cann_graph;
cann_options.dump_graphs = internal_options.dump_graphs;
cann_options.precision_mode = internal_options.precision_mode;
cann_options.op_select_impl_mode = internal_options.op_select_impl_mode;
cann_options.optypelist_for_implmode = internal_options.optypelist_for_implmode;
cann_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
}

View file

@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cann/cann_stream_handle.h"
#include "core/common/spin_pause.h"
namespace onnxruntime {
struct CannNotification : public synchronize::Notification {
explicit CannNotification(Stream& s) : Notification(s) {
CANN_CALL_THROW(aclrtCreateEvent(&event_));
}
~CannNotification() {
if (event_)
CANN_CALL_THROW(aclrtDestroyEvent(event_));
}
void Activate() override {
CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast<aclrtStream>(stream_.GetHandle())));
}
void wait_on_device(Stream& device_stream) {
ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::NPU);
CANN_CALL_THROW(aclrtStreamWaitEvent(static_cast<aclrtStream>(device_stream.GetHandle()), event_));
}
void wait_on_host() {
CANN_CALL_THROW(aclrtSynchronizeEvent(event_));
}
aclrtEvent event_;
};
CannStream::CannStream(aclrtStream stream,
const OrtDevice& device,
bool own_flag) : Stream(stream, device),
own_stream_(own_flag) {}
CannStream::~CannStream() {
ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd());
if (own_stream_) {
auto* handle = GetHandle();
if (handle)
aclrtDestroyStream(static_cast<aclrtStream>(handle));
}
}
std::unique_ptr<synchronize::Notification> CannStream::CreateNotification(size_t /*num_consumers*/) {
return std::make_unique<CannNotification>(*this);
}
void CannStream::Flush() {
if (own_stream_)
CANN_CALL_THROW(aclrtSynchronizeStream(static_cast<aclrtStream>(GetHandle())));
}
// CPU Stream command handles
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
static_cast<CannNotification*>(&notification)->wait_on_device(stream);
}
void WaitCannNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) {
static_cast<CannNotification*>(&notification)->wait_on_host();
}
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
const OrtDevice::DeviceType device_type) {
// wait cann notification on cann ep
stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCannNotificationOnDevice);
// wait cann notification on cpu ep
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCannNotificationOnHost);
stream_handle_registry.RegisterCreateStreamFn(device_type, [](const OrtDevice& device) {
aclrtStream stream = nullptr;
CANN_CALL_THROW(aclrtCreateStream(&stream));
return std::make_unique<CannStream>(stream, device, true);
});
}
} // namespace onnxruntime

View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Huawei. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <vector>
#include "core/framework/stream_handles.h"
#include "core/providers/cann/cann_inc.h"
#include "core/providers/cann/cann_call.h"
namespace onnxruntime {
struct CannStream : Stream {
CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
~CannStream();
std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) override;
void Flush() override;
bool own_stream_{true};
};
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
const OrtDevice::DeviceType device_type);
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime

View file

@ -3,6 +3,7 @@
// Licensed under the MIT License.
#include <unistd.h>
#include <algorithm>
#include "core/providers/cann/cann_utils.h"
@ -223,5 +224,34 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
}
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
} // namespace cann
} // namespace onnxruntime

View file

@ -124,6 +124,9 @@ Status aclrtblasGemmEx(aclTransType transA,
bool FileExist(const std::string& file_name);
void GenerateHashValue(const std::string string, HashValue& hash_value);
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
const TensorShape& rhs_shape, TensorShape& out_shape);
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);
} // namespace cann

View file

@ -11,35 +11,6 @@ using onnxruntime::common::Status;
namespace onnxruntime {
namespace cann {
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape,
const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
template <typename T>
Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare) const {
const aclDataType aclType = getACLType<T>();
@ -56,14 +27,14 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
void* B_data = const_cast<void*>(B->DataRaw());
if (A->Shape() != C->Shape()) {
IAllocatorUniquePtr<void> pA = GetScratchBuffer<void>(C->SizeInBytes());
ORT_RETURN_IF_ERROR(Broadcast<T>(A, C, pA.get()));
IAllocatorUniquePtr<void> pA = GetScratchBuffer<void>(C->SizeInBytes(), ctx->GetComputeStream());
ORT_RETURN_IF_ERROR(Broadcast<T>(A, C, pA.get(), Stream(ctx)));
A_data = pA.get();
}
if (B->Shape() != C->Shape()) {
IAllocatorUniquePtr<void> pB = GetScratchBuffer<void>(C->SizeInBytes());
ORT_RETURN_IF_ERROR(Broadcast<T>(B, C, pB.get()));
IAllocatorUniquePtr<void> pB = GetScratchBuffer<void>(C->SizeInBytes(), ctx->GetComputeStream());
ORT_RETURN_IF_ERROR(Broadcast<T>(B, C, pB.get(), Stream(ctx)));
B_data = pB.get();
}
@ -85,9 +56,9 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
#define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
CannPreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
prepare.inputDesc_.size(), \
prepare.inputDesc_.data(), \
@ -99,7 +70,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare
ACL_ENGINE_SYS, \
ACL_COMPILE_SYS, \
NULL, \
Stream())); \
Stream(ctx))); \
return Status::OK(); \
}

View file

@ -24,28 +24,28 @@ template <typename T>
class Add final : public BinaryElementwise {
public:
Add(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Sub final : public BinaryElementwise {
public:
Sub(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Mul final : public BinaryElementwise {
public:
Mul(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Div final : public BinaryElementwise {
public:
Div(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -11,10 +11,10 @@ namespace onnxruntime {
namespace cann {
template <typename T>
Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
const auto* A = context->Input<Tensor>(0);
const auto* B = context->Input<Tensor>(1);
const auto* C = context->Input<Tensor>(2);
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
const auto* A = ctx->Input<Tensor>(0);
const auto* B = ctx->Input<Tensor>(1);
const auto* C = ctx->Input<Tensor>(2);
GemmHelper helper(A->Shape(), trans_A_, B->Shape(), trans_B_, C != nullptr ? C->Shape() : TensorShape({}));
if (!helper.State().IsOK())
@ -24,13 +24,13 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
int N = gsl::narrow_cast<int>(helper.N());
int K = gsl::narrow_cast<int>(helper.K());
auto* Y = context->Output(0, {M, N});
auto* Y = ctx->Output(0, {M, N});
// broadcast C if needed.
if (beta_ != 0 && C != nullptr) {
if (C->Shape().Size() == 1) {
// C is (), (1,) or (1, 1), fill the scalar to Y
ORT_RETURN_IF_ERROR(Fill<T>(Y, const_cast<void*>(C->DataRaw())));
ORT_RETURN_IF_ERROR(Fill<T>(Y, const_cast<void*>(C->DataRaw()), Stream(ctx)));
} else if (C->Shape() == Y->Shape()) {
// C is (M, N), no broadcast needed.
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y->MutableDataRaw(),
@ -38,10 +38,10 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
const_cast<void*>(C->DataRaw()),
Y->SizeInBytes(),
ACL_MEMCPY_DEVICE_TO_DEVICE,
Stream()));
Stream(ctx)));
} else {
// others, broadcast needed.
ORT_RETURN_IF_ERROR(Broadcast<T>(C, Y, Y->MutableDataRaw()));
ORT_RETURN_IF_ERROR(Broadcast<T>(C, Y, Y->MutableDataRaw(), Stream(ctx)));
}
}
@ -49,8 +49,8 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
T alpha = ToCannType<T>::FromFloat(alpha_);
T beta = ToCannType<T>::FromFloat(beta_);
IAllocatorUniquePtr<void> pAlpha = GetScratchBuffer<void>(sizeof(T));
IAllocatorUniquePtr<void> pBeta = GetScratchBuffer<void>(sizeof(T));
IAllocatorUniquePtr<void> pAlpha = GetScratchBuffer<void>(sizeof(T), ctx->GetComputeStream());
IAllocatorUniquePtr<void> pBeta = GetScratchBuffer<void>(sizeof(T), ctx->GetComputeStream());
CANN_RETURN_IF_ERROR(aclrtMemcpy(pAlpha.get(), sizeof(T), &alpha, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE));
CANN_RETURN_IF_ERROR(aclrtMemcpy(pBeta.get(), sizeof(T), &beta, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE));
@ -67,7 +67,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* context) const {
pBeta.get(),
Y->MutableDataRaw(), -1, aclType,
ACL_COMPUTE_HIGH_PRECISION,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -25,7 +25,7 @@ class Gemm final : public CannKernel {
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
bool trans_A_;

View file

@ -52,7 +52,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -15,7 +15,7 @@ class MatMul final : public CannKernel {
MatMul(const OpKernelInfo& info)
: CannKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -38,9 +38,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare)
#define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
Status x<T>::ComputeInternal(OpKernelContext* ctx) const { \
CannPreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare<T>(context, prepare)); \
ORT_RETURN_IF_ERROR(Prepare<T>(ctx, prepare)); \
CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \
prepare.inputDesc_.size(), \
prepare.inputDesc_.data(), \
@ -52,7 +52,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare)
ACL_ENGINE_SYS, \
ACL_COMPILE_SYS, \
NULL, \
Stream())); \
Stream(ctx))); \
return Status::OK(); \
}

View file

@ -29,84 +29,84 @@ template <typename T>
class Abs final : public UnaryElementwise {
public:
Abs(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Neg final : public UnaryElementwise {
public:
Neg(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Floor final : public UnaryElementwise {
public:
Floor(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Ceil final : public UnaryElementwise {
public:
Ceil(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Reciprocal final : public UnaryElementwise {
public:
Reciprocal(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Sqrt final : public UnaryElementwise {
public:
Sqrt(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Log final : public UnaryElementwise {
public:
Log(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Exp final : public UnaryElementwise {
public:
Exp(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Erf final : public UnaryElementwise {
public:
Erf(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Round final : public UnaryElementwise {
public:
Round(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Sin final : public UnaryElementwise {
public:
Sin(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
template <typename T>
class Cos final : public UnaryElementwise {
public:
Cos(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -11,8 +11,8 @@ namespace onnxruntime {
namespace cann {
template <typename T>
Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
Status AveragePool<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto X_dims = X_shape.GetDims();
@ -35,7 +35,7 @@ Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
TensorShape Y_shape(Y_dims);
Tensor* Y = context->Output(0, Y_shape);
Tensor* Y = ctx->Output(0, Y_shape);
if (Y_shape.Size() == 0)
return Status::OK();
@ -86,7 +86,7 @@ Status AveragePool<T>::ComputeInternal(OpKernelContext* context) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -15,7 +15,7 @@ class AveragePool : public CannKernel, public PoolBase {
public:
explicit AveragePool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -22,10 +22,10 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
// There is only one output in inference mode
Tensor* Y = ctx->Output(0, X->Shape());
IAllocatorUniquePtr<void> pbatch_mean = GetScratchBuffer<void>(mean->SizeInBytes());
IAllocatorUniquePtr<void> pbatch_variance = GetScratchBuffer<void>(var->SizeInBytes());
IAllocatorUniquePtr<void> preserver_space_1 = GetScratchBuffer<void>(mean->SizeInBytes());
IAllocatorUniquePtr<void> preserver_space_2 = GetScratchBuffer<void>(var->SizeInBytes());
IAllocatorUniquePtr<void> pbatch_mean = GetScratchBuffer<void>(mean->SizeInBytes(), ctx->GetComputeStream());
IAllocatorUniquePtr<void> pbatch_variance = GetScratchBuffer<void>(var->SizeInBytes(), ctx->GetComputeStream());
IAllocatorUniquePtr<void> preserver_space_1 = GetScratchBuffer<void>(mean->SizeInBytes(), ctx->GetComputeStream());
IAllocatorUniquePtr<void> preserver_space_2 = GetScratchBuffer<void>(var->SizeInBytes(), ctx->GetComputeStream());
const aclDataType aclType = getACLType<T>();
aclFormat format = ACL_FORMAT_NCHW;
@ -76,7 +76,7 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -20,7 +20,7 @@ class BatchNorm final : public CannKernel {
ORT_ENFORCE(!is_training_mode_, "only supports inference mode");
}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
float epsilon_;

View file

@ -106,7 +106,7 @@ Status Conv<T>::ComputeInternal(OpKernelContext* ctx) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -26,17 +26,17 @@ float GetRatioOrDefault(const Tensor* ratio) {
} // namespace
template <typename T1, typename T2>
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const Tensor* ratio = context->Input<Tensor>(1);
const Tensor* ratio = ctx->Input<Tensor>(1);
const float ratio_value = GetRatioOrDefault<T2>(ratio);
const Tensor* training_mode = context->Input<Tensor>(2);
const Tensor* training_mode = ctx->Input<Tensor>(2);
auto Y = context->Output(0, X_shape);
auto mask = context->Output(1, X_shape);
auto Y = ctx->Output(0, X_shape);
auto mask = ctx->Output(1, X_shape);
if (ratio_value == 0.f || !training_mode || !(*(training_mode->Data<bool>()))) {
const void* X_data = X->DataRaw();
@ -44,22 +44,22 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
if (Y_data != X_data) {
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y_data, Y->SizeInBytes(), X_data, Y->SizeInBytes(),
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
}
if (mask) {
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask->MutableData<bool>(), mask->SizeInBytes(), true,
mask->SizeInBytes(), Stream()));
mask->SizeInBytes(), Stream(ctx)));
}
} else {
IAllocatorUniquePtr<void> pmask{};
IAllocatorUniquePtr<void> pseed = GetScratchBuffer<void>(sizeof(float));
IAllocatorUniquePtr<void> pseed = GetScratchBuffer<void>(sizeof(float), ctx->GetComputeStream());
void* mask_data = nullptr;
if (mask) {
mask_data = mask->MutableDataRaw();
} else {
pmask = GetScratchBuffer<void>(X_shape.Size() * sizeof(bool));
pmask = GetScratchBuffer<void>(X_shape.Size() * sizeof(bool), ctx->GetComputeStream());
mask_data = pmask.get();
}
@ -107,7 +107,7 @@ Status Dropout<T1, T2>::ComputeInternal(OpKernelContext* context) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
}
return Status::OK();

View file

@ -21,7 +21,7 @@ class Dropout final : public CannKernel {
}
}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
mutable std::unique_ptr<RandomGenerator> generator_;

View file

@ -11,8 +11,8 @@ namespace onnxruntime {
namespace cann {
template <typename T>
Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
Status MaxPool<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto X_dims = X_shape.GetDims();
@ -35,7 +35,7 @@ Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
TensorShape Y_shape(Y_dims);
Tensor* Y = context->Output(0, Y_shape);
Tensor* Y = ctx->Output(0, Y_shape);
if (Y_shape.Size() == 0)
return Status::OK();
@ -85,7 +85,7 @@ Status MaxPool<T>::ComputeInternal(OpKernelContext* context) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -15,7 +15,7 @@ class MaxPool : public CannKernel, public PoolBase {
public:
explicit MaxPool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cann

View file

@ -4,35 +4,62 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cann/npu_data_transfer.h"
#include "core/providers/cann/cann_call.h"
namespace onnxruntime {
NPUDataTransfer::NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream) {
do_copy_in_default_stream_ = do_copy_in_default_stream;
streams_[kCannStreamDefault] = stream;
if (do_copy_in_default_stream) {
streams_[kCannStreamCopyIn] = stream;
streams_[kCannStreamCopyOut] = stream;
} else {
CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyIn]));
CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyOut]));
}
}
NPUDataTransfer::NPUDataTransfer() {}
NPUDataTransfer::~NPUDataTransfer() {
if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyIn] != nullptr) {
CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyIn]));
}
if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyOut] != nullptr) {
CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyOut]));
}
}
NPUDataTransfer::~NPUDataTransfer() {}
bool NPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
return src_device.Type() == OrtDevice::NPU || dst_device.Type() == OrtDevice::NPU;
}
common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
size_t bytes = src.SizeInBytes();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
auto& src_device = src.Location().device;
auto& dst_device = dst.Location().device;
// for the sync version of memcpy, launch to cann default stream
if (dst_device.Type() == OrtDevice::NPU) {
if (src_device.Type() == OrtDevice::NPU) {
// Copy only if the two addresses are different.
if (dst_data != src_data) {
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE));
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
}
} else {
// copy from other CPU memory to NPU, this is blocking
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
}
} else if (src_device.Type() == OrtDevice::NPU) {
// copying from NPU to CPU memory, this is blocking
CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_DEVICE_TO_HOST));
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr));
} else {
// copying between cpu memory
memcpy(dst_data, src_data, bytes);
}
return Status::OK();
}
common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const {
size_t bytes = src.SizeInBytes();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();
@ -41,30 +68,44 @@ common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e
auto& dst_device = dst.Location().device;
if (dst_device.Type() == OrtDevice::NPU) {
if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
ACL_MEMCPY_HOST_TO_DEVICE, GetStream(exec_queue_id)));
if (src_device.Type() == OrtDevice::CPU) {
// copy from pinned memory to NPU, this is non-blocking
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_HOST_TO_DEVICE,
static_cast<aclrtStream>(stream.GetHandle())));
} else if (src_device.Type() == OrtDevice::NPU) {
// copying between NPU, this is non-blocking
if (dst_data != src_data) {
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE, GetStream(kCannStreamDefault)));
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE,
static_cast<aclrtStream>(stream.GetHandle())));
}
} else {
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
ACL_MEMCPY_HOST_TO_DEVICE, GetStream(kCannStreamDefault)));
CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault)));
}
} else if (src_device.Type() == OrtDevice::NPU) {
if (dst_device.Type() == OrtDevice::CPU) {
// copying from NPU to pinned memory, this is non-blocking
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data,
bytes,
src_data,
bytes,
ACL_MEMCPY_DEVICE_TO_HOST,
static_cast<aclrtStream>(stream.GetHandle())));
}
} else {
if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
ACL_MEMCPY_DEVICE_TO_HOST, GetStream(exec_queue_id)));
} else {
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes,
ACL_MEMCPY_DEVICE_TO_HOST, GetStream(kCannStreamDefault)));
CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault)));
if (src_device.MemType() == OrtDevice::MemType::CANN_PINNED) {
// sync the stream first to make sure the data arrived
CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(static_cast<aclrtStream>(stream.GetHandle())));
}
memcpy(dst_data, src_data, bytes);
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -4,35 +4,23 @@
#pragma once
#include "cann_inc.h"
#include "core/framework/data_transfer.h"
#include "core/providers/cann/cann_inc.h"
#include "core/providers/cann/cann_call.h"
#include "core/providers/cann/cann_common.h"
namespace onnxruntime {
enum CANNStreamType : int {
kCannStreamDefault = 0,
kCannStreamCopyIn,
kCannStreamCopyOut,
kTotalCannStreams,
};
class NPUDataTransfer : public IDataTransfer {
public:
explicit NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream = true);
NPUDataTransfer();
~NPUDataTransfer();
bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
aclrtStream GetStream(int queue_id) const {
ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCannStreams);
return streams_[queue_id];
}
private:
bool do_copy_in_default_stream_;
aclrtStream streams_[kTotalCannStreams];
common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override;
};
} // namespace onnxruntime

View file

@ -42,10 +42,10 @@ aclDataType getACLTypeByMap(ONNX_NAMESPACE::TensorProto_DataType type) {
}
template <typename T>
Status Cast<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
Status Cast<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
Tensor* Y = context->Output(0, X->Shape());
Tensor* Y = ctx->Output(0, X->Shape());
aclFormat format = ACL_FORMAT_ND;
const aclDataType aclTypeX = getACLType<T>();
@ -78,7 +78,7 @@ Status Cast<T>::ComputeInternal(OpKernelContext* context) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -20,7 +20,7 @@ class Cast final : public CannKernel {
to_ = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(to);
}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
ONNX_NAMESPACE::TensorProto_DataType to_;

View file

@ -53,7 +53,7 @@ Status Flatten<T>::ComputeInternal(OpKernelContext* ctx) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
}
return Status::OK();

View file

@ -16,7 +16,7 @@ class Flatten final : public CannKernel {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
int64_t axis_;

View file

@ -18,16 +18,16 @@ class IdentityOp final : public CannKernel {
IdentityOp(const OpKernelInfo& info) : CannKernel(info) {
}
Status ComputeInternal(OpKernelContext* context) const override {
auto X_ml_type = context->InputType(0);
Status ComputeInternal(OpKernelContext* ctx) const override {
auto X_ml_type = ctx->InputType(0);
if (X_ml_type->IsTensorType()) {
const Tensor* X = context->Input<Tensor>(0);
const Tensor* X = ctx->Input<Tensor>(0);
if (nullptr == X) {
return Status(common::ONNXRUNTIME, common::FAIL,
"IdentityOp cann: input count mismatch.");
}
const TensorShape& shape = X->Shape();
Tensor* Y = context->Output(0, shape);
Tensor* Y = ctx->Output(0, shape);
if (nullptr == Y) {
return Status(common::ONNXRUNTIME, common::FAIL,
"IdentityOp cann: failed to allocate output tensor.");
@ -39,20 +39,20 @@ class IdentityOp final : public CannKernel {
if (target != source) {
CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(target, Y->SizeInBytes(), source,
X->Shape().Size() * X->DataType()->Size(),
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
}
if (is_dropout) {
Tensor* mask = context->Output(1, shape);
Tensor* mask = ctx->Output(1, shape);
if (mask != nullptr) {
void* mask_data = mask->MutableDataRaw();
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream()));
CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream(ctx)));
}
}
} else if (X_ml_type->IsTensorSequenceType()) {
const TensorSeq* X = context->Input<TensorSeq>(0);
const TensorSeq* X = ctx->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "IdentityOp cann: input tensor is missing.");
TensorSeq* Y = context->Output<TensorSeq>(0);
TensorSeq* Y = ctx->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "IdentityOp cann: failed to allocate output tensor sequence.");
if (X == Y) {
return Status::OK();
@ -60,7 +60,7 @@ class IdentityOp final : public CannKernel {
auto X_type = X->DataType();
Y->SetType(X_type);
AllocatorPtr alloc;
auto status = context->GetTempSpaceAllocator(&alloc);
auto status = ctx->GetTempSpaceAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"IdentityOp cann: unable to get an allocator.");
@ -75,7 +75,7 @@ class IdentityOp final : public CannKernel {
target_tensor->SizeInBytes(),
source_tensor.DataRaw(),
source_tensor.SizeInBytes(),
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream()));
ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx)));
Y->Add(std::move(*target_tensor));
}
} else {

View file

@ -18,8 +18,8 @@ class Reshape final : public CannKernel {
allow_zero_ = (info.GetAttrOrDefault("allowzero", static_cast<int64_t>(0)) == 1);
}
Status ComputeInternal(OpKernelContext* context) const override {
const Tensor* shapeTensor = context->Input<Tensor>(1);
Status ComputeInternal(OpKernelContext* ctx) const override {
const Tensor* shapeTensor = ctx->Input<Tensor>(1);
if (shapeTensor == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "the 0th input is missing");
if (shapeTensor->Shape().NumDimensions() != 1)
@ -28,14 +28,14 @@ class Reshape final : public CannKernel {
auto data_span = shapeTensor->template DataAsSpan<int64_t>();
TensorShapeVector shape(data_span.begin(), data_span.end());
const Tensor* X = context->Input<Tensor>(0);
const Tensor* X = ctx->Input<Tensor>(0);
if (X == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "the 1th input is missing");
const TensorShape& X_shape = X->Shape();
ReshapeHelper helper(X_shape, shape, allow_zero_);
Tensor* Y = context->Output(0, TensorShape(shape));
Tensor* Y = ctx->Output(0, TensorShape(shape));
const void* source = X->DataRaw();
void* target = Y->MutableDataRaw();
if (target != source) {
@ -56,14 +56,14 @@ class Reshape_1 final : public CannKernel {
ORT_ENFORCE(status.IsOK(), "Attribute shape is not set.");
}
Status ComputeInternal(OpKernelContext* context) const override {
Status ComputeInternal(OpKernelContext* ctx) const override {
TensorShapeVector shape = shape_;
const Tensor* X = context->Input<Tensor>(0);
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
ReshapeHelper helper(X_shape, shape);
Tensor* Y = context->Output(0, TensorShape(shape));
Tensor* Y = ctx->Output(0, TensorShape(shape));
const void* source = X->DataRaw();
void* target = Y->MutableDataRaw();
if (target != source) {

View file

@ -55,7 +55,7 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
ACL_ENGINE_SYS,
ACL_COMPILE_SYS,
NULL,
Stream()));
Stream(ctx)));
return Status::OK();
}

View file

@ -690,6 +690,10 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const {
#if defined(USE_CANN)
RandomGenerator& RandomGenerator::Default() { return g_host->RandomGenerator__Default(); }
void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream,
WaitNotificationFn wait_fn) {
return g_host->Allocator__AllocateBufferWithOptions(allocator, size, use_reserve, stream, wait_fn);
}
namespace cann {
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger) {

View file

@ -1847,8 +1847,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider
(*out)->device_id = 0;
(*out)->npu_mem_limit = SIZE_MAX;
(*out)->arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(0);
(*out)->do_copy_in_default_stream = 1;
(*out)->enable_cann_graph = 1;
(*out)->dump_graphs = 0;
(*out)->default_memory_arena_cfg = nullptr;
return nullptr;
#else