Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Revert "Always unspecialize float in OSS (#138922)"
This reverts commit ba5253da9b.
Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/yf225 due to perf regression on torchbench ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2499277511))
Parent: 964655bf0c
Commit: ad37afd590
14 changed files with 75 additions and 156 deletions
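For context on what is being reverted: torch._dynamo.config.specialize_float decides how Dynamo treats plain Python float arguments. With specialize_float = True the float is burned into the compiled graph as a constant and guarded by value (see the L['zs'][0] == 3.0 guard asserted below), so each new float value triggers a recompile; with False, floats are promoted to tensor/symbolic inputs (the ___as_tensor(L['ys'][0]) sources and lifted .item() calls below) and one graph serves many values. A minimal illustrative sketch, not taken from the diff:

    import torch

    @torch.compile(backend="eager")
    def f(x, y):
        return x * y  # y is a plain Python float

    x = torch.randn(3)
    f(x, 0.5)
    # specialize_float=True: 0.5 was baked in as a constant, so the call
    # below fails the value guard and recompiles.
    # specialize_float=False: y is lifted to a symbolic float input and
    # the first compiled graph is reused.
    f(x, 0.6)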
.github/ci_commit_pins/xla.txt (vendored, 2 changes)
@@ -1 +1 @@
-39e9d3084686b291546cbfdbfc3e34f53659783d
+2ec22641e390cda25ec7c61fcbce07507727d584
@@ -360,8 +360,8 @@ class GraphModule(torch.nn.Module):
         actual_graph = self._test_wrap_simple(
             f,
             default_args_generator((x, y)),
-            ifdynstaticdefault(3, 4),
-            expected_opcount=3,
+            ifdynstaticdefault(2, 3),
+            expected_opcount=2,
             return_graph=True,
         )
         if torch._dynamo.config.assume_static_by_default:
@@ -369,20 +369,18 @@ class GraphModule(torch.nn.Module):
                 actual_graph,
                 """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_: "f32[3, 1]", L_y_: "f64[]"):
+    def forward(self, L_x_: "f32[3, 1]"):
         l_x_ = L_x_
-        l_y_ = L_y_
 
-        item: "Sym(zf0)" = l_y_.item(); l_y_ = None
         wrap_body_0 = self.wrap_body_0
-        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
+        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
         getitem: "f32[3]" = wrap[0]; wrap = None
         return (getitem,)
 
     class wrap_body_0(torch.nn.Module):
-        def forward(self, l_x_: "f32[3, 1]", item: "Sym(zf0)"):
+        def forward(self, l_x_: "f32[3, 1]"):
             view: "f32[3]" = l_x_.view(3); l_x_ = None
-            add: "f32[3]" = view + item; view = item = None
+            add: "f32[3]" = view + 0.5; view = None
             return (add,)
 """,
             )
@@ -391,20 +389,18 @@ class GraphModule(torch.nn.Module):
                 actual_graph,
                 """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]", L_y_: "f64[]"):
+    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
         l_x_ = L_x_
-        l_y_ = L_y_
 
-        item: "Sym(zf1)" = l_y_.item(); l_y_ = None
         wrap_body_0 = self.wrap_body_0
-        wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
+        wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
         getitem: "f32[s0]" = wrap[0]; wrap = None
         return (getitem,)
 
     class wrap_body_0(torch.nn.Module):
-        def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]", item: "Sym(zf1)"):
+        def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
             view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
-            add: "f32[s0]" = view + item; view = item = None
+            add: "f32[s0]" = view + 0.5; view = None
             return (add,)
 """,
             )
@@ -2409,7 +2405,7 @@ class GraphModule(torch.nn.Module):
 
         x = torch.zeros([])
         # Numbers don't get lifted, so args is still 2.
-        self._test_wrap_simple(f, default_args_generator((x,)), 3)
+        self._test_wrap_simple(f, default_args_generator((x,)), 2)
 
     def test_capture_global_num_adds_guard(self):
         @torch.compile(backend="eager", fullgraph=True)
@@ -2432,7 +2428,7 @@ class GraphModule(torch.nn.Module):
         x = torch.zeros([])
         y = 3.14
         # Numbers don't get lifted, so args is still 2.
-        self._test_wrap_simple(f, default_args_generator((x, y)), 3, expected_opcount=3)
+        self._test_wrap_simple(f, default_args_generator((x, y)), 2)
 
     def test_side_effect_in_body(self):
         counters.clear()
@@ -628,11 +628,11 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
 
         def inner(x, ys, zs):
             for y, z in zip(ys, zs):
-                x += y * (3.0 if z else 3.2)
+                x += y * z
             return x
 
         ys = [1.0, 2.0]
-        zs = [True]
+        zs = [3.0]
         x = torch.tensor([1.0])
 
         fn_opt = torch._dynamo.optimize("eager")(fn)
@@ -641,9 +641,8 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
 
         record_str = "\n".join(r.getMessage() for r in records)
 
-        # TODO: this is a very sensitive test
         self.assertIn(
-            f"___check_obj_id(L['zs'][0], {id(True)})",
+            """L['zs'][0] == 3.0""",
             record_str,
         )
         self.assertIn(
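One way to see the guard strings this test asserts on (assuming a standard PyTorch build) is to enable guard logging via the TORCH_LOGS environment variable:

    # Run as: TORCH_LOGS="guards" python repro.py
    import torch

    def fn(x, zs):
        return x + zs[0]

    torch.compile(fn, backend="eager")(torch.ones(1), [3.0])
    # With specialize_float=True the logged guards include a value check on
    # the float element, e.g. L['zs'][0] == 3.0, matching the assertion above.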
@@ -1396,7 +1396,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         cfg2.val = 2.0
         v = opt_fn(v, cfg2)  # 7
         self.assertEqual(v[0], 7)
-        self.assertEqual(cnts.op_count, 6)
+        self.assertEqual(cnts.op_count, 8)
 
     def test_config_getattr_default(self):
         class Cfg:
@@ -1491,7 +1491,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(opt_fn_ret(1.5)[0], -459)
         self.assertEqual(out[0], 2100)
         self.assertEqual(cnts.frame_count, 2)
-        self.assertEqual(cnts.op_count, 9)
+        self.assertEqual(cnts.op_count, 7)
 
     def test_tensor_dict1(self):
         def fn(inputs):
@@ -3717,7 +3717,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertAlmostEqual(cell1 + 1, result1)
         self.assertTrue(torch.allclose(cell2 + 3, result2))
         self.assertEqual(cnts.frame_count, 1)
-        self.assertEqual(cnts.op_count, 4)
+        self.assertEqual(cnts.op_count, 1)
 
     def test_closure_out_of_scope_cell_with_mutation(self):
         cell1 = torch.rand(1).item()
@@ -3745,12 +3745,8 @@ utils_device.CURRENT_DEVICE == None""".split(
             result1, result2, _ = opt_fn()
             self.assertAlmostEqual(orig1 + 1 * i, result1)
             self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
-            if i == 1:
-                self.assertEqual(cnts.frame_count, 1)
-                self.assertEqual(cnts.op_count, 6)
-            else:
-                self.assertEqual(cnts.frame_count, 0)
-                self.assertEqual(cnts.op_count, 0)
+            self.assertEqual(cnts.frame_count, 1)
+            self.assertEqual(cnts.op_count, 3)
             cnts.clear()
 
     def test_closure_with_mutation_and_graph_break(self):
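The frame_count and op_count values in these tests come from torch._dynamo.testing.CompileCounter. A sketch of how float specialization shows up in frame_count (illustrative; exact counts depend on the build):

    import torch
    import torch._dynamo.testing

    cnts = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnts)
    def fn(x, s):
        return x * s

    fn(torch.ones(2), 1.5)
    fn(torch.ones(2), 2.5)  # with specialize_float=True this recompiles
    print(cnts.frame_count)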
@@ -1242,13 +1242,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         with torch.no_grad():
             cnt = self._reformer(nopython=True)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 13)
+        self.assertEqual(cnt.op_count, 11)
 
     def test_reformer_train(self):
         with torch.enable_grad():
             cnt = self._reformer(nopython=False)
         expected_op_count = (
-            """13""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
+            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
         )
 
         self.assertExpectedInline(cnt.frame_count, """1""")
@@ -1725,7 +1725,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         opt_model(inp)
         opt_model(inp)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(18, cnt.op_count)
+        self.assertEqual(12, cnt.op_count)
 
     def test_exec_import(self):
         def fn1():
@@ -674,13 +674,7 @@ class StructuredTraceTest(TestCase):
 {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
 {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
 {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_storage": {"id": 1, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 1, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 1, "source": "___as_tensor(L['ys'][0])"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_storage": {"id": 2, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 2, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 2, "source": "___as_tensor(L['zs'][0])"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
-{"dynamo_output_graph": {"sizes": {"l_x_": [1], "l_ys_0_": [], "l_zs_0_": [], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
+{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
 {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@@ -688,13 +682,7 @@ class StructuredTraceTest(TestCase):
 {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_storage": {"id": 1, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_tensor": {"id": 1, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 1, "source": "___as_tensor(L['ys'][0])"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_storage": {"id": 2, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_tensor": {"id": 2, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 2, "source": "___as_tensor(L['zs'][0])"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"dynamo_output_graph": {"sizes": {"l_x_": [1], "l_ys_0_": [], "l_zs_0_": [], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
+{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 """, # noqa: B950
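The JSON records above are Dynamo's structured trace output. Assuming a standard build, the same records can be captured for any program by pointing the TORCH_TRACE environment variable at a directory before the run (the dumps are the input format of the tlparse report tool):

    # Run as: TORCH_TRACE=/tmp/trace_logs python repro.py
    import torch

    @torch.compile(backend="eager")
    def f(x):
        return x * 2

    f(torch.ones(1))  # emits describe_tensor / dynamo_output_graph records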
@@ -270,7 +270,7 @@ class TestDynamoTimed(TestCase):
         'runtime_cudagraphify_time_us': None,
         'runtime_triton_autotune_time_us': None,
         'shape_env_guard_count': 0,
-        'specialize_float': False,
+        'specialize_float': True,
         'start_time': 0.0001,
         'start_time_us': 100,
         'structured_logging_overhead_s': 0.0,
@@ -114,115 +114,67 @@ class KernelCounts(NamedTuple):
 # This maps the test name to the
 # expected kernel count
 KERNEL_COUNT_OVERRIDES = {
-    "test_adadelta_cpu": 6,
-    "test_adadelta_foreach_rho_weight_decay_cpu": 12,
-    "test_adadelta_foreach_weight_decay_cpu": 12,
-    "test_adadelta_foreach_weight_decay_maximize_cpu": 12,
-    "test_adadelta_maximize_cpu": 6,
-    "test_adadelta_rho_weight_decay_cpu": 6,
-    "test_adadelta_tensor_lr_capturable_cuda": 6,
-    "test_adadelta_tensor_lr_capturable_xpu": 6,
-    "test_adadelta_weight_decay_cpu": 6,
-    "test_adadelta_weight_decay_maximize_cpu": 6,
-    "test_adagrad_cpu": 6,
-    "test_adagrad_cuda": 6,
-    "test_adagrad_initial_accumulator_value_weight_decay_cpu": 6,
-    "test_adagrad_initial_accumulator_value_weight_decay_cuda": 6,
-    "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
-    "test_adagrad_lr_decay_weight_decay_cpu": 6,
-    "test_adagrad_lr_decay_weight_decay_cuda": 6,
-    "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
-    "test_adagrad_tensor_lr_cpu": 6,
-    "test_adagrad_tensor_lr_cuda": 6,
-    "test_adagrad_tensor_lr_xpu": 6,
-    "test_adagrad_weight_decay_cpu": 6,
-    "test_adagrad_weight_decay_cuda": 6,
-    "test_adagrad_weight_decay_foreach_xpu": 2,
-    "test_adagrad_weight_decay_maximize_cpu": 6,
-    "test_adagrad_weight_decay_maximize_cuda": 6,
-    "test_adagrad_weight_decay_maximize_foreach_xpu": 2,
-    "test_adam_amsgrad_capturable_cuda": 6,
-    "test_adam_amsgrad_capturable_xpu": 6,
-    "test_adam_cpu": 6,
+    "test_rmsprop_foreach_weight_decay_cpu": 12,
+    "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
+    "test_adamw_amsgrad_capturable_foreach_cuda": 3,
+    "test_adamw_amsgrad_capturable_foreach_xpu": 3,
+    "test_adamw_amsgrad_capturable_cuda": 6,
+    "test_adamw_amsgrad_capturable_xpu": 6,
+    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
+    "test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6,
+    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6,
+    "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
+    "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
     "test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
     "test_adam_tensor_lr_amsgrad_capturable_xpu": 6,
     "test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
     "test_adam_tensor_lr_tensor_betas_capturable_cuda": 6,
-    "test_adam_weight_decay_cpu": 6,
-    "test_adam_weight_decay_maximize_cpu": 6,
-    "test_adamax_cpu": 6,
-    "test_adamax_maximize_cpu": 6,
-    "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
-    "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
-    "test_adamax_weight_decay_cpu": 6,
-    "test_adamax_weight_decay_maximize_cpu": 6,
-    "test_adamw_amsgrad_capturable_cuda": 6,
-    "test_adamw_amsgrad_capturable_foreach_cuda": 3,
-    "test_adamw_amsgrad_capturable_foreach_xpu": 3,
-    "test_adamw_amsgrad_capturable_xpu": 6,
-    "test_adamw_cpu": 6,
-    "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
-    "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
-    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
-    "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6,
-    "test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6,
-    "test_adamw_weight_decay_cpu": 6,
-    "test_adamw_weight_decay_maximize_cpu": 6,
-    "test_asgd_cpu": 3,
-    "test_asgd_lambd_cpu": 3,
-    "test_asgd_maximize_cpu": 3,
-    "test_asgd_recompile_single": 16,
-    "test_asgd_t0_cpu": 3,
-    "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
-    "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
-    "test_asgd_weight_decay_cpu": 3,
-    "test_asgd_weight_decay_maximize_cpu": 3,
-    "test_nadam_cpu": 3,
-    "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
-    "test_nadam_momentum_decay_cpu": 3,
-    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
-    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
-    "test_nadam_weight_decay_cpu": 5,
-    "test_nadam_weight_decay_maximize_cpu": 5,
-    "test_nadam_weight_decay_momentum_decay_cpu": 5,
-    "test_nadam_weight_decay_momentum_decay_decoupled_weight_decay_cpu": 5,
-    "test_radam_cpu": 7,
-    "test_radam_eps_cpu": 7,
-    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
-    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
-    "test_radam_weight_decay_cpu": 7,
-    "test_radam_weight_decay_decoupled_weight_decay_cpu": 7,
-    "test_radam_weight_decay_maximize_cpu": 7,
-    "test_rmsprop_cpu": 6,
-    "test_rmsprop_foreach_weight_decay_cpu": 12,
-    "test_rmsprop_maximize_cpu": 6,
-    "test_rmsprop_maximize_weight_decay_cpu": 6,
+    "test_adam_amsgrad_capturable_cuda": 6,
+    "test_adam_amsgrad_capturable_xpu": 6,
+    "test_adadelta_tensor_lr_capturable_cuda": 6,
+    "test_adadelta_tensor_lr_capturable_xpu": 6,
     "test_rmsprop_tensor_lr_capturable_cuda": 6,
     "test_rmsprop_tensor_lr_capturable_xpu": 6,
-    "test_rmsprop_weight_decay_centered_cpu": 6,
-    "test_rmsprop_weight_decay_cpu": 6,
-    "test_sgd_cpu": 4,
-    "test_sgd_cuda": 4,
-    "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
+    "test_adadelta_foreach_weight_decay_maximize_cpu": 12,
+    "test_adadelta_foreach_rho_weight_decay_cpu": 12,
+    "test_adadelta_foreach_weight_decay_cpu": 12,
     "test_sgd_foreach_momentum_weight_decay_cpu": 16,
+    "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
     "test_sgd_momentum_dampening_foreach_cuda": 5,
     "test_sgd_momentum_dampening_foreach_xpu": 5,
     "test_sgd_momentum_foreach_cuda": 5,
     "test_sgd_momentum_foreach_xpu": 5,
-    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
-    "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
+    "test_sgd_weight_decay_maximize_cuda": 4,
+    "test_sgd_weight_decay_maximize_xpu": 4,
+    "test_sgd_weight_decay_maximize_cpu": 4,
+    "test_sgd_weight_decay_cpu": 4,
+    "test_sgd_weight_decay_cuda": 4,
+    "test_sgd_weight_decay_xpu": 4,
     "test_sgd_momentum_weight_decay_foreach_cuda": 2,
     "test_sgd_momentum_weight_decay_foreach_xpu": 2,
+    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
+    "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
+    "test_sgd_cuda": 4,
+    "test_sgd_cpu": 4,
+    "test_sgd_xpu": 4,
+    "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
+    "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
+    "test_adagrad_weight_decay_foreach_xpu": 2,
+    "test_adagrad_weight_decay_maximize_foreach_xpu": 2,
+    "test_adagrad_tensor_lr_cpu": 6,
+    "test_adagrad_tensor_lr_cuda": 6,
+    "test_adagrad_tensor_lr_xpu": 6,
+    "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
+    "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
+    "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
+    "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
+    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
+    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
+    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
+    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
     "test_sgd_tensor_lr_cpu": 2,
     "test_sgd_tensor_lr_cuda": 2,
     "test_sgd_tensor_lr_xpu": 2,
-    "test_sgd_weight_decay_cpu": 4,
-    "test_sgd_weight_decay_cuda": 4,
-    "test_sgd_weight_decay_maximize_cpu": 4,
-    "test_sgd_weight_decay_maximize_cuda": 4,
-    "test_sgd_weight_decay_maximize_xpu": 4,
-    "test_sgd_weight_decay_xpu": 4,
-    "test_sgd_xpu": 4,
 }
 
 # also tracks currently supported optimizers
@@ -679,7 +631,7 @@ class CompiledOptimizerTests(TestCase):
     test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01)
     test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01)
     test_asgd_recompile_single = make_recompile_test(
-        ASGD, kernel_count=3, lr=0.01, foreach=False
+        ASGD, kernel_count=8, lr=0.01, foreach=False
     )
     test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True)
     test_sgd_recompile_single = make_recompile_test(
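The kernel counts asserted throughout this file come from Inductor's generated_kernel_count metric. A minimal sketch of observing it (a toy update step, not the test harness itself):

    import torch
    from torch._inductor import metrics

    @torch.compile  # default inductor backend
    def step(p, g):
        return p - 0.01 * g

    metrics.reset()
    step(torch.randn(8), torch.randn(8))
    print(metrics.generated_kernel_count)  # kernels Inductor generated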
@@ -296,7 +296,6 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
             )
             and epilogue != "mul"
             and epilogue != "div"
-            and epilogue != "leaky_relu"
             or (
                 dtype in (torch.float16, torch.bfloat16)
                 and epilogue == "add"
@@ -13,7 +13,6 @@ from torch import nn
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
 
-from torch._dynamo import config as dynamo_config
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import TestCase
@@ -95,9 +94,6 @@ class MultiUserConvOp(nn.Module):
 
 
 class EfficientConvBNEvalTemplate(TestCase):
-    # With specialize_float = False, momentum becomes an input and so
-    # the number of bytes accessed wobbles
-    @dynamo_config.patch(specialize_float=True)
     @inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
     def test_basic(self):
         def test_conv_bn_eval(
@@ -114,9 +114,6 @@ def cal_conv_generated_kernel_number(mod, input, dtype):
     return input_kernel + output_kernel
 
 
-# The pattern match for this is kind of broken. I'll cc the
-# person who wrote this test/match on the diff to see if they can help me fix it.
-@torch._dynamo.config.patch(specialize_float=True)
 @config.patch({"freezing": True})
 class TestPatternMatcherBase(TestCase):
     def _check_unary_is_decomposed(self, unary_fn):
@@ -5,7 +5,6 @@ from unittest.mock import patch
 
 import functorch
 import torch
-import torch._dynamo.config as dynamo_config
 import torch._inductor.config as config
 import torch.autograd
 from torch._inductor import metrics
@@ -481,9 +480,6 @@ class FusionTests(TestCase):
         inp = (T(10, 10), T(10, 10), T(10, 10))
         self.assertExpectedInline(count_numel(f, *inp), """500""")
 
-    # With specialize_float = False, epsilon becomes an input and so
-    # the number of bytes accessed wobbles
-    @dynamo_config.patch(specialize_float=True)
     def test_reduction_pointwise_multi_level_reduction(self):
         hidden_size = 4096
         layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()
@@ -3939,7 +3939,7 @@ class TestAttnBias(NNTestCase):
                                          SDPBackend.MATH,
                                          SDPBackend.CUDNN_ATTENTION]):
             self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
-        self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
+        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
 
     @skipIfRocm
     @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
@@ -65,7 +65,7 @@ specialize_int = False
 # Whether or not to specialize on float inputs. Dynamo will always promote
 # float inputs into Tensor inputs, but at the moment, backends inconsistently
 # support codegen on float (this is to be fixed).
-specialize_float = True if is_fbcode() else False
+specialize_float = True
 
 # legacy config, does nothing now!
 dynamic_shapes = True
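With this revert, OSS builds default back to specialize_float = True. Code that wants the unspecialized behavior can still opt in locally with the same config patch the deleted test decorators used (a sketch):

    import torch
    import torch._dynamo

    def fn(x, scale):
        return x * scale

    # config.patch works as a context manager or decorator, mirroring the
    # @dynamo_config.patch(specialize_float=True) lines removed above.
    with torch._dynamo.config.patch(specialize_float=False):
        out = torch.compile(fn, backend="eager")(torch.ones(2), 0.5)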