diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index f36d4c0161..ac60fbe4a2 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -4,6 +4,8 @@ #pragma once +#include + #include "onnxruntime_c_api.h" #include "core/framework/arena_extend_strategy.h" @@ -11,9 +13,13 @@ struct OrtCANNProviderOptions { int device_id; // CANN device id size_t npu_mem_limit; // BFC Arena memory limit for CANN onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena - int do_copy_in_default_stream; // Flag indicating if copying needs to take place on the - // same stream as the compute stream in the CANN EP int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities + int dump_graphs; // Flag indicating if dumping graphs + std::string precision_mode; // Operator Precision Mode + std::string op_select_impl_mode; // Operator-level model compilation options: + // Mode selection + std::string optypelist_for_implmode; // Operator-level model compilation options: + // Operator list OrtArenaCfg* default_memory_arena_cfg; // CANN memory arena configuration parameters }; diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index d47f861c29..db6a6cf731 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,6 +181,19 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess return; } +#ifdef USE_CANN + // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, + // which is different from CUDA Runtime API, but similar to CUDA Driver API. + auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); + for (auto& xp : execution_providers) { + auto status = xp->OnRunStart(); + if (!status.IsOK()) { + ctx.SetStatus(status); + return; + } + } +#endif + // get logic stream auto& execution_plan = ctx.GetSessionState().GetExecutionPlan()->execution_plan; auto& logic_stream = execution_plan[stream_idx]; diff --git a/onnxruntime/core/providers/cann/activation/activations.cc b/onnxruntime/core/providers/cann/activation/activations.cc index 327efa7c6e..1c8363e8b5 100644 --- a/onnxruntime/core/providers/cann/activation/activations.cc +++ b/onnxruntime/core/providers/cann/activation/activations.cc @@ -32,9 +32,9 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons #define REGISTER_ACTIVATION_TYPED_COMPUTE(x, T) \ template <> \ - Status x::ComputeInternal(OpKernelContext* context) const { \ + Status x::ComputeInternal(OpKernelContext* ctx) const { \ CannPreparation prepare; \ - ORT_RETURN_IF_ERROR(Prepare(context, prepare)); \ + ORT_RETURN_IF_ERROR(Prepare(ctx, prepare)); \ CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \ prepare.inputDesc_.size(), \ prepare.inputDesc_.data(), \ @@ -46,7 +46,7 @@ Status Activations::Prepare(OpKernelContext* ctx, CannPreparation& prepare) cons ACL_ENGINE_SYS, \ ACL_COMPILE_SYS, \ NULL, \ - Stream())); \ + Stream(ctx))); \ return Status::OK(); \ } diff --git a/onnxruntime/core/providers/cann/activation/activations.h b/onnxruntime/core/providers/cann/activation/activations.h index 425a66c1ca..5e20980c89 100644 --- a/onnxruntime/core/providers/cann/activation/activations.h +++ b/onnxruntime/core/providers/cann/activation/activations.h @@ -23,7 +23,7 @@ template class Relu final : public Activations { public: Relu(const OpKernelInfo& info) : Activations(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/cann_allocator.cc b/onnxruntime/core/providers/cann/cann_allocator.cc index 479207082d..7f9f9e2ffd 100644 --- a/onnxruntime/core/providers/cann/cann_allocator.cc +++ b/onnxruntime/core/providers/cann/cann_allocator.cc @@ -7,18 +7,10 @@ #include "core/providers/cann/cann_call.h" #include "core/providers/cann/cann_allocator.h" #include "core/framework/allocatormgr.h" -#include "core/providers/cann/cann_fence.h" #include "core/providers/cann/npu_data_transfer.h" namespace onnxruntime { -static const NPUDataTransfer* GetNPUDataTransfer(const SessionState* session_state) { - OrtDevice npu_device(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, 0); - OrtDevice cpu_device; - return static_cast( - session_state->GetDataTransferMgr().GetDataTransfer(npu_device, cpu_device)); -} - void* CANNAllocator::Alloc(size_t size) { void* p = nullptr; aclrtMemMallocPolicy policy = ACL_MEM_MALLOC_HUGE_FIRST; @@ -32,10 +24,6 @@ void CANNAllocator::Free(void* p) { aclrtFree(p); } -FencePtr CANNAllocator::CreateFence(const SessionState* session_state) { - return std::make_shared(GetNPUDataTransfer(session_state)); -} - void* CANNPinnedAllocator::Alloc(size_t size) { void* p = nullptr; if (size > 0) { @@ -48,8 +36,4 @@ void CANNPinnedAllocator::Free(void* p) { CANN_CALL_THROW(aclrtFreeHost(p)); } -FencePtr CANNPinnedAllocator::CreateFence(const SessionState* session_state) { - return std::make_shared(GetNPUDataTransfer(session_state)); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_allocator.h b/onnxruntime/core/providers/cann/cann_allocator.h index 0d607fd76a..15fa7b1779 100644 --- a/onnxruntime/core/providers/cann/cann_allocator.h +++ b/onnxruntime/core/providers/cann/cann_allocator.h @@ -19,7 +19,6 @@ class CANNAllocator : public IAllocator { device_id, OrtMemTypeDefault)) {} void* Alloc(size_t size) override; void Free(void* p) override; - FencePtr CreateFence(const SessionState* session_state) override; }; class CANNPinnedAllocator : public IAllocator { @@ -32,7 +31,6 @@ class CANNPinnedAllocator : public IAllocator { void* Alloc(size_t size) override; void Free(void* p) override; - FencePtr CreateFence(const SessionState* session_state) override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_call.cc b/onnxruntime/core/providers/cann/cann_call.cc index dee7cf2ed7..396d97792d 100644 --- a/onnxruntime/core/providers/cann/cann_call.cc +++ b/onnxruntime/core/providers/cann/cann_call.cc @@ -101,6 +101,8 @@ template <> const char* CannErrString(ge::graphStatus e) { using namespace ge; + aclrtSynchronizeDevice(); + switch (e) { CASE_ENUM_TO_STR(GRAPH_FAILED); CASE_ENUM_TO_STR(GRAPH_SUCCESS); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index d4428a4c09..25f5bfadb3 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -16,14 +16,14 @@ #include "core/providers/cann/cann_inc.h" #include "core/providers/cann/cann_call.h" #include "core/providers/cann/cann_allocator.h" -#include "core/providers/cann/cann_fence.h" #include "core/providers/cann/cann_fwd.h" +#include "core/providers/cann/cann_stream_handle.h" #include "core/providers/cann/npu_data_transfer.h" using onnxruntime::cann::BuildONNXModel; +using onnxruntime::cann::CannModelPreparation; using onnxruntime::cann::ParserONNXModel; using onnxruntime::cann::SupportONNXModel; -using onnxruntime::cann::CannModelPreparation; using onnxruntime::common::Status; namespace onnxruntime { @@ -42,49 +42,63 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId()); - } else if (X_type->IsSparseTensorType()) { - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); - return X->Copy(Info().GetDataTransferManager(), Info().GetKernelDef().ExecQueueId(), *Y); - } else if (X_type->IsTensorSequenceType()) { - const TensorSeq* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); - TensorSeq* Y = ctx->Output(0); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); - auto X_dtype = X->DataType(); - Y->SetType(X_dtype); - AllocatorPtr alloc; - - if (Node().OpType() == "MemcpyFromHost") { - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy cann: unable to get an allocator."); - } - } else { - auto status = ctx->GetTempSpaceCPUAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy cann: unable to get the CPU allocator."); - } - } - auto X_size = X->Size(); - for (size_t i = 0; i < X_size; ++i) { - const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); - Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor, - Info().GetKernelDef().ExecQueueId()); - if (!retval.IsOK()) { - return retval; - } - Y->Add(std::move(*target_tensor)); - } + auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, + Y->Location().device); + ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); + } else { + if (X_type->IsSparseTensorType()) { + aclrtSynchronizeStream(static_cast(ctx->GetComputeStream()->GetHandle())); + const auto* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); + return X->Copy(Info().GetDataTransferManager(), *Y); + } else if (X_type->IsTensorSequenceType()) { + const TensorSeq* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); + TensorSeq* Y = ctx->Output(0); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); + auto X_dtype = X->DataType(); + Y->SetType(X_dtype); + AllocatorPtr alloc; + + // If we are copying contents to CANN, the allocator to use + // to allocate the buffers of the new tensors in the sequence + // can be temp space allocator associated with the CANN EP + if (Node().OpType() == "MemcpyFromHost") { + auto status = ctx->GetTempSpaceAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy cann: unable to get an allocator."); + } + } else { + // If we are copying contents to CPU (op type is "MemcpyToHost"), + // the allocator to use to allocate the buffers of the new tensors + // in the sequence will be the allocator from the CPU EP + auto status = ctx->GetTempSpaceCPUAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy cann: unable to get the CPU allocator."); + } + } + auto X_size = X->Size(); + for (size_t i = 0; i < X_size; ++i) { + const Tensor& source_tensor = X->Get(i); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), + source_tensor.Shape(), + alloc); + auto* npu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, + target_tensor->Location().device); + ORT_RETURN_IF_ERROR(npu_data_transfer->CopyTensorAsync(source_tensor, + *target_tensor, + *ctx->GetComputeStream())); + Y->Add(std::move(*target_tensor)); + } + return Status::OK(); + } + return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } - return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } }; @@ -97,7 +111,6 @@ ONNX_OPERATOR_KERNEL_EX( kCannExecutionProvider, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) - .ExecQueueId(kCannStreamCopyIn) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), Memcpy); @@ -108,7 +121,6 @@ ONNX_OPERATOR_KERNEL_EX( kCannExecutionProvider, (*KernelDefBuilder::Create()) .OutputMemoryType(OrtMemTypeCPUOutput, 0) - .ExecQueueId(kCannStreamCopyOut) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), Memcpy); @@ -1013,27 +1025,20 @@ CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& in InitProviderOrtApi(); CANN_CALL_THROW(aclrtSetDevice(info_.device_id)); - CANN_CALL_THROW(aclrtCreateStream(&stream_)); soc_name_ = aclrtGetSocName(); ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr"); } CANNExecutionProvider::~CANNExecutionProvider() { - CANN_CALL_THROW(aclrtDestroyStream(stream_)); + for (auto modelID : modelIDs_) { + CANN_CALL_THROW(aclmdlUnload(modelID.second)); + } } // All threads share the same context and stream Status CANNExecutionProvider::OnRunStart() { - CANN_CALL_THROW(aclrtSetDevice(info_.device_id)); - - return Status::OK(); -} - -Status CANNExecutionProvider::OnRunEnd(bool sync_stream) { - if (sync_stream) { - CANN_CALL_THROW(aclrtSynchronizeStream(stream_)); - } + CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); } @@ -1050,6 +1055,8 @@ void InitializeRegistry() { void DeleteRegistry() { s_kernel_registry.reset(); + ge::aclgrphBuildFinalize(); + CANN_CALL_THROW(aclFinalize()); } @@ -1058,8 +1065,7 @@ std::shared_ptr CANNExecutionProvider::GetKernelRegistry() const } std::unique_ptr CANNExecutionProvider::GetDataTransfer() const { - return std::make_unique(static_cast(GetComputeStream()), - info_.do_copy_in_default_stream); + return std::make_unique(); } std::unique_ptr CANNExecutionProvider::GetSubGraph( @@ -1072,9 +1078,9 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( } // Get parent graph output names - std::vector graph_output_names; + std::unordered_set graph_output_names; for (const auto* output_arg : graph_viewer.GetOutputs()) { - graph_output_names.push_back(output_arg->Name()); + graph_output_names.insert(output_arg->Name()); } // Find inputs and outputs of the subgraph @@ -1084,26 +1090,35 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( int input_order = 0; int output_order = 0; + std::vector initializers; for (const auto& index : graph_nodes_index) { sub_graph->Nodes().push_back(index); const auto& node = graph_viewer.GetNode(index); for (const auto& input : node->InputDefs()) { + if (graph_viewer.IsConstantInitializer(input->Name(), true)) { + initializers.push_back(input->Name()); + continue; + } const auto& it = fused_outputs.find(input); if (it != fused_outputs.end()) { fused_outputs.erase(it); erased.insert(input); - } else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) { + } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list fused_inputs[input] = input_order++; } } for (const auto& input : node->ImplicitInputDefs()) { + if (graph_viewer.IsConstantInitializer(input->Name(), true)) { + initializers.push_back(input->Name()); + continue; + } const auto& it = fused_outputs.find(input); if (it != fused_outputs.end()) { fused_outputs.erase(it); erased.insert(input); - } else if (erased.find(input) == erased.end() && !graph_viewer.GetAllInitializedTensors().count(input->Name())) { + } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list fused_inputs[input] = input_order++; } @@ -1118,15 +1133,20 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { const auto& node_idx = it->GetNode().Index(); - const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; + const onnxruntime::NodeArg* output; + if (it->GetDstArgIndex() < static_cast(it->GetNode().InputDefs().size())) { + output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; + } else { + auto index = it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size()); + output = (it->GetNode()).ImplicitInputDefs()[index]; + } if (node_set.find(node_idx) != node_set.end()) { const auto& iter = fused_inputs.find(output); if (iter != fused_inputs.end()) { fused_inputs.erase(iter); erased.insert(output); } else if (erased.find(output) == erased.end()) { - auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()); - if (it != graph_output_names.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -1144,8 +1164,7 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( } else { // Only when output is neither in input list nor erased list, add the output to output list if (erased.find(output) == erased.end()) { - auto it = std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()); - if (it != graph_output_names.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -1168,27 +1187,6 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( outputs.insert(std::pair(it->second, it->first)); } - // It is possible that an output of an node is put bebind the output of an later - // node in the graph output list. So we should sort the output name according - // to the graph output names - std::vector output_names; - std::unordered_set graph_out_names; - for (const auto& output : outputs) { - if (output.second->Exists()) { - auto name = output.second->Name(); - if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) { - output_names.push_back(name); - } else { - graph_out_names.insert(name); - } - } - } - - for (auto& name : graph_output_names) { - if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end()) - output_names.push_back(name); - } - // Generate unique kernel name for CANN subgraph HashValue model_hash = 0; int id = GenerateMetaDefId(graph_viewer, model_hash); @@ -1202,8 +1200,14 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( } } - for (const auto& output : output_names) { - meta_def->outputs().push_back(output); + for (const auto& initializer : initializers) { + meta_def->constant_initializers().push_back(initializer); + } + + for (const auto& output : outputs) { + if (output.second->Exists()) { + meta_def->outputs().push_back(output.second->Name()); + } } meta_def->domain() = kMSDomain; @@ -1306,13 +1310,13 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse const std::string node_name = fused_node.Name(); - std::unordered_map names2index; + std::unordered_map index2name; const auto& input_defs = fused_node.InputDefs(); - names2index.reserve(input_defs.size()); + index2name.reserve(input_defs.size()); for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - names2index[i] = input_defs[i]->Name(); + index2name[i] = input_defs[i]->Name(); } - names_[node_name] = names2index; + names_[node_name] = index2name; std::string string_model; auto model = cann::CreateModel(graph_body_viewer, *GetLogger()); @@ -1322,6 +1326,11 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse model_proto->SerializeToString(string_model); models_[node_name] = string_model; + if (info_.dump_graphs) { + std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + model_proto->SerializeToOstream(dump); + } + NodeComputeInfo compute_info; compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); @@ -1335,20 +1344,19 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { + compute_info.compute_func = [this](FunctionState state, const OrtApi*, OrtKernelContext* context) { Ort::KernelContext ctx(context); CannFuncState* cann_state = reinterpret_cast(state); std::string& string_model = models_[cann_state->node_name]; - std::unordered_map& names2index = names_[cann_state->node_name]; + std::unordered_map& index2name = names_[cann_state->node_name]; - std::string input_shape = [&ctx, &names2index]() -> std::string { + std::string input_shape = [&ctx, &index2name]() -> std::string { std::string res; for (size_t i = 0; i < ctx.GetInputCount(); i++) { auto&& shape = ctx.GetInput(i).GetTensorTypeAndShapeInfo().GetShape(); - auto name = names2index[i]; - std::string s = name + ":"; + std::string s = index2name[i] + ":"; for (auto& d : shape) { s += std::to_string(d) + ","; } @@ -1366,8 +1374,13 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse std::string filename = cann_state->node_name + "_" + std::to_string(hash); std::string filename_with_suffix = filename + ".om"; + // TODO(FFFrog): Resource Management + // It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by + // device memory exhaustion uint32_t modelID; - { + if (modelIDs_.find(filename) != modelIDs_.end()) { + modelID = modelIDs_[filename]; + } else { std::lock_guard lock(g_mutex); if (cann::FileExist(filename_with_suffix)) { @@ -1377,10 +1390,12 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse ORT_RETURN_IF_ERROR(ParserONNXModel(string_model, graph)); ge::ModelBufferData model; - ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, model)); + ORT_RETURN_IF_ERROR(BuildONNXModel(graph, input_shape, soc_name_, filename, info_, model)); CANN_RETURN_IF_ERROR(aclmdlLoadFromMem(model.data.get(), model.length, &modelID)); } + + modelIDs_.emplace(filename, modelID); } CannModelPreparation prepare(modelID); @@ -1407,7 +1422,8 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); } - CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream_)); + aclrtStream stream = static_cast(ctx.GetGPUComputeStream()); + CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream)); return Status::OK(); }; @@ -1436,7 +1452,12 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage true, {info_.default_memory_arena_cfg ? *info_.default_memory_arena_cfg : OrtArenaCfg(info_.npu_mem_limit, - static_cast(info_.arena_extend_strategy), -1, -1, -1)}); + static_cast(info_.arena_extend_strategy), + -1, + -1, + -1)}, + true, + false); cann_alloc = CreateAllocator(default_memory_info); allocator_manager.InsertAllocator(cann_alloc); @@ -1484,4 +1505,8 @@ void CANNExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manage } } +void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const { + RegisterCannStreamHandles(stream_handle_registry, OrtDevice::NPU); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 2fe4024487..9391842541 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -36,36 +36,33 @@ class CANNExecutionProvider : public IExecutionProvider { Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; - - void* GetComputeStream() const override { return static_cast(stream_); } - template - IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes) const { + IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, Stream* stream, WaitNotificationFn wait_fn) const { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeDefault), count_or_bytes); + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeDefault), count_or_bytes, false, stream, wait_fn); } template IAllocatorUniquePtr GetScratchBufferOnCANNPinned(size_t count_or_bytes) const { if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeCPU), - count_or_bytes); + + return IAllocator::MakeUniquePtr(GetAllocator(OrtMemTypeCPU), count_or_bytes); } template - Status Fill(Tensor* y, void* addr) const { - return cann::Fill(y, addr, stream_); + Status Fill(Tensor* y, void* addr, aclrtStream stream) const { + return cann::Fill(y, addr, stream); } template - Status Broadcast(const Tensor* x, Tensor* y, void* addr) const { - return cann::Broadcast(x, y, addr, stream_); + Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const { + return cann::Broadcast(x, y, addr, stream); } + int GetDeviceId() const override { return info_.device_id; } std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; @@ -86,11 +83,13 @@ class CANNExecutionProvider : public IExecutionProvider { void RegisterAllocator(AllocatorManager& allocator_manager) override; + void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override; + private: CANNExecutionProviderInfo info_; - aclrtStream stream_ = nullptr; const char* soc_name_ = nullptr; + std::unordered_map modelIDs_; std::unordered_map models_; std::unordered_map> names_; }; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc index e1fb6c5018..5f1a6d8f1b 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc @@ -19,8 +19,11 @@ namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "npu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; -constexpr const char* kDoCopyInDefaultStream = "do_copy_in_default_stream"; constexpr const char* kEnableCannGraph = "enable_cann_graph"; +constexpr const char* kDumpGraphs = "dump_graphs"; +constexpr const char* kPrecisionMode = "precision_mode"; +constexpr const char* kOpSelectImplMode = "op_select_impl_mode"; +constexpr const char* kOpTypeListForImplMode = "optypelist_for_implmode"; } // namespace provider_option_names } // namespace cann @@ -53,8 +56,11 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToEnumReference( cann::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) - .AddAssignmentToReference(cann::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph) + .AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs) + .AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode) + .AddAssignmentToReference(cann::provider_option_names::kOpSelectImplMode, info.op_select_impl_mode) + .AddAssignmentToReference(cann::provider_option_names::kOpTypeListForImplMode, info.optypelist_for_implmode) .Parse(options)); return info; } @@ -65,9 +71,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution {cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)}, {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {cann::provider_option_names::kDoCopyInDefaultStream, - MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, - {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}}; + {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, + {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, + {cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)}, + {cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}}; return options; } @@ -77,9 +85,11 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid {cann::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.npu_mem_limit)}, {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))}, - {cann::provider_option_names::kDoCopyInDefaultStream, - MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, - {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}}; + {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, + {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, + {cann::provider_option_names::kOpSelectImplMode, MakeStringWithClassicLocale(info.op_select_impl_mode)}, + {cann::provider_option_names::kOpTypeListForImplMode, MakeStringWithClassicLocale(info.optypelist_for_implmode)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.h b/onnxruntime/core/providers/cann/cann_execution_provider_info.h index 0091e80768..b5c022c9e9 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/ortdevice.h" @@ -16,8 +17,11 @@ struct CANNExecutionProviderInfo { OrtDevice::DeviceId device_id{0}; size_t npu_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; - bool do_copy_in_default_stream{true}; bool enable_cann_graph{true}; + bool dump_graphs{false}; + std::string precision_mode; + std::string op_select_impl_mode; + std::string optypelist_for_implmode; OrtArenaCfg* default_memory_arena_cfg{nullptr}; static CANNExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/cann/cann_fence.cc b/onnxruntime/core/providers/cann/cann_fence.cc deleted file mode 100644 index 9a1a0cb710..0000000000 --- a/onnxruntime/core/providers/cann/cann_fence.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) Huawei. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/cann/cann_call.h" -#include "core/providers/cann/npu_data_transfer.h" -#include "core/providers/cann/cann_fence.h" -#include - -namespace onnxruntime { - -CANNFence::CANNFence(const NPUDataTransfer* data_transfer) : data_transfer_(data_transfer) { - CANN_CALL_THROW(aclrtCreateEvent(&read_event_)); - CANN_CALL_THROW(aclrtCreateEvent(&write_event_)); -} - -CANNFence::~CANNFence() { - CANN_CALL_THROW(aclrtDestroyEvent(read_event_)); - CANN_CALL_THROW(aclrtDestroyEvent(write_event_)); -} - -void CANNFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) { - if (provider_type == onnxruntime::kCannExecutionProvider) { - CANN_CALL_THROW(aclrtStreamWaitEvent(data_transfer_->GetStream(async_queue_id), write_event_)); - } else { - CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_)); - } -} - -void CANNFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) { - if (provider_type == onnxruntime::kCannExecutionProvider) { - aclrtStream stream = data_transfer_->GetStream(queue_id); - CANN_CALL_THROW(aclrtStreamWaitEvent(stream, read_event_)); - CANN_CALL_THROW(aclrtStreamWaitEvent(stream, write_event_)); - } else { - CANN_CALL_THROW(aclrtSynchronizeEvent(read_event_)); - CANN_CALL_THROW(aclrtSynchronizeEvent(write_event_)); - } -} - -bool CANNFence::CanRelease() { - aclrtEventRecordedStatus read_status; - aclrtEventRecordedStatus write_status; - - return aclrtQueryEventStatus(read_event_, &read_status) == ACL_SUCCESS && - aclrtQueryEventStatus(write_event_, &write_status) == ACL_SUCCESS && - read_status == ACL_EVENT_RECORDED_STATUS_COMPLETE && - write_status == ACL_EVENT_RECORDED_STATUS_COMPLETE; -} - -void CANNFence::AfterUsedAsInput(int queue_id) { - aclrtStream stream = data_transfer_->GetStream(queue_id); - CANN_CALL_THROW(aclrtRecordEvent(read_event_, stream)); -} - -void CANNFence::AfterUsedAsOutput(int queue_id) { - aclrtStream stream = data_transfer_->GetStream(queue_id); - CANN_CALL_THROW(aclrtRecordEvent(write_event_, stream)); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_fence.h b/onnxruntime/core/providers/cann/cann_fence.h deleted file mode 100644 index ad49041be7..0000000000 --- a/onnxruntime/core/providers/cann/cann_fence.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) Huawei. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/fence.h" -#include "core/providers/cann/cann_inc.h" - -namespace onnxruntime { -class NPUDataTransfer; - -class CANNFence : public IFence { - public: - explicit CANNFence(const NPUDataTransfer* data_transfer); - virtual ~CANNFence(); - void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override; - void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override; - void AfterUsedAsInput(int queue_id) override; - void AfterUsedAsOutput(int queue_id) override; - bool CanRelease() override; - - private: - aclrtEvent read_event_; - aclrtEvent write_event_; - const NPUDataTransfer* data_transfer_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_graph.cc b/onnxruntime/core/providers/cann/cann_graph.cc index c74fdf4f20..c680e36380 100644 --- a/onnxruntime/core/providers/cann/cann_graph.cc +++ b/onnxruntime/core/providers/cann/cann_graph.cc @@ -13,6 +13,8 @@ namespace cann { static int lower_bound = 8; // Supported domain version lower bounds static int upper_bound = 15; // Supported domain version upper bounds +std::once_flag flag; + /** * This function will been changed with the evolution of ONNX and CANN * and will be replaced by the corresponding API provided by CANN in the future, probably. @@ -32,7 +34,7 @@ std::vector SupportONNXModel(const GraphViewer& graph_viewer) { "GatherElements", "GatherND", "Gemm", "GlobalAveragePool", "GlobalLpPool", "GlobalMaxPool", "Greater", "GreaterOrEqual", "Hardmax", "HardSigmoid", "HardSwish", "Identity", - "If", "InstanceNormalization", "LeakyRelu", "Less", + "InstanceNormalization", "LeakyRelu", "Less", "LessOrEqual", "Log", "LogSoftmax", "LpNormalization", "LpPool", "LRN", "LSTM", "MatMul", "Max", "MaxPool", "MaxRoiPool", "MaxUnpool", @@ -94,20 +96,27 @@ Status ParserONNXModel(std::string string_model, ge::Graph& graph) { } Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name, - ge::ModelBufferData& model) { + CANNExecutionProviderInfo& info, ge::ModelBufferData& model) { + std::call_once(flag, [&soc_name, &info]() { + std::map options; + options.emplace(ge::ir_option::SOC_VERSION, soc_name); + + if (!info.precision_mode.empty()) + options.emplace(ge::ir_option::PRECISION_MODE, info.precision_mode.c_str()); + if (!info.op_select_impl_mode.empty()) + options.emplace(ge::ir_option::OP_SELECT_IMPL_MODE, info.op_select_impl_mode.c_str()); + if (!info.optypelist_for_implmode.empty()) + options.emplace(ge::ir_option::OPTYPELIST_FOR_IMPLMODE, info.optypelist_for_implmode.c_str()); + + CANN_CALL_THROW(ge::aclgrphBuildInitialize(options)); + }); + std::map options; - - options.emplace(ge::ir_option::SOC_VERSION, soc_name); - CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildInitialize(options)); - - options.clear(); options.emplace(ge::ir_option::INPUT_SHAPE, input_shape.c_str()); CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphBuildModel(graph, options, model)); CANN_GRAPH_RETURN_IF_ERROR(ge::aclgrphSaveModel(file_name.c_str(), model)); - ge::aclgrphBuildFinalize(); - return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/cann_graph.h b/onnxruntime/core/providers/cann/cann_graph.h index 0946db438c..95686df8d6 100644 --- a/onnxruntime/core/providers/cann/cann_graph.h +++ b/onnxruntime/core/providers/cann/cann_graph.h @@ -12,6 +12,7 @@ #include "core/providers/cann/cann_common.h" #include "core/providers/cann/cann_inc.h" #include "core/providers/cann/cann_utils.h" +#include "core/providers/cann/cann_execution_provider_info.h" namespace onnxruntime { namespace cann { @@ -72,7 +73,7 @@ struct CannModelPreparation { std::vector SupportONNXModel(const GraphViewer& graph_viewer); Status ParserONNXModel(std::string string_model, ge::Graph& graph); Status BuildONNXModel(ge::Graph& graph, std::string input_shape, const char* soc_name, std::string file_name, - ge::ModelBufferData& model); + CANNExecutionProviderInfo& info, ge::ModelBufferData& model); } // namespace cann } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_kernel.h b/onnxruntime/core/providers/cann/cann_kernel.h index e0ccc47426..cd2998795d 100644 --- a/onnxruntime/core/providers/cann/cann_kernel.h +++ b/onnxruntime/core/providers/cann/cann_kernel.h @@ -4,11 +4,13 @@ #pragma once +#include "core/platform/ort_mutex.h" #include "core/providers/cann/cann_inc.h" #include "core/providers/cann/cann_call.h" #include "core/providers/cann/cann_execution_provider.h" #include "core/providers/cann/cann_fwd.h" #include "core/providers/cann/cann_utils.h" +#include "core/providers/cann/cann_stream_handle.h" namespace onnxruntime { namespace cann { @@ -35,11 +37,14 @@ class CannKernel : public OpKernel { virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0; - inline aclrtStream Stream() const { return static_cast(provider_->GetComputeStream()); } + inline aclrtStream Stream(OpKernelContext* ctx) const { + auto* stream = ctx->GetComputeStream(); + return stream ? static_cast(stream->GetHandle()) : nullptr; + } template - inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes) const { - return provider_->GetScratchBuffer(count_or_bytes); + inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const { + return provider_->GetScratchBuffer(count_or_bytes, stream, WaitCannNotificationOnDevice); } template @@ -48,13 +53,13 @@ class CannKernel : public OpKernel { } template - inline Status Fill(Tensor* y, void* addr) const { - return provider_->Fill(y, addr); + inline Status Fill(Tensor* y, void* addr, aclrtStream stream) const { + return provider_->Fill(y, addr, stream); } template - inline Status Broadcast(const Tensor* x, Tensor* y, void* addr) const { - return provider_->Broadcast(x, y, addr); + inline Status Broadcast(const Tensor* x, Tensor* y, void* addr, aclrtStream stream) const { + return provider_->Broadcast(x, y, addr, stream); } protected: diff --git a/onnxruntime/core/providers/cann/cann_provider_factory.cc b/onnxruntime/core/providers/cann/cann_provider_factory.cc index d3ddc6c574..636ca22591 100644 --- a/onnxruntime/core/providers/cann/cann_provider_factory.cc +++ b/onnxruntime/core/providers/cann/cann_provider_factory.cc @@ -53,8 +53,11 @@ struct CANN_Provider : Provider { info.device_id = static_cast(params->device_id); info.npu_mem_limit = params->npu_mem_limit; info.arena_extend_strategy = params->arena_extend_strategy; - info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0; info.enable_cann_graph = params->enable_cann_graph != 0; + info.dump_graphs = params->dump_graphs != 0; + info.precision_mode = params->precision_mode; + info.op_select_impl_mode = params->op_select_impl_mode; + info.optypelist_for_implmode = params->optypelist_for_implmode; info.default_memory_arena_cfg = params->default_memory_arena_cfg; return std::make_shared(info); @@ -67,8 +70,11 @@ struct CANN_Provider : Provider { cann_options.device_id = internal_options.device_id; cann_options.npu_mem_limit = internal_options.npu_mem_limit; cann_options.arena_extend_strategy = internal_options.arena_extend_strategy; - cann_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; cann_options.enable_cann_graph = internal_options.enable_cann_graph; + cann_options.dump_graphs = internal_options.dump_graphs; + cann_options.precision_mode = internal_options.precision_mode; + cann_options.op_select_impl_mode = internal_options.op_select_impl_mode; + cann_options.optypelist_for_implmode = internal_options.optypelist_for_implmode; cann_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; } diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.cc b/onnxruntime/core/providers/cann/cann_stream_handle.cc new file mode 100644 index 0000000000..bcb5a62cf6 --- /dev/null +++ b/onnxruntime/core/providers/cann/cann_stream_handle.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Huawei. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cann/cann_stream_handle.h" +#include "core/common/spin_pause.h" + +namespace onnxruntime { + +struct CannNotification : public synchronize::Notification { + explicit CannNotification(Stream& s) : Notification(s) { + CANN_CALL_THROW(aclrtCreateEvent(&event_)); + } + + ~CannNotification() { + if (event_) + CANN_CALL_THROW(aclrtDestroyEvent(event_)); + } + + void Activate() override { + CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast(stream_.GetHandle()))); + } + + void wait_on_device(Stream& device_stream) { + ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::NPU); + CANN_CALL_THROW(aclrtStreamWaitEvent(static_cast(device_stream.GetHandle()), event_)); + } + + void wait_on_host() { + CANN_CALL_THROW(aclrtSynchronizeEvent(event_)); + } + + aclrtEvent event_; +}; + +CannStream::CannStream(aclrtStream stream, + const OrtDevice& device, + bool own_flag) : Stream(stream, device), + own_stream_(own_flag) {} + +CannStream::~CannStream() { + ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + if (own_stream_) { + auto* handle = GetHandle(); + if (handle) + aclrtDestroyStream(static_cast(handle)); + } +} + +std::unique_ptr CannStream::CreateNotification(size_t /*num_consumers*/) { + return std::make_unique(*this); +} + +void CannStream::Flush() { + if (own_stream_) + CANN_CALL_THROW(aclrtSynchronizeStream(static_cast(GetHandle()))); +} + +// CPU Stream command handles +void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_device(stream); +} + +void WaitCannNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_host(); +} + +void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type) { + // wait cann notification on cann ep + stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCannNotificationOnDevice); + // wait cann notification on cpu ep + stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCannNotificationOnHost); + stream_handle_registry.RegisterCreateStreamFn(device_type, [](const OrtDevice& device) { + aclrtStream stream = nullptr; + CANN_CALL_THROW(aclrtCreateStream(&stream)); + return std::make_unique(stream, device, true); + }); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h new file mode 100644 index 0000000000..4d03fe5201 --- /dev/null +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Huawei. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/framework/stream_handles.h" +#include "core/providers/cann/cann_inc.h" +#include "core/providers/cann/cann_call.h" + +namespace onnxruntime { + +struct CannStream : Stream { + CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag); + + ~CannStream(); + + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; + + void Flush() override; + + bool own_stream_{true}; +}; + +void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type); + +void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_utils.cc b/onnxruntime/core/providers/cann/cann_utils.cc index 8448f7c60f..b0e61848ba 100644 --- a/onnxruntime/core/providers/cann/cann_utils.cc +++ b/onnxruntime/core/providers/cann/cann_utils.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. #include +#include #include "core/providers/cann/cann_utils.h" @@ -223,5 +224,34 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) { hash_value = hash[0] | (uint64_t(hash[1]) << 32); } +Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, + const TensorShape& rhs_shape, TensorShape& out_shape) { + size_t lhs_rank = lhs_shape.NumDimensions(); + size_t rhs_rank = rhs_shape.NumDimensions(); + size_t out_rank = std::max(lhs_rank, rhs_rank); + + std::vector output_dims(out_rank, 0); + for (size_t i = 0; i < out_rank; ++i) { + int64_t lhs_dim = 1; + if (i < lhs_rank) + lhs_dim = lhs_shape[lhs_rank - 1 - i]; + int64_t rhs_dim = 1; + if (i < rhs_rank) + rhs_dim = rhs_shape[rhs_rank - 1 - i]; + int64_t max = std::max(lhs_dim, rhs_dim); + int64_t min = std::min(lhs_dim, rhs_dim); + int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. + if (lhs_dim != out_dim && lhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + if (rhs_dim != out_dim && rhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + output_dims[out_rank - 1 - i] = out_dim; + } + out_shape = TensorShape(output_dims); + return Status::OK(); +} + } // namespace cann } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_utils.h b/onnxruntime/core/providers/cann/cann_utils.h index 41e0eeb65e..5eb1873ae3 100644 --- a/onnxruntime/core/providers/cann/cann_utils.h +++ b/onnxruntime/core/providers/cann/cann_utils.h @@ -124,6 +124,9 @@ Status aclrtblasGemmEx(aclTransType transA, bool FileExist(const std::string& file_name); void GenerateHashValue(const std::string string, HashValue& hash_value); +Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, + const TensorShape& rhs_shape, TensorShape& out_shape); + std::unique_ptr CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace cann diff --git a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc index 1e013608b7..d8911a4caa 100644 --- a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc @@ -11,35 +11,6 @@ using onnxruntime::common::Status; namespace onnxruntime { namespace cann { -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, - const TensorShape& rhs_shape, TensorShape& out_shape) { - size_t lhs_rank = lhs_shape.NumDimensions(); - size_t rhs_rank = rhs_shape.NumDimensions(); - size_t out_rank = std::max(lhs_rank, rhs_rank); - - std::vector output_dims(out_rank, 0); - for (size_t i = 0; i < out_rank; ++i) { - int64_t lhs_dim = 1; - if (i < lhs_rank) - lhs_dim = lhs_shape[lhs_rank - 1 - i]; - int64_t rhs_dim = 1; - if (i < rhs_rank) - rhs_dim = rhs_shape[rhs_rank - 1 - i]; - int64_t max = std::max(lhs_dim, rhs_dim); - int64_t min = std::min(lhs_dim, rhs_dim); - int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. - if (lhs_dim != out_dim && lhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - if (rhs_dim != out_dim && rhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - output_dims[out_rank - 1 - i] = out_dim; - } - out_shape = TensorShape(output_dims); - return Status::OK(); -} - template Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare) const { const aclDataType aclType = getACLType(); @@ -56,14 +27,14 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare void* B_data = const_cast(B->DataRaw()); if (A->Shape() != C->Shape()) { - IAllocatorUniquePtr pA = GetScratchBuffer(C->SizeInBytes()); - ORT_RETURN_IF_ERROR(Broadcast(A, C, pA.get())); + IAllocatorUniquePtr pA = GetScratchBuffer(C->SizeInBytes(), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(Broadcast(A, C, pA.get(), Stream(ctx))); A_data = pA.get(); } if (B->Shape() != C->Shape()) { - IAllocatorUniquePtr pB = GetScratchBuffer(C->SizeInBytes()); - ORT_RETURN_IF_ERROR(Broadcast(B, C, pB.get())); + IAllocatorUniquePtr pB = GetScratchBuffer(C->SizeInBytes(), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(Broadcast(B, C, pB.get(), Stream(ctx))); B_data = pB.get(); } @@ -85,9 +56,9 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare #define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \ template <> \ - Status x::ComputeInternal(OpKernelContext* context) const { \ + Status x::ComputeInternal(OpKernelContext* ctx) const { \ CannPreparation prepare; \ - ORT_RETURN_IF_ERROR(Prepare(context, prepare)); \ + ORT_RETURN_IF_ERROR(Prepare(ctx, prepare)); \ CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \ prepare.inputDesc_.size(), \ prepare.inputDesc_.data(), \ @@ -99,7 +70,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare ACL_ENGINE_SYS, \ ACL_COMPILE_SYS, \ NULL, \ - Stream())); \ + Stream(ctx))); \ return Status::OK(); \ } diff --git a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.h b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.h index 4aeb38bcce..41922d4bbf 100644 --- a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.h @@ -24,28 +24,28 @@ template class Add final : public BinaryElementwise { public: Add(const OpKernelInfo& info) : BinaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Sub final : public BinaryElementwise { public: Sub(const OpKernelInfo& info) : BinaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Mul final : public BinaryElementwise { public: Mul(const OpKernelInfo& info) : BinaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Div final : public BinaryElementwise { public: Div(const OpKernelInfo& info) : BinaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/math/gemm.cc b/onnxruntime/core/providers/cann/math/gemm.cc index d8ea11d8f3..00551de1e0 100644 --- a/onnxruntime/core/providers/cann/math/gemm.cc +++ b/onnxruntime/core/providers/cann/math/gemm.cc @@ -11,10 +11,10 @@ namespace onnxruntime { namespace cann { template -Status Gemm::ComputeInternal(OpKernelContext* context) const { - const auto* A = context->Input(0); - const auto* B = context->Input(1); - const auto* C = context->Input(2); +Status Gemm::ComputeInternal(OpKernelContext* ctx) const { + const auto* A = ctx->Input(0); + const auto* B = ctx->Input(1); + const auto* C = ctx->Input(2); GemmHelper helper(A->Shape(), trans_A_, B->Shape(), trans_B_, C != nullptr ? C->Shape() : TensorShape({})); if (!helper.State().IsOK()) @@ -24,13 +24,13 @@ Status Gemm::ComputeInternal(OpKernelContext* context) const { int N = gsl::narrow_cast(helper.N()); int K = gsl::narrow_cast(helper.K()); - auto* Y = context->Output(0, {M, N}); + auto* Y = ctx->Output(0, {M, N}); // broadcast C if needed. if (beta_ != 0 && C != nullptr) { if (C->Shape().Size() == 1) { // C is (), (1,) or (1, 1), fill the scalar to Y - ORT_RETURN_IF_ERROR(Fill(Y, const_cast(C->DataRaw()))); + ORT_RETURN_IF_ERROR(Fill(Y, const_cast(C->DataRaw()), Stream(ctx))); } else if (C->Shape() == Y->Shape()) { // C is (M, N), no broadcast needed. CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y->MutableDataRaw(), @@ -38,10 +38,10 @@ Status Gemm::ComputeInternal(OpKernelContext* context) const { const_cast(C->DataRaw()), Y->SizeInBytes(), ACL_MEMCPY_DEVICE_TO_DEVICE, - Stream())); + Stream(ctx))); } else { // others, broadcast needed. - ORT_RETURN_IF_ERROR(Broadcast(C, Y, Y->MutableDataRaw())); + ORT_RETURN_IF_ERROR(Broadcast(C, Y, Y->MutableDataRaw(), Stream(ctx))); } } @@ -49,8 +49,8 @@ Status Gemm::ComputeInternal(OpKernelContext* context) const { T alpha = ToCannType::FromFloat(alpha_); T beta = ToCannType::FromFloat(beta_); - IAllocatorUniquePtr pAlpha = GetScratchBuffer(sizeof(T)); - IAllocatorUniquePtr pBeta = GetScratchBuffer(sizeof(T)); + IAllocatorUniquePtr pAlpha = GetScratchBuffer(sizeof(T), ctx->GetComputeStream()); + IAllocatorUniquePtr pBeta = GetScratchBuffer(sizeof(T), ctx->GetComputeStream()); CANN_RETURN_IF_ERROR(aclrtMemcpy(pAlpha.get(), sizeof(T), &alpha, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE)); CANN_RETURN_IF_ERROR(aclrtMemcpy(pBeta.get(), sizeof(T), &beta, sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE)); @@ -67,7 +67,7 @@ Status Gemm::ComputeInternal(OpKernelContext* context) const { pBeta.get(), Y->MutableDataRaw(), -1, aclType, ACL_COMPUTE_HIGH_PRECISION, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/math/gemm.h b/onnxruntime/core/providers/cann/math/gemm.h index e4312d315d..8a1b7b4cfc 100644 --- a/onnxruntime/core/providers/cann/math/gemm.h +++ b/onnxruntime/core/providers/cann/math/gemm.h @@ -25,7 +25,7 @@ class Gemm final : public CannKernel { ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK()); } - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; private: bool trans_A_; diff --git a/onnxruntime/core/providers/cann/math/matmul.cc b/onnxruntime/core/providers/cann/math/matmul.cc index 9103cd2fb3..2b02e52d0f 100644 --- a/onnxruntime/core/providers/cann/math/matmul.cc +++ b/onnxruntime/core/providers/cann/math/matmul.cc @@ -52,7 +52,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/math/matmul.h b/onnxruntime/core/providers/cann/math/matmul.h index f7504b7a15..99e655e3c8 100644 --- a/onnxruntime/core/providers/cann/math/matmul.h +++ b/onnxruntime/core/providers/cann/math/matmul.h @@ -15,7 +15,7 @@ class MatMul final : public CannKernel { MatMul(const OpKernelInfo& info) : CannKernel(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cann/math/unary_elementwise_ops.cc index 6c1a670d73..4ac524b063 100644 --- a/onnxruntime/core/providers/cann/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cann/math/unary_elementwise_ops.cc @@ -38,9 +38,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare) #define REGISTER_ELEMENTWISE_TYPED_COMPUTE(x, T) \ template <> \ - Status x::ComputeInternal(OpKernelContext* context) const { \ + Status x::ComputeInternal(OpKernelContext* ctx) const { \ CannPreparation prepare; \ - ORT_RETURN_IF_ERROR(Prepare(context, prepare)); \ + ORT_RETURN_IF_ERROR(Prepare(ctx, prepare)); \ CANN_RETURN_IF_ERROR(aclopCompileAndExecute(#x, \ prepare.inputDesc_.size(), \ prepare.inputDesc_.data(), \ @@ -52,7 +52,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare) ACL_ENGINE_SYS, \ ACL_COMPILE_SYS, \ NULL, \ - Stream())); \ + Stream(ctx))); \ return Status::OK(); \ } diff --git a/onnxruntime/core/providers/cann/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cann/math/unary_elementwise_ops.h index 7cd03a5a9b..18467c772b 100644 --- a/onnxruntime/core/providers/cann/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cann/math/unary_elementwise_ops.h @@ -29,84 +29,84 @@ template class Abs final : public UnaryElementwise { public: Abs(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Neg final : public UnaryElementwise { public: Neg(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Floor final : public UnaryElementwise { public: Floor(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Ceil final : public UnaryElementwise { public: Ceil(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Reciprocal final : public UnaryElementwise { public: Reciprocal(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Sqrt final : public UnaryElementwise { public: Sqrt(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Log final : public UnaryElementwise { public: Log(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Exp final : public UnaryElementwise { public: Exp(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Erf final : public UnaryElementwise { public: Erf(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Round final : public UnaryElementwise { public: Round(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Sin final : public UnaryElementwise { public: Sin(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; template class Cos final : public UnaryElementwise { public: Cos(const OpKernelInfo& info) : UnaryElementwise(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/nn/average_pool.cc b/onnxruntime/core/providers/cann/nn/average_pool.cc index 78c5323040..8adcdfb057 100644 --- a/onnxruntime/core/providers/cann/nn/average_pool.cc +++ b/onnxruntime/core/providers/cann/nn/average_pool.cc @@ -11,8 +11,8 @@ namespace onnxruntime { namespace cann { template -Status AveragePool::ComputeInternal(OpKernelContext* context) const { - const Tensor* X = context->Input(0); +Status AveragePool::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* X = ctx->Input(0); const TensorShape& X_shape = X->Shape(); const auto X_dims = X_shape.GetDims(); @@ -35,7 +35,7 @@ Status AveragePool::ComputeInternal(OpKernelContext* context) const { auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads); TensorShape Y_shape(Y_dims); - Tensor* Y = context->Output(0, Y_shape); + Tensor* Y = ctx->Output(0, Y_shape); if (Y_shape.Size() == 0) return Status::OK(); @@ -86,7 +86,7 @@ Status AveragePool::ComputeInternal(OpKernelContext* context) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/nn/average_pool.h b/onnxruntime/core/providers/cann/nn/average_pool.h index 4ffa6842d9..9c833d93ac 100644 --- a/onnxruntime/core/providers/cann/nn/average_pool.h +++ b/onnxruntime/core/providers/cann/nn/average_pool.h @@ -15,7 +15,7 @@ class AveragePool : public CannKernel, public PoolBase { public: explicit AveragePool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/nn/batch_norm.cc b/onnxruntime/core/providers/cann/nn/batch_norm.cc index a3966cc2fe..f922365045 100644 --- a/onnxruntime/core/providers/cann/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cann/nn/batch_norm.cc @@ -22,10 +22,10 @@ Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const { // There is only one output in inference mode Tensor* Y = ctx->Output(0, X->Shape()); - IAllocatorUniquePtr pbatch_mean = GetScratchBuffer(mean->SizeInBytes()); - IAllocatorUniquePtr pbatch_variance = GetScratchBuffer(var->SizeInBytes()); - IAllocatorUniquePtr preserver_space_1 = GetScratchBuffer(mean->SizeInBytes()); - IAllocatorUniquePtr preserver_space_2 = GetScratchBuffer(var->SizeInBytes()); + IAllocatorUniquePtr pbatch_mean = GetScratchBuffer(mean->SizeInBytes(), ctx->GetComputeStream()); + IAllocatorUniquePtr pbatch_variance = GetScratchBuffer(var->SizeInBytes(), ctx->GetComputeStream()); + IAllocatorUniquePtr preserver_space_1 = GetScratchBuffer(mean->SizeInBytes(), ctx->GetComputeStream()); + IAllocatorUniquePtr preserver_space_2 = GetScratchBuffer(var->SizeInBytes(), ctx->GetComputeStream()); const aclDataType aclType = getACLType(); aclFormat format = ACL_FORMAT_NCHW; @@ -76,7 +76,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* ctx) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/nn/batch_norm.h b/onnxruntime/core/providers/cann/nn/batch_norm.h index 1f11c61d3e..b92c8b5cdd 100644 --- a/onnxruntime/core/providers/cann/nn/batch_norm.h +++ b/onnxruntime/core/providers/cann/nn/batch_norm.h @@ -20,7 +20,7 @@ class BatchNorm final : public CannKernel { ORT_ENFORCE(!is_training_mode_, "only supports inference mode"); } - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; private: float epsilon_; diff --git a/onnxruntime/core/providers/cann/nn/conv.cc b/onnxruntime/core/providers/cann/nn/conv.cc index 05676e77f3..aad6f971e5 100644 --- a/onnxruntime/core/providers/cann/nn/conv.cc +++ b/onnxruntime/core/providers/cann/nn/conv.cc @@ -106,7 +106,7 @@ Status Conv::ComputeInternal(OpKernelContext* ctx) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/nn/dropout.cc b/onnxruntime/core/providers/cann/nn/dropout.cc index 5bd09354df..1af36b0aaa 100644 --- a/onnxruntime/core/providers/cann/nn/dropout.cc +++ b/onnxruntime/core/providers/cann/nn/dropout.cc @@ -26,17 +26,17 @@ float GetRatioOrDefault(const Tensor* ratio) { } // namespace template -Status Dropout::ComputeInternal(OpKernelContext* context) const { - const Tensor* X = context->Input(0); +Status Dropout::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* X = ctx->Input(0); const TensorShape& X_shape = X->Shape(); - const Tensor* ratio = context->Input(1); + const Tensor* ratio = ctx->Input(1); const float ratio_value = GetRatioOrDefault(ratio); - const Tensor* training_mode = context->Input(2); + const Tensor* training_mode = ctx->Input(2); - auto Y = context->Output(0, X_shape); - auto mask = context->Output(1, X_shape); + auto Y = ctx->Output(0, X_shape); + auto mask = ctx->Output(1, X_shape); if (ratio_value == 0.f || !training_mode || !(*(training_mode->Data()))) { const void* X_data = X->DataRaw(); @@ -44,22 +44,22 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { if (Y_data != X_data) { CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(Y_data, Y->SizeInBytes(), X_data, Y->SizeInBytes(), - ACL_MEMCPY_DEVICE_TO_DEVICE, Stream())); + ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx))); } if (mask) { CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask->MutableData(), mask->SizeInBytes(), true, - mask->SizeInBytes(), Stream())); + mask->SizeInBytes(), Stream(ctx))); } } else { IAllocatorUniquePtr pmask{}; - IAllocatorUniquePtr pseed = GetScratchBuffer(sizeof(float)); + IAllocatorUniquePtr pseed = GetScratchBuffer(sizeof(float), ctx->GetComputeStream()); void* mask_data = nullptr; if (mask) { mask_data = mask->MutableDataRaw(); } else { - pmask = GetScratchBuffer(X_shape.Size() * sizeof(bool)); + pmask = GetScratchBuffer(X_shape.Size() * sizeof(bool), ctx->GetComputeStream()); mask_data = pmask.get(); } @@ -107,7 +107,7 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cann/nn/dropout.h b/onnxruntime/core/providers/cann/nn/dropout.h index a589b49568..f085b89442 100644 --- a/onnxruntime/core/providers/cann/nn/dropout.h +++ b/onnxruntime/core/providers/cann/nn/dropout.h @@ -21,7 +21,7 @@ class Dropout final : public CannKernel { } } - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; private: mutable std::unique_ptr generator_; diff --git a/onnxruntime/core/providers/cann/nn/max_pool.cc b/onnxruntime/core/providers/cann/nn/max_pool.cc index e53b05f6c9..8d6637e5e0 100644 --- a/onnxruntime/core/providers/cann/nn/max_pool.cc +++ b/onnxruntime/core/providers/cann/nn/max_pool.cc @@ -11,8 +11,8 @@ namespace onnxruntime { namespace cann { template -Status MaxPool::ComputeInternal(OpKernelContext* context) const { - const Tensor* X = context->Input(0); +Status MaxPool::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* X = ctx->Input(0); const TensorShape& X_shape = X->Shape(); const auto X_dims = X_shape.GetDims(); @@ -35,7 +35,7 @@ Status MaxPool::ComputeInternal(OpKernelContext* context) const { auto Y_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads); TensorShape Y_shape(Y_dims); - Tensor* Y = context->Output(0, Y_shape); + Tensor* Y = ctx->Output(0, Y_shape); if (Y_shape.Size() == 0) return Status::OK(); @@ -85,7 +85,7 @@ Status MaxPool::ComputeInternal(OpKernelContext* context) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/nn/max_pool.h b/onnxruntime/core/providers/cann/nn/max_pool.h index 09893fda01..9679e756a5 100644 --- a/onnxruntime/core/providers/cann/nn/max_pool.h +++ b/onnxruntime/core/providers/cann/nn/max_pool.h @@ -15,7 +15,7 @@ class MaxPool : public CannKernel, public PoolBase { public: explicit MaxPool(const OpKernelInfo& info) : CannKernel(info), PoolBase(info) {} - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; }; } // namespace cann diff --git a/onnxruntime/core/providers/cann/npu_data_transfer.cc b/onnxruntime/core/providers/cann/npu_data_transfer.cc index e7cd8e751c..2f51c550b2 100644 --- a/onnxruntime/core/providers/cann/npu_data_transfer.cc +++ b/onnxruntime/core/providers/cann/npu_data_transfer.cc @@ -4,35 +4,62 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/cann/npu_data_transfer.h" -#include "core/providers/cann/cann_call.h" namespace onnxruntime { -NPUDataTransfer::NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream) { - do_copy_in_default_stream_ = do_copy_in_default_stream; - streams_[kCannStreamDefault] = stream; - if (do_copy_in_default_stream) { - streams_[kCannStreamCopyIn] = stream; - streams_[kCannStreamCopyOut] = stream; - } else { - CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyIn])); - CANN_CALL_THROW(aclrtCreateStream(&streams_[kCannStreamCopyOut])); - } -} +NPUDataTransfer::NPUDataTransfer() {} -NPUDataTransfer::~NPUDataTransfer() { - if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyIn] != nullptr) { - CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyIn])); - } - if (!do_copy_in_default_stream_ && streams_[kCannStreamCopyOut] != nullptr) { - CANN_CALL_THROW(aclrtDestroyStream(streams_[kCannStreamCopyOut])); - } -} +NPUDataTransfer::~NPUDataTransfer() {} bool NPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::NPU || dst_device.Type() == OrtDevice::NPU; } -common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const { +common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + // for the sync version of memcpy, launch to cann default stream + if (dst_device.Type() == OrtDevice::NPU) { + if (src_device.Type() == OrtDevice::NPU) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE)); + CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr)); + } + } else { + // copy from other CPU memory to NPU, this is blocking + CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr)); + } + } else if (src_device.Type() == OrtDevice::NPU) { + // copying from NPU to CPU memory, this is blocking + CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr)); + } else { + // copying between cpu memory + memcpy(dst_data, src_data, bytes); + } + + return Status::OK(); +} + +common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const { size_t bytes = src.SizeInBytes(); const void* src_data = src.DataRaw(); void* dst_data = dst.MutableDataRaw(); @@ -41,30 +68,44 @@ common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::NPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::CANN_PINNED) { - CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_HOST_TO_DEVICE, GetStream(exec_queue_id))); + if (src_device.Type() == OrtDevice::CPU) { + // copy from pinned memory to NPU, this is non-blocking + CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_HOST_TO_DEVICE, + static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::NPU) { + // copying between NPU, this is non-blocking if (dst_data != src_data) { - CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_DEVICE_TO_DEVICE, GetStream(kCannStreamDefault))); + CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, + static_cast(stream.GetHandle()))); } - } else { - CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_HOST_TO_DEVICE, GetStream(kCannStreamDefault))); - CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault))); + } + } else if (src_device.Type() == OrtDevice::NPU) { + if (dst_device.Type() == OrtDevice::CPU) { + // copying from NPU to pinned memory, this is non-blocking + CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_DEVICE_TO_HOST, + static_cast(stream.GetHandle()))); } } else { - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::CANN_PINNED) { - CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_DEVICE_TO_HOST, GetStream(exec_queue_id))); - } else { - CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_DEVICE_TO_HOST, GetStream(kCannStreamDefault))); - CANN_CALL_THROW(aclrtSynchronizeStream(GetStream(kCannStreamDefault))); + if (src_device.MemType() == OrtDevice::MemType::CANN_PINNED) { + // sync the stream first to make sure the data arrived + CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(static_cast(stream.GetHandle()))); } + memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/npu_data_transfer.h b/onnxruntime/core/providers/cann/npu_data_transfer.h index 40a4dace48..d86e3fc406 100644 --- a/onnxruntime/core/providers/cann/npu_data_transfer.h +++ b/onnxruntime/core/providers/cann/npu_data_transfer.h @@ -4,35 +4,23 @@ #pragma once -#include "cann_inc.h" #include "core/framework/data_transfer.h" +#include "core/providers/cann/cann_inc.h" +#include "core/providers/cann/cann_call.h" +#include "core/providers/cann/cann_common.h" namespace onnxruntime { -enum CANNStreamType : int { - kCannStreamDefault = 0, - kCannStreamCopyIn, - kCannStreamCopyOut, - kTotalCannStreams, -}; - class NPUDataTransfer : public IDataTransfer { public: - explicit NPUDataTransfer(aclrtStream stream, bool do_copy_in_default_stream = true); + NPUDataTransfer(); ~NPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override; + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; - aclrtStream GetStream(int queue_id) const { - ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCannStreams); - return streams_[queue_id]; - } - - private: - bool do_copy_in_default_stream_; - aclrtStream streams_[kTotalCannStreams]; + common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/tensor/cast.cc b/onnxruntime/core/providers/cann/tensor/cast.cc index 07601580bf..8711bee410 100644 --- a/onnxruntime/core/providers/cann/tensor/cast.cc +++ b/onnxruntime/core/providers/cann/tensor/cast.cc @@ -42,10 +42,10 @@ aclDataType getACLTypeByMap(ONNX_NAMESPACE::TensorProto_DataType type) { } template -Status Cast::ComputeInternal(OpKernelContext* context) const { - const Tensor* X = context->Input(0); +Status Cast::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* X = ctx->Input(0); - Tensor* Y = context->Output(0, X->Shape()); + Tensor* Y = ctx->Output(0, X->Shape()); aclFormat format = ACL_FORMAT_ND; const aclDataType aclTypeX = getACLType(); @@ -78,7 +78,7 @@ Status Cast::ComputeInternal(OpKernelContext* context) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/cann/tensor/cast.h b/onnxruntime/core/providers/cann/tensor/cast.h index 54f7188920..e565aae408 100644 --- a/onnxruntime/core/providers/cann/tensor/cast.h +++ b/onnxruntime/core/providers/cann/tensor/cast.h @@ -20,7 +20,7 @@ class Cast final : public CannKernel { to_ = gsl::narrow_cast(to); } - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; private: ONNX_NAMESPACE::TensorProto_DataType to_; diff --git a/onnxruntime/core/providers/cann/tensor/flatten.cc b/onnxruntime/core/providers/cann/tensor/flatten.cc index c3348601b7..4a2e122cb2 100644 --- a/onnxruntime/core/providers/cann/tensor/flatten.cc +++ b/onnxruntime/core/providers/cann/tensor/flatten.cc @@ -53,7 +53,7 @@ Status Flatten::ComputeInternal(OpKernelContext* ctx) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cann/tensor/flatten.h b/onnxruntime/core/providers/cann/tensor/flatten.h index 59759cfa96..0029a413c9 100644 --- a/onnxruntime/core/providers/cann/tensor/flatten.h +++ b/onnxruntime/core/providers/cann/tensor/flatten.h @@ -16,7 +16,7 @@ class Flatten final : public CannKernel { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK()); } - Status ComputeInternal(OpKernelContext* context) const override; + Status ComputeInternal(OpKernelContext* ctx) const override; private: int64_t axis_; diff --git a/onnxruntime/core/providers/cann/tensor/identity_op.h b/onnxruntime/core/providers/cann/tensor/identity_op.h index 8d9802ed8b..203bb43eb9 100644 --- a/onnxruntime/core/providers/cann/tensor/identity_op.h +++ b/onnxruntime/core/providers/cann/tensor/identity_op.h @@ -18,16 +18,16 @@ class IdentityOp final : public CannKernel { IdentityOp(const OpKernelInfo& info) : CannKernel(info) { } - Status ComputeInternal(OpKernelContext* context) const override { - auto X_ml_type = context->InputType(0); + Status ComputeInternal(OpKernelContext* ctx) const override { + auto X_ml_type = ctx->InputType(0); if (X_ml_type->IsTensorType()) { - const Tensor* X = context->Input(0); + const Tensor* X = ctx->Input(0); if (nullptr == X) { return Status(common::ONNXRUNTIME, common::FAIL, "IdentityOp cann: input count mismatch."); } const TensorShape& shape = X->Shape(); - Tensor* Y = context->Output(0, shape); + Tensor* Y = ctx->Output(0, shape); if (nullptr == Y) { return Status(common::ONNXRUNTIME, common::FAIL, "IdentityOp cann: failed to allocate output tensor."); @@ -39,20 +39,20 @@ class IdentityOp final : public CannKernel { if (target != source) { CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(target, Y->SizeInBytes(), source, X->Shape().Size() * X->DataType()->Size(), - ACL_MEMCPY_DEVICE_TO_DEVICE, Stream())); + ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx))); } if (is_dropout) { - Tensor* mask = context->Output(1, shape); + Tensor* mask = ctx->Output(1, shape); if (mask != nullptr) { void* mask_data = mask->MutableDataRaw(); - CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream())); + CANN_RETURN_IF_ERROR(aclrtMemsetAsync(mask_data, mask->SizeInBytes(), 0, mask->SizeInBytes(), Stream(ctx))); } } } else if (X_ml_type->IsTensorSequenceType()) { - const TensorSeq* X = context->Input(0); + const TensorSeq* X = ctx->Input(0); ORT_ENFORCE(X != nullptr, "IdentityOp cann: input tensor is missing."); - TensorSeq* Y = context->Output(0); + TensorSeq* Y = ctx->Output(0); ORT_ENFORCE(Y != nullptr, "IdentityOp cann: failed to allocate output tensor sequence."); if (X == Y) { return Status::OK(); @@ -60,7 +60,7 @@ class IdentityOp final : public CannKernel { auto X_type = X->DataType(); Y->SetType(X_type); AllocatorPtr alloc; - auto status = context->GetTempSpaceAllocator(&alloc); + auto status = ctx->GetTempSpaceAllocator(&alloc); if (!status.IsOK()) { return Status(common::ONNXRUNTIME, common::FAIL, "IdentityOp cann: unable to get an allocator."); @@ -75,7 +75,7 @@ class IdentityOp final : public CannKernel { target_tensor->SizeInBytes(), source_tensor.DataRaw(), source_tensor.SizeInBytes(), - ACL_MEMCPY_DEVICE_TO_DEVICE, Stream())); + ACL_MEMCPY_DEVICE_TO_DEVICE, Stream(ctx))); Y->Add(std::move(*target_tensor)); } } else { diff --git a/onnxruntime/core/providers/cann/tensor/reshape.h b/onnxruntime/core/providers/cann/tensor/reshape.h index b9dae46e13..48a92aee91 100644 --- a/onnxruntime/core/providers/cann/tensor/reshape.h +++ b/onnxruntime/core/providers/cann/tensor/reshape.h @@ -18,8 +18,8 @@ class Reshape final : public CannKernel { allow_zero_ = (info.GetAttrOrDefault("allowzero", static_cast(0)) == 1); } - Status ComputeInternal(OpKernelContext* context) const override { - const Tensor* shapeTensor = context->Input(1); + Status ComputeInternal(OpKernelContext* ctx) const override { + const Tensor* shapeTensor = ctx->Input(1); if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "the 0th input is missing"); if (shapeTensor->Shape().NumDimensions() != 1) @@ -28,14 +28,14 @@ class Reshape final : public CannKernel { auto data_span = shapeTensor->template DataAsSpan(); TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); + const Tensor* X = ctx->Input(0); if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "the 1th input is missing"); const TensorShape& X_shape = X->Shape(); ReshapeHelper helper(X_shape, shape, allow_zero_); - Tensor* Y = context->Output(0, TensorShape(shape)); + Tensor* Y = ctx->Output(0, TensorShape(shape)); const void* source = X->DataRaw(); void* target = Y->MutableDataRaw(); if (target != source) { @@ -56,14 +56,14 @@ class Reshape_1 final : public CannKernel { ORT_ENFORCE(status.IsOK(), "Attribute shape is not set."); } - Status ComputeInternal(OpKernelContext* context) const override { + Status ComputeInternal(OpKernelContext* ctx) const override { TensorShapeVector shape = shape_; - const Tensor* X = context->Input(0); + const Tensor* X = ctx->Input(0); const TensorShape& X_shape = X->Shape(); ReshapeHelper helper(X_shape, shape); - Tensor* Y = context->Output(0, TensorShape(shape)); + Tensor* Y = ctx->Output(0, TensorShape(shape)); const void* source = X->DataRaw(); void* target = Y->MutableDataRaw(); if (target != source) { diff --git a/onnxruntime/core/providers/cann/tensor/transpose.cc b/onnxruntime/core/providers/cann/tensor/transpose.cc index 869128a3d5..d91bfe38e1 100644 --- a/onnxruntime/core/providers/cann/tensor/transpose.cc +++ b/onnxruntime/core/providers/cann/tensor/transpose.cc @@ -55,7 +55,7 @@ Status Transpose::ComputeInternal(OpKernelContext* ctx) const { ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, - Stream())); + Stream(ctx))); return Status::OK(); } diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index bd24fad008..817ae047da 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -690,6 +690,10 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const { #if defined(USE_CANN) RandomGenerator& RandomGenerator::Default() { return g_host->RandomGenerator__Default(); } +void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, + WaitNotificationFn wait_fn) { + return g_host->Allocator__AllocateBufferWithOptions(allocator, size, use_reserve, stream, wait_fn); +} namespace cann { std::unique_ptr CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 8c266e34d0..faa25c39f5 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1847,8 +1847,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider (*out)->device_id = 0; (*out)->npu_mem_limit = SIZE_MAX; (*out)->arena_extend_strategy = static_cast(0); - (*out)->do_copy_in_default_stream = 1; (*out)->enable_cann_graph = 1; + (*out)->dump_graphs = 0; (*out)->default_memory_arena_cfg = nullptr; return nullptr; #else