[fx] move DCE rand check to import time (#145118)

Mitigates the deterministic benchmark regression reported in https://github.com/pytorch/pytorch/issues/144775#issuecomment-2593411844, and possibly the dashboard issue as well.

fx.Node.is_impure is unexpectedly a hot spot. It gets called for every node in the graph whenever we invoke DCE, which should be okay, except that we invoke DCE on the full graph ~10 times at various stages of torch.compile, and an insane number of times (more than O(parameters)) for the subgraphs traced by the pattern matcher.
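
For illustration, here is a minimal sketch (not code from this PR) of why is_impure sits on the DCE path: fx.Graph.eliminate_dead_code has to ask each unused node whether it is impure, and ops tagged nondeterministic_seeded (e.g. aten.rand_like) must count as impure because they advance global RNG state:

```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
# aten.rand_like is tagged nondeterministic_seeded in native_functions.yml.
# Its result is unused below, so only the impurity check keeps it alive.
g.call_function(torch.ops.aten.rand_like.default, (x,))
out = g.call_function(torch.ops.aten.add.Tensor, (x, 1))
g.output(out)

g.eliminate_dead_code()  # consults Node.is_impure() for nodes without users
print(g)  # rand_like survives DCE: dropping it would change RNG behavior
```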

I considered addressing this by reducing the number of times DCE is called, but I think we can only trim the calls from the pattern matcher, and that would require a refactor/caching solution that I leave out of this PR.

torch.Tag.nondeterministic_seeded is provided by native_functions.yml, and an op's tags are stored as a list. Most of the time the list has <= 2 elements, so it's not worth turning it into a set for fast lookup. Instead, this PR performs the membership check once, when the OpOverload is constructed at import time, and caches the result as a boolean for is_impure to read.
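
Concretely, the shape of the change (an illustrative sketch distilled from the diff below, not the verbatim PyTorch code):

```python
import torch

# Before: every Node.is_impure() call re-scanned the target's tag list.
def nondeterministic_before(target) -> bool:
    tags = getattr(target, "_tags", None)
    return tags is not None and torch.Tag.nondeterministic_seeded in tags

# After: the list scan runs once per OpOverload, when the op is constructed
# at import time; the hot path becomes a single cached-attribute lookup.
def nondeterministic_after(target) -> bool:
    return getattr(target, "_nondeterministic_seeded", False)
```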

Using the deterministic instruction count benchmarks:
```
# before
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8914894946
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8866669058
# after
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8770562314
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8779547794
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145118
Approved by: https://github.com/ezyang, https://github.com/zou3519
Simon Fan, 2025-01-21 14:03:04 -08:00, committed by PyTorch MergeBot
parent f2cfe8b59f
commit 27598cd154
3 changed files with 8 additions and 7 deletions

```diff
@@ -6,7 +6,7 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,32440000000,0.015
+add_loop_inductor,compile_time_instruction_count,32120000000,0.015
@@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45210000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,27740000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
@@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,21760000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,21310000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17810000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17600000000,0.015
@@ -54,7 +54,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5764000000,0
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9203000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9103000000,0.015
```

```diff
@@ -688,6 +688,8 @@ class OpOverload(OperatorBase):
         self._overloadname = (
             "default" if schema.overload_name == "" else schema.overload_name
         )
+        if tags:
+            self._nondeterministic_seeded = torch.Tag.nondeterministic_seeded in tags
         self._name = self._schema.name
         if schema.overload_name:
             self._name += "." + schema.overload_name
```

```diff
@@ -764,8 +764,7 @@ class Node(_NodeBase):
             # impure since it mutates inputs
             return True
-        tags: Optional[list[torch.Tag]] = getattr(self.target, "_tags", None)
-        if tags is not None and torch.Tag.nondeterministic_seeded in tags:
+        if getattr(self.target, "_nondeterministic_seeded", False):
             # impure since it mutates RNG state
             return True
```
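
Note that reading the cached flag via getattr with a False default preserves the old behavior for targets that never carried _tags (e.g. plain Python functions) and for ops with an empty tag list, since both cases previously evaluated to False as well.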