mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Support MatMulBnb4 in PaddingElimination (#18646)
Also support a Cast pattern between the input and the embedding node for sparsity inspection
This commit is contained in:
parent
ccfea55942
commit
182c525416
2 changed files with 25 additions and 10 deletions
|
|
@ -282,7 +282,8 @@ void IterateSubgraphFromNode(Graph& graph,
|
|||
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
|
||||
subgraph.insert(cur->MutableOutputDefs()[0]);
|
||||
PushAllOutputNode(graph, to_visit, cur, visited);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMulBnb4", {1}, kMSDomain)) {
|
||||
if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
|
||||
// If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
|
||||
// The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
|
||||
|
|
|
|||
|
|
@ -157,12 +157,7 @@ class InputDensityObserver:
|
|||
self._embedding_graph_input_to_padding_idx_map.clear()
|
||||
|
||||
for node in model.graph.node:
|
||||
if not (
|
||||
node.domain == "org.pytorch.aten"
|
||||
and node.op_type == "ATen"
|
||||
and node.input[1] in user_input_names
|
||||
and len(node.input) >= 3
|
||||
):
|
||||
if not (node.domain == "org.pytorch.aten" and node.op_type == "ATen" and len(node.input) >= 3):
|
||||
continue
|
||||
|
||||
found = [attr for attr in node.attribute if attr.name == "operator"]
|
||||
|
|
@ -194,10 +189,29 @@ class InputDensityObserver:
|
|||
if padding_idx < 0:
|
||||
continue
|
||||
|
||||
if node.input[1] not in self._embedding_graph_input_to_padding_idx_map:
|
||||
self._embedding_graph_input_to_padding_idx_map[node.input[1]] = set()
|
||||
# Given the input arg of embedding node, find the corresponding user input that feeds into the data.
|
||||
# Will iterate the args recursively if some subgraph pattern is found between the input and the embedding,
|
||||
# such as Input -> Cast -> Cast -> Embedding.
|
||||
# TODO: This is a workaround for the case that the input of embedding is a list of Cast nodes which is found
|
||||
# in Llama-2. We need to find a general way to handle all types of subgraph pattern between input and embedding.
|
||||
def _get_embedding_graph_input(node_arg):
|
||||
if node_arg in user_input_names:
|
||||
return node_arg
|
||||
input_node = self._try_get_node_from_its_output(node_arg)
|
||||
if input_node.op_type == "Cast":
|
||||
return _get_embedding_graph_input(input_node.input[0])
|
||||
else:
|
||||
self._logger.warning(f"Cannot find embedding input {node_arg}")
|
||||
return None
|
||||
|
||||
self._embedding_graph_input_to_padding_idx_map[node.input[1]].add(padding_idx)
|
||||
embedding_graph_input = _get_embedding_graph_input(node.input[1])
|
||||
if embedding_graph_input is None:
|
||||
continue
|
||||
|
||||
if embedding_graph_input not in self._embedding_graph_input_to_padding_idx_map:
|
||||
self._embedding_graph_input_to_padding_idx_map[embedding_graph_input] = set()
|
||||
|
||||
self._embedding_graph_input_to_padding_idx_map[embedding_graph_input].add(padding_idx)
|
||||
|
||||
def _initialize_loss_label_padding_inspector(self, model, user_input_names):
|
||||
"""Register loss label input padding inspector.
|
||||
|
|
|
|||
Loading…
Reference in a new issue