String to string label encoder (#15379)

This commit is contained in:
Aditya Goel 2023-04-05 22:04:34 +01:00 committed by GitHub
parent ea6b32fea8
commit a7d321e9dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 0 deletions

View file

@ -2337,6 +2337,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2,
// Forward declarations of the typed LabelEncoder kernel classes (one per
// key/value type tag). The class names are produced by
// ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME from the provider, domain, the
// number 2 (presumably the operator version -- confirm against the macro),
// the type tag and the operator name.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string, LabelEncoder);
// Forward declarations of the typed TreeEnsembleClassifier kernel classes
// (float and double variants).
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleClassifier);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleClassifier);
@ -2431,6 +2432,8 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
LabelEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64,
LabelEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string,
LabelEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float,
TreeEnsembleClassifier)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,

View file

@ -134,6 +134,23 @@ void LabelEncoder_2<float, std::int64_t>::InitializeSomeFields(const OpKernelInf
info.GetAttrOrDefault<std::int64_t>("default_int64", &_default_value, (std::int64_t)-1);
};
// Kernel definition for the string -> string LabelEncoder variant.
// Both type constraints ("T1" and "T2") are limited to std::string tensors,
// and the kernel is implemented by LabelEncoder_2<std::string, std::string>.
// NOTE(review): the literal 2 is presumably the operator version and
// "string_string" the type tag used to build the kernel class name --
// confirm against the ONNX_CPU_OPERATOR_TYPED_ML_KERNEL macro definition.
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
LabelEncoder,
2,
string_string,
KernelDefBuilder().TypeConstraint("T1",
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()})
.TypeConstraint("T2",
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()}),
LabelEncoder_2<std::string, std::string>)
// Specialization for the string -> string encoder: selects the attribute
// names used for this key/value type pair and loads the default value.
template <>
void LabelEncoder_2<std::string, std::string>::InitializeSomeFields(const OpKernelInfo& info) {
  // Value passed as the default to GetAttrOrDefault for "default_string".
  const std::string fallback_label{"_Unused"};
  info.GetAttrOrDefault<std::string>("default_string", &_default_value, fallback_label);

  // Attribute names for this specialization's keys and values.
  _key_field_name = "keys_strings";
  _value_field_name = "values_strings";
};
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
LabelEncoder,
2,

View file

@ -24,6 +24,7 @@ static void RunTest(const std::vector<int64_t>& dims, const std::vector<TInput>&
test.Run();
}
TEST(LabelEncoder, StringToInt) {
std::vector<int64_t> dims{2, 2, 2};
@ -189,5 +190,26 @@ TEST(LabelEncoder, Int64ToInt64Opset2) {
test.Run();
}
// LabelEncoder (version 2, ML domain) with string keys and string values.
// The table maps A->X, B->Y, C->Z. The inputs "D" and "E" are not in the key
// list and are expected to come out as the configured default_string "!".
TEST(LabelEncoder, StringToStringOpset2) {
  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);

  // Encoder configuration.
  const std::vector<std::string> keys{"A", "B", "C"};
  const std::vector<std::string> values{"X", "Y", "Z"};
  test.AddAttribute("keys_strings", keys);
  test.AddAttribute("values_strings", values);
  test.AddAttribute("default_string", "!");

  // A single 1x5 string tensor in, and the same shape out.
  const std::vector<std::int64_t> shape{1, 5};
  const std::vector<std::string> labels_in{"A", "A", "C", "D", "E"};
  const std::vector<std::string> labels_expected{"X", "X", "Z", "!", "!"};
  test.AddInput<std::string>("X", shape, labels_in);
  test.AddOutput<std::string>("Y", shape, labels_expected);

  test.Run();
}
} // namespace test
} // namespace onnxruntime