Update typing hints to support python 3.8 for training apis (#14649)

This commit is contained in:
Baiju Meswani 2023-02-13 09:52:05 -08:00 committed by GitHub
parent 326cf2f5e9
commit 22de2798f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,6 +8,8 @@ from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue, get_ort_device_type
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
from typing import List
class Module:
"""
@ -125,9 +127,12 @@ class Module:
"""
self._model.copy_buffer_to_parameters(buffer)
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None:
"""
Exports the model for inferencing.
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: List[str]) -> None:
"""Exports the model for inferencing.
Once training is complete, this function can be used to drop the training specific nodes in the onnx model.
In particular, this function does the following:
- Parse over the training graph and identify nodes that generate the given output names.
- Drop all subsequent nodes in the graph since they are not relevant to the inference graph.
"""
self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)
self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)