diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index cabf6d0553..3aeb6f651d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2337,6 +2337,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleClassifier); @@ -2431,6 +2432,8 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { LabelEncoder)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo::InitializeSomeFields(const OpKernelInf info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); }; +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + string_string, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2) + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_strings"; + _value_field_name = "values_strings"; + info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); +}; + ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( LabelEncoder, 2, diff --git a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc index 3f637317c4..4f04cd1d45 100644 --- a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc +++ b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc @@ -24,6 +24,7 @@ static void RunTest(const std::vector& dims, const std::vector& test.Run(); } + TEST(LabelEncoder, StringToInt) { std::vector dims{2, 2, 2}; @@ -189,5 +190,26 @@ TEST(LabelEncoder, Int64ToInt64Opset2) { test.Run(); } +TEST(LabelEncoder, StringToStringOpset2) { + std::vector dims{1, 5}; + + std::vector input{"A", "A", "C", "D", "E"}; + std::vector output{"X", "X", "Z", "!", "!"}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{"A", "B", "C"}; + const std::vector values{"X", "Y", "Z"}; + + test.AddAttribute("keys_strings", keys); + test.AddAttribute("values_strings", values); + test.AddAttribute("default_string", "!"); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + } // namespace test } // namespace onnxruntime