mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
String to string label encoder (#15379)
This commit is contained in:
parent
ea6b32fea8
commit
a7d321e9dc
3 changed files with 42 additions and 0 deletions
|
|
@ -2337,6 +2337,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string, LabelEncoder);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleClassifier);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleClassifier);
|
||||
|
|
@ -2431,6 +2432,8 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_int64,
|
||||
LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_string,
|
||||
LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float,
|
||||
TreeEnsembleClassifier)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,
|
||||
|
|
|
|||
|
|
@ -134,6 +134,23 @@ void LabelEncoder_2<float, std::int64_t>::InitializeSomeFields(const OpKernelInf
|
|||
info.GetAttrOrDefault<std::int64_t>("default_int64", &_default_value, (std::int64_t)-1);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
string_string,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()}),
|
||||
LabelEncoder_2<std::string, std::string>)
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<std::string, std::string>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_strings";
|
||||
_value_field_name = "values_strings";
|
||||
info.GetAttrOrDefault<std::string>("default_string", &_default_value, std::string("_Unused"));
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ static void RunTest(const std::vector<int64_t>& dims, const std::vector<TInput>&
|
|||
test.Run();
|
||||
}
|
||||
|
||||
|
||||
TEST(LabelEncoder, StringToInt) {
|
||||
std::vector<int64_t> dims{2, 2, 2};
|
||||
|
||||
|
|
@ -189,5 +190,26 @@ TEST(LabelEncoder, Int64ToInt64Opset2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(LabelEncoder, StringToStringOpset2) {
|
||||
std::vector<std::int64_t> dims{1, 5};
|
||||
|
||||
std::vector<std::string> input{"A", "A", "C", "D", "E"};
|
||||
std::vector<std::string> output{"X", "X", "Z", "!", "!"};
|
||||
|
||||
OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
|
||||
|
||||
const std::vector<std::string> keys{"A", "B", "C"};
|
||||
const std::vector<std::string> values{"X", "Y", "Z"};
|
||||
|
||||
test.AddAttribute("keys_strings", keys);
|
||||
test.AddAttribute("values_strings", values);
|
||||
test.AddAttribute("default_string", "!");
|
||||
|
||||
test.AddInput<std::string>("X", dims, input);
|
||||
test.AddOutput<std::string>("Y", dims, output);
|
||||
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue