mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
update kernel memory type interface (#225)
* refactor the kernel memory type interface * remove useless change * fix comments in PR
This commit is contained in:
parent
a43382e390
commit
c453b48b71
5 changed files with 59 additions and 37 deletions
|
|
@ -20,13 +20,15 @@ class KernelDefBuilder;
|
|||
typedef std::map<size_t, OrtMemType> MemTypeMap;
|
||||
|
||||
// note that input/output might be on CPU implicitly when the node is from CPU execution provider
|
||||
inline bool MemTypeOnCpuExplicitly(const MemTypeMap& mem_type_map, size_t index) {
|
||||
auto iter = mem_type_map.find(index);
|
||||
return iter != mem_type_map.end() && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
|
||||
inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
|
||||
return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
|
||||
}
|
||||
|
||||
class KernelDef {
|
||||
public:
|
||||
explicit KernelDef() : default_inputs_mem_type_(OrtMemTypeDefault), default_outputs_mem_type_(OrtMemTypeDefault) {
|
||||
}
|
||||
|
||||
const std::string& OpName() const {
|
||||
return op_name_;
|
||||
}
|
||||
|
|
@ -56,17 +58,20 @@ class KernelDef {
|
|||
return alias_map_;
|
||||
}
|
||||
|
||||
const MemTypeMap& InputMemoryType() const {
|
||||
return input_memory_type_args_;
|
||||
OrtMemType InputMemoryType(size_t input_index) const {
|
||||
auto it = input_memory_type_args_.find(input_index);
|
||||
if (it == input_memory_type_args_.end())
|
||||
return default_inputs_mem_type_;
|
||||
else
|
||||
return it->second;
|
||||
}
|
||||
|
||||
const MemTypeMap& OutputMemoryType() const {
|
||||
return output_memory_type_args_;
|
||||
}
|
||||
|
||||
// legacy interface for winml, should not be used in onnxruntime
|
||||
const MemTypeMap& MemoryType() const {
|
||||
return output_memory_type_args_;
|
||||
OrtMemType OutputMemoryType(size_t output_index) const {
|
||||
auto it = output_memory_type_args_.find(output_index);
|
||||
if (it == output_memory_type_args_.end())
|
||||
return default_outputs_mem_type_;
|
||||
else
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int ExecQueueId() const {
|
||||
|
|
@ -111,6 +116,10 @@ class KernelDef {
|
|||
|
||||
// execution command queue id, 0 for default queue in execution provider
|
||||
int exec_queue_id_ = 0;
|
||||
// Default memory type for all inputs
|
||||
OrtMemType default_inputs_mem_type_;
|
||||
// Default memory type for all outputs
|
||||
OrtMemType default_outputs_mem_type_;
|
||||
};
|
||||
|
||||
class KernelDefBuilder {
|
||||
|
|
@ -212,6 +221,22 @@ class KernelDefBuilder {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
Specify the default inputs memory type, if not specified, it is DefaultMemory
|
||||
*/
|
||||
KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
|
||||
kernel_def_->default_inputs_mem_type_ = mem_type;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
Specify the default outputs memory type, if not specified, it is DefaultMemory
|
||||
*/
|
||||
KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
|
||||
kernel_def_->default_outputs_mem_type_ = mem_type;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
Return the kernel definition, passing ownership of the KernelDef to the caller
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -380,7 +380,6 @@ class PlannerImpl {
|
|||
ORT_ENFORCE(exec_provider);
|
||||
|
||||
auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
auto& mem_type_allocated_args = p_kernelDef->OutputMemoryType();
|
||||
auto& outputs = pnode->OutputDefs();
|
||||
auto num_outputs = outputs.size();
|
||||
|
||||
|
|
@ -393,11 +392,11 @@ class PlannerImpl {
|
|||
if (strcmp(default_allocator_info.name, CPU) != 0) {
|
||||
// By default, outputs of this node are allocated on the default device allocator,
|
||||
// except for outputs marked for allocation in MemoryType:
|
||||
auto memory_type_iter = mem_type_allocated_args.find(i);
|
||||
if (memory_type_iter == mem_type_allocated_args.end()) {
|
||||
auto memory_type = p_kernelDef->OutputMemoryType(i);
|
||||
if (memory_type == OrtMemTypeDefault) {
|
||||
AllocPlan(index).location = default_allocator_info;
|
||||
} else {
|
||||
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type_iter->second)->Info();
|
||||
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type)->Info();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -438,7 +437,7 @@ class PlannerImpl {
|
|||
|
||||
thisplan.alloc_kind = AllocKind::kAllocateStatically;
|
||||
auto p_opkernelDef = utils::GetKernelDef(kernel_registry_, node);
|
||||
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(), index))
|
||||
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(index)))
|
||||
// weights are not output from any node, so it's OK to put its location on CPU provider
|
||||
thisplan.location = execution_providers_.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
else
|
||||
|
|
|
|||
|
|
@ -66,20 +66,20 @@ bool KernelDef::IsConflict(const KernelDef& other) const {
|
|||
return false;
|
||||
|
||||
//check memory type
|
||||
auto other_input_mem_types = other.InputMemoryType();
|
||||
auto& other_input_mem_types = other.input_memory_type_args_;
|
||||
for (auto it : input_memory_type_args_) {
|
||||
if (other_input_mem_types.count(it.first) && other_input_mem_types[it.first] == it.second)
|
||||
if (other_input_mem_types.count(it.first) && other_input_mem_types.find(it.first)->second == it.second)
|
||||
return false;
|
||||
}
|
||||
if (input_memory_type_args_.empty() && !other.InputMemoryType().empty())
|
||||
if (input_memory_type_args_.empty() && !other.input_memory_type_args_.empty())
|
||||
return false;
|
||||
|
||||
auto other_output_mem_types = other.OutputMemoryType();
|
||||
auto& other_output_mem_types = other.output_memory_type_args_;
|
||||
for (auto it : output_memory_type_args_) {
|
||||
if (other_output_mem_types.count(it.first) && other_output_mem_types[it.first] == it.second)
|
||||
if (other_output_mem_types.count(it.first) && other_output_mem_types.find(it.second)->second == it.second)
|
||||
return false;
|
||||
}
|
||||
return !(output_memory_type_args_.empty() && !other.OutputMemoryType().empty());
|
||||
return !(output_memory_type_args_.empty() && !other.output_memory_type_args_.empty());
|
||||
}
|
||||
|
||||
KernelDefBuilder& KernelDefBuilder::SetName(const std::string& op_name) {
|
||||
|
|
|
|||
|
|
@ -68,25 +68,24 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
// note KernelCreateInfo might be nullptr for custom kernel
|
||||
const KernelCreateInfo* kci = nullptr;
|
||||
kernel_registries.SearchKernelRegistry(node, &kci);
|
||||
const auto* input_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
|
||||
const auto* output_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
|
||||
|
||||
ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(
|
||||
node.InputDefs(),
|
||||
[this, &input_mem_types](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (input_mem_types && MemTypeOnCpuExplicitly(*input_mem_types, index))
|
||||
non_provider_input_defs_.insert(&arg);
|
||||
else
|
||||
provider_input_defs_.insert(&arg);
|
||||
return Status::OK();
|
||||
})
|
||||
.IsOK());
|
||||
node.InputDefs(),
|
||||
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index)))
|
||||
non_provider_input_defs_.insert(&arg);
|
||||
else
|
||||
provider_input_defs_.insert(&arg);
|
||||
return Status::OK();
|
||||
})
|
||||
.IsOK());
|
||||
auto& output_defs = node.MutableOutputDefs();
|
||||
for (size_t i = 0; i < output_defs.size(); ++i) {
|
||||
auto arg = output_defs[i];
|
||||
if (!arg->Exists())
|
||||
continue;
|
||||
|
||||
if (output_mem_types && MemTypeOnCpuExplicitly(*output_mem_types, i))
|
||||
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->OutputMemoryType(i)))
|
||||
non_provider_output_defs_.insert(arg);
|
||||
else
|
||||
provider_output_defs_.insert(arg);
|
||||
|
|
|
|||
|
|
@ -60,10 +60,9 @@ common::Status IOBinding::CopyOneInputAcrossDevices(const SessionState& session_
|
|||
size_t index = node_info.index;
|
||||
auto& node = *node_info.p_node;
|
||||
const KernelCreateInfo* kci = node_info.kci;
|
||||
const auto* node_input_mem_types = (kci != nullptr) ? &kci->kernel_def->InputMemoryType() : nullptr;
|
||||
|
||||
// node may declare input_mem_type to be on CPU explicitly
|
||||
bool node_input_on_cpu = node_input_mem_types && MemTypeOnCpuExplicitly(*node_input_mem_types, index);
|
||||
bool node_input_on_cpu = kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index));
|
||||
auto& required_provider_type = node_input_on_cpu ? onnxruntime::kCpuExecutionProvider : node.GetExecutionProviderType();
|
||||
if (!orig_mlvalue.IsTensor()) {
|
||||
// copying not supported for non-tensor types
|
||||
|
|
|
|||
Loading…
Reference in a new issue