diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 8a493ed87a..1b786d514f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -706,11 +706,12 @@ class GraphExecutionManager(GraphExecutionInterface): valid_token = torch.count_nonzero(ebd_input - module.padding_idx) total_token = ebd_input.numel() embed_density = float(valid_token) / float(total_token) * 100 + if module not in self._runtime_inspector._embedding_module_to_padding_density_map: + self._logger.warning("Found Embedding module not in the map. %s", module) + return None + if embed_density < 90: self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density) - if module not in self._runtime_inspector._embedding_module_to_padding_density_map: - self._logger.warning("Found Embedding module not in the map. %s", module) - return None if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1: self._logger.warning( "Found duplicate Embedding module. %s", @@ -794,6 +795,7 @@ class GraphExecutionManager(GraphExecutionInterface): [ f"{v[0]}:{v[1]:.0f}%" for v in self._runtime_inspector._embedding_module_to_padding_density_map.values() + if v[1] != -1 ] ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index e231579887..0839f957c2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5776,11 +5776,15 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l _ = run_step(ort_model, input, label) found_embed_is_sparse = False + found_embed_is_dense = False found_label_is_sparse = False for record in caplog.records: if "Label sparsity-based optimization is ON for" in record.getMessage(): found_label_is_sparse = True + if "Embedding sparsity-based optimization is OFF for" in record.getMessage(): + found_embed_is_dense = True + if "Embedding sparsity-based optimization is ON for" in record.getMessage(): found_embed_is_sparse = True @@ -5788,7 +5792,9 @@ def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, l assert found_label_is_sparse if embed_is_sparse: - assert found_embed_is_sparse + assert found_embed_is_sparse and not found_embed_is_dense + else: + assert not found_embed_is_sparse and found_embed_is_dense @pytest.mark.parametrize(