Improve TreeNodeElementId hash function (#16459)

### Description
This PR improves `TreeNodeElementId` hash function by employing [Elegant
Pairing function](http://szudzik.com/ElegantPairing.pdf). In few works,
Elegant Pairing function maps two non−negative integers to a
non−negative integer that is uniquely associated with that pair. This
drastically reduces the collision and therefore reduces the time
required to create a session in order to use a large tree ensemble
model.

### Motivation and Context
We use ONNX runtime to serve our models as part of Triton backend. We
noticed that it was taking around 2 minutes to load a model which is a
large tree ensemble model (around 5k trees with around 3 millions nodes
in total). After investigating the issue, it was clear that the
`TreeNodeElementId` hash function wasn't being able to map keys to
buckets of C++ `unordered_map` without a significant amount of
collisions (in same cases 700 items per bucket).

The following picture shows graphically the improvement obtained by the
proposed change. We used the `onnx_test_runner` command.

![flamegraph](https://github.com/microsoft/onnxruntime/assets/3594678/2588e87c-125b-4a4b-8f03-55e00ae25e08)

#### Before
```
$> time ./onnx_test_runner -v ~/folder_with_model
result:
	Models: 1
	Total test cases: 0
		Succeeded: 0
		Not implemented: 0
		Failed: 0
	Stats by Operator type:
		Not implemented(0):
		Failed:
Failed Test Cases:

real	0m55.695s
user	0m52.919s
sys	0m0.760s
```

#### After
```
$> time ./onnx_test_runner -v ~/folder_with_model
result:
	Models: 1
	Total test cases: 0
		Succeeded: 0
		Not implemented: 0
		Failed: 0
	Stats by Operator type:
		Not implemented(0):
		Failed:
Failed Test Cases:

real	0m17.152s
user	0m14.318s
sys	0m0.619s
```
This commit is contained in:
Luis Rios 2023-07-25 09:25:50 -03:00 committed by GitHub
parent daef133982
commit feeb0b50f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -23,9 +23,7 @@ struct TreeNodeElementId {
}
struct hash_fn {
std::size_t operator()(const TreeNodeElementId& key) const {
std::size_t h1 = std::hash<int64_t>()(key.tree_id);
std::size_t h2 = std::hash<int64_t>()(key.node_id);
return h1 ^ h2;
return static_cast<std::size_t>(static_cast<uint64_t>(key.tree_id) << 32 | static_cast<uint64_t>(key.node_id));
}
};
};