mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
fix embedding sparsity log bug of -1% density (#20420)
### Description When not checked valid embedding sparsity, the log print a wrong info of "-1% density", this pr is to fix it.
This commit is contained in:
parent
ed6f1adcb8
commit
ffb9c8d598
2 changed files with 12 additions and 4 deletions
|
|
@ -706,11 +706,12 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
|
||||
total_token = ebd_input.numel()
|
||||
embed_density = float(valid_token) / float(total_token) * 100
|
||||
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
|
||||
self._logger.warning("Found Embedding module not in the map. %s", module)
|
||||
return None
|
||||
|
||||
if embed_density < 90:
|
||||
self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
|
||||
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
|
||||
self._logger.warning("Found Embedding module not in the map. %s", module)
|
||||
return None
|
||||
if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1:
|
||||
self._logger.warning(
|
||||
"Found duplicate Embedding module. %s",
|
||||
|
|
@ -794,6 +795,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
[
|
||||
f"{v[0]}:{v[1]:.0f}%"
|
||||
for v in self._runtime_inspector._embedding_module_to_padding_density_map.values()
|
||||
if v[1] != -1
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5776,11 +5776,15 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
|
|||
_ = run_step(ort_model, input, label)
|
||||
|
||||
found_embed_is_sparse = False
|
||||
found_embed_is_dense = False
|
||||
found_label_is_sparse = False
|
||||
for record in caplog.records:
|
||||
if "Label sparsity-based optimization is ON for" in record.getMessage():
|
||||
found_label_is_sparse = True
|
||||
|
||||
if "Embedding sparsity-based optimization is OFF for" in record.getMessage():
|
||||
found_embed_is_dense = True
|
||||
|
||||
if "Embedding sparsity-based optimization is ON for" in record.getMessage():
|
||||
found_embed_is_sparse = True
|
||||
|
||||
|
|
@ -5788,7 +5792,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
|
|||
assert found_label_is_sparse
|
||||
|
||||
if embed_is_sparse:
|
||||
assert found_embed_is_sparse
|
||||
assert found_embed_is_sparse and not found_embed_is_dense
|
||||
else:
|
||||
assert not found_embed_is_sparse and found_embed_is_dense
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
Loading…
Reference in a new issue