Fix kernel registry validation to reenable DML kernels

This commit is contained in:
Jeff Bloomfield 2019-11-26 19:26:10 -08:00 committed by Changming Sun
parent ddaad86605
commit b9faa0b6fd
3 changed files with 25 additions and 9 deletions

View file

@ -28,10 +28,18 @@ inline bool AreVectorsOverlap(const std::vector<T>& v1, const std::vector<T>& v2
bool KernelDef::IsConflict(const KernelDef& other) const {
if (op_name_ != other.OpName() || provider_type_ != other.Provider())
return false;
int start = 0;
int end = 0;
other.SinceVersion(&start, &end);
if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_end_, start, end))
int other_since_version_start = 0;
int other_since_version_end = 0;
other.SinceVersion(&other_since_version_start, &other_since_version_end);
//When max version is INT_MAX, it means that it should be determined based on the
//SinceVersion of schema from a higher version. Since this sometimes isn't known until
//all custom schema are available, make a conservative assumption here that the operator
//is valid for only one version.
int op_since_version_conservative_end = (op_since_version_end_ == INT_MAX) ? op_since_version_start_ : op_since_version_end_;
int other_conservative_since_version_end = (other_since_version_end == INT_MAX) ? other_since_version_start : other_since_version_end;
if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_conservative_end, other_since_version_start, other_conservative_since_version_end))
return false;
//only one case they don't conflict:
//There is a type_constraint, it exists in both hands, but they don't overlap

View file

@ -499,11 +499,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
registration->requiredConstantCpuInputs = constantCpuInputCapture;
// TODO: Propagate errors here once the presence of overlapping built-in DML kernels is addressed
if (m_kernelRegistry->RegisterCustomKernel(create_info).IsOK())
{
(*m_graphNodeFactoryMap)[create_info.kernel_def.get()] = registration;
}
onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get();
THROW_IF_NOT_OK(m_kernelRegistry->RegisterCustomKernel(create_info));
(*m_graphNodeFactoryMap)[kernelDef] = registration;
}
else
{

View file

@ -90,5 +90,15 @@ TEST(KernelRegistryTests, two_versions3) {
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
//One op two versions
TEST(KernelRegistryTests, two_versions4) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(5,6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6,7).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
}