From df45b12fdf8fbeb68c314b2475eabbc7c2ea3af8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 22 Feb 2023 12:20:02 +0100 Subject: [PATCH] Reduce memory usage for TreeEnsemble operators (#14670) ### Description The onnx file is about 5Mb for a lightgbm model with 500 trees. onnxruntime uses additional 10Mb to compute the inference and keeps the onnx structure. This PR reduces the memory usage by almost 50%. The memory used by the onnx node could be freed if there is no optimized graph to save but that's not covered by this PR. ### Motivation and Context Reduce memory usage. --- onnxruntime/core/providers/cpu/ml/ml_common.h | 16 +- .../cpu/ml/tree_ensemble_aggregator.h | 94 ++++--- .../providers/cpu/ml/tree_ensemble_common.h | 248 +++++++++++------- .../cpu/ml/tree_ensembler_classifier_test.cc | 14 +- 4 files changed, 219 insertions(+), 153 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 8660af0f2e..89fb556ca6 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -20,14 +20,14 @@ enum class OUTPUT_MODE { ALL_SCORES }; -enum class NODE_MODE { - BRANCH_LEQ, - BRANCH_LT, - BRANCH_GTE, - BRANCH_GT, - BRANCH_EQ, - BRANCH_NEQ, - LEAF +enum NODE_MODE : uint8_t { + LEAF = 1, + BRANCH_LEQ = 2, + BRANCH_LT = 4, + BRANCH_GTE = 6, + BRANCH_GT = 8, + BRANCH_EQ = 10, + BRANCH_NEQ = 12 }; static inline NODE_MODE MakeTreeNodeMode(const std::string& input) { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index ae5f33616c..6edb5b8497 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -13,8 +13,8 @@ namespace ml { namespace detail { struct TreeNodeElementId { - int tree_id; - int node_id; + int64_t tree_id; + int64_t node_id; bool operator==(const TreeNodeElementId& xyz) const { return (tree_id == xyz.tree_id) && (node_id == xyz.node_id); } @@ -23,8 +23,8 @@ struct TreeNodeElementId { } struct hash_fn { std::size_t operator()(const TreeNodeElementId& key) const { - std::size_t h1 = std::hash()(key.tree_id); - std::size_t h2 = std::hash()(key.node_id); + std::size_t h1 = std::hash()(key.tree_id); + std::size_t h2 = std::hash()(key.node_id); return h1 ^ h2; } }; @@ -61,26 +61,39 @@ struct ScoreValue { } }; -enum MissingTrack { - kNone, - kTrue, - kFalse +enum MissingTrack : uint8_t { + kTrue = 16, + kFalse = 0 }; template struct TreeNodeElement { - TreeNodeElementId id; int feature_id; - T value; - T hitrates; - NODE_MODE mode; - TreeNodeElement* truenode; - TreeNodeElement* falsenode; - MissingTrack missing_tracks; - std::vector> weights; - bool is_not_leaf; - bool is_missing_track_true; + // Stores the node threshold or the weights if the tree has one target. + T value_or_unique_weight; + + // onnx specification says hitrates is used to store information about the node, + // but this information is not used for inference. + // T hitrates; + + // True node, false node are obtained by computing `this + truenode_inc_or_first_weight`, + // `this + falsenode_inc_or_n_weights` if the node is not a leaf. + // In case of a leaf, these attributes are used to indicate the position of the weight + // in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, + // the weight is also stored in `value_or_unique_weight`. + // This implementation assumes a tree has less than 2^21 nodes, + // and the total number of leave in the set of trees is below 2^21. + uint32_t truenode_inc_or_first_weight; + // In case of a leaf, the following attribute indicates the number of weights + // in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates + // `this + falsenode_inc_or_n_weights` is the false node. + uint32_t falsenode_inc_or_n_weights; + uint8_t flags; + + inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } + inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); } + inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; } }; template @@ -121,7 +134,8 @@ class TreeAggregator { // N outputs void ProcessTreeNodePrediction(InlinedVector>& /*predictions*/, - const TreeNodeElement& /*root*/) const {} + const TreeNodeElement& /*root*/, + gsl::span> /*weights*/) const {} void MergePrediction(InlinedVector>& /*predictions*/, const InlinedVector>& /*predictions2*/) const {} @@ -158,7 +172,7 @@ class TreeAggregatorSum : public TreeAggregator& prediction, const TreeNodeElement& root) const { - prediction.score += root.weights[0].value; + prediction.score += root.value_or_unique_weight; } void MergePrediction1(ScoreValue& prediction, @@ -176,8 +190,10 @@ class TreeAggregatorSum : public TreeAggregator>& predictions, - const TreeNodeElement& root) const { - for (auto it = root.weights.cbegin(); it != root.weights.cend(); ++it) { + const TreeNodeElement& root, + gsl::span> weights) const { + auto it = weights.begin() + root.truenode_inc_or_first_weight; + for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { ORT_ENFORCE(it->i < (int64_t)predictions.size()); predictions[onnxruntime::narrow(it->i)].score += it->value; predictions[onnxruntime::narrow(it->i)].has_score = 1; @@ -260,8 +276,8 @@ class TreeAggregatorMin : public TreeAggregator& prediction, const TreeNodeElement& root) const { - prediction.score = (!(prediction.has_score) || root.weights[0].value < prediction.score) - ? root.weights[0].value + prediction.score = (!(prediction.has_score) || root.value_or_unique_weight < prediction.score) + ? root.value_or_unique_weight : prediction.score; prediction.has_score = 1; } @@ -279,11 +295,14 @@ class TreeAggregatorMin : public TreeAggregator>& predictions, - const TreeNodeElement& root) const { - for (auto it = root.weights.begin(); it != root.weights.end(); ++it) { - predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) - ? it->value - : predictions[onnxruntime::narrow(it->i)].score; + const TreeNodeElement& root, + gsl::span> weights) const { + auto it = weights.begin() + root.truenode_inc_or_first_weight; + for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + predictions[onnxruntime::narrow(it->i)].score = + (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) + ? it->value + : predictions[onnxruntime::narrow(it->i)].score; predictions[onnxruntime::narrow(it->i)].has_score = 1; } } @@ -316,8 +335,8 @@ class TreeAggregatorMax : public TreeAggregator& prediction, const TreeNodeElement& root) const { - prediction.score = (!(prediction.has_score) || root.weights[0].value > prediction.score) - ? root.weights[0].value + prediction.score = (!(prediction.has_score) || root.value_or_unique_weight > prediction.score) + ? root.value_or_unique_weight : prediction.score; prediction.has_score = 1; } @@ -334,11 +353,14 @@ class TreeAggregatorMax : public TreeAggregator>& predictions, - const TreeNodeElement& root) const { - for (auto it = root.weights.begin(); it != root.weights.end(); ++it) { - predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) - ? it->value - : predictions[onnxruntime::narrow(it->i)].score; + const TreeNodeElement& root, + gsl::span> weights) const { + auto it = weights.begin() + root.truenode_inc_or_first_weight; + for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + predictions[onnxruntime::narrow(it->i)].score = + (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) + ? it->value + : predictions[onnxruntime::narrow(it->i)].score; predictions[onnxruntime::narrow(it->i)].has_score = 1; } } diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 3d016f62c6..a7a3925457 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -42,6 +42,10 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { protected: std::vector base_values_; std::vector> nodes_; + // Type of weights should be a vector of OutputType. Onnx specifications says it must be float. + // Lightgbm requires a double to do the summation of all trees predictions. That's why + // `ThresholdType` is used as well for output type (double as well for lightgbm) and not `OutputType`. + std::vector> weights_; std::vector*> roots_; public: @@ -180,14 +184,17 @@ Status TreeEnsembleCommon::Init( } n_targets_or_classes_ = n_targets_or_classes; max_tree_depth_ = 1000; + ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); // additional members - size_t i, limit; - std::vector cmodes(nodes_modes.size()); + size_t limit; + uint32_t i; + InlinedVector cmodes; + cmodes.reserve(nodes_modes.size()); same_mode_ = true; int fpos = -1; for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { - cmodes[i] = MakeTreeNodeMode(nodes_modes[i]); + cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); if (cmodes[i] == NODE_MODE::LEAF) continue; if (fpos == -1) { @@ -201,103 +208,138 @@ Status TreeEnsembleCommon::Init( // filling nodes n_nodes_ = nodes_treeids.size(); - nodes_.resize(nodes_treeids.size()); + limit = static_cast(n_nodes_); + InlinedVector node_tree_ids; + node_tree_ids.reserve(limit); + nodes_.clear(); + nodes_.reserve(limit); roots_.clear(); - std::unordered_map*, TreeNodeElementId::hash_fn> idi; + std::unordered_map idi; + idi.reserve(limit); max_feature_id_ = 0; - for (i = 0, limit = nodes_treeids.size(); i < limit; ++i) { - TreeNodeElement& node = nodes_[i]; - node.id.tree_id = static_cast(nodes_treeids[i]); - node.id.node_id = static_cast(nodes_nodeids[i]); + for (i = 0; i < limit; ++i) { + TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), + static_cast(nodes_nodeids[i])}; + TreeNodeElement node; node.feature_id = static_cast(nodes_featureids[i]); if (node.feature_id > max_feature_id_) { max_feature_id_ = node.feature_id; } - if (nodes_values_as_tensor.empty()) { - node.value = static_cast(nodes_values[i]); - } else { - node.value = nodes_values_as_tensor[i]; - } + node.value_or_unique_weight = nodes_values_as_tensor.empty() + ? static_cast(nodes_values[i]) + : nodes_values_as_tensor[i]; + + /* hitrates is not used for inference, they are ignored. if (nodes_hitrates_as_tensor.empty()) { node.hitrates = static_cast(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1); } else { node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1; + } */ + + node.flags = static_cast(cmodes[i]); + node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf + node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf + + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { + node.flags |= static_cast(MissingTrack::kTrue); } - node.mode = cmodes[i]; - node.is_not_leaf = node.mode != NODE_MODE::LEAF; - node.truenode = nullptr; // nodes_truenodeids[i]; - node.falsenode = nullptr; // nodes_falsenodeids[i]; - node.missing_tracks = i < static_cast(nodes_missing_value_tracks_true.size()) - ? (nodes_missing_value_tracks_true[i] == 1 - ? MissingTrack::kTrue - : MissingTrack::kFalse) - : MissingTrack::kNone; - node.is_missing_track_true = node.missing_tracks == MissingTrack::kTrue; - if (idi.find(node.id) != idi.end()) { - ORT_THROW("Node ", node.id.node_id, " in tree ", node.id.tree_id, " is already there."); + auto p = idi.insert(std::pair(node_tree_id, i)); + if (!p.second) { + ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); } - idi.insert(std::pair*>(node.id, &node)); + nodes_.emplace_back(node); + node_tree_ids.emplace_back(node_tree_id); } + InlinedVector truenode_ids, falsenode_ids; + truenode_ids.reserve(limit); + falsenode_ids.reserve(limit); TreeNodeElementId coor; + i = 0; for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) { - if (!it->is_not_leaf) + if (!it->is_not_leaf()) { + truenode_ids.push_back(0); + falsenode_ids.push_back(0); continue; - i = std::distance(nodes_.begin(), it); - coor.tree_id = it->id.tree_id; + } + + TreeNodeElementId& node_tree_id = node_tree_ids[i]; + coor.tree_id = node_tree_id.tree_id; coor.node_id = static_cast(nodes_truenodeids[i]); auto found = idi.find(coor); if (found == idi.end()) { ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); } - if (coor.node_id >= 0 && coor.node_id < n_nodes_) { - it->truenode = found->second; - if ((it->truenode->id.tree_id != it->id.tree_id) || - (it->truenode->id.node_id == it->id.node_id)) { - ORT_THROW("One falsenode is pointing either to itself, either to another tree."); - } - } else - it->truenode = nullptr; + truenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0); coor.node_id = static_cast(nodes_falsenodeids[i]); found = idi.find(coor); if (found == idi.end()) { ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); } - if (coor.node_id >= 0 && coor.node_id < n_nodes_) { - it->falsenode = found->second; - if ((it->falsenode->id.tree_id != it->id.tree_id) || - (it->falsenode->id.node_id == it->id.node_id)) { - ORT_THROW("One falsenode is pointing either to itself, either to another tree."); - } - } else - it->falsenode = nullptr; + falsenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0); } - int64_t previous = -1; - for (i = 0; i < static_cast(n_nodes_); ++i) { - if ((previous == -1) || (previous != nodes_[i].id.tree_id)) - roots_.push_back(&(nodes_[i])); - previous = nodes_[i].id.tree_id; + // sort targets + InlinedVector> indices; + indices.reserve(target_class_nodeids.size()); + for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { + indices.emplace_back(std::pair( + TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, + i)); } + std::sort(indices.begin(), indices.end()); + // Initialize the leaves. TreeNodeElementId ind; SparseValue w; - for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - ind.tree_id = static_cast(target_class_treeids[i]); - ind.node_id = static_cast(target_class_nodeids[i]); - if (idi.find(ind) == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (weights)."); + size_t indi; + for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { + ind = indices[indi].first; + i = indices[indi].second; + auto found = idi.find(ind); + if (found == idi.end()) { + ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights)."); } + + TreeNodeElement& leaf = nodes_[found->second]; + if (leaf.is_not_leaf()) { + // An exception should be raised in that case. But this case may happen in + // models converted with an old version of onnxmltools. There weights are ignored. + // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); + continue; + } + w.i = target_class_ids[i]; - if (target_class_weights_as_tensor.empty()) { - w.value = static_cast(target_class_weights[i]); - } else { - w.value = target_class_weights_as_tensor[i]; + w.value = target_class_weights_as_tensor.empty() + ? static_cast(target_class_weights[i]) + : target_class_weights_as_tensor[i]; + if (leaf.falsenode_inc_or_n_weights == 0) { + leaf.truenode_inc_or_first_weight = static_cast(weights_.size()); + leaf.value_or_unique_weight = w.value; } - idi[ind]->weights.push_back(w); + ++leaf.falsenode_inc_or_n_weights; + weights_.push_back(w); + } + + // Initialize all the nodes but the leaves. + int64_t previous = -1; + for (i = 0, limit = static_cast(n_nodes_); i < limit; ++i) { + if ((previous == -1) || (previous != node_tree_ids[i].tree_id)) + roots_.push_back(&(nodes_[idi[node_tree_ids[i]]])); + previous = node_tree_ids[i].tree_id; + if (!nodes_[i].is_not_leaf()) { + if (nodes_[i].falsenode_inc_or_n_weights == 0) { + ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, "."); + } + continue; + } + ORT_ENFORCE(truenode_ids[i] == 0 || truenode_ids[i] > i); + nodes_[i].truenode_inc_or_first_weight = truenode_ids[i] == 0 ? 0 : static_cast(truenode_ids[i] - i); + ORT_ENFORCE(falsenode_ids[i] == 0 || falsenode_ids[i] > i); + nodes_[i].falsenode_inc_or_n_weights = falsenode_ids[i] == 0 ? 0 : static_cast(falsenode_ids[i] - i); } n_trees_ = roots_.size(); @@ -479,7 +521,7 @@ void TreeEnsembleCommon::ComputeAgg(concur if (n_trees_ <= parallel_tree_ || max_num_threads == 1) { /* section A2 */ InlinedVector> scores(onnxruntime::narrow(n_targets_or_classes_), {0, 0}); for (int64_t j = 0; j < n_trees_; ++j) { - agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow(j)], x_data)); + agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow(j)], x_data), weights_); } agg.FinalizeScores(scores, z_data, -1, label_data); } else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */ @@ -492,7 +534,7 @@ void TreeEnsembleCommon::ComputeAgg(concur scores[batch_num].resize(onnxruntime::narrow(n_targets_or_classes_), {0, 0}); auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, onnxruntime::narrow(n_trees_)); for (auto j = work.start; j < work.end; ++j) { - agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data)); + agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data), weights_); } }); for (size_t i = 1, limit = scores.size(); i < limit; ++i) { @@ -515,7 +557,7 @@ void TreeEnsembleCommon::ComputeAgg(concur } for (j = 0, limit = roots_.size(); j < limit; ++j) { for (i = batch; i < batch_end; ++i) { - agg.ProcessTreeNodePrediction(scores[SafeInt(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + agg.ProcessTreeNodePrediction(scores[SafeInt(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_); } } for (i = batch; i < batch_end; ++i) { @@ -541,7 +583,7 @@ void TreeEnsembleCommon::ComputeAgg(concur for (auto j = work.start; j < work.end; ++j) { for (int64_t i = begin_n; i < end_n; ++i) { agg.ProcessTreeNodePrediction(scores[batch_num * SafeInt(N) + i], - *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_); } } }); @@ -573,7 +615,7 @@ void TreeEnsembleCommon::ComputeAgg(concur for (auto i = work.start; i < work.end; ++i) { std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); for (j = 0, limit = roots_.size(); j < limit; ++j) { - agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_); } agg.FinalizeScores(scores, @@ -587,17 +629,19 @@ void TreeEnsembleCommon::ComputeAgg(concur #define TREE_FIND_VALUE(CMP) \ if (has_missing_tracks_) { \ - while (root->is_not_leaf) { \ + while (root->is_not_leaf()) { \ val = x_data[root->feature_id]; \ - root = (val CMP root->value || \ - (root->is_missing_track_true && _isnan_(val))) \ - ? root->truenode \ - : root->falsenode; \ + root += (val CMP root->value_or_unique_weight || \ + (root->is_missing_track_true() && _isnan_(val))) \ + ? root->truenode_inc_or_first_weight \ + : root->falsenode_inc_or_n_weights; \ } \ } else { \ - while (root->is_not_leaf) { \ + while (root->is_not_leaf()) { \ val = x_data[root->feature_id]; \ - root = val CMP root->value ? root->truenode : root->falsenode; \ + root += val CMP root->value_or_unique_weight \ + ? root->truenode_inc_or_first_weight \ + : root->falsenode_inc_or_n_weights; \ } \ } @@ -612,20 +656,20 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( TreeNodeElement* root, const InputType* x_data) const { InputType val; if (same_mode_) { - switch (root->mode) { + switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: if (has_missing_tracks_) { - while (root->is_not_leaf) { + while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root = (val <= root->value || - (root->is_missing_track_true && _isnan_(val))) - ? root->truenode - : root->falsenode; + root += (val <= root->value_or_unique_weight || + (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; } } else { - while (root->is_not_leaf) { + while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root = val <= root->value ? root->truenode : root->falsenode; + root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights; } } break; @@ -649,39 +693,39 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { // Different rules to compare to node thresholds. ThresholdType threshold; - while (root->is_not_leaf) { + while (root->is_not_leaf()) { val = x_data[root->feature_id]; - threshold = root->value; - switch (root->mode) { + threshold = root->value_or_unique_weight; + switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: - root = val <= threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val <= threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::BRANCH_LT: - root = val < threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val < threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::BRANCH_GTE: - root = val >= threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val >= threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::BRANCH_GT: - root = val > threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val > threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::BRANCH_EQ: - root = val == threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val == threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::BRANCH_NEQ: - root = val != threshold || (root->is_missing_track_true && _isnan_(val)) - ? root->truenode - : root->falsenode; + root += val != threshold || (root->is_missing_track_true() && _isnan_(val)) + ? root->truenode_inc_or_first_weight + : root->falsenode_inc_or_n_weights; break; case NODE_MODE::LEAF: break; diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_classifier_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_classifier_test.cc index 2a47b678e6..fad78d2337 100644 --- a/onnxruntime/test/providers/cpu/ml/tree_ensembler_classifier_test.cc +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_classifier_test.cc @@ -322,8 +322,8 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) { std::vector modes = {"BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF", "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF"}; - //std::vector classes = {0, 1, 2, 3}; - std::vector class_treeids = {0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2}; + //std::vector classes = {0, 1}; + std::vector class_treeids = {0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2}; std::vector class_nodeids = {1, 3, 4, 1, 4, 5, 6, 2, 4, 5, 6}; std::vector class_classids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}; std::vector class_weights = {-1.f, 4.f, -1.f, 2.f, -1.f, +1.f, -2.f, 1.f, -1.f, 2.f, -3.f}; @@ -335,13 +335,13 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) { std::vector probs = {}; std::vector log_probs = {}; std::vector scores{ - 0.2689414f, 0.73105859f, + 0.00669282f, 0.99330717f, 0.04742586f, 0.88079702f, - 0.73105859f, 0.26894140f, - 0.73105859f, 0.26894140f, - 0.73105859f, 0.26894140f, + 0.73105859f, 0.26894142f, + 0.73105859f, 0.26894142f, + 0.73105859f, 0.26894142f, 0.26894140f, 0.73105859f, - 0.73105859f, 0.26894140f, + 0.73105859f, 0.26894142f, 0.5f, 0.04742586f}; //define the context of the operator call