mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Implement new LabelEncoder in opset 2 in ML domain (#1393)
* Implement new LabelEncoder in opset 2 in ML domain * Fix compilation error * Fix tests * Include ONNX's fix * Formatting and addressing a comment * Address a minor comment
This commit is contained in:
parent
6d783e8a07
commit
0187d876cb
4 changed files with 311 additions and 5 deletions
|
|
@ -594,7 +594,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t_double, DictVectorizer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, FeatureVectorizer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, Imputer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, LabelEncoder);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 1, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float, LinearClassifier);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, double, LinearClassifier);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t, LinearClassifier);
|
||||
|
|
@ -621,6 +622,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, TreeEnsembleRegressor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap);
|
||||
|
||||
// Opset-2 LabelEncoder kernels in the ML domain: one typed kernel class per <key>_<value> type pairing.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_float, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_float, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_int64, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder);
|
||||
|
||||
void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float, ArrayFeatureExtractor)>,
|
||||
|
|
@ -639,7 +647,7 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t_double, DictVectorizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, FeatureVectorizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, Imputer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 1, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float, LinearClassifier)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, double, LinearClassifier)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int64_t, LinearClassifier)>,
|
||||
|
|
@ -665,6 +673,13 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int32_t, TreeEnsembleClassifier)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, TreeEnsembleRegressor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_float, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_float, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_int64, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -9,15 +9,16 @@ using namespace ::onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
|
||||
ONNX_CPU_OPERATOR_ML_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
1,
|
||||
1, 1,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
DataTypeImpl::GetTensorType<int64_t>()})
|
||||
.SinceVersion(1, 2),
|
||||
LabelEncoder);
|
||||
|
||||
Status LabelEncoder::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -67,5 +68,107 @@ Status LabelEncoder::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
float_string,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()}),
|
||||
LabelEncoder_2<float, std::string>);
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<float, std::string>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_floats";
|
||||
_value_field_name = "values_strings";
|
||||
info.GetAttrOrDefault<std::string>("default_string", &_default_value, std::string("_Unused"));
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
string_float,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>()}),
|
||||
LabelEncoder_2<std::string, float>);
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<std::string, float>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_strings";
|
||||
_value_field_name = "values_floats";
|
||||
info.GetAttrOrDefault<float>("default_float", &_default_value, -0.0f);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
int64_float,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::int64_t>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>()}),
|
||||
LabelEncoder_2<std::int64_t, float>);
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<std::int64_t, float>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_int64s";
|
||||
_value_field_name = "values_floats";
|
||||
info.GetAttrOrDefault<float>("default_float", &_default_value, -0.0f);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
float_int64,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::int64_t>()}),
|
||||
LabelEncoder_2<float, std::int64_t>);
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<float, std::int64_t>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_floats";
|
||||
_value_field_name = "values_int64s";
|
||||
info.GetAttrOrDefault<std::int64_t>("default_int64", &_default_value, (std::int64_t)-1);
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
int64_string,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::int64_t>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()}),
|
||||
LabelEncoder_2<std::int64_t, std::string>)
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<std::int64_t, std::string>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_int64s";
|
||||
_value_field_name = "values_strings";
|
||||
info.GetAttrOrDefault<std::string>("default_string", &_default_value, std::string("_Unused"));
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
LabelEncoder,
|
||||
2,
|
||||
string_int64,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::string>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<std::int64_t>()}),
|
||||
LabelEncoder_2<std::string, std::int64_t>)
|
||||
|
||||
template <>
|
||||
void LabelEncoder_2<std::string, std::int64_t>::InitializeSomeFields(const OpKernelInfo& info) {
|
||||
_key_field_name = "keys_strings";
|
||||
_value_field_name = "values_int64s";
|
||||
info.GetAttrOrDefault<std::int64_t>("default_int64", &_default_value, (std::int64_t)-1);
|
||||
};
|
||||
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -43,5 +43,67 @@ class LabelEncoder final : public OpKernel {
|
|||
int64_t default_int_;
|
||||
};
|
||||
|
||||
template <typename TKey, typename TValue>
|
||||
class LabelEncoder_2 final : public OpKernel {
|
||||
public:
|
||||
LabelEncoder_2(const OpKernelInfo& info) : OpKernel(info) {
|
||||
// Let the specialized member function to tell which fields to load.
|
||||
InitializeSomeFields(info);
|
||||
|
||||
std::vector<TKey> keys;
|
||||
std::vector<TValue> values;
|
||||
|
||||
ORT_ENFORCE(info.GetAttrs<TKey>(_key_field_name, keys).IsOK());
|
||||
ORT_ENFORCE(info.GetAttrs<TValue>(_value_field_name, values).IsOK());
|
||||
|
||||
auto num_keys = keys.size();
|
||||
auto num_values = values.size();
|
||||
ORT_ENFORCE(num_keys == num_values,
|
||||
"The ", _key_field_name, " and ", _value_field_name, " attribtues in LabelEncoder ",
|
||||
"(name: ", info.node().Name(), ") must have the same length. ",
|
||||
"However, the number of key is ", num_keys, " and the number of ",
|
||||
"values is ", num_values, ".");
|
||||
|
||||
for (size_t i = 0; i < num_keys; ++i)
|
||||
_map[keys[i]] = values[i];
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const auto* tensor_pointer = context->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const Tensor& X = *tensor_pointer;
|
||||
const TensorShape& shape = X.Shape();
|
||||
Tensor& Y = *context->Output(0, TensorShape(shape));
|
||||
|
||||
auto input = X.template DataAsSpan<TKey>();
|
||||
auto output = Y.template MutableDataAsSpan<TValue>();
|
||||
|
||||
for (int64_t i = 0; i < shape.Size(); ++i) {
|
||||
const auto found = _map.find(input[i]);
|
||||
if (found == _map.end())
|
||||
output[i] = _default_value;
|
||||
else
|
||||
output[i] = found->second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Specialize this method to set attribute names. For example, if keys' type
|
||||
// is 64-bit integer, _key_field_name should be "keys_int64s". Field names
|
||||
// for other types can be found in ONNX spec.
|
||||
void InitializeSomeFields(const OpKernelInfo& info);
|
||||
|
||||
// A collection of key-value pairs. Each (a_key, a_value) pair
|
||||
// means that the "a_key" in the input would be mapped to "a_value".
|
||||
// If _map doesn't contain "a_key", we use _default_value as its output.
|
||||
std::unordered_map<TKey, TValue> _map;
|
||||
TValue _default_value;
|
||||
// ONNX attribute name to load keys.
|
||||
std::string _key_field_name;
|
||||
// ONNX attribute name to load values.
|
||||
std::string _value_field_name;
|
||||
};
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -42,5 +42,131 @@ TEST(LabelEncoder, IntToString) {
|
|||
RunTest(dims, input, output);
|
||||
}
|
||||
|
||||
// Opset-2 LabelEncoder, string -> int64. "CC" is not a key, so it must map to default_int64.
TEST(LabelEncoder, StringToIntOpset2) {
  const std::vector<std::string> label_keys{"AA", "BB", "DD"};
  const std::vector<std::int64_t> label_values{9, 1, 4};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_strings", label_keys);
  test.AddAttribute("values_int64s", label_values);
  test.AddAttribute("default_int64", static_cast<std::int64_t>(5566));

  std::vector<std::int64_t> shape{1, 5};
  std::vector<std::string> x_data{"AA", "BB", "CC", "DD", "AA"};
  std::vector<std::int64_t> expected{9, 1, 5566, 4, 9};
  test.AddInput<std::string>("X", shape, x_data);
  test.AddOutput<std::int64_t>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
// Opset-2 LabelEncoder, int64 -> string. 5566 is not a key, so it must map to default_string.
TEST(LabelEncoder, IntToStringOpset2) {
  const std::vector<std::int64_t> label_keys{9, 1, 4};
  const std::vector<std::string> label_values{"AA", "BB", "DD"};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_int64s", label_keys);
  test.AddAttribute("values_strings", label_values);
  test.AddAttribute<std::string>("default_string", "CC");

  std::vector<std::int64_t> shape{1, 5};
  std::vector<std::int64_t> x_data{9, 1, 5566, 4, 9};
  std::vector<std::string> expected{"AA", "BB", "CC", "DD", "AA"};
  test.AddInput<std::int64_t>("X", shape, x_data);
  test.AddOutput<std::string>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
// Opset-2 LabelEncoder, float -> string. 1.2f and 2.8f are not keys, so both map to default_string.
TEST(LabelEncoder, FloatToStringOpset2) {
  const std::vector<float> label_keys{9.4f, 1.7f, 3.6f};
  const std::vector<std::string> label_values{"AA", "BB", "DD"};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_floats", label_keys);
  test.AddAttribute("values_strings", label_values);
  test.AddAttribute<std::string>("default_string", "CC");

  std::vector<std::int64_t> shape{5, 1};
  std::vector<float> x_data{9.4f, 1.7f, 3.6f, 1.2f, 2.8f};
  std::vector<std::string> expected{"AA", "BB", "DD", "CC", "CC"};
  test.AddInput<float>("X", shape, x_data);
  test.AddOutput<std::string>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
// Opset-2 LabelEncoder, string -> float. "CC" is not a key, so it must map to default_float.
TEST(LabelEncoder, StringToFloatOpset2) {
  const std::vector<std::string> label_keys{"AA", "BB", "DD"};
  const std::vector<float> label_values{9.4f, 1.7f, 3.6f};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_strings", label_keys);
  test.AddAttribute("values_floats", label_values);
  test.AddAttribute("default_float", 55.66f);

  std::vector<std::int64_t> shape{5, 1};
  std::vector<std::string> x_data{"AA", "BB", "DD", "CC", "CC"};
  std::vector<float> expected{9.4f, 1.7f, 3.6f, 55.66f, 55.66f};
  test.AddInput<std::string>("X", shape, x_data);
  test.AddOutput<float>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
// Opset-2 LabelEncoder, float -> int64. 55.66f is not a key, so it must map to default_int64.
TEST(LabelEncoder, FloatToInt64Opset2) {
  const std::vector<float> label_keys{9.4f, 1.7f, 3.6f};
  const std::vector<std::int64_t> label_values{1, 9, 3};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_floats", label_keys);
  test.AddAttribute("values_int64s", label_values);
  test.AddAttribute("default_int64", static_cast<std::int64_t>(-8));

  std::vector<std::int64_t> shape{5};
  std::vector<float> x_data{9.4f, 1.7f, 3.6f, 55.66f, 55.66f};
  std::vector<std::int64_t> expected{1, 9, 3, -8, -8};
  test.AddInput<float>("X", shape, x_data);
  test.AddOutput<std::int64_t>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
// Opset-2 LabelEncoder, int64 -> float. -8 is not a key, so it must map to default_float.
TEST(LabelEncoder, Int64ToFloatOpset2) {
  const std::vector<std::int64_t> label_keys{1, 9, 3};
  const std::vector<float> label_values{9.4f, 1.7f, 3.6f};

  OpTester test("LabelEncoder", 2, onnxruntime::kMLDomain);
  test.AddAttribute("keys_int64s", label_keys);
  test.AddAttribute("values_floats", label_values);
  test.AddAttribute("default_float", 55.66f);

  std::vector<std::int64_t> shape{5};
  std::vector<std::int64_t> x_data{3, 1, 9, -8, -8};
  std::vector<float> expected{3.6f, 9.4f, 1.7f, 55.66f, 55.66f};
  test.AddInput<std::int64_t>("X", shape, x_data);
  test.AddOutput<float>("Y", shape, expected);

  test.Run();
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue