Support MatMulBnb4 in PaddingElimination (#18646)

Also support the Cast pattern between the input and the embedding node for sparsity
inspection.
This commit is contained in:
guyang3532 2023-12-01 19:27:50 +08:00 committed by GitHub
parent ccfea55942
commit 182c525416
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View file

@ -282,7 +282,8 @@ void IterateSubgraphFromNode(Graph& graph,
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
subgraph.insert(cur->MutableOutputDefs()[0]);
PushAllOutputNode(graph, to_visit, cur, visited);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMulBnb4", {1}, kMSDomain)) {
if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
// If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
// The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.

View file

@ -157,12 +157,7 @@ class InputDensityObserver:
self._embedding_graph_input_to_padding_idx_map.clear()
for node in model.graph.node:
if not (
node.domain == "org.pytorch.aten"
and node.op_type == "ATen"
and node.input[1] in user_input_names
and len(node.input) >= 3
):
if not (node.domain == "org.pytorch.aten" and node.op_type == "ATen" and len(node.input) >= 3):
continue
found = [attr for attr in node.attribute if attr.name == "operator"]
@ -194,10 +189,29 @@ class InputDensityObserver:
if padding_idx < 0:
continue
if node.input[1] not in self._embedding_graph_input_to_padding_idx_map:
self._embedding_graph_input_to_padding_idx_map[node.input[1]] = set()
# Given the input arg of embedding node, find the corresponding user input that feeds into the data.
# Will iterate the args recursively if some subgraph pattern is found between the input and the embedding,
# such as Input -> Cast -> Cast -> Embedding.
# TODO: This is a workaround for the case where the input of the embedding passes through a chain of Cast nodes,
# as found in Llama-2. We need to find a general way to handle all types of subgraph pattern between input and embedding.
def _get_embedding_graph_input(node_arg):
    """Resolve the user input that feeds ``node_arg``, looking through Cast nodes.

    Starting from ``node_arg``, follow the first input of each producing Cast node
    (e.g. Input -> Cast -> Cast -> Embedding) until a name in ``user_input_names``
    is reached.

    Returns:
        The matching user input name, or None (after logging a warning) when the
        chain reaches an arg that is not produced by a Cast node.
    """
    while node_arg not in user_input_names:
        input_node = self._try_get_node_from_its_output(node_arg)
        # NOTE(review): the helper's "try" name suggests it may return None when no node
        # produces the arg - confirm. Guarding here routes that case to the same
        # warn-and-skip path instead of raising AttributeError on `.op_type`.
        if input_node is None or input_node.op_type != "Cast":
            self._logger.warning(f"Cannot find embedding input {node_arg}")
            return None
        # Cast has a single data input; keep walking upstream from it.
        node_arg = input_node.input[0]
    return node_arg
self._embedding_graph_input_to_padding_idx_map[node.input[1]].add(padding_idx)
embedding_graph_input = _get_embedding_graph_input(node.input[1])
if embedding_graph_input is None:
continue
if embedding_graph_input not in self._embedding_graph_input_to_padding_idx_map:
self._embedding_graph_input_to_padding_idx_map[embedding_graph_input] = set()
self._embedding_graph_input_to_padding_idx_map[embedding_graph_input].add(padding_idx)
def _initialize_loss_label_padding_inspector(self, model, user_input_names):
"""Register loss label input padding inspector.