diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 98aaef9c0f..47ccb903af 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -12,6 +12,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, Ngram); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int64_t, Ngram); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); @@ -32,6 +35,9 @@ void RegisterContribKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); diff --git a/onnxruntime/contrib_ops/cpu/ngram.cc b/onnxruntime/contrib_ops/cpu/ngram.cc new file mode 100644 index 0000000000..74117806f5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/ngram.cc @@ -0,0 +1,582 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "ngram.h"
+#include "onnx/defs/schema.h"
+#include "core/common/common.h"
+#include "core/framework/tensor.h"
+
+#include <functional>
+#include <iostream>
+#include <iterator>
+#include <unordered_set>
+
+namespace onnxruntime {
+namespace contrib {
+
+// One CPU kernel registration per supported input element type (T).
+// The output type (T1) is float in all three cases.
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    Ngram,
+    1,
+    string,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()),
+    contrib::Ngram);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    Ngram,
+    1,
+    int32_t,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()),
+    contrib::Ngram);
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    Ngram,
+    1,
+    int64_t,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()),
+    contrib::Ngram);
+
+namespace ngram_details {
+
+// Common base of every pool entry: stores the entry's id in the pool.
+// Construction and destruction are protected, so the base is only ever
+// used through one of the NgramEntry specializations below.
+class NgramEntryBase {
+  size_t id_;  // Id in the pool
+ protected:
+  explicit NgramEntryBase(size_t id) : id_(id) {}
+  ~NgramEntryBase() = default;
+
+ public:
+  size_t Id() const { return id_; }
+};
+
+template <class T>
+class NgramEntry;
+
+// Entry made of int64 items. The items are copied into the entry and a
+// running hash is updated for each item in the order it is added.
+template <>
+class NgramEntry<int64_t> : public NgramEntryBase {
+  std::vector<int64_t> items_;
+  size_t hash_ = 0;
+
+  // Folds the hash of `v` into hash_; the result depends on item order.
+  void RunningHash(int64_t v) {
+    std::hash<int64_t> hf{};
+    hash_ ^= hf(v) + 0x9e3779b9 + (hash_ << 6) + (hash_ >> 2);
+  }
+
+ public:
+  // Builds an entry with the given pool id from the items in [first, last).
+  // The range must not be empty.
+  template <typename ForwardIter>
+  explicit NgramEntry(size_t id, ForwardIter first, ForwardIter last) : NgramEntryBase(id) {
+    while (first != last) {
+      RunningHash(*first);
+      items_.push_back(*first);
+      ++first;
+    }
+    assert(!items_.empty());
+  }
+  // For sampling: an empty entry (id 0) that is filled with AddItem()
+  // and reused via Clear().
+  explicit NgramEntry() : NgramEntryBase(0) {}
+  void AddItem(int64_t v) {
+    items_.push_back(v);
+    RunningHash(v);
+  }
+  void DebugPrint() const {
+    std::copy(items_.cbegin(), items_.cend(), std::ostream_iterator<int64_t>(std::cout, ","));
+    std::cout << std::endl;
+  }
+  // Drops all items and resets the running hash so the entry can be refilled.
+  void Clear() {
+    items_.clear();
+    hash_ = 0;
+  }
+  // Entries are equal when their item sequences are equal; the id is ignored.
+  bool operator==(const NgramEntry& o) const {
+    return items_ == o.items_;
+  }
+  size_t Hash() const {
+    return hash_;
+  }
+};
+
+// int32 entries reuse the int64 implementation.
+template <>
+class NgramEntry<int32_t> : public NgramEntry<int64_t> {
+ public:
+  template <typename ForwardIter>
+  explicit NgramEntry(size_t id, ForwardIter first, ForwardIter last) : NgramEntry<int64_t>(id, first, last) {}
+  explicit NgramEntry() = default;
+};
+
+// Entry made of strings. It stores references, not copies, so every string
+// added to an entry must stay alive for as long as the entry is used.
+template <>
+class NgramEntry<std::string> : public NgramEntryBase {
+ private:
+  std::vector<std::reference_wrapper<const std::string>> items_;
+  size_t hash_ = 0;
+
+  // Folds the hash of `s` into hash_; the result depends on item order.
+  void RunningHash(const std::string& s) {
+    std::hash<std::string> hf{};
+    hash_ ^= hf(s) + 0x9e3779b9 + (hash_ << 6) + (hash_ >> 2);
+  }
+
+ public:
+  // Builds an entry with the given pool id referencing the strings in
+  // [first, last). The range must not be empty.
+  template <typename ForwardIter>
+  explicit NgramEntry(size_t id, ForwardIter first, ForwardIter last) : NgramEntryBase(id) {
+    while (first != last) {
+      RunningHash(*first);
+      items_.push_back(std::cref(*first));
+      ++first;
+    }
+    assert(!items_.empty());
+  }
+  // For sampling: an empty entry (id 0) that is filled with AddItem()
+  // and reused via Clear().
+  explicit NgramEntry() : NgramEntryBase(0) {}
+  void AddItem(const std::string& s) {
+    items_.push_back(std::cref(s));
+    RunningHash(s);
+  }
+  void DebugPrint() const {
+    std::copy(items_.cbegin(), items_.cend(), std::ostream_iterator<std::string>(std::cout, ","));
+    std::cout << std::endl;
+  }
+  // Drops all references and resets the running hash so the entry can be refilled.
+  void Clear() {
+    items_.clear();
+    hash_ = 0;
+  }
+
+  // Entries are equal when the referenced strings compare equal element by
+  // element; the id is ignored.
+  bool operator==(const NgramEntry& o) const {
+    if (items_.size() == o.items_.size()) {
+      return std::equal(items_.cbegin(), items_.cend(),
+                        o.items_.cbegin(), o.items_.cend(),
+                        std::equal_to<std::string>());
+    }
+    return false;
+  }
+  size_t Hash() const {
+    return hash_;
+  }
+};
+
+using IntegerPoolSet = std::unordered_set<NgramEntry<int64_t>>;
+// Does not own strings, contains references to them. This helps
+// to search by string references that point to the current input.
+using StringPoolSet = std::unordered_set<NgramEntry<std::string>>;
+
+// Emplaces `ngrams` consecutive entries of `ngram_size` items each into `c`,
+// reading the items from `first` onwards. Entries get sequential ids starting
+// at `ngram_id`; `ngram_id` is advanced once per entry, whether or not the
+// container accepted it.
+template <typename ForwardIter, typename Cont>
+inline void Emplace(ForwardIter first, size_t ngrams, size_t ngram_size, size_t& ngram_id, Cont& c) {
+  for (; ngrams > 0; --ngrams) {
+    c.emplace(ngram_id, first, first + ngram_size);
+    first += ngram_size;
+    ++ngram_id;
+  }
+}
+
+}  // namespace ngram_details
+}  // namespace contrib
+}  // namespace onnxruntime
+
+using namespace onnxruntime::contrib::ngram_details;
+
+namespace std {
+// Makes NgramEntry usable as an unordered_set key by returning the hash
+// the entry accumulated while its items were added.
+template <typename T>
+struct hash<NgramEntry<T>> {
+  typedef NgramEntry<T> argument_type;
+  typedef size_t result_type;
+  result_type operator()(const argument_type& a) const {
+    return a.Hash();
+  }
+};
+}  // namespace std
+
+namespace onnxruntime {
+namespace contrib {
+
+// The weighting criteria.
+// "TF"(term frequency),
+//    the counts are propagated to output
+// "IDF"(inverse document frequency),
+//    all the counts larger than 1
+//    would be truncated to 1 and the i-th element
+//    in weights would be used to scale (by multiplication)
+//    the count of the i-th n-gram in pool
+// "TFIDF" (the combination of TF and IDF).
+//  counts are scaled by the associated values in the weights attribute.
+
+enum WeightingCriteria {
+  kNone = 0,
+  kTF = 1,
+  kIDF = 2,
+  kTFIDF = 3
+};
+
+// Holds the parsed attributes and the n-gram lookup sets of one kernel instance.
+struct Ngram::Impl {
+  WeightingCriteria weighting_criteria_ = kNone;
+  int64_t max_gram_length_ = 0;
+  int64_t min_gram_length_ = 0;
+  int64_t max_skip_count_ = 0;
+  // This is the content of ngram_counts attribute.
+  // The starting indexes of 1-grams, 2-grams,
+  // and so on in pool. For example, if ngram_counts is [0, 17, 36],
+  // the first index (zero-based) of 1-gram/2-gram/3-gram
+  // in pool are 0/17/36.
+  std::vector<int64_t> ngram_counts_;
+  // Contains output indexes
+  // represents ngram_indexes output
+  std::vector<int64_t> ngram_indexes_;
+  std::vector<float> weights_;
+
+  std::vector<std::string> pool_strings_;
+  // This set contains references to pool_string_ entries
+  // of pool_strings attribute
+  StringPoolSet str_set_;
+  // This set contains pool_int64s entries
+  IntegerPoolSet int64_set_;
+  // Length of the last output axis: max(ngram_indexes) + 1.
+  size_t output_size_ = 0;
+
+  Impl() = default;
+  ~Impl() = default;
+  Impl(const Impl&) = delete;
+  Impl& operator=(const Impl&) = delete;
+
+  // End iterator of the set that serves input element type T.
+  template <typename T>
+  auto PoolEnd() const;
+
+  // Looks the sample up in the set that serves input element type T.
+  template <typename T>
+  auto PoolFind(const ngram_details::NgramEntry<T>&) const;
+
+  // Adds one hit for n-gram `ngram_id` in row `row_num` of the flat
+  // [rows x output_size_] frequency buffer.
+  void IncrementCount(size_t ngram_id, size_t row_num,
+                      std::vector<uint32_t>& frequencies) const {
+    assert(ngram_id < ngram_indexes_.size());
+    auto output_idx = row_num * output_size_ + ngram_indexes_[ngram_id];
+    assert(static_cast<size_t>(output_idx) < frequencies.size());
+    ++frequencies[output_idx];
+  }
+};
+
+template <>
+inline auto Ngram::Impl::PoolEnd<int64_t>() const {
+  return int64_set_.cend();
+}
+
+// int32 input is searched in the int64 set.
+template <>
+inline auto Ngram::Impl::PoolEnd<int32_t>() const {
+  return PoolEnd<int64_t>();
+}
+
+template <>
+inline auto Ngram::Impl::PoolEnd<std::string>() const {
+  return str_set_.cend();
+}
+
+template <>
+inline auto Ngram::Impl::PoolFind<int64_t>(const NgramEntry<int64_t>& i) const {
+  return int64_set_.find(i);
+}
+
+// int32 input is searched in the int64 set.
+template <>
+inline auto Ngram::Impl::PoolFind<int32_t>(const NgramEntry<int32_t>& i) const {
+  return int64_set_.find(i);
+}
+
+template <>
+inline auto Ngram::Impl::PoolFind<std::string>(const NgramEntry<std::string>& i) const {
+  return str_set_.find(i);
+}
+
+// Reads and validates the operator attributes, then loads the n-grams whose
+// size lies in [min_gram_length, max_gram_length] into the lookup set.
+// Any invalid or missing required attribute fails an ORT_ENFORCE.
+Ngram::Ngram(const OpKernelInfo& info) : OpKernel(info), impl_(new Impl) {
+  std::string mode;
+  Status status = info.GetAttr("mode", &mode);
+  ORT_ENFORCE(status.IsOK(), "mode is required");
+  if (mode == "TF") {
+    impl_->weighting_criteria_ = kTF;
+  } else if (mode == "IDF") {
+    impl_->weighting_criteria_ = kIDF;
+  } else if (mode == "TFIDF") {
+    impl_->weighting_criteria_ = kTFIDF;
+  }
+  ORT_ENFORCE(impl_->weighting_criteria_ != kNone, "mode: ", mode, " is unrecognized, acceptable values are TF,IDF,TFIDF");
+
+  status = info.GetAttr("min_gram_length", &impl_->min_gram_length_);
+  ORT_ENFORCE(status.IsOK(), "min_gram_length is required");
+  ORT_ENFORCE(impl_->min_gram_length_ > 0, "Required min_gram_length must be positive: ", std::to_string(impl_->min_gram_length_));
+
+  status = info.GetAttr("max_gram_length", &impl_->max_gram_length_);
+  ORT_ENFORCE(status.IsOK(), "max_gram_length is required");
+  ORT_ENFORCE(impl_->max_gram_length_ >= impl_->min_gram_length_,
+              "max_gram_length >= min_gram_length required: ",
+              std::to_string(impl_->max_gram_length_), " >= ", std::to_string(impl_->min_gram_length_));
+
+  status = info.GetAttr("max_skip_count", &impl_->max_skip_count_);
+  ORT_ENFORCE(status.IsOK(), "max_skip_count is required");
+  ORT_ENFORCE(impl_->max_skip_count_ >= 0, "max_skip_count must be non-negative: ", std::to_string(impl_->max_skip_count_));
+
+  status = info.GetAttrs(std::string("ngram_counts"), impl_->ngram_counts_);
+  ORT_ENFORCE(status.IsOK() && !impl_->ngram_counts_.empty(), "Non-empty ngram_counts is required");
+  ORT_ENFORCE(static_cast<size_t>(impl_->min_gram_length_) <= impl_->ngram_counts_.size(),
+              "min_gram_length must be inbounds of ngram_counts: ",
+              std::to_string(impl_->min_gram_length_), " <= ", std::to_string(impl_->ngram_counts_.size()));
+  ORT_ENFORCE(static_cast<size_t>(impl_->max_gram_length_) <= impl_->ngram_counts_.size(),
+              "max_gram_length must be inbounds of ngram_counts: ",
+              std::to_string(impl_->max_gram_length_), " <= ", std::to_string(impl_->ngram_counts_.size()));
+
+  status = info.GetAttrs("ngram_indexes", impl_->ngram_indexes_);
+  ORT_ENFORCE(status.IsOK() && !impl_->ngram_indexes_.empty(), "Non-empty ngram_indexes is required");
+  {
+    // Check that all are non-negative
+    ORT_ENFORCE(std::all_of(impl_->ngram_indexes_.cbegin(), impl_->ngram_indexes_.cend(),
+                            [](int64_t i) { return i >= 0; }),
+                "Negative ngram_indexes values are not allowed");
+    // Set output size to max output index + 1;
+    auto greatest_hit = std::max_element(impl_->ngram_indexes_.cbegin(), impl_->ngram_indexes_.cend());
+    impl_->output_size_ = static_cast<size_t>(*greatest_hit) + 1;
+  }
+
+  // weights is optional; when present there is one weight per ngram_indexes entry.
+  status = info.GetAttrs("weights", impl_->weights_);
+  if (status.IsOK()) {
+    ORT_ENFORCE(impl_->weights_.size() == impl_->ngram_indexes_.size(),
+                "Got weights of size: ", std::to_string(impl_->weights_.size()),
+                " but ngram_indexes size: ", std::to_string(impl_->ngram_indexes_.size()),
+                " must be of equal size");
+  }
+
+  // pool_strings takes precedence; pool_int64s is only read when it is absent.
+  std::vector<int64_t> pool_int64s;
+  status = info.GetAttrs("pool_strings", impl_->pool_strings_);
+  if (status.IsOK()) {
+    ORT_ENFORCE(!impl_->pool_strings_.empty(), "pool_strings must not be empty if specified");
+  } else {
+    status = info.GetAttrs("pool_int64s", pool_int64s);
+    ORT_ENFORCE(status.IsOK() && !pool_int64s.empty(), "non-empty pool_int64s is required if pool_strings not provided");
+  }
+
+  // Iterator via the pool. Insert 1 item for 1-grams, 2 items for 2-grams, etc.
+  const auto total_items = (impl_->pool_strings_.empty()) ? pool_int64s.size() : impl_->pool_strings_.size();
+  size_t ngram_id = 0;
+  // Load into dictionary only required gram sizes
+  const size_t min_gram_length = impl_->min_gram_length_;
+  const size_t max_gram_length = impl_->max_gram_length_;
+  size_t ngram_size = 1;
+  for (size_t i = 0; i < impl_->ngram_counts_.size(); ++i) {
+    size_t start_idx = impl_->ngram_counts_[i];
+    size_t end_idx = ((i + 1) < impl_->ngram_counts_.size()) ? impl_->ngram_counts_[i + 1] : total_items;
+    ORT_ENFORCE(end_idx >= start_idx && end_idx <= total_items,
+                "n-gram counts out of bounds for ", std::to_string(ngram_size), "-grams");
+    auto items = end_idx - start_idx;
+    if (items > 0) {
+      ORT_ENFORCE((items % ngram_size == 0),
+                  "Number of items must compose whole ", std::to_string(ngram_size), "-grams");
+      auto ngrams = items / ngram_size;
+      // Skip loading into hash_set ngrams that are not in the range of [min_gram_length-max_gram_length]
+      if (ngram_size >= min_gram_length && ngram_size <= max_gram_length) {
+        if (impl_->pool_strings_.empty()) {
+          auto before_insert = impl_->int64_set_.size();
+          Emplace(pool_int64s.begin() + start_idx, ngrams, ngram_size, ngram_id, impl_->int64_set_);
+          ORT_ENFORCE((before_insert + ngrams) == impl_->int64_set_.size(), "pool_int64s duplicate ", std::to_string(ngram_size), "-grams detected");
+        } else {
+          auto before_insert = impl_->str_set_.size();
+          Emplace(impl_->pool_strings_.begin() + start_idx, ngrams, ngram_size, ngram_id, impl_->str_set_);
+          ORT_ENFORCE((before_insert + ngrams) == impl_->str_set_.size(), "pool_strings duplicate ", std::to_string(ngram_size), "-grams detected");
+        }
+        // Every id just handed out is later used to index ngram_indexes_ (and
+        // weights_) at compute time, so reject pools that outgrow it here
+        // instead of relying on a debug-only assert.
+        ORT_ENFORCE(ngram_id <= impl_->ngram_indexes_.size(),
+                    "ngram_indexes has ", std::to_string(impl_->ngram_indexes_.size()),
+                    " entries but the pool requires at least ", std::to_string(ngram_id));
+      } else {
+        // Ids stay aligned with pool positions even for sizes that are not loaded.
+        ngram_id += ngrams;
+      }
+    }
+    ++ngram_size;
+  }
+}
+
+Ngram::~Ngram() {
+}
+
+// Applies the weighting criteria to the collected counts and writes output 0.
+// B == 0 selects a 1-D output of [output_size]; otherwise the output is
+// [B, output_size]. `frequences` is the flat row-major count buffer.
+void Ngram::OutputResult(OpKernelContext* ctx, size_t B, const std::vector<uint32_t>& frequences) const {
+  const Impl& impl = *impl_;
+  std::vector<int64_t> output_dims;
+  if (B == 0) {
+    output_dims.push_back(impl.output_size_);
+  } else {
+    output_dims.push_back(B);
+    output_dims.push_back(impl.output_size_);
+  }
+
+  TensorShape output_shape(output_dims);
+  assert(frequences.size() == static_cast<size_t>(output_shape.Size()));
+
+  auto Y = ctx->Output(0, output_shape);
+  auto output_data = Y->MutableData<float>();
+
+  // weights_ holds one value per n-gram, while the counts are laid out per
+  // output column, repeated for every row. Scatter the weights to columns
+  // through ngram_indexes_ so each count is scaled by the weight of the
+  // n-gram that feeds its column. Columns that no n-gram maps to keep 1.0f;
+  // their counts are always zero. If several n-grams share a column, the
+  // weight of the last one is used.
+  const size_t columns = impl.output_size_;
+  const bool weighted = !impl.weights_.empty() && impl.weighting_criteria_ != kTF;
+  std::vector<float> column_weights;
+  if (weighted) {
+    column_weights.assign(columns, 1.0f);
+    for (size_t i = 0; i < impl.weights_.size(); ++i) {
+      column_weights[static_cast<size_t>(impl.ngram_indexes_[i])] = impl.weights_[i];
+    }
+  }
+
+  switch (impl.weighting_criteria_) {
+    case kTF: {
+      for (auto f : frequences) {
+        *output_data++ = static_cast<float>(f);
+      }
+    } break;
+    case kIDF: {
+      if (weighted) {
+        for (size_t i = 0; i < frequences.size(); ++i) {
+          *output_data++ = (frequences[i] > 0) ? column_weights[i % columns] : 0;
+        }
+      } else {
+        for (auto f : frequences) {
+          *output_data++ = (f > 0) ? 1.0f : 0;
+        }
+      }
+    } break;
+    case kTFIDF: {
+      if (weighted) {
+        for (size_t i = 0; i < frequences.size(); ++i) {
+          *output_data++ = frequences[i] * column_weights[i % columns];
+        }
+      } else {
+        for (auto f : frequences) {
+          *output_data++ = static_cast<float>(f);
+        }
+      }
+    } break;
+    case kNone:  // fall-through
+    default:
+      assert(false);
+  }
+}
+
+// Counts pool n-grams in input 0 for element type T and emits the result.
+// Accepts a scalar, [C] or [B,C] input; each row is scanned independently,
+// so n-grams never span two rows. Returns INVALID_ARGUMENT for any other
+// rank or for a zero-sized dimension.
+template <typename T>
+Status Ngram::ComputeImpl(OpKernelContext* ctx) const {
+  const auto& impl = *impl_;
+  auto const set_end = impl.PoolEnd<T>();
+
+  auto X = ctx->Input<Tensor>(0);
+  auto& input_shape = X->Shape();
+  const size_t total_items = input_shape.Size();
+
+  size_t b_dim = 0;  // number of rows to scan
+  size_t B = 0;      // batch dimension reported to OutputResult (0 => 1-D output)
+  size_t C = 0;      // items per row
+  auto& input_dims = input_shape.GetDims();
+  if (input_dims.empty()) {
+    b_dim = 1;
+    C = 1;
+    assert(total_items == 1);
+  } else if (input_dims.size() == 1) {
+    b_dim = 1;
+    C = input_dims[0];
+    if (C < 1) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Input shape must have either [C] or [B,C] dimensions where C > 0 and B > 0");
+    }
+  } else if (input_dims.size() == 2) {
+    B = input_dims[0];
+    C = input_dims[1];
+    b_dim = B;
+    if (B < 1 || C < 1) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Input shape must have either [C] or [B,C] dimensions where C > 0 and B > 0");
+    }
+  } else {
+    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                  "Input shape must have either [C] or [B,C] dimensions where C > 0 and B > 0");
+  }
+
+  assert((b_dim * C) == total_items);
+
+  // Frequency holder allocate [B..output_size_]
+  // and init all to zero
+  std::vector<uint32_t> frequencies(b_dim * impl.output_size_, 0);
+
+  const int64_t max_gram_length = impl.max_gram_length_;
+  const int64_t max_skip_distance = impl.max_skip_count_ + 1;  // Convert to distance
+  int64_t start_ngram_size = impl.min_gram_length_;
+  auto const input_data = X->template Data<T>();
+  NgramEntry<T> sample;
+
+  // Treat 1-grams in a special way: they do not depend on the skip distance,
+  // so they are counted exactly once here.
+  if (start_ngram_size == 1) {
+    for (size_t row_num = 0; row_num < b_dim; ++row_num) {
+      auto const row_begin = input_data + row_num * C;
+      for (size_t pos = 0; pos < C; ++pos) {
+        sample.Clear();
+        sample.AddItem(row_begin[pos]);
+        auto hit = impl.PoolFind<T>(sample);
+        if (hit != set_end) {
+          // record frequency
+          impl.IncrementCount(hit->Id(), row_num, frequencies);
+        }
+      }
+    }
+    if (++start_ngram_size > max_gram_length) {
+      OutputResult(ctx, B, frequencies);
+      return Status::OK();
+    }
+  }
+
+  for (int64_t skip_distance = 1; skip_distance <= max_skip_distance; ++skip_distance) {
+    for (size_t row_num = 0; row_num < b_dim; ++row_num) {
+      auto const row_begin = input_data + row_num * C;
+      for (size_t pos = 0; pos < C; ++pos) {
+        // Items left in the row from the window start. Offsets are compared
+        // against this count so that no pointer past the row end is formed.
+        const int64_t remaining = static_cast<int64_t>(C - pos);
+        // Check if any n-gram size in [start_ngram_size..max_gram_length] range
+        // fit before the end of the row so we do not waste time adding [1..start_ngram_size)
+        // At least items of start_ngram_size should fit
+        if (skip_distance * (start_ngram_size - 1) >= remaining) {
+          break;
+        }
+        sample.Clear();
+        int64_t offset = 0;
+        for (int64_t ngram_size = 1;
+             ngram_size <= max_gram_length && offset < remaining;
+             ++ngram_size, offset += skip_distance) {
+          sample.AddItem(row_begin[pos + static_cast<size_t>(offset)]);
+
+          // Do not test anything before start_ngram_size
+          if (ngram_size >= start_ngram_size) {
+            auto hit = impl.PoolFind<T>(sample);
+            if (hit != set_end) {
+              // record frequency
+              impl.IncrementCount(hit->Id(), row_num, frequencies);
+            }
+          }
+        }
+        // Sliding window shift happens through ++pos
+      }
+    }
+  }
+  OutputResult(ctx, B, frequencies);
+  return Status::OK();
+}
+
+// Dispatches to ComputeImpl for the element type of input 0 and returns
+// INVALID_ARGUMENT for any type other than int32, int64 or string.
+Status Ngram::Compute(OpKernelContext* ctx) const {
+  Status s;
+
+  auto X = ctx->Input<Tensor>(0);
+
+  if (X->DataType() == DataTypeImpl::GetType<int32_t>()) {
+    s = ComputeImpl<int32_t>(ctx);
+  } else if (X->DataType() == DataTypeImpl::GetType<int64_t>()) {
+    s = ComputeImpl<int64_t>(ctx);
+  } else if (X->DataType() == DataTypeImpl::GetType<std::string>()) {
+    s = ComputeImpl<std::string>(ctx);
+  } else {
+    s = Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+               "Invalid type of the input argument");
+  }
+
+  return s;
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/ngram.h b/onnxruntime/contrib_ops/cpu/ngram.h
new file mode 100644
index 0000000000..adf95ebcb8
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/ngram.h
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include +#include + +namespace onnxruntime { +namespace contrib { + +class Ngram final : public OpKernel { + public: + explicit Ngram(const OpKernelInfo& info); + ~Ngram(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Ngram); + + Status Compute(OpKernelContext* ctx) const override; + + private: + template + Status ComputeImpl(OpKernelContext* ctx) const; + + // Apply weighing criteria and output + void OutputResult(OpKernelContext* ctx, size_t b_dim, const std::vector& frequences) const; + + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index b87775277a..c14e06ed3f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -214,6 +214,139 @@ activation.)DOC") }) .SetDoc(R"DOC(Tokenizer divides each string in X into a vector of strings along the last axis. All input strings including attributes are UTF-8 encoded.)DOC"); + ONNX_CONTRIB_OPERATOR_SCHEMA(Ngram) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "X", "Input for n-gram extraction", "T") + .Output(0, "Y", "Ngram results", "T1") + .TypeConstraint( + "T", + {"tensor(string)", "tensor(int32)", "tensor(int64)"}, + "Input is ether string UTF-8 or int32/int64") + .TypeConstraint( + "T1", + {"tensor(float)"}, + "1-D tensor of floats") + .Attr( + "max_gram_length", + "Maximum n-gram length. If this value is 3, 3-grams will be used to generate the output.", + AttributeProto::INT) + .Attr( + "min_gram_length", + "Minimum n-gram length. If this value is 2 and max_gram_length is 3, output may contain counts of 2-grams and 3-grams.", + AttributeProto::INT) + .Attr( + "max_skip_count", + "Maximum number of items (integers/strings) to be skipped when constructing an n-gram from X." 
+ "If max_skip_count=1, min_gram_length=2, max_gram_length=3, this operator may generate 2-grams" + "with skip_count=0 and skip_count=1, and 3-grams with skip_count=0 and skip_count=1", + AttributeProto::INT) + .Attr( + "pool_strings", + "List of strings n-grams learned from the training set. Either this or pool_int64s attributes must be present but not both." + "It's an 1-D tensor starting with the collections of all 1-grams and ending with the collections of n-grams." + "The i-th element in pool stores the n-gram that should be mapped to index ngram_indexes[i] in the output vector.", + AttributeProto::STRINGS, + OPTIONAL) + .Attr( + "pool_int64s", + "List of int64 n-grams learned from the training set. Either this or pool_strings attributes must be present but not both." + "It's an 1-D tensor starting with the collections of all 1-grams and ending with the collections of n-grams." + "The i-th element in pool stores the n-gram that should be mapped to index ngram_indexes[i] in the output vector.", + AttributeProto::INTS, + OPTIONAL) + .Attr( + "ngram_counts", + "The starting indexes of 1-grams, 2-grams, and so on in pool." + "It is useful when determining the boundary between two consecutive collections of n-grams." + "For example, if ngram_counts is [0, 17, 36], the first index (zero-based) of 1-gram/2-gram/3-gram" + "in pool are 0/17/36. This format is essentially identical to CSR (or CSC) sparse matrix format, " + "and we choose to keep this due to its popularity.", + AttributeProto::INTS) + .Attr( + "ngram_indexes", + "list of int64s (type: AttributeProto::INTS). This list is parallel to the specified 'pool_*' attribute." + "The i-th element in ngram_indexes indicate the coordinate of the i-th n-gram in the output tensor.", + AttributeProto::INTS) + .Attr( + "weights", + "list of floats. This attribute stores the weight of each n-gram in pool. The i-th element in weights" + "is the weight of the i-th n-gram in pool. 
Its length equals to the size of ngram_indexes." + "By default, weights is an all-one tensor.This attribute is used when mode is \"IDF\" or \"TFIDF\"" + "to scale the associated word counts.", + AttributeProto::FLOATS, + OPTIONAL) + .Attr( + "mode", + "The weighting criteria. It can be one of \"TF\" (term frequency)," + "\"IDF\" (inverse document frequency), and \"TFIDF\" (the combination of TF and IDF)", + AttributeProto::STRING) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT); + + if (hasInputShape(ctx, 0)) { + std::vector ngram_indexes; + ONNX_NAMESPACE::getRepeatedAttribute(ctx, "ngram_indexes", ngram_indexes); + if (ngram_indexes.empty() || !std::all_of(ngram_indexes.cbegin(), ngram_indexes.cend(), + [](int64_t i) { return i >= 0; })) { + fail_shape_inference( + "ngram_indexes must be non-empty with no negative values"); + } + + auto greatest_hit = std::max_element(ngram_indexes.cbegin(), ngram_indexes.cend()); + auto max_last_axis = *greatest_hit + 1; + + ONNX_NAMESPACE::TensorShapeProto output_shape; + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto dim_size = input_shape.dim_size(); + if (dim_size == 0 || dim_size == 1) { + output_shape.add_dim()->set_dim_value(max_last_axis); + } else if (dim_size == 2) { + auto& B_dim = input_shape.dim(0); + if (!B_dim.has_dim_value()) { + fail_shape_inference( + "Input shape does not have first dimension value"); + } + output_shape.add_dim()->set_dim_value(B_dim.dim_value()); + output_shape.add_dim()->set_dim_value(max_last_axis); + } else { + fail_shape_inference( + "Input shape must have either [C] or [B,C] dimensions where C > 0 and B > 0"); + } + updateOutputShape(ctx, 0, output_shape); + } + }) + .SetDoc(R"DOC( +This transform extracts n-grams from the input sequence and save them as a vector. 
Input can +be either a 1-D or 2-D tensor. For 1-D input, output is the n-gram representation of that input. +For 2-D input, the output is also a 2-D tensor whose i-th row is the n-gram representation of the i-th input row. +More specifically, if input shape is [C], the corresponding output shape would be [max(ngram_indexes) + 1]. +If input shape is [N, C], this operator produces a [N, max(ngram_indexes) + 1]-tensor. + +In contrast to standard n-gram extraction, here, the indexes of extracting an n-gram from the original +sequence are not necessarily consecutive numbers. The discontinuity between indexes are controlled by the number of skips. +If the number of skips is 2, we should skip two tokens when scanning through the original sequence. +Let's consider an example. Assume that input sequence is [94, 17, 36, 12, 28] and the number of skips is 2. +The associated 2-grams are [94, 12] and [17, 28] respectively indexed by [0, 3] and [1, 4]. +If the number of skips becomes 0, the 2-grams generated are [94, 17], [17, 36], [36, 12], [12, 28] +indexed by [0, 1], [1, 2], [2, 3], [3, 4], respectively. + +The output vector stores the count of each n-gram; +Y[i] indicates the times that the i-th n-gram is found. The attribute ngram_indexes is used to determine the mapping +between index i and the corresponding n-gram. If pool_int64s is [94 , 17 ,17, 36], ngram_indexes is [1, 0], +ngram_counts=[0, 0], then the Y[0] (first element in Y) and Y[1] (second element in Y) are the counts of [17, 36] and [94, 17], +respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output. +Note that we may consider all skips up to S when generating the n-grams. + +The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and +the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. 
If mode is "TFIDF", +this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute. + +Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor. +If pool_strings is set, the input must be a string tensor. +)DOC"); + // Operators for linear 8 bit quanitzation support. ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear) .SetDomain(kMSDomain) diff --git a/onnxruntime/test/contrib_ops/ngram_test.cc b/onnxruntime/test/contrib_ops/ngram_test.cc new file mode 100644 index 0000000000..727fdd4521 --- /dev/null +++ b/onnxruntime/test/contrib_ops/ngram_test.cc @@ -0,0 +1,559 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +#include + +namespace onnxruntime { +namespace test { +namespace ngram_test { + +constexpr const char* domain = onnxruntime::kMSDomain; +const int opset_ver = 1; + +void InitTestAttr(OpTester& test, const std::string& mode, + int64_t min_gram_length, int64_t max_gram_length, int64_t max_skip_count, + const std::vector& ngram_counts, + const std::vector& ngram_indexes, + const std::vector& weights, + const std::vector& pool_int64s, + const std::vector& pool_strings) { + test.AddAttribute("mode", mode); + test.AddAttribute("min_gram_length", min_gram_length); + test.AddAttribute("max_gram_length", max_gram_length); + test.AddAttribute("max_skip_count", max_skip_count); + test.AddAttribute("ngram_counts", ngram_counts); + test.AddAttribute("ngram_indexes", ngram_indexes); + // optional + if (!weights.empty()) { + test.AddAttribute("weights", weights); + } + if (!pool_int64s.empty()) { + test.AddAttribute("pool_int64s", pool_int64s); + } else { + test.AddAttribute("pool_strings", pool_strings); + } +} +} // namespace ngram_test + +using namespace ngram_test; + +// Here is what takes place in general and in particular +// 
in this unit test.There are 7 n - grams : 4 unigrams and 3 bigrams +// that are expressed as 10 items(integers in this case) contained within pool_int64 attribute. +// We only count and then optionally scale those ngrams that appear in the supplied pool parameter(either int64 or string). +// M = 1 and N = 2 in this case. +// However, attribute all controls whether we consider all of the supplied ngram[M..N] sizes +// into consideration or not.With all = false, we only consider N - grams. + +TEST(ContribOpNgramTest, Int32_TF_onlyBigrams_Skip0) { + OpTester test("Ngram", opset_ver, domain); + // s=0, Min=Max=2, weights empty, int32 + InitTestAttr(test, "TF", 2, 2, 0, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TF_BatchOnlyBigrams_Skip0) { + OpTester test("Ngram", opset_ver, domain); + // s=0, Min=Max=2, weights empty, int32 + InitTestAttr(test, "TF", 2, 2, 0, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + // Tow batches by six + std::vector dims{2, 6}; + std::vector input = {1, 1, 3, 3, 3, 7, + 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + std::vector output = {0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_OnlyBigrams_Skip0) { + OpTester test("Ngram", opset_ver, domain); + // s=0, Min=Max=2, weights empty, string + InitTestAttr(test, "TF", 2, 2, 0, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, 
+ {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_BatchOnlyBigrams_Skip0) { + OpTester test("Ngram", opset_ver, domain); + // s=0, Min=Max=2, weights empty, string + InitTestAttr(test, "TF", 2, 2, 0, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{2, 6}; + std::vector input{"one", "one", "three", "three", "three", "seven", + "eight", "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + // ["seven", "eight"] can not be found due to batch boundary and s=0 + // bigram elements have to be next to each other + std::vector output = {0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1}; + + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TF_onlyBigrams_LevelEmpty) { + OpTester test("Ngram", opset_ver, domain); + // s=0, Min=Max=2, weights empty, int32 + InitTestAttr(test, "TF", 2, 2, 0, + {0, 0}, // no unigrams, bi-grams start immediately + { + 0, + 1, + 2, + }, //7 output indexes + {}, + { //1-grams none + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{3}; + std::vector output = {1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, 
Int32_TF_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, int32 + InitTestAttr(test, "TF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // No 1-grams but Skip is 5 so we manage to count 3 + // occurrences of [7,8] + std::vector output = {0, 0, 0, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TF_BatchOnlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, int32 + InitTestAttr(test, "TF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{2, 6}; + std::vector input = {1, 1, 3, 3, 3, 7, + 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + // Skip is 5 but we are constrained by row boundaries + // so count only 1 of each bigram, all within the second row + std::vector output = {0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, string + InitTestAttr(test, "TF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // No 1-grams but Skip is 5 so we manage to count 3 + 
// occurrences of ["seven", "eight"] in one batch (row) + std::vector output = {0, 0, 0, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_BatchOnlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, string + InitTestAttr(test, "TF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{2, 6}; + // dims is {2, 6}, so each row holds 6 of the 12 strings; the line + // break in the initializer below is formatting only + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + // Same expected counts as the int32 batch test: matching is + // constrained by row boundaries + std::vector output = {0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TF_UniAndBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=1, Max=2, weights empty, int32 + InitTestAttr(test, "TF", 1, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // We consider both 1-grams and 2-grams so get all the counts here + std::vector output = {0, 3, 1, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TF_BatchUniAndBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=1, Max=2, weights empty, int32 + InitTestAttr(test, "TF", 1, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{2, 6}; + std::vector input = {1, 1, 3, 3, 3, 7, + 8, 
6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + // Counts are now per row (batch) + std::vector output = {0, 3, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_UniAndBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=1, Max=2, weights empty, string + InitTestAttr(test, "TF", 1, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // Both 1-grams and 2-grams are counted; same expected counts as the int32 variant + std::vector output = {0, 3, 1, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TF_BatchUniAndBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=1, Max=2, weights empty, string + InitTestAttr(test, "TF", 1, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{2, 6}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{2, 7}; + // Counts are per row (batch); same expected counts as the int32 batch variant + std::vector output = {0, 3, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 1, 1, 1}; + + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_IDF_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, int32 + // We change to IDF but do not supply weights so + // we should get all 1.0f where count is 
not zero + InitTestAttr(test, "IDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_IDF_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, string + InitTestAttr(test, "IDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // IDF without weights: every bigram with a non-zero count is expected as 1 + std::vector output = {0, 0, 0, 0, 1, 1, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TFIDF_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, int32 + // We change to TFIDF but do not supply weights so + // we should get all the original counts as weights are 1.0f by + // default + InitTestAttr(test, "TFIDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TFIDF_onlyBigrams_Skip5) { + OpTester 
test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights empty, string + InitTestAttr(test, "TFIDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {}, + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // TFIDF without weights: expected values equal the plain TF counts + std::vector output = {0, 0, 0, 0, 1, 3, 1}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_IDFWeights_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights specified, int32 + // We change to IDF with supplied weights. All entries + // with non-zero counts must be replaced with the supplied weights + InitTestAttr(test, "IDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 2.0}, // weights + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 2, 3, 2}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_IDFWeights_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights specified, string + InitTestAttr(test, "IDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 2.0}, // weights + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + 
test.AddInput("T", dims, input); + + std::vector out_dims{7}; + std::vector output = {0, 0, 0, 0, 2, 3, 2}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, Int32_TFIDFWeights_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights specified, int32 + // We change to TFIDF with supplied weights. + // We should have all counts scaled by weights + InitTestAttr(test, "TFIDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 2.0}, // weights + {2, 3, 5, 4, //1-grams + 5, 6, 7, 8, 6, 7}, //bi-grams + {}); + + std::vector dims{12}; + std::vector input = {1, 1, 3, 3, 3, 7, 8, 6, 7, 5, 6, 8}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // Bigram counts {1, 3, 1} times their weights {2.0, 3.0, 2.0} + std::vector output = {0, 0, 0, 0, 2, 9, 2}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpNgramTest, String_TFIDFWeights_onlyBigrams_Skip5) { + OpTester test("Ngram", opset_ver, domain); + // s=5, Min=Max=2, weights specified, string + InitTestAttr(test, "TFIDF", 2, 2, 5, + {0, 4}, + {0, 1, 2, 3, 4, 5, 6}, //7 output indexes + {2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 2.0}, // weights + {}, + {"two", "three", "five", "four", //1-grams + "five", "six", "seven", "eight", "six", "seven"}); //bi-grams + + std::vector dims{12}; + std::vector input{"one", "one", "three", "three", "three", "seven", "eight", + "six", "seven", "five", "six", "eight"}; + test.AddInput("T", dims, input); + + std::vector out_dims{7}; + // Bigram counts {1, 3, 1} times their weights {2.0, 3.0, 2.0} + std::vector output = {0, 0, 0, 0, 2, 9, 2}; + test.AddOutput("Y", out_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +} // namespace test +} // namespace onnxruntime