diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index fd79d6f5b85..cd36952819a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4782,6 +4782,33 @@ class CPUReproTests(TestCase): code ) + @config.patch(freezing=True) + def test_add_layernorm(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.dense = torch.nn.Linear(768, 768) + self.layernorm = torch.nn.LayerNorm(768, eps=1e-12) + + def forward(self, context_layer, hidden_states): + attention_output = self.dense(context_layer) + hidden_states = attention_output + hidden_states + layer_output = self.layernorm(hidden_states) + return layer_output + + model = Model() + example_batch = (torch.rand(1, 197, 768), torch.rand(1, 197, 768)) + from torch.testing._internal.common_quantization import ( + _generate_qdq_quantized_model, + ) + + with torch.no_grad(): + converted_model = _generate_qdq_quantized_model(model, example_batch) + torch.ao.quantization.move_exported_model_to_eval(converted_model) + metrics.reset() + torch.compile(converted_model)(*example_batch) + check_metrics_vec_kernel_count(3) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index c742c1c86a4..2acf40ec12c 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -812,6 +812,16 @@ class GraphLowering(torch.fx.Interpreter): def get_dtype(self, buffer_name: str) -> torch.dtype: if buffer_name in self.constants: return self.constants[buffer_name].dtype + # For a mutation op we should return the dtype of the buffer being mutated + if ( + hasattr(self.scheduler, "mutation_real_name") + and buffer_name in self.scheduler.mutation_real_name + ): + mutated_buf = self.scheduler.mutation_real_name[buffer_name] + if mutated_buf in self.name_to_buffer: + return self.name_to_buffer[mutated_buf].get_dtype() + if mutated_buf in self.graph_inputs: + return self.graph_inputs[mutated_buf].get_dtype() if buffer_name in self.name_to_buffer: return self.name_to_buffer[buffer_name].get_dtype() if buffer_name in self.graph_inputs: