diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index 433e190924..97194e4474 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -8,6 +8,8 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import OrtValue, get_ort_device_type from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector +from typing import List + class Module: """ @@ -125,9 +127,12 @@ class Module: """ self._model.copy_buffer_to_parameters(buffer) - def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None: - """ - Exports the model for inferencing. + def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: List[str]) -> None: + """Exports the model for inferencing. + + Once training is complete, this function can be used to drop the training specific nodes in the onnx model. + In particular, this function does the following: + - Parse over the training graph and identify nodes that generate the given output names. + - Drop all subsequent nodes in the graph since they are not relevant to the inference graph. """ self._model.export_model_for_inferencing(inference_model_uri, graph_output_names) - self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)