update kernel memory type interface (#225)

* refactor the kernel memory type interface

* remove useless change

* fix comments in PR
This commit is contained in:
Tang, Cheng 2018-12-20 11:11:50 -08:00 committed by GitHub
parent a43382e390
commit c453b48b71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 37 deletions

View file

@ -20,13 +20,15 @@ class KernelDefBuilder;
typedef std::map<size_t, OrtMemType> MemTypeMap;
// note that input/output might be on CPU implicitly when the node is from CPU execution provider
inline bool MemTypeOnCpuExplicitly(const MemTypeMap& mem_type_map, size_t index) {
auto iter = mem_type_map.find(index);
return iter != mem_type_map.end() && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
// Returns true when mem_type is one of the explicit CPU memory types
// (OrtMemTypeCPUInput or OrtMemTypeCPUOutput), false for any other value.
inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
  if (mem_type == OrtMemTypeCPUInput) {
    return true;
  }
  return mem_type == OrtMemTypeCPUOutput;
}
class KernelDef {
public:
// Default constructor: both the input and the output default memory types
// start out as OrtMemTypeDefault.
explicit KernelDef()
    : default_inputs_mem_type_{OrtMemTypeDefault},
      default_outputs_mem_type_{OrtMemTypeDefault} {}
const std::string& OpName() const {
return op_name_;
}
@ -56,17 +58,20 @@ class KernelDef {
return alias_map_;
}
const MemTypeMap& InputMemoryType() const {
return input_memory_type_args_;
// Memory type of the input at input_index: the explicit entry stored in
// input_memory_type_args_ when one exists, otherwise default_inputs_mem_type_.
OrtMemType InputMemoryType(size_t input_index) const {
  const auto entry = input_memory_type_args_.find(input_index);
  const bool has_explicit_type = entry != input_memory_type_args_.end();
  return has_explicit_type ? entry->second : default_inputs_mem_type_;
}
const MemTypeMap& OutputMemoryType() const {
return output_memory_type_args_;
}
// legacy interface for winml, should not be used in onnxruntime
const MemTypeMap& MemoryType() const {
return output_memory_type_args_;
// Memory type of the output at output_index: the explicit entry stored in
// output_memory_type_args_ when one exists, otherwise default_outputs_mem_type_.
OrtMemType OutputMemoryType(size_t output_index) const {
  const auto entry = output_memory_type_args_.find(output_index);
  const bool has_explicit_type = entry != output_memory_type_args_.end();
  return has_explicit_type ? entry->second : default_outputs_mem_type_;
}
int ExecQueueId() const {
@ -111,6 +116,10 @@ class KernelDef {
// execution command queue id, 0 for default queue in execution provider
int exec_queue_id_ = 0;
// Default memory type for all inputs; InputMemoryType(index) returns this
// when input_memory_type_args_ has no entry for the requested index.
OrtMemType default_inputs_mem_type_;
// Default memory type for all outputs; OutputMemoryType(index) returns this
// when output_memory_type_args_ has no entry for the requested index.
OrtMemType default_outputs_mem_type_;
};
class KernelDefBuilder {
@ -212,6 +221,22 @@ class KernelDefBuilder {
return *this;
}
/**
Specify the default memory type for all inputs.
If this is never called, the default remains OrtMemTypeDefault, the value the
KernelDef constructor assigns.
Returns this builder.
*/
KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
kernel_def_->default_inputs_mem_type_ = mem_type;
return *this;
}
/**
Specify the default memory type for all outputs.
If this is never called, the default remains OrtMemTypeDefault, the value the
KernelDef constructor assigns.
Returns this builder.
NOTE(review): the name uses singular "Output" while the input counterpart is
SetDefaultInputsMemoryType; confirm whether the asymmetry is intended.
*/
KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
kernel_def_->default_outputs_mem_type_ = mem_type;
return *this;
}
/**
Return the kernel definition, passing ownership of the KernelDef to the caller
*/

View file

@ -380,7 +380,6 @@ class PlannerImpl {
ORT_ENFORCE(exec_provider);
auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
auto& mem_type_allocated_args = p_kernelDef->OutputMemoryType();
auto& outputs = pnode->OutputDefs();
auto num_outputs = outputs.size();
@ -393,11 +392,11 @@ class PlannerImpl {
if (strcmp(default_allocator_info.name, CPU) != 0) {
// By default, outputs of this node are allocated on the default device allocator,
// except for outputs marked for allocation in MemoryType:
auto memory_type_iter = mem_type_allocated_args.find(i);
if (memory_type_iter == mem_type_allocated_args.end()) {
auto memory_type = p_kernelDef->OutputMemoryType(i);
if (memory_type == OrtMemTypeDefault) {
AllocPlan(index).location = default_allocator_info;
} else {
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type_iter->second)->Info();
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type)->Info();
}
}
}
@ -438,7 +437,7 @@ class PlannerImpl {
thisplan.alloc_kind = AllocKind::kAllocateStatically;
auto p_opkernelDef = utils::GetKernelDef(kernel_registry_, node);
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(), index))
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(index)))
// weights are not the output of any node, so it's OK to place them on the CPU provider
thisplan.location = execution_providers_.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault)->Info();
else

View file

@ -66,20 +66,20 @@ bool KernelDef::IsConflict(const KernelDef& other) const {
return false;
//check memory type
auto other_input_mem_types = other.InputMemoryType();
auto& other_input_mem_types = other.input_memory_type_args_;
for (auto it : input_memory_type_args_) {
if (other_input_mem_types.count(it.first) && other_input_mem_types[it.first] == it.second)
if (other_input_mem_types.count(it.first) && other_input_mem_types.find(it.first)->second == it.second)
return false;
}
if (input_memory_type_args_.empty() && !other.InputMemoryType().empty())
if (input_memory_type_args_.empty() && !other.input_memory_type_args_.empty())
return false;
auto other_output_mem_types = other.OutputMemoryType();
auto& other_output_mem_types = other.output_memory_type_args_;
for (auto it : output_memory_type_args_) {
if (other_output_mem_types.count(it.first) && other_output_mem_types[it.first] == it.second)
// Look up the other def's entry by output index (it.first). it.second is the
// memory type, not a key; using it as the key can miss and dereference end().
if (other_output_mem_types.count(it.first) && other_output_mem_types.find(it.first)->second == it.second)
  return false;
}
return !(output_memory_type_args_.empty() && !other.OutputMemoryType().empty());
return !(output_memory_type_args_.empty() && !other.output_memory_type_args_.empty());
}
KernelDefBuilder& KernelDefBuilder::SetName(const std::string& op_name) {

View file

@ -68,25 +68,24 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
kernel_registries.SearchKernelRegistry(node, &kci);
const auto* input_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
const auto* output_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(
node.InputDefs(),
[this, &input_mem_types](const onnxruntime::NodeArg& arg, size_t index) {
if (input_mem_types && MemTypeOnCpuExplicitly(*input_mem_types, index))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
})
.IsOK());
node.InputDefs(),
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index)))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
})
.IsOK());
auto& output_defs = node.MutableOutputDefs();
for (size_t i = 0; i < output_defs.size(); ++i) {
auto arg = output_defs[i];
if (!arg->Exists())
continue;
if (output_mem_types && MemTypeOnCpuExplicitly(*output_mem_types, i))
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->OutputMemoryType(i)))
non_provider_output_defs_.insert(arg);
else
provider_output_defs_.insert(arg);

View file

@ -60,10 +60,9 @@ common::Status IOBinding::CopyOneInputAcrossDevices(const SessionState& session_
size_t index = node_info.index;
auto& node = *node_info.p_node;
const KernelCreateInfo* kci = node_info.kci;
const auto* node_input_mem_types = (kci != nullptr) ? &kci->kernel_def->InputMemoryType() : nullptr;
// node may declare input_mem_type to be on CPU explicitly
bool node_input_on_cpu = node_input_mem_types && MemTypeOnCpuExplicitly(*node_input_mem_types, index);
bool node_input_on_cpu = kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index));
auto& required_provider_type = node_input_on_cpu ? onnxruntime::kCpuExecutionProvider : node.GetExecutionProviderType();
if (!orig_mlvalue.IsTensor()) {
// copying not supported for non-tensor types