// Review preamble (free text ahead of the first `diff --git` header; it is not part of any hunk).
//
// Purpose of the patch: add opset-2 `LabelEncoder` kernels for the ML domain on the CPU provider.
//
// * cpu_execution_provider.cc
//     - The opset-1 `LabelEncoder` kernel class declaration and its function-table entry are
//       replaced: the declaration now uses the VERSIONED class-name macro with range `1, 1`.
//     - Six typed opset-2 `LabelEncoder` kernel classes are declared (float_string, string_float,
//       int64_float, float_int64, int64_string, string_int64) and six `BuildKernelCreateInfo`
//       entries are appended to `function_table` in `RegisterOnnxMLOperatorKernels`.
// * ml/label_encoder.cc
//     - The opset-1 registration switches to `ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL` with `1, 1`.
//     - Six `ONNX_CPU_OPERATOR_TYPED_ML_KERNEL` registrations are added for opset 2, each followed
//       by a specialization of `LabelEncoder_2::InitializeSomeFields` that sets the key attribute
//       name, the value attribute name, and reads the default value (`default_string` falling back
//       to "_Unused", `default_float` to -0.0f, `default_int64` to -1).
// * ml/label_encoder.h
//     - Adds the `LabelEncoder_2` kernel. Its constructor calls `InitializeSomeFields`, loads the
//       key and value attribute lists, enforces that both lists have the same length, and fills
//       `_map`. `Compute` returns a FAIL status when input 0 is null; otherwise it creates an
//       output of the input's shape and writes, per element, the mapped value or `_default_value`
//       when the key is not in `_map`.
// * test/providers/cpu/ml/label_encoder_test.cc
//     - Adds six opset-2 tests: string->int64, int64->string, float->string, string->float,
//       float->int64 and int64->float, each including inputs that hit the default value.
//
// NOTE(review): template argument lists look stripped from this copy of the patch (for example
//   `template` with no parameter list, `std::vector keys;`, bare `BuildKernelCreateInfo,`).
//   Confirm against the original commit before applying.
// NOTE(review): the opset-1 registration is given the range `1, 1` but its builder also calls
//   `.SinceVersion(1, 2)`; the two ranges disagree - confirm which one the macro honours.
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d08b82c014..fbb19fe332 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -594,7 +594,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t_double, DictVectorizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, FeatureVectorizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, Imputer); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, LabelEncoder); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 1, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float, LinearClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, double, LinearClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t, LinearClassifier); @@ -621,6 +622,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, TreeEnsembleRegressor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_float, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_int64, LabelEncoder); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder); + void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -639,7 +647,7 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -665,6 +673,13 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.cc b/onnxruntime/core/providers/cpu/ml/label_encoder.cc index 4a2ac686b4..b497300a72 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.cc +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.cc @@ -9,15 +9,16 @@ using namespace ::onnxruntime::common; namespace onnxruntime { namespace ml { -ONNX_CPU_OPERATOR_ML_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL( LabelEncoder, - 1, + 1, 1, KernelDefBuilder().TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType()}) + .SinceVersion(1, 2), LabelEncoder); Status LabelEncoder::Compute(OpKernelContext* context) const { @@ -67,5 +68,107 @@ Status LabelEncoder::Compute(OpKernelContext* context) const { return Status::OK(); } +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + 
float_string, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_floats"; + _value_field_name = "values_strings"; + info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + string_float, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_strings"; + _value_field_name = "values_floats"; + info.GetAttrOrDefault("default_float", &_default_value, -0.0f); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + int64_float, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_int64s"; + _value_field_name = "values_floats"; + info.GetAttrOrDefault("default_float", &_default_value, -0.0f); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + float_int64, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_floats"; + _value_field_name = "values_int64s"; + info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + int64_string, 
+ KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_int64s"; + _value_field_name = "values_strings"; + info.GetAttrOrDefault("default_string", &_default_value, std::string("_Unused")); +}; + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + LabelEncoder, + 2, + string_int64, + KernelDefBuilder().TypeConstraint("T1", + std::vector{DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", + std::vector{DataTypeImpl::GetTensorType()}), + LabelEncoder_2); + +template <> +void LabelEncoder_2::InitializeSomeFields(const OpKernelInfo& info) { + _key_field_name = "keys_strings"; + _value_field_name = "values_int64s"; + info.GetAttrOrDefault("default_int64", &_default_value, (std::int64_t)-1); +}; + } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.h b/onnxruntime/core/providers/cpu/ml/label_encoder.h index 597cf240c6..0f7c59b574 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.h +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.h @@ -43,5 +43,67 @@ class LabelEncoder final : public OpKernel { int64_t default_int_; }; +template +class LabelEncoder_2 final : public OpKernel { + public: + LabelEncoder_2(const OpKernelInfo& info) : OpKernel(info) { + // Let the specialized member function to tell which fields to load. + InitializeSomeFields(info); + + std::vector keys; + std::vector values; + + ORT_ENFORCE(info.GetAttrs(_key_field_name, keys).IsOK()); + ORT_ENFORCE(info.GetAttrs(_value_field_name, values).IsOK()); + + auto num_keys = keys.size(); + auto num_values = values.size(); + ORT_ENFORCE(num_keys == num_values, + "The ", _key_field_name, " and ", _value_field_name, " attributes in LabelEncoder ", + "(name: ", info.node().Name(), ") must have the same length. 
", + "However, the number of keys is ", num_keys, " and the number of ", + "values is ", num_values, "."); + + for (size_t i = 0; i < num_keys; ++i) + _map[keys[i]] = values[i]; + } + + Status Compute(OpKernelContext* context) const override { + const auto* tensor_pointer = context->Input(0); + if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const Tensor& X = *tensor_pointer; + const TensorShape& shape = X.Shape(); + Tensor& Y = *context->Output(0, TensorShape(shape)); + + auto input = X.template DataAsSpan(); + auto output = Y.template MutableDataAsSpan(); + + for (int64_t i = 0; i < shape.Size(); ++i) { + const auto found = _map.find(input[i]); + if (found == _map.end()) + output[i] = _default_value; + else + output[i] = found->second; + } + + return Status::OK(); + } + + private: + // Specialize this method to set attribute names. For example, if keys' type + // is 64-bit integer, _key_field_name should be "keys_int64s". Field names + // for other types can be found in ONNX spec. + void InitializeSomeFields(const OpKernelInfo& info); + + // A collection of key-value pairs. Each (a_key, a_value) pair + // means that the "a_key" in the input would be mapped to "a_value". + // If _map doesn't contain "a_key", we use _default_value as its output. + std::unordered_map _map; + TValue _default_value; + // ONNX attribute name to load keys. + std::string _key_field_name; + // ONNX attribute name to load values. 
+ std::string _value_field_name; +}; } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc index c52f43d682..05abc8a917 100644 --- a/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc +++ b/onnxruntime/test/providers/cpu/ml/label_encoder_test.cc @@ -42,5 +42,131 @@ TEST(LabelEncoder, IntToString) { RunTest(dims, input, output); } +TEST(LabelEncoder, StringToIntOpset2) { + std::vector dims{1, 5}; + + std::vector input{"AA", "BB", "CC", "DD", "AA"}; + std::vector output{9, 1, 5566, 4, 9}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{"AA", "BB", "DD"}; + const std::vector values{9, 1, 4}; + + test.AddAttribute("keys_strings", keys); + test.AddAttribute("values_int64s", values); + test.AddAttribute("default_int64", (std::int64_t)5566); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, IntToStringOpset2) { + std::vector dims{1, 5}; + + std::vector input{9, 1, 5566, 4, 9}; + std::vector output{"AA", "BB", "CC", "DD", "AA"}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{9, 1, 4}; + const std::vector values{"AA", "BB", "DD"}; + + test.AddAttribute("keys_int64s", keys); + test.AddAttribute("values_strings", values); + test.AddAttribute("default_string", "CC"); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, FloatToStringOpset2) { + std::vector dims{5, 1}; + + std::vector input{9.4f, 1.7f, 3.6f, 1.2f, 2.8f}; + std::vector output{"AA", "BB", "DD", "CC", "CC"}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{9.4f, 1.7f, 3.6f}; + const std::vector values{"AA", "BB", "DD"}; + + test.AddAttribute("keys_floats", keys); + test.AddAttribute("values_strings", values); + test.AddAttribute("default_string", 
"CC"); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, StringToFloatOpset2) { + std::vector dims{5, 1}; + + std::vector input{"AA", "BB", "DD", "CC", "CC"}; + std::vector output{9.4f, 1.7f, 3.6f, 55.66f, 55.66f}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{"AA", "BB", "DD"}; + const std::vector values{9.4f, 1.7f, 3.6f}; + + test.AddAttribute("keys_strings", keys); + test.AddAttribute("values_floats", values); + test.AddAttribute("default_float", 55.66f); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, FloatToInt64Opset2) { + std::vector dims{5}; + + std::vector input{9.4f, 1.7f, 3.6f, 55.66f, 55.66f}; + std::vector output{1, 9, 3, -8, -8}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{9.4f, 1.7f, 3.6f}; + const std::vector values{1, 9, 3}; + + test.AddAttribute("keys_floats", keys); + test.AddAttribute("values_int64s", values); + test.AddAttribute("default_int64", (std::int64_t)-8); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + +TEST(LabelEncoder, Int64ToFloatOpset2) { + std::vector dims{5}; + + std::vector input{3, 1, 9, -8, -8}; + std::vector output{3.6f, 9.4f, 1.7f, 55.66f, 55.66f}; + + OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain); + + const std::vector keys{1, 9, 3}; + const std::vector values{9.4f, 1.7f, 3.6f}; + + test.AddAttribute("keys_int64s", keys); + test.AddAttribute("values_floats", values); + test.AddAttribute("default_float", 55.66f); + + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); + + test.Run(); +} + } // namespace test } // namespace onnxruntime