mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Update typing hints to support python 3.8 for training apis (#14649)
This commit is contained in:
parent
326cf2f5e9
commit
22de2798f2
1 changed files with 9 additions and 4 deletions
|
|
@ -8,6 +8,8 @@ from onnxruntime.capi import _pybind_state as C
|
|||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue, get_ort_device_type
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
class Module:
|
||||
"""
|
||||
|
|
@ -125,9 +127,12 @@ class Module:
|
|||
"""
|
||||
self._model.copy_buffer_to_parameters(buffer)
|
||||
|
||||
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None:
|
||||
"""
|
||||
Exports the model for inferencing.
|
||||
def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: List[str]) -> None:
|
||||
"""Exports the model for inferencing.
|
||||
|
||||
Once training is complete, this function can be used to drop the training specific nodes in the onnx model.
|
||||
In particular, this function does the following:
|
||||
- Parse over the training graph and identify nodes that generate the given output names.
|
||||
- Drop all subsequent nodes in the graph since they are not relevant to the inference graph.
|
||||
"""
|
||||
self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)
|
||||
self._model.export_model_for_inferencing(inference_model_uri, graph_output_names)
|
||||
|
|
|
|||
Loading…
Reference in a new issue