Reduce memory usage for TreeEnsemble operators (#14670)

### Description
The ONNX file is about 5 MB for a LightGBM model with 500 trees.
onnxruntime uses an additional 10 MB to compute the inference and keeps the
ONNX structure. This PR reduces the memory usage by almost 50%. The
memory used by the ONNX node could be freed if there is no optimized
graph to save, but that is not covered by this PR.

### Motivation and Context
Reduce memory usage.
This commit is contained in:
Xavier Dupré 2023-02-22 12:20:02 +01:00 committed by GitHub
parent 262e46e8ce
commit df45b12fdf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 153 deletions

View file

@ -20,14 +20,14 @@ enum class OUTPUT_MODE {
ALL_SCORES
};
enum class NODE_MODE {
BRANCH_LEQ,
BRANCH_LT,
BRANCH_GTE,
BRANCH_GT,
BRANCH_EQ,
BRANCH_NEQ,
LEAF
enum NODE_MODE : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12
};
static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {

View file

@ -13,8 +13,8 @@ namespace ml {
namespace detail {
struct TreeNodeElementId {
int tree_id;
int node_id;
int64_t tree_id;
int64_t node_id;
bool operator==(const TreeNodeElementId& xyz) const {
return (tree_id == xyz.tree_id) && (node_id == xyz.node_id);
}
@ -23,8 +23,8 @@ struct TreeNodeElementId {
}
struct hash_fn {
std::size_t operator()(const TreeNodeElementId& key) const {
std::size_t h1 = std::hash<int>()(key.tree_id);
std::size_t h2 = std::hash<int>()(key.node_id);
std::size_t h1 = std::hash<int64_t>()(key.tree_id);
std::size_t h2 = std::hash<int64_t>()(key.node_id);
return h1 ^ h2;
}
};
@ -61,26 +61,39 @@ struct ScoreValue {
}
};
enum MissingTrack {
kNone,
kTrue,
kFalse
enum MissingTrack : uint8_t {
kTrue = 16,
kFalse = 0
};
template <typename T>
struct TreeNodeElement {
TreeNodeElementId id;
int feature_id;
T value;
T hitrates;
NODE_MODE mode;
TreeNodeElement<T>* truenode;
TreeNodeElement<T>* falsenode;
MissingTrack missing_tracks;
std::vector<SparseValue<T>> weights;
bool is_not_leaf;
bool is_missing_track_true;
// Stores the node threshold or the weights if the tree has one target.
T value_or_unique_weight;
// onnx specification says hitrates is used to store information about the node,
// but this information is not used for inference.
// T hitrates;
// If the node is not a leaf, the true node and the false node are obtained by computing
// `this + truenode_inc_or_first_weight` and `this + falsenode_inc_or_n_weights`.
// In case of a leaf, these attributes are used to indicate the position of the weights
// in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one,
// the weight is also stored in `value_or_unique_weight`.
// This implementation assumes a tree has fewer than 2^21 nodes,
// and the total number of leaves in the set of trees is below 2^21.
uint32_t truenode_inc_or_first_weight;
// In case of a leaf, the following attribute indicates the number of weights
// in array `TreeEnsembleCommon::weights_`. If the node is not a leaf,
// `this + falsenode_inc_or_n_weights` is the false node.
uint32_t falsenode_inc_or_n_weights;
uint8_t flags;
inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
};
template <typename InputType, typename ThresholdType, typename OutputType>
@ -121,7 +134,8 @@ class TreeAggregator {
// N outputs
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& /*predictions*/,
const TreeNodeElement<ThresholdType>& /*root*/) const {}
const TreeNodeElement<ThresholdType>& /*root*/,
gsl::span<const SparseValue<ThresholdType>> /*weights*/) const {}
void MergePrediction(InlinedVector<ScoreValue<ThresholdType>>& /*predictions*/,
const InlinedVector<ScoreValue<ThresholdType>>& /*predictions2*/) const {}
@ -158,7 +172,7 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
const TreeNodeElement<ThresholdType>& root) const {
prediction.score += root.weights[0].value;
prediction.score += root.value_or_unique_weight;
}
void MergePrediction1(ScoreValue<ThresholdType>& prediction,
@ -176,8 +190,10 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
// N outputs
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root) const {
for (auto it = root.weights.cbegin(); it != root.weights.cend(); ++it) {
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
ORT_ENFORCE(it->i < (int64_t)predictions.size());
predictions[onnxruntime::narrow<size_t>(it->i)].score += it->value;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
@ -260,8 +276,8 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
const TreeNodeElement<ThresholdType>& root) const {
prediction.score = (!(prediction.has_score) || root.weights[0].value < prediction.score)
? root.weights[0].value
prediction.score = (!(prediction.has_score) || root.value_or_unique_weight < prediction.score)
? root.value_or_unique_weight
: prediction.score;
prediction.has_score = 1;
}
@ -279,11 +295,14 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
// N outputs
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root) const {
for (auto it = root.weights.begin(); it != root.weights.end(); ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score = (!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
}
}
@ -316,8 +335,8 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
const TreeNodeElement<ThresholdType>& root) const {
prediction.score = (!(prediction.has_score) || root.weights[0].value > prediction.score)
? root.weights[0].value
prediction.score = (!(prediction.has_score) || root.value_or_unique_weight > prediction.score)
? root.value_or_unique_weight
: prediction.score;
prediction.has_score = 1;
}
@ -334,11 +353,14 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
// N outputs
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root) const {
for (auto it = root.weights.begin(); it != root.weights.end(); ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score = (!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
}
}

View file

@ -42,6 +42,10 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
protected:
std::vector<ThresholdType> base_values_;
std::vector<TreeNodeElement<ThresholdType>> nodes_;
// The type of weights should be a vector of OutputType. The ONNX specification says it must be float.
// LightGBM requires a double to do the summation of all tree predictions. That's why
// `ThresholdType` is also used for the output type (also double for LightGBM) and not `OutputType`.
std::vector<SparseValue<ThresholdType>> weights_;
std::vector<TreeNodeElement<ThresholdType>*> roots_;
public:
@ -180,14 +184,17 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
n_targets_or_classes_ = n_targets_or_classes;
max_tree_depth_ = 1000;
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max());
// additional members
size_t i, limit;
std::vector<NODE_MODE> cmodes(nodes_modes.size());
size_t limit;
uint32_t i;
InlinedVector<NODE_MODE> cmodes;
cmodes.reserve(nodes_modes.size());
same_mode_ = true;
int fpos = -1;
for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
cmodes[i] = MakeTreeNodeMode(nodes_modes[i]);
cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
if (cmodes[i] == NODE_MODE::LEAF)
continue;
if (fpos == -1) {
@ -201,103 +208,138 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
// filling nodes
n_nodes_ = nodes_treeids.size();
nodes_.resize(nodes_treeids.size());
limit = static_cast<size_t>(n_nodes_);
InlinedVector<TreeNodeElementId> node_tree_ids;
node_tree_ids.reserve(limit);
nodes_.clear();
nodes_.reserve(limit);
roots_.clear();
std::unordered_map<TreeNodeElementId, TreeNodeElement<ThresholdType>*, TreeNodeElementId::hash_fn> idi;
std::unordered_map<TreeNodeElementId, uint32_t, TreeNodeElementId::hash_fn> idi;
idi.reserve(limit);
max_feature_id_ = 0;
for (i = 0, limit = nodes_treeids.size(); i < limit; ++i) {
TreeNodeElement<ThresholdType>& node = nodes_[i];
node.id.tree_id = static_cast<int>(nodes_treeids[i]);
node.id.node_id = static_cast<int>(nodes_nodeids[i]);
for (i = 0; i < limit; ++i) {
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]),
static_cast<int>(nodes_nodeids[i])};
TreeNodeElement<ThresholdType> node;
node.feature_id = static_cast<int>(nodes_featureids[i]);
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
if (nodes_values_as_tensor.empty()) {
node.value = static_cast<ThresholdType>(nodes_values[i]);
} else {
node.value = nodes_values_as_tensor[i];
}
node.value_or_unique_weight = nodes_values_as_tensor.empty()
? static_cast<ThresholdType>(nodes_values[i])
: nodes_values_as_tensor[i];
/* hitrates are not used for inference; they are ignored.
if (nodes_hitrates_as_tensor.empty()) {
node.hitrates = static_cast<ThresholdType>(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1);
} else {
node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1;
} */
node.flags = static_cast<uint8_t>(cmodes[i]);
node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf
node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
node.mode = cmodes[i];
node.is_not_leaf = node.mode != NODE_MODE::LEAF;
node.truenode = nullptr; // nodes_truenodeids[i];
node.falsenode = nullptr; // nodes_falsenodeids[i];
node.missing_tracks = i < static_cast<size_t>(nodes_missing_value_tracks_true.size())
? (nodes_missing_value_tracks_true[i] == 1
? MissingTrack::kTrue
: MissingTrack::kFalse)
: MissingTrack::kNone;
node.is_missing_track_true = node.missing_tracks == MissingTrack::kTrue;
if (idi.find(node.id) != idi.end()) {
ORT_THROW("Node ", node.id.node_id, " in tree ", node.id.tree_id, " is already there.");
auto p = idi.insert(std::pair<TreeNodeElementId, uint32_t>(node_tree_id, i));
if (!p.second) {
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
}
idi.insert(std::pair<TreeNodeElementId, TreeNodeElement<ThresholdType>*>(node.id, &node));
nodes_.emplace_back(node);
node_tree_ids.emplace_back(node_tree_id);
}
InlinedVector<int64_t> truenode_ids, falsenode_ids;
truenode_ids.reserve(limit);
falsenode_ids.reserve(limit);
TreeNodeElementId coor;
i = 0;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) {
if (!it->is_not_leaf)
if (!it->is_not_leaf()) {
truenode_ids.push_back(0);
falsenode_ids.push_back(0);
continue;
i = std::distance(nodes_.begin(), it);
coor.tree_id = it->id.tree_id;
}
TreeNodeElementId& node_tree_id = node_tree_ids[i];
coor.tree_id = node_tree_id.tree_id;
coor.node_id = static_cast<int>(nodes_truenodeids[i]);
auto found = idi.find(coor);
if (found == idi.end()) {
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode).");
}
if (coor.node_id >= 0 && coor.node_id < n_nodes_) {
it->truenode = found->second;
if ((it->truenode->id.tree_id != it->id.tree_id) ||
(it->truenode->id.node_id == it->id.node_id)) {
ORT_THROW("One falsenode is pointing either to itself, either to another tree.");
}
} else
it->truenode = nullptr;
truenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0);
coor.node_id = static_cast<int>(nodes_falsenodeids[i]);
found = idi.find(coor);
if (found == idi.end()) {
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode).");
}
if (coor.node_id >= 0 && coor.node_id < n_nodes_) {
it->falsenode = found->second;
if ((it->falsenode->id.tree_id != it->id.tree_id) ||
(it->falsenode->id.node_id == it->id.node_id)) {
ORT_THROW("One falsenode is pointing either to itself, either to another tree.");
}
} else
it->falsenode = nullptr;
falsenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0);
}
int64_t previous = -1;
for (i = 0; i < static_cast<size_t>(n_nodes_); ++i) {
if ((previous == -1) || (previous != nodes_[i].id.tree_id))
roots_.push_back(&(nodes_[i]));
previous = nodes_[i].id.tree_id;
// sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(std::pair<TreeNodeElementId, uint32_t>(
TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]},
i));
}
std::sort(indices.begin(), indices.end());
// Initialize the leaves.
TreeNodeElementId ind;
SparseValue<ThresholdType> w;
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
ind.tree_id = static_cast<int>(target_class_treeids[i]);
ind.node_id = static_cast<int>(target_class_nodeids[i]);
if (idi.find(ind) == idi.end()) {
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (weights).");
size_t indi;
for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) {
ind = indices[indi].first;
i = indices[indi].second;
auto found = idi.find(ind);
if (found == idi.end()) {
ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights).");
}
TreeNodeElement<ThresholdType>& leaf = nodes_[found->second];
if (leaf.is_not_leaf()) {
// An exception should be raised in that case, but this case may happen in
// models converted with an old version of onnxmltools. Such weights are ignored.
// ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf.");
continue;
}
w.i = target_class_ids[i];
if (target_class_weights_as_tensor.empty()) {
w.value = static_cast<ThresholdType>(target_class_weights[i]);
} else {
w.value = target_class_weights_as_tensor[i];
w.value = target_class_weights_as_tensor.empty()
? static_cast<ThresholdType>(target_class_weights[i])
: target_class_weights_as_tensor[i];
if (leaf.falsenode_inc_or_n_weights == 0) {
leaf.truenode_inc_or_first_weight = static_cast<uint32_t>(weights_.size());
leaf.value_or_unique_weight = w.value;
}
idi[ind]->weights.push_back(w);
++leaf.falsenode_inc_or_n_weights;
weights_.push_back(w);
}
// Initialize all the nodes but the leaves.
int64_t previous = -1;
for (i = 0, limit = static_cast<uint32_t>(n_nodes_); i < limit; ++i) {
if ((previous == -1) || (previous != node_tree_ids[i].tree_id))
roots_.push_back(&(nodes_[idi[node_tree_ids[i]]]));
previous = node_tree_ids[i].tree_id;
if (!nodes_[i].is_not_leaf()) {
if (nodes_[i].falsenode_inc_or_n_weights == 0) {
ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, ".");
}
continue;
}
ORT_ENFORCE(truenode_ids[i] == 0 || truenode_ids[i] > i);
nodes_[i].truenode_inc_or_first_weight = truenode_ids[i] == 0 ? 0 : static_cast<uint32_t>(truenode_ids[i] - i);
ORT_ENFORCE(falsenode_ids[i] == 0 || falsenode_ids[i] > i);
nodes_[i].falsenode_inc_or_n_weights = falsenode_ids[i] == 0 ? 0 : static_cast<uint32_t>(falsenode_ids[i] - i);
}
n_trees_ = roots_.size();
@ -479,7 +521,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
if (n_trees_ <= parallel_tree_ || max_num_threads == 1) { /* section A2 */
InlinedVector<ScoreValue<ThresholdType>> scores(onnxruntime::narrow<size_t>(n_targets_or_classes_), {0, 0});
for (int64_t j = 0; j < n_trees_; ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow<size_t>(j)], x_data));
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow<size_t>(j)], x_data), weights_);
}
agg.FinalizeScores(scores, z_data, -1, label_data);
} else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */
@ -492,7 +534,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
scores[batch_num].resize(onnxruntime::narrow<size_t>(n_targets_or_classes_), {0, 0});
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, onnxruntime::narrow<size_t>(n_trees_));
for (auto j = work.start; j < work.end; ++j) {
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data));
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data), weights_);
}
});
for (size_t i = 1, limit = scores.size(); i < limit; ++i) {
@ -515,7 +557,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
}
for (j = 0, limit = roots_.size(); j < limit; ++j) {
for (i = batch; i < batch_end; ++i) {
agg.ProcessTreeNodePrediction(scores[SafeInt<ptrdiff_t>(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
agg.ProcessTreeNodePrediction(scores[SafeInt<ptrdiff_t>(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
}
}
for (i = batch; i < batch_end; ++i) {
@ -541,7 +583,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
for (auto j = work.start; j < work.end; ++j) {
for (int64_t i = begin_n; i < end_n; ++i) {
agg.ProcessTreeNodePrediction(scores[batch_num * SafeInt<ptrdiff_t>(N) + i],
*ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
*ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
}
}
});
@ -573,7 +615,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
for (auto i = work.start; i < work.end; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<ThresholdType>({0, 0}));
for (j = 0, limit = roots_.size(); j < limit; ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
}
agg.FinalizeScores(scores,
@ -587,17 +629,19 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
#define TREE_FIND_VALUE(CMP) \
if (has_missing_tracks_) { \
while (root->is_not_leaf) { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = (val CMP root->value || \
(root->is_missing_track_true && _isnan_(val))) \
? root->truenode \
: root->falsenode; \
root += (val CMP root->value_or_unique_weight || \
(root->is_missing_track_true() && _isnan_(val))) \
? root->truenode_inc_or_first_weight \
: root->falsenode_inc_or_n_weights; \
} \
} else { \
while (root->is_not_leaf) { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = val CMP root->value ? root->truenode : root->falsenode; \
root += val CMP root->value_or_unique_weight \
? root->truenode_inc_or_first_weight \
: root->falsenode_inc_or_n_weights; \
} \
}
@ -612,20 +656,20 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
TreeNodeElement<ThresholdType>* root, const InputType* x_data) const {
InputType val;
if (same_mode_) {
switch (root->mode) {
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
if (has_missing_tracks_) {
while (root->is_not_leaf) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (val <= root->value ||
(root->is_missing_track_true && _isnan_(val)))
? root->truenode
: root->falsenode;
root += (val <= root->value_or_unique_weight ||
(root->is_missing_track_true() && _isnan_(val)))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
}
} else {
while (root->is_not_leaf) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = val <= root->value ? root->truenode : root->falsenode;
root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights;
}
}
break;
@ -649,39 +693,39 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
}
} else { // Different rules to compare to node thresholds.
ThresholdType threshold;
while (root->is_not_leaf) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
threshold = root->value;
switch (root->mode) {
threshold = root->value_or_unique_weight;
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
root = val <= threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val <= threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::BRANCH_LT:
root = val < threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val < threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::BRANCH_GTE:
root = val >= threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val >= threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::BRANCH_GT:
root = val > threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val > threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::BRANCH_EQ:
root = val == threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val == threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::BRANCH_NEQ:
root = val != threshold || (root->is_missing_track_true && _isnan_(val))
? root->truenode
: root->falsenode;
root += val != threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_inc_or_first_weight
: root->falsenode_inc_or_n_weights;
break;
case NODE_MODE::LEAF:
break;

View file

@ -322,8 +322,8 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) {
std::vector<std::string> modes = {"BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "BRANCH_LEQ",
"LEAF", "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF",
"BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF"};
//std::vector<int64_t> classes = {0, 1, 2, 3};
std::vector<int64_t> class_treeids = {0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2};
//std::vector<int64_t> classes = {0, 1};
std::vector<int64_t> class_treeids = {0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2};
std::vector<int64_t> class_nodeids = {1, 3, 4, 1, 4, 5, 6, 2, 4, 5, 6};
std::vector<int64_t> class_classids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1};
std::vector<float> class_weights = {-1.f, 4.f, -1.f, 2.f, -1.f, +1.f, -2.f, 1.f, -1.f, 2.f, -3.f};
@ -335,13 +335,13 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) {
std::vector<float> probs = {};
std::vector<float> log_probs = {};
std::vector<float> scores{
0.2689414f, 0.73105859f,
0.00669282f, 0.99330717f,
0.04742586f, 0.88079702f,
0.73105859f, 0.26894140f,
0.73105859f, 0.26894140f,
0.73105859f, 0.26894140f,
0.73105859f, 0.26894142f,
0.73105859f, 0.26894142f,
0.73105859f, 0.26894142f,
0.26894140f, 0.73105859f,
0.73105859f, 0.26894140f,
0.73105859f, 0.26894142f,
0.5f, 0.04742586f};
//define the context of the operator call