mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Reduce GPU memory for Whisper models converted to ONNX (#17378)
### Description This PR changes the Whisper export scripts to further optimize the process of removing duplicate initializers from two subgraphs. The current greedy approach is faster by a large factor, but it fails to catch and remove some duplicate initializers. This results not only in a slightly larger Whisper model, but also in a model that uses more GPU memory. The approach in this PR uses data hashes and caches to keep the export quick while no longer relying on a greedy approach. --------- Co-authored-by: Peter McAughan <petermca@microsoft.com>
This commit is contained in:
parent
dbcc60bed5
commit
fa28359beb
3 changed files with 70 additions and 21 deletions
|
|
@ -883,7 +883,8 @@ def remove_shared_initializers(
|
|||
graph2: GraphProto,
|
||||
shared_prefix: str = "shared_",
|
||||
min_elements: int = 1024,
|
||||
require_raw_data: bool = False,
|
||||
signature_cache1: Optional[dict] = None,
|
||||
signature_cache2: Optional[dict] = None,
|
||||
):
|
||||
"""Remove initializers with same value from two graphs.
|
||||
|
||||
|
|
@ -892,7 +893,8 @@ def remove_shared_initializers(
|
|||
graph2 (GraphProto): the second graph to process
|
||||
shared_prefix (str): add prefix to the shared initializers among two graphs
|
||||
min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024.
|
||||
require_raw_data (bool, optional): Only remove tensors with raw_data field to speed up method
|
||||
signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison
|
||||
signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison
|
||||
"""
|
||||
|
||||
mapping_initializers_1 = {}
|
||||
|
|
@ -909,7 +911,7 @@ def remove_shared_initializers(
|
|||
if not (initializer2.dims and sum(initializer2.dims) >= min_elements):
|
||||
continue
|
||||
|
||||
if OnnxModel.has_same_value(initializer1, initializer2, require_raw_data=True):
|
||||
if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2):
|
||||
mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name
|
||||
shared_initializers_1.append(initializer1)
|
||||
|
||||
|
|
@ -982,14 +984,17 @@ def remove_shared_initializers(
|
|||
return shared_initializers_2
|
||||
|
||||
|
||||
def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto, require_raw_data: bool = False):
|
||||
def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto):
|
||||
encoder = OnnxModel(encoder_model)
|
||||
decoder = OnnxModel(decoder_model)
|
||||
encoder.add_prefix_to_names("e_")
|
||||
decoder.add_prefix_to_names("d_")
|
||||
encoder.remove_duplicated_initializer(require_raw_data)
|
||||
decoder.remove_duplicated_initializer(require_raw_data)
|
||||
initializers = remove_shared_initializers(decoder.model.graph, encoder.model.graph, "s_", require_raw_data)
|
||||
signature_cache1, signature_cache2 = {}, {}
|
||||
encoder.remove_duplicated_initializer(signature_cache1)
|
||||
decoder.remove_duplicated_initializer(signature_cache2)
|
||||
initializers = remove_shared_initializers(
|
||||
decoder.model.graph, encoder.model.graph, "s_", signature_cache1, signature_cache2
|
||||
)
|
||||
return initializers
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ def chain_model(args):
|
|||
|
||||
# Initializers/opsets
|
||||
# Delete shared data between decoder/encoder and move to larger graph initializers
|
||||
initializers = get_shared_initializers(encoder_model, decoder_model, require_raw_data=True)
|
||||
initializers = get_shared_initializers(encoder_model, decoder_model)
|
||||
node.attribute.extend(
|
||||
[
|
||||
helper.make_attribute("decoder", decoder_model.graph),
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from onnx import (
|
|||
numpy_helper,
|
||||
save_model,
|
||||
)
|
||||
from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
|
||||
from shape_infer_helper import SymbolicShapeInferenceHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -1091,29 +1092,72 @@ class OnnxModel:
|
|||
return op_count
|
||||
|
||||
@staticmethod
|
||||
def has_same_value(tensor1: TensorProto, tensor2: TensorProto, require_raw_data: bool = False) -> bool:
|
||||
def to_data_hash(tensor: TensorProto, base_dir: str = "") -> int:
|
||||
"""Converts a tensor def object to a hash for data comparison purposes.
|
||||
Args:
|
||||
tensor: a TensorProto object.
|
||||
base_dir: if external tensor exists, base_dir can help to find the path to it
|
||||
Returns:
|
||||
hash: a hash of the data.
|
||||
"""
|
||||
if tensor.HasField("segment"):
|
||||
raise ValueError("Currently not supporting loading segments.")
|
||||
if tensor.data_type == TensorProto.UNDEFINED:
|
||||
raise TypeError("The element type in the input tensor is not defined.")
|
||||
tensor_dtype = tensor.data_type
|
||||
storage_field = helper.tensor_dtype_to_field(tensor_dtype)
|
||||
|
||||
if tensor.data_type == TensorProto.STRING:
|
||||
utf8_strings = getattr(tensor, storage_field)
|
||||
return hash(tuple(s.decode("utf-8") for s in utf8_strings))
|
||||
# Load raw data from external tensor if it exists
|
||||
if uses_external_data(tensor):
|
||||
load_external_data_for_tensor(tensor, base_dir)
|
||||
if tensor.HasField("raw_data"):
|
||||
return hash(tensor.raw_data)
|
||||
else:
|
||||
np_data = numpy_helper.to_array(tensor)
|
||||
return hash(np_data.tobytes())
|
||||
|
||||
@staticmethod
|
||||
def has_same_value(
|
||||
tensor1: TensorProto,
|
||||
tensor2: TensorProto,
|
||||
signature_cache1: Optional[dict] = None,
|
||||
signature_cache2: Optional[dict] = None,
|
||||
) -> bool:
|
||||
"""Returns True when two tensors have same value.
|
||||
Note that name can be different.
|
||||
|
||||
Args:
|
||||
tensor1 (TensorProto): initializer 1
|
||||
tensor2 (TensorProto): initializer 2
|
||||
require_raw_data (bool): ignore tensors without raw_data
|
||||
Note: Flag can speed up runtime significantly
|
||||
|
||||
signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison.
|
||||
signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison.
|
||||
Returns:
|
||||
bool: True when two intializers has same value.
|
||||
"""
|
||||
if tensor1.data_type != tensor2.data_type or tensor1.dims != tensor2.dims:
|
||||
return False
|
||||
if tensor1.HasField("raw_data") and tensor2.HasField("raw_data"):
|
||||
return tensor1.raw_data == tensor2.raw_data
|
||||
if require_raw_data:
|
||||
return False
|
||||
sig1 = (
|
||||
signature_cache1[tensor1.name]
|
||||
if signature_cache1 and tensor1.name in signature_cache1
|
||||
else OnnxModel.to_data_hash(tensor1)
|
||||
)
|
||||
sig2 = (
|
||||
signature_cache2[tensor2.name]
|
||||
if signature_cache2 and tensor2.name in signature_cache2
|
||||
else OnnxModel.to_data_hash(tensor2)
|
||||
)
|
||||
if signature_cache1 is not None:
|
||||
signature_cache1[tensor1.name] = sig1
|
||||
if signature_cache2 is not None:
|
||||
signature_cache2[tensor2.name] = sig2
|
||||
if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims:
|
||||
# Same signature, now do the expensive check to confirm the data is the same
|
||||
return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all()
|
||||
|
||||
return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all()
|
||||
return False
|
||||
|
||||
def remove_duplicated_initializer(self, require_raw_data: bool = False):
|
||||
def remove_duplicated_initializer(self, cache: Optional[dict] = None):
|
||||
"""Remove initializers with duplicated values, and only keep the first one.
|
||||
It could help reduce size of models (like ALBert) with shared weights.
|
||||
If require_raw_data passed, method will only compare raw_data initializers to speed runtime
|
||||
|
|
@ -1130,7 +1174,7 @@ class OnnxModel:
|
|||
continue
|
||||
for j in range(i + 1, initializer_count):
|
||||
if OnnxModel.has_same_value(
|
||||
self.model.graph.initializer[i], self.model.graph.initializer[j], require_raw_data
|
||||
self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache
|
||||
):
|
||||
same[j] = i
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue