mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix kernel registry validation to reenable DML kernels
This commit is contained in:
parent
ddaad86605
commit
b9faa0b6fd
3 changed files with 25 additions and 9 deletions
|
|
@ -28,10 +28,18 @@ inline bool AreVectorsOverlap(const std::vector<T>& v1, const std::vector<T>& v2
|
|||
bool KernelDef::IsConflict(const KernelDef& other) const {
|
||||
if (op_name_ != other.OpName() || provider_type_ != other.Provider())
|
||||
return false;
|
||||
int start = 0;
|
||||
int end = 0;
|
||||
other.SinceVersion(&start, &end);
|
||||
if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_end_, start, end))
|
||||
int other_since_version_start = 0;
|
||||
int other_since_version_end = 0;
|
||||
other.SinceVersion(&other_since_version_start, &other_since_version_end);
|
||||
|
||||
//When max version is INT_MAX, it means that it should be determined based on the
|
||||
//SinceVersion of schema from a higher version. Since this sometimes isn't known until
|
||||
//all custom schema are available, make a conservative assumption here that the operator
|
||||
//is valid for only one version.
|
||||
int op_since_version_conservative_end = (op_since_version_end_ == INT_MAX) ? op_since_version_start_ : op_since_version_end_;
|
||||
int other_conservative_since_version_end = (other_since_version_end == INT_MAX) ? other_since_version_start : other_since_version_end;
|
||||
|
||||
if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_conservative_end, other_since_version_start, other_conservative_since_version_end))
|
||||
return false;
|
||||
//only one case they don't conflict:
|
||||
//There is a type_constraint, it exists in both hands, but they don't overlap
|
||||
|
|
|
|||
|
|
@ -499,11 +499,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
|
||||
registration->requiredConstantCpuInputs = constantCpuInputCapture;
|
||||
|
||||
// TODO: Propagate errors here once the presence of overlapping built-in DML kernels is addressed
|
||||
if (m_kernelRegistry->RegisterCustomKernel(create_info).IsOK())
|
||||
{
|
||||
(*m_graphNodeFactoryMap)[create_info.kernel_def.get()] = registration;
|
||||
}
|
||||
onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get();
|
||||
THROW_IF_NOT_OK(m_kernelRegistry->RegisterCustomKernel(create_info));
|
||||
(*m_graphNodeFactoryMap)[kernelDef] = registration;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
|||
|
|
@ -90,5 +90,15 @@ TEST(KernelRegistryTests, two_versions3) {
|
|||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
//One op two versions
|
||||
TEST(KernelRegistryTests, two_versions4) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(5,6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6,7).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
|
||||
}
|
||||
Loading…
Reference in a new issue