mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Reduce memory usage for TreeEnsemble operators (#14670)
### Description The onnx file is about 5Mb for a lightgbm model with 500 trees. onnxruntime uses additional 10Mb to compute the inference and keeps the onnx structure. This PR reduces the memory usage by almost 50%. The memory used by the onnx node could be freed if there is no optimized graph to save but that's not covered by this PR. ### Motivation and Context Reduce memory usage.
This commit is contained in:
parent
262e46e8ce
commit
df45b12fdf
4 changed files with 219 additions and 153 deletions
|
|
@ -20,14 +20,14 @@ enum class OUTPUT_MODE {
|
|||
ALL_SCORES
|
||||
};
|
||||
|
||||
enum class NODE_MODE {
|
||||
BRANCH_LEQ,
|
||||
BRANCH_LT,
|
||||
BRANCH_GTE,
|
||||
BRANCH_GT,
|
||||
BRANCH_EQ,
|
||||
BRANCH_NEQ,
|
||||
LEAF
|
||||
enum NODE_MODE : uint8_t {
|
||||
LEAF = 1,
|
||||
BRANCH_LEQ = 2,
|
||||
BRANCH_LT = 4,
|
||||
BRANCH_GTE = 6,
|
||||
BRANCH_GT = 8,
|
||||
BRANCH_EQ = 10,
|
||||
BRANCH_NEQ = 12
|
||||
};
|
||||
|
||||
static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ namespace ml {
|
|||
namespace detail {
|
||||
|
||||
struct TreeNodeElementId {
|
||||
int tree_id;
|
||||
int node_id;
|
||||
int64_t tree_id;
|
||||
int64_t node_id;
|
||||
bool operator==(const TreeNodeElementId& xyz) const {
|
||||
return (tree_id == xyz.tree_id) && (node_id == xyz.node_id);
|
||||
}
|
||||
|
|
@ -23,8 +23,8 @@ struct TreeNodeElementId {
|
|||
}
|
||||
struct hash_fn {
|
||||
std::size_t operator()(const TreeNodeElementId& key) const {
|
||||
std::size_t h1 = std::hash<int>()(key.tree_id);
|
||||
std::size_t h2 = std::hash<int>()(key.node_id);
|
||||
std::size_t h1 = std::hash<int64_t>()(key.tree_id);
|
||||
std::size_t h2 = std::hash<int64_t>()(key.node_id);
|
||||
return h1 ^ h2;
|
||||
}
|
||||
};
|
||||
|
|
@ -61,26 +61,39 @@ struct ScoreValue {
|
|||
}
|
||||
};
|
||||
|
||||
enum MissingTrack {
|
||||
kNone,
|
||||
kTrue,
|
||||
kFalse
|
||||
enum MissingTrack : uint8_t {
|
||||
kTrue = 16,
|
||||
kFalse = 0
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TreeNodeElement {
|
||||
TreeNodeElementId id;
|
||||
int feature_id;
|
||||
T value;
|
||||
T hitrates;
|
||||
NODE_MODE mode;
|
||||
TreeNodeElement<T>* truenode;
|
||||
TreeNodeElement<T>* falsenode;
|
||||
MissingTrack missing_tracks;
|
||||
std::vector<SparseValue<T>> weights;
|
||||
|
||||
bool is_not_leaf;
|
||||
bool is_missing_track_true;
|
||||
// Stores the node threshold or the weights if the tree has one target.
|
||||
T value_or_unique_weight;
|
||||
|
||||
// onnx specification says hitrates is used to store information about the node,
|
||||
// but this information is not used for inference.
|
||||
// T hitrates;
|
||||
|
||||
// True node, false node are obtained by computing `this + truenode_inc_or_first_weight`,
|
||||
// `this + falsenode_inc_or_n_weights` if the node is not a leaf.
|
||||
// In case of a leaf, these attributes are used to indicate the position of the weight
|
||||
// in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one,
|
||||
// the weight is also stored in `value_or_unique_weight`.
|
||||
// This implementation assumes a tree has less than 2^21 nodes,
|
||||
// and the total number of leave in the set of trees is below 2^21.
|
||||
uint32_t truenode_inc_or_first_weight;
|
||||
// In case of a leaf, the following attribute indicates the number of weights
|
||||
// in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates
|
||||
// `this + falsenode_inc_or_n_weights` is the false node.
|
||||
uint32_t falsenode_inc_or_n_weights;
|
||||
uint8_t flags;
|
||||
|
||||
inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
|
||||
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
|
||||
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
|
||||
};
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
|
|
@ -121,7 +134,8 @@ class TreeAggregator {
|
|||
// N outputs
|
||||
|
||||
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& /*predictions*/,
|
||||
const TreeNodeElement<ThresholdType>& /*root*/) const {}
|
||||
const TreeNodeElement<ThresholdType>& /*root*/,
|
||||
gsl::span<const SparseValue<ThresholdType>> /*weights*/) const {}
|
||||
|
||||
void MergePrediction(InlinedVector<ScoreValue<ThresholdType>>& /*predictions*/,
|
||||
const InlinedVector<ScoreValue<ThresholdType>>& /*predictions2*/) const {}
|
||||
|
|
@ -158,7 +172,7 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
|
|||
|
||||
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
prediction.score += root.weights[0].value;
|
||||
prediction.score += root.value_or_unique_weight;
|
||||
}
|
||||
|
||||
void MergePrediction1(ScoreValue<ThresholdType>& prediction,
|
||||
|
|
@ -176,8 +190,10 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
|
|||
// N outputs
|
||||
|
||||
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
for (auto it = root.weights.cbegin(); it != root.weights.cend(); ++it) {
|
||||
const TreeNodeElement<ThresholdType>& root,
|
||||
gsl::span<const SparseValue<ThresholdType>> weights) const {
|
||||
auto it = weights.begin() + root.truenode_inc_or_first_weight;
|
||||
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
|
||||
ORT_ENFORCE(it->i < (int64_t)predictions.size());
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].score += it->value;
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
|
||||
|
|
@ -260,8 +276,8 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
|
|||
|
||||
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
prediction.score = (!(prediction.has_score) || root.weights[0].value < prediction.score)
|
||||
? root.weights[0].value
|
||||
prediction.score = (!(prediction.has_score) || root.value_or_unique_weight < prediction.score)
|
||||
? root.value_or_unique_weight
|
||||
: prediction.score;
|
||||
prediction.has_score = 1;
|
||||
}
|
||||
|
|
@ -279,11 +295,14 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
|
|||
// N outputs
|
||||
|
||||
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
for (auto it = root.weights.begin(); it != root.weights.end(); ++it) {
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].score = (!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
|
||||
? it->value
|
||||
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
|
||||
const TreeNodeElement<ThresholdType>& root,
|
||||
gsl::span<const SparseValue<ThresholdType>> weights) const {
|
||||
auto it = weights.begin() + root.truenode_inc_or_first_weight;
|
||||
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].score =
|
||||
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
|
||||
? it->value
|
||||
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
|
||||
}
|
||||
}
|
||||
|
|
@ -316,8 +335,8 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
|
|||
|
||||
void ProcessTreeNodePrediction1(ScoreValue<ThresholdType>& prediction,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
prediction.score = (!(prediction.has_score) || root.weights[0].value > prediction.score)
|
||||
? root.weights[0].value
|
||||
prediction.score = (!(prediction.has_score) || root.value_or_unique_weight > prediction.score)
|
||||
? root.value_or_unique_weight
|
||||
: prediction.score;
|
||||
prediction.has_score = 1;
|
||||
}
|
||||
|
|
@ -334,11 +353,14 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
|
|||
// N outputs
|
||||
|
||||
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
|
||||
const TreeNodeElement<ThresholdType>& root) const {
|
||||
for (auto it = root.weights.begin(); it != root.weights.end(); ++it) {
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].score = (!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
|
||||
? it->value
|
||||
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
|
||||
const TreeNodeElement<ThresholdType>& root,
|
||||
gsl::span<const SparseValue<ThresholdType>> weights) const {
|
||||
auto it = weights.begin() + root.truenode_inc_or_first_weight;
|
||||
for (uint32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].score =
|
||||
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
|
||||
? it->value
|
||||
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
|
||||
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
|
|||
protected:
|
||||
std::vector<ThresholdType> base_values_;
|
||||
std::vector<TreeNodeElement<ThresholdType>> nodes_;
|
||||
// Type of weights should be a vector of OutputType. Onnx specifications says it must be float.
|
||||
// Lightgbm requires a double to do the summation of all trees predictions. That's why
|
||||
// `ThresholdType` is used as well for output type (double as well for lightgbm) and not `OutputType`.
|
||||
std::vector<SparseValue<ThresholdType>> weights_;
|
||||
std::vector<TreeNodeElement<ThresholdType>*> roots_;
|
||||
|
||||
public:
|
||||
|
|
@ -180,14 +184,17 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
}
|
||||
n_targets_or_classes_ = n_targets_or_classes;
|
||||
max_tree_depth_ = 1000;
|
||||
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max());
|
||||
|
||||
// additional members
|
||||
size_t i, limit;
|
||||
std::vector<NODE_MODE> cmodes(nodes_modes.size());
|
||||
size_t limit;
|
||||
uint32_t i;
|
||||
InlinedVector<NODE_MODE> cmodes;
|
||||
cmodes.reserve(nodes_modes.size());
|
||||
same_mode_ = true;
|
||||
int fpos = -1;
|
||||
for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
|
||||
cmodes[i] = MakeTreeNodeMode(nodes_modes[i]);
|
||||
cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
|
||||
if (cmodes[i] == NODE_MODE::LEAF)
|
||||
continue;
|
||||
if (fpos == -1) {
|
||||
|
|
@ -201,103 +208,138 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
// filling nodes
|
||||
|
||||
n_nodes_ = nodes_treeids.size();
|
||||
nodes_.resize(nodes_treeids.size());
|
||||
limit = static_cast<size_t>(n_nodes_);
|
||||
InlinedVector<TreeNodeElementId> node_tree_ids;
|
||||
node_tree_ids.reserve(limit);
|
||||
nodes_.clear();
|
||||
nodes_.reserve(limit);
|
||||
roots_.clear();
|
||||
std::unordered_map<TreeNodeElementId, TreeNodeElement<ThresholdType>*, TreeNodeElementId::hash_fn> idi;
|
||||
std::unordered_map<TreeNodeElementId, uint32_t, TreeNodeElementId::hash_fn> idi;
|
||||
idi.reserve(limit);
|
||||
max_feature_id_ = 0;
|
||||
|
||||
for (i = 0, limit = nodes_treeids.size(); i < limit; ++i) {
|
||||
TreeNodeElement<ThresholdType>& node = nodes_[i];
|
||||
node.id.tree_id = static_cast<int>(nodes_treeids[i]);
|
||||
node.id.node_id = static_cast<int>(nodes_nodeids[i]);
|
||||
for (i = 0; i < limit; ++i) {
|
||||
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]),
|
||||
static_cast<int>(nodes_nodeids[i])};
|
||||
TreeNodeElement<ThresholdType> node;
|
||||
node.feature_id = static_cast<int>(nodes_featureids[i]);
|
||||
if (node.feature_id > max_feature_id_) {
|
||||
max_feature_id_ = node.feature_id;
|
||||
}
|
||||
if (nodes_values_as_tensor.empty()) {
|
||||
node.value = static_cast<ThresholdType>(nodes_values[i]);
|
||||
} else {
|
||||
node.value = nodes_values_as_tensor[i];
|
||||
}
|
||||
node.value_or_unique_weight = nodes_values_as_tensor.empty()
|
||||
? static_cast<ThresholdType>(nodes_values[i])
|
||||
: nodes_values_as_tensor[i];
|
||||
|
||||
/* hitrates is not used for inference, they are ignored.
|
||||
if (nodes_hitrates_as_tensor.empty()) {
|
||||
node.hitrates = static_cast<ThresholdType>(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1);
|
||||
} else {
|
||||
node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1;
|
||||
} */
|
||||
|
||||
node.flags = static_cast<uint8_t>(cmodes[i]);
|
||||
node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf
|
||||
node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf
|
||||
|
||||
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
|
||||
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
|
||||
}
|
||||
node.mode = cmodes[i];
|
||||
node.is_not_leaf = node.mode != NODE_MODE::LEAF;
|
||||
node.truenode = nullptr; // nodes_truenodeids[i];
|
||||
node.falsenode = nullptr; // nodes_falsenodeids[i];
|
||||
node.missing_tracks = i < static_cast<size_t>(nodes_missing_value_tracks_true.size())
|
||||
? (nodes_missing_value_tracks_true[i] == 1
|
||||
? MissingTrack::kTrue
|
||||
: MissingTrack::kFalse)
|
||||
: MissingTrack::kNone;
|
||||
node.is_missing_track_true = node.missing_tracks == MissingTrack::kTrue;
|
||||
if (idi.find(node.id) != idi.end()) {
|
||||
ORT_THROW("Node ", node.id.node_id, " in tree ", node.id.tree_id, " is already there.");
|
||||
auto p = idi.insert(std::pair<TreeNodeElementId, uint32_t>(node_tree_id, i));
|
||||
if (!p.second) {
|
||||
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
|
||||
}
|
||||
idi.insert(std::pair<TreeNodeElementId, TreeNodeElement<ThresholdType>*>(node.id, &node));
|
||||
nodes_.emplace_back(node);
|
||||
node_tree_ids.emplace_back(node_tree_id);
|
||||
}
|
||||
|
||||
InlinedVector<int64_t> truenode_ids, falsenode_ids;
|
||||
truenode_ids.reserve(limit);
|
||||
falsenode_ids.reserve(limit);
|
||||
TreeNodeElementId coor;
|
||||
i = 0;
|
||||
for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) {
|
||||
if (!it->is_not_leaf)
|
||||
if (!it->is_not_leaf()) {
|
||||
truenode_ids.push_back(0);
|
||||
falsenode_ids.push_back(0);
|
||||
continue;
|
||||
i = std::distance(nodes_.begin(), it);
|
||||
coor.tree_id = it->id.tree_id;
|
||||
}
|
||||
|
||||
TreeNodeElementId& node_tree_id = node_tree_ids[i];
|
||||
coor.tree_id = node_tree_id.tree_id;
|
||||
coor.node_id = static_cast<int>(nodes_truenodeids[i]);
|
||||
|
||||
auto found = idi.find(coor);
|
||||
if (found == idi.end()) {
|
||||
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode).");
|
||||
}
|
||||
if (coor.node_id >= 0 && coor.node_id < n_nodes_) {
|
||||
it->truenode = found->second;
|
||||
if ((it->truenode->id.tree_id != it->id.tree_id) ||
|
||||
(it->truenode->id.node_id == it->id.node_id)) {
|
||||
ORT_THROW("One falsenode is pointing either to itself, either to another tree.");
|
||||
}
|
||||
} else
|
||||
it->truenode = nullptr;
|
||||
truenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0);
|
||||
|
||||
coor.node_id = static_cast<int>(nodes_falsenodeids[i]);
|
||||
found = idi.find(coor);
|
||||
if (found == idi.end()) {
|
||||
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode).");
|
||||
}
|
||||
if (coor.node_id >= 0 && coor.node_id < n_nodes_) {
|
||||
it->falsenode = found->second;
|
||||
if ((it->falsenode->id.tree_id != it->id.tree_id) ||
|
||||
(it->falsenode->id.node_id == it->id.node_id)) {
|
||||
ORT_THROW("One falsenode is pointing either to itself, either to another tree.");
|
||||
}
|
||||
} else
|
||||
it->falsenode = nullptr;
|
||||
falsenode_ids.emplace_back((coor.node_id >= 0 && coor.node_id < n_nodes_) ? found->second : 0);
|
||||
}
|
||||
|
||||
int64_t previous = -1;
|
||||
for (i = 0; i < static_cast<size_t>(n_nodes_); ++i) {
|
||||
if ((previous == -1) || (previous != nodes_[i].id.tree_id))
|
||||
roots_.push_back(&(nodes_[i]));
|
||||
previous = nodes_[i].id.tree_id;
|
||||
// sort targets
|
||||
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
|
||||
indices.reserve(target_class_nodeids.size());
|
||||
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
|
||||
indices.emplace_back(std::pair<TreeNodeElementId, uint32_t>(
|
||||
TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]},
|
||||
i));
|
||||
}
|
||||
std::sort(indices.begin(), indices.end());
|
||||
|
||||
// Initialize the leaves.
|
||||
TreeNodeElementId ind;
|
||||
SparseValue<ThresholdType> w;
|
||||
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
|
||||
ind.tree_id = static_cast<int>(target_class_treeids[i]);
|
||||
ind.node_id = static_cast<int>(target_class_nodeids[i]);
|
||||
if (idi.find(ind) == idi.end()) {
|
||||
ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (weights).");
|
||||
size_t indi;
|
||||
for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) {
|
||||
ind = indices[indi].first;
|
||||
i = indices[indi].second;
|
||||
auto found = idi.find(ind);
|
||||
if (found == idi.end()) {
|
||||
ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights).");
|
||||
}
|
||||
|
||||
TreeNodeElement<ThresholdType>& leaf = nodes_[found->second];
|
||||
if (leaf.is_not_leaf()) {
|
||||
// An exception should be raised in that case. But this case may happen in
|
||||
// models converted with an old version of onnxmltools. There weights are ignored.
|
||||
// ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf.");
|
||||
continue;
|
||||
}
|
||||
|
||||
w.i = target_class_ids[i];
|
||||
if (target_class_weights_as_tensor.empty()) {
|
||||
w.value = static_cast<ThresholdType>(target_class_weights[i]);
|
||||
} else {
|
||||
w.value = target_class_weights_as_tensor[i];
|
||||
w.value = target_class_weights_as_tensor.empty()
|
||||
? static_cast<ThresholdType>(target_class_weights[i])
|
||||
: target_class_weights_as_tensor[i];
|
||||
if (leaf.falsenode_inc_or_n_weights == 0) {
|
||||
leaf.truenode_inc_or_first_weight = static_cast<uint32_t>(weights_.size());
|
||||
leaf.value_or_unique_weight = w.value;
|
||||
}
|
||||
idi[ind]->weights.push_back(w);
|
||||
++leaf.falsenode_inc_or_n_weights;
|
||||
weights_.push_back(w);
|
||||
}
|
||||
|
||||
// Initialize all the nodes but the leaves.
|
||||
int64_t previous = -1;
|
||||
for (i = 0, limit = static_cast<uint32_t>(n_nodes_); i < limit; ++i) {
|
||||
if ((previous == -1) || (previous != node_tree_ids[i].tree_id))
|
||||
roots_.push_back(&(nodes_[idi[node_tree_ids[i]]]));
|
||||
previous = node_tree_ids[i].tree_id;
|
||||
if (!nodes_[i].is_not_leaf()) {
|
||||
if (nodes_[i].falsenode_inc_or_n_weights == 0) {
|
||||
ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, ".");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
ORT_ENFORCE(truenode_ids[i] == 0 || truenode_ids[i] > i);
|
||||
nodes_[i].truenode_inc_or_first_weight = truenode_ids[i] == 0 ? 0 : static_cast<uint32_t>(truenode_ids[i] - i);
|
||||
ORT_ENFORCE(falsenode_ids[i] == 0 || falsenode_ids[i] > i);
|
||||
nodes_[i].falsenode_inc_or_n_weights = falsenode_ids[i] == 0 ? 0 : static_cast<uint32_t>(falsenode_ids[i] - i);
|
||||
}
|
||||
|
||||
n_trees_ = roots_.size();
|
||||
|
|
@ -479,7 +521,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
if (n_trees_ <= parallel_tree_ || max_num_threads == 1) { /* section A2 */
|
||||
InlinedVector<ScoreValue<ThresholdType>> scores(onnxruntime::narrow<size_t>(n_targets_or_classes_), {0, 0});
|
||||
for (int64_t j = 0; j < n_trees_; ++j) {
|
||||
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow<size_t>(j)], x_data));
|
||||
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[onnxruntime::narrow<size_t>(j)], x_data), weights_);
|
||||
}
|
||||
agg.FinalizeScores(scores, z_data, -1, label_data);
|
||||
} else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */
|
||||
|
|
@ -492,7 +534,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
scores[batch_num].resize(onnxruntime::narrow<size_t>(n_targets_or_classes_), {0, 0});
|
||||
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, onnxruntime::narrow<size_t>(n_trees_));
|
||||
for (auto j = work.start; j < work.end; ++j) {
|
||||
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data));
|
||||
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data), weights_);
|
||||
}
|
||||
});
|
||||
for (size_t i = 1, limit = scores.size(); i < limit; ++i) {
|
||||
|
|
@ -515,7 +557,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
}
|
||||
for (j = 0, limit = roots_.size(); j < limit; ++j) {
|
||||
for (i = batch; i < batch_end; ++i) {
|
||||
agg.ProcessTreeNodePrediction(scores[SafeInt<ptrdiff_t>(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
|
||||
agg.ProcessTreeNodePrediction(scores[SafeInt<ptrdiff_t>(i - batch)], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
|
||||
}
|
||||
}
|
||||
for (i = batch; i < batch_end; ++i) {
|
||||
|
|
@ -541,7 +583,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
for (auto j = work.start; j < work.end; ++j) {
|
||||
for (int64_t i = begin_n; i < end_n; ++i) {
|
||||
agg.ProcessTreeNodePrediction(scores[batch_num * SafeInt<ptrdiff_t>(N) + i],
|
||||
*ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
|
||||
*ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -573,7 +615,7 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
for (auto i = work.start; i < work.end; ++i) {
|
||||
std::fill(scores.begin(), scores.end(), ScoreValue<ThresholdType>({0, 0}));
|
||||
for (j = 0, limit = roots_.size(); j < limit; ++j) {
|
||||
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
|
||||
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride), weights_);
|
||||
}
|
||||
|
||||
agg.FinalizeScores(scores,
|
||||
|
|
@ -587,17 +629,19 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
|
||||
#define TREE_FIND_VALUE(CMP) \
|
||||
if (has_missing_tracks_) { \
|
||||
while (root->is_not_leaf) { \
|
||||
while (root->is_not_leaf()) { \
|
||||
val = x_data[root->feature_id]; \
|
||||
root = (val CMP root->value || \
|
||||
(root->is_missing_track_true && _isnan_(val))) \
|
||||
? root->truenode \
|
||||
: root->falsenode; \
|
||||
root += (val CMP root->value_or_unique_weight || \
|
||||
(root->is_missing_track_true() && _isnan_(val))) \
|
||||
? root->truenode_inc_or_first_weight \
|
||||
: root->falsenode_inc_or_n_weights; \
|
||||
} \
|
||||
} else { \
|
||||
while (root->is_not_leaf) { \
|
||||
while (root->is_not_leaf()) { \
|
||||
val = x_data[root->feature_id]; \
|
||||
root = val CMP root->value ? root->truenode : root->falsenode; \
|
||||
root += val CMP root->value_or_unique_weight \
|
||||
? root->truenode_inc_or_first_weight \
|
||||
: root->falsenode_inc_or_n_weights; \
|
||||
} \
|
||||
}
|
||||
|
||||
|
|
@ -612,20 +656,20 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
|
|||
TreeNodeElement<ThresholdType>* root, const InputType* x_data) const {
|
||||
InputType val;
|
||||
if (same_mode_) {
|
||||
switch (root->mode) {
|
||||
switch (root->mode()) {
|
||||
case NODE_MODE::BRANCH_LEQ:
|
||||
if (has_missing_tracks_) {
|
||||
while (root->is_not_leaf) {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
root = (val <= root->value ||
|
||||
(root->is_missing_track_true && _isnan_(val)))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += (val <= root->value_or_unique_weight ||
|
||||
(root->is_missing_track_true() && _isnan_(val)))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
}
|
||||
} else {
|
||||
while (root->is_not_leaf) {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
root = val <= root->value ? root->truenode : root->falsenode;
|
||||
root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
|
@ -649,39 +693,39 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
|
|||
}
|
||||
} else { // Different rules to compare to node thresholds.
|
||||
ThresholdType threshold;
|
||||
while (root->is_not_leaf) {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
threshold = root->value;
|
||||
switch (root->mode) {
|
||||
threshold = root->value_or_unique_weight;
|
||||
switch (root->mode()) {
|
||||
case NODE_MODE::BRANCH_LEQ:
|
||||
root = val <= threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val <= threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_LT:
|
||||
root = val < threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val < threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GTE:
|
||||
root = val >= threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val >= threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GT:
|
||||
root = val > threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val > threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_EQ:
|
||||
root = val == threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val == threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_NEQ:
|
||||
root = val != threshold || (root->is_missing_track_true && _isnan_(val))
|
||||
? root->truenode
|
||||
: root->falsenode;
|
||||
root += val != threshold || (root->is_missing_track_true() && _isnan_(val))
|
||||
? root->truenode_inc_or_first_weight
|
||||
: root->falsenode_inc_or_n_weights;
|
||||
break;
|
||||
case NODE_MODE::LEAF:
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -322,8 +322,8 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) {
|
|||
std::vector<std::string> modes = {"BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "BRANCH_LEQ",
|
||||
"LEAF", "BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF",
|
||||
"BRANCH_LEQ", "BRANCH_LEQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF", "LEAF"};
|
||||
//std::vector<int64_t> classes = {0, 1, 2, 3};
|
||||
std::vector<int64_t> class_treeids = {0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2};
|
||||
//std::vector<int64_t> classes = {0, 1};
|
||||
std::vector<int64_t> class_treeids = {0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2};
|
||||
std::vector<int64_t> class_nodeids = {1, 3, 4, 1, 4, 5, 6, 2, 4, 5, 6};
|
||||
std::vector<int64_t> class_classids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1};
|
||||
std::vector<float> class_weights = {-1.f, 4.f, -1.f, 2.f, -1.f, +1.f, -2.f, 1.f, -1.f, 2.f, -3.f};
|
||||
|
|
@ -335,13 +335,13 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) {
|
|||
std::vector<float> probs = {};
|
||||
std::vector<float> log_probs = {};
|
||||
std::vector<float> scores{
|
||||
0.2689414f, 0.73105859f,
|
||||
0.00669282f, 0.99330717f,
|
||||
0.04742586f, 0.88079702f,
|
||||
0.73105859f, 0.26894140f,
|
||||
0.73105859f, 0.26894140f,
|
||||
0.73105859f, 0.26894140f,
|
||||
0.73105859f, 0.26894142f,
|
||||
0.73105859f, 0.26894142f,
|
||||
0.73105859f, 0.26894142f,
|
||||
0.26894140f, 0.73105859f,
|
||||
0.73105859f, 0.26894140f,
|
||||
0.73105859f, 0.26894142f,
|
||||
0.5f, 0.04742586f};
|
||||
|
||||
//define the context of the operator call
|
||||
|
|
|
|||
Loading…
Reference in a new issue