diff --git a/onnxruntime/core/framework/kernel_def_builder.cc b/onnxruntime/core/framework/kernel_def_builder.cc index 0d475eab46..6270e30d3e 100644 --- a/onnxruntime/core/framework/kernel_def_builder.cc +++ b/onnxruntime/core/framework/kernel_def_builder.cc @@ -28,10 +28,18 @@ inline bool AreVectorsOverlap(const std::vector& v1, const std::vector& v2 bool KernelDef::IsConflict(const KernelDef& other) const { if (op_name_ != other.OpName() || provider_type_ != other.Provider()) return false; - int start = 0; - int end = 0; - other.SinceVersion(&start, &end); - if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_end_, start, end)) + int other_since_version_start = 0; + int other_since_version_end = 0; + other.SinceVersion(&other_since_version_start, &other_since_version_end); + + //When max version is INT_MAX, it means that it should be determined based on the + //SinceVersion of schema from a higher version. Since this sometimes isn't known until + //all custom schema are available, make a conservative assumption here that the operator + //is valid for only one version. + int op_since_version_conservative_end = (op_since_version_end_ == INT_MAX) ? op_since_version_start_ : op_since_version_end_; + int other_conservative_since_version_end = (other_since_version_end == INT_MAX) ? other_since_version_start : other_since_version_end; + + if (!AreIntervalsOverlap(op_since_version_start_, op_since_version_conservative_end, other_since_version_start, other_conservative_since_version_end)) return false; //only one case they don't conflict: //There is a type_constraint, it exists in both hands, but they don't overlap diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index cdbfb759e6..5365d10c9f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -499,11 +499,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph; registration->requiredConstantCpuInputs = constantCpuInputCapture; - // TODO: Propagate errors here once the presence of overlapping built-in DML kernels is addressed - if (m_kernelRegistry->RegisterCustomKernel(create_info).IsOK()) - { - (*m_graphNodeFactoryMap)[create_info.kernel_def.get()] = registration; - } + onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get(); + THROW_IF_NOT_OK(m_kernelRegistry->RegisterCustomKernel(create_info)); + (*m_graphNodeFactoryMap)[kernelDef] = registration; } else { diff --git a/onnxruntime/test/framework/kernel_registry_test.cc b/onnxruntime/test/framework/kernel_registry_test.cc index c186a0b035..d59de5cf6f 100644 --- a/onnxruntime/test/framework/kernel_registry_test.cc +++ b/onnxruntime/test/framework/kernel_registry_test.cc @@ -90,5 +90,15 @@ TEST(KernelRegistryTests, two_versions3) { function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build()); Status st; + ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +//One op two versions +TEST(KernelRegistryTests, two_versions4) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(5,6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6,7).Provider(kCpuExecutionProvider).Build()); + Status st; ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()); } \ No newline at end of file