fix embedding sparsity log bug of -1% density (#20420)

### Description
When not checked valid embedding sparsity, the log print a wrong info of
"-1% density", this pr is to fix it.
This commit is contained in:
guyang3532 2024-04-23 20:37:50 +08:00 committed by GitHub
parent ed6f1adcb8
commit ffb9c8d598
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View file

@ -706,11 +706,12 @@ class GraphExecutionManager(GraphExecutionInterface):
valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
total_token = ebd_input.numel()
embed_density = float(valid_token) / float(total_token) * 100
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None
if embed_density < 90:
self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None
if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1:
self._logger.warning(
"Found duplicate Embedding module. %s",
@ -794,6 +795,7 @@ class GraphExecutionManager(GraphExecutionInterface):
[
f"{v[0]}:{v[1]:.0f}%"
for v in self._runtime_inspector._embedding_module_to_padding_density_map.values()
if v[1] != -1
]
)

View file

@ -5776,11 +5776,15 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
_ = run_step(ort_model, input, label)
found_embed_is_sparse = False
found_embed_is_dense = False
found_label_is_sparse = False
for record in caplog.records:
if "Label sparsity-based optimization is ON for" in record.getMessage():
found_label_is_sparse = True
if "Embedding sparsity-based optimization is OFF for" in record.getMessage():
found_embed_is_dense = True
if "Embedding sparsity-based optimization is ON for" in record.getMessage():
found_embed_is_sparse = True
@ -5788,7 +5792,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l
assert found_label_is_sparse
if embed_is_sparse:
assert found_embed_is_sparse
assert found_embed_is_sparse and not found_embed_is_dense
else:
assert not found_embed_is_sparse and found_embed_is_dense
@pytest.mark.parametrize(