diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index f50b18b736..84631bd1f6 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations. ```bash diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index fda6e345da..e189ffff9c 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -681,11 +681,15 @@ class GraphExecutionManager(GraphExecutionInterface): ) if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: - graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) - self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - self._runtime_options.embed_sparsity_ratio = ",".join( - [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] - ) + if detected_device.type == "cuda": + # Embedding sparsity optimization is only supported on CUDA devices. + graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) + self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] + ) + else: + self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.") # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 539859a0d5..93d24a34df 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -271,7 +271,7 @@ class _RuntimeOptions: self.enable_sparse_optimizer = True self.label_sparsity_ratio = "" self.embed_sparsity_ratio = "" - self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. + self.enable_embedding_sparse_optimizer = True # Configuration for memory optimization. self.memory_optimization_level = (