Revert "Always unspecialize float in OSS (#138922)"

This reverts commit ba5253da9b.

Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/yf225 due to perf regression on torchbench ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2499277511))
This commit is contained in:
PyTorch MergeBot 2024-11-26 00:03:03 +00:00
parent 964655bf0c
commit ad37afd590
14 changed files with 75 additions and 156 deletions

View file

@ -1 +1 @@
39e9d3084686b291546cbfdbfc3e34f53659783d
2ec22641e390cda25ec7c61fcbce07507727d584

View file

@ -360,8 +360,8 @@ class GraphModule(torch.nn.Module):
actual_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
ifdynstaticdefault(3, 4),
expected_opcount=3,
ifdynstaticdefault(2, 3),
expected_opcount=2,
return_graph=True,
)
if torch._dynamo.config.assume_static_by_default:
@ -369,20 +369,18 @@ class GraphModule(torch.nn.Module):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 1]", L_y_: "f64[]"):
def forward(self, L_x_: "f32[3, 1]"):
l_x_ = L_x_
l_y_ = L_y_
item: "Sym(zf0)" = l_y_.item(); l_y_ = None
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 1]", item: "Sym(zf0)"):
def forward(self, l_x_: "f32[3, 1]"):
view: "f32[3]" = l_x_.view(3); l_x_ = None
add: "f32[3]" = view + item; view = item = None
add: "f32[3]" = view + 0.5; view = None
return (add,)
""",
)
@ -391,20 +389,18 @@ class GraphModule(torch.nn.Module):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]", L_y_: "f64[]"):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
l_x_ = L_x_
l_y_ = L_y_
item: "Sym(zf1)" = l_y_.item(); l_y_ = None
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]", item: "Sym(zf1)"):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
add: "f32[s0]" = view + item; view = item = None
add: "f32[s0]" = view + 0.5; view = None
return (add,)
""",
)
@ -2409,7 +2405,7 @@ class GraphModule(torch.nn.Module):
x = torch.zeros([])
# Numbers don't get lifted, so args is still 2.
self._test_wrap_simple(f, default_args_generator((x,)), 3)
self._test_wrap_simple(f, default_args_generator((x,)), 2)
def test_capture_global_num_adds_guard(self):
@torch.compile(backend="eager", fullgraph=True)
@ -2432,7 +2428,7 @@ class GraphModule(torch.nn.Module):
x = torch.zeros([])
y = 3.14
# Numbers don't get lifted, so args is still 2.
self._test_wrap_simple(f, default_args_generator((x, y)), 3, expected_opcount=3)
self._test_wrap_simple(f, default_args_generator((x, y)), 2)
def test_side_effect_in_body(self):
counters.clear()

View file

@ -628,11 +628,11 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
def inner(x, ys, zs):
for y, z in zip(ys, zs):
x += y * (3.0 if z else 3.2)
x += y * z
return x
ys = [1.0, 2.0]
zs = [True]
zs = [3.0]
x = torch.tensor([1.0])
fn_opt = torch._dynamo.optimize("eager")(fn)
@ -641,9 +641,8 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
record_str = "\n".join(r.getMessage() for r in records)
# TODO: this is a very sensitive test
self.assertIn(
f"___check_obj_id(L['zs'][0], {id(True)})",
"""L['zs'][0] == 3.0""",
record_str,
)
self.assertIn(

View file

@ -1396,7 +1396,7 @@ utils_device.CURRENT_DEVICE == None""".split(
cfg2.val = 2.0
v = opt_fn(v, cfg2) # 7
self.assertEqual(v[0], 7)
self.assertEqual(cnts.op_count, 6)
self.assertEqual(cnts.op_count, 8)
def test_config_getattr_default(self):
class Cfg:
@ -1491,7 +1491,7 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(opt_fn_ret(1.5)[0], -459)
self.assertEqual(out[0], 2100)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 9)
self.assertEqual(cnts.op_count, 7)
def test_tensor_dict1(self):
def fn(inputs):
@ -3717,7 +3717,7 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertAlmostEqual(cell1 + 1, result1)
self.assertTrue(torch.allclose(cell2 + 3, result2))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 4)
self.assertEqual(cnts.op_count, 1)
def test_closure_out_of_scope_cell_with_mutation(self):
cell1 = torch.rand(1).item()
@ -3745,12 +3745,8 @@ utils_device.CURRENT_DEVICE == None""".split(
result1, result2, _ = opt_fn()
self.assertAlmostEqual(orig1 + 1 * i, result1)
self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
if i == 1:
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 6)
else:
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(cnts.op_count, 0)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)
cnts.clear()
def test_closure_with_mutation_and_graph_break(self):

View file

@ -1242,13 +1242,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with torch.no_grad():
cnt = self._reformer(nopython=True)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 13)
self.assertEqual(cnt.op_count, 11)
def test_reformer_train(self):
with torch.enable_grad():
cnt = self._reformer(nopython=False)
expected_op_count = (
"""13""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
"""11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
)
self.assertExpectedInline(cnt.frame_count, """1""")
@ -1725,7 +1725,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
opt_model(inp)
opt_model(inp)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(18, cnt.op_count)
self.assertEqual(12, cnt.op_count)
def test_exec_import(self):
def fn1():

View file

@ -674,13 +674,7 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 1, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "___as_tensor(L['ys'][0])"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 2, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 2, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 2, "source": "___as_tensor(L['zs'][0])"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "l_ys_0_": [], "l_zs_0_": [], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@ -688,13 +682,7 @@ class StructuredTraceTest(TestCase):
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 1, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "___as_tensor(L['ys'][0])"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_storage": {"id": 2, "describer_id": "ID", "size": 8}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_tensor": {"id": 2, "ndim": 0, "dtype": "torch.float64", "device": "device(type='cpu')", "size": [], "is_leaf": true, "stride": [], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 2, "source": "___as_tensor(L['zs'][0])"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "l_ys_0_": [], "l_zs_0_": [], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
""", # noqa: B950

View file

@ -270,7 +270,7 @@ class TestDynamoTimed(TestCase):
'runtime_cudagraphify_time_us': None,
'runtime_triton_autotune_time_us': None,
'shape_env_guard_count': 0,
'specialize_float': False,
'specialize_float': True,
'start_time': 0.0001,
'start_time_us': 100,
'structured_logging_overhead_s': 0.0,

View file

@ -114,115 +114,67 @@ class KernelCounts(NamedTuple):
# This maps the test name to the
# expected kernel count
KERNEL_COUNT_OVERRIDES = {
"test_adadelta_cpu": 6,
"test_adadelta_foreach_rho_weight_decay_cpu": 12,
"test_adadelta_foreach_weight_decay_cpu": 12,
"test_adadelta_foreach_weight_decay_maximize_cpu": 12,
"test_adadelta_maximize_cpu": 6,
"test_adadelta_rho_weight_decay_cpu": 6,
"test_adadelta_tensor_lr_capturable_cuda": 6,
"test_adadelta_tensor_lr_capturable_xpu": 6,
"test_adadelta_weight_decay_cpu": 6,
"test_adadelta_weight_decay_maximize_cpu": 6,
"test_adagrad_cpu": 6,
"test_adagrad_cuda": 6,
"test_adagrad_initial_accumulator_value_weight_decay_cpu": 6,
"test_adagrad_initial_accumulator_value_weight_decay_cuda": 6,
"test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
"test_adagrad_lr_decay_weight_decay_cpu": 6,
"test_adagrad_lr_decay_weight_decay_cuda": 6,
"test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
"test_adagrad_tensor_lr_cpu": 6,
"test_adagrad_tensor_lr_cuda": 6,
"test_adagrad_tensor_lr_xpu": 6,
"test_adagrad_weight_decay_cpu": 6,
"test_adagrad_weight_decay_cuda": 6,
"test_adagrad_weight_decay_foreach_xpu": 2,
"test_adagrad_weight_decay_maximize_cpu": 6,
"test_adagrad_weight_decay_maximize_cuda": 6,
"test_adagrad_weight_decay_maximize_foreach_xpu": 2,
"test_adam_amsgrad_capturable_cuda": 6,
"test_adam_amsgrad_capturable_xpu": 6,
"test_adam_cpu": 6,
"test_rmsprop_foreach_weight_decay_cpu": 12,
"test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
"test_adamw_amsgrad_capturable_foreach_cuda": 3,
"test_adamw_amsgrad_capturable_foreach_xpu": 3,
"test_adamw_amsgrad_capturable_cuda": 6,
"test_adamw_amsgrad_capturable_xpu": 6,
"test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6,
"test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6,
"test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
"test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
"test_adam_tensor_lr_amsgrad_capturable_xpu": 6,
"test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
"test_adam_tensor_lr_tensor_betas_capturable_cuda": 6,
"test_adam_weight_decay_cpu": 6,
"test_adam_weight_decay_maximize_cpu": 6,
"test_adamax_cpu": 6,
"test_adamax_maximize_cpu": 6,
"test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
"test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
"test_adamax_weight_decay_cpu": 6,
"test_adamax_weight_decay_maximize_cpu": 6,
"test_adamw_amsgrad_capturable_cuda": 6,
"test_adamw_amsgrad_capturable_foreach_cuda": 3,
"test_adamw_amsgrad_capturable_foreach_xpu": 3,
"test_adamw_amsgrad_capturable_xpu": 6,
"test_adamw_cpu": 6,
"test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
"test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6,
"test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6,
"test_adamw_weight_decay_cpu": 6,
"test_adamw_weight_decay_maximize_cpu": 6,
"test_asgd_cpu": 3,
"test_asgd_lambd_cpu": 3,
"test_asgd_maximize_cpu": 3,
"test_asgd_recompile_single": 16,
"test_asgd_t0_cpu": 3,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
"test_asgd_weight_decay_cpu": 3,
"test_asgd_weight_decay_maximize_cpu": 3,
"test_nadam_cpu": 3,
"test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
"test_nadam_momentum_decay_cpu": 3,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
"test_nadam_weight_decay_cpu": 5,
"test_nadam_weight_decay_maximize_cpu": 5,
"test_nadam_weight_decay_momentum_decay_cpu": 5,
"test_nadam_weight_decay_momentum_decay_decoupled_weight_decay_cpu": 5,
"test_radam_cpu": 7,
"test_radam_eps_cpu": 7,
"test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
"test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
"test_radam_weight_decay_cpu": 7,
"test_radam_weight_decay_decoupled_weight_decay_cpu": 7,
"test_radam_weight_decay_maximize_cpu": 7,
"test_rmsprop_cpu": 6,
"test_rmsprop_foreach_weight_decay_cpu": 12,
"test_rmsprop_maximize_cpu": 6,
"test_rmsprop_maximize_weight_decay_cpu": 6,
"test_adam_amsgrad_capturable_cuda": 6,
"test_adam_amsgrad_capturable_xpu": 6,
"test_adadelta_tensor_lr_capturable_cuda": 6,
"test_adadelta_tensor_lr_capturable_xpu": 6,
"test_rmsprop_tensor_lr_capturable_cuda": 6,
"test_rmsprop_tensor_lr_capturable_xpu": 6,
"test_rmsprop_weight_decay_centered_cpu": 6,
"test_rmsprop_weight_decay_cpu": 6,
"test_sgd_cpu": 4,
"test_sgd_cuda": 4,
"test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
"test_adadelta_foreach_weight_decay_maximize_cpu": 12,
"test_adadelta_foreach_rho_weight_decay_cpu": 12,
"test_adadelta_foreach_weight_decay_cpu": 12,
"test_sgd_foreach_momentum_weight_decay_cpu": 16,
"test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
"test_sgd_momentum_dampening_foreach_cuda": 5,
"test_sgd_momentum_dampening_foreach_xpu": 5,
"test_sgd_momentum_foreach_cuda": 5,
"test_sgd_momentum_foreach_xpu": 5,
"test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
"test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
"test_sgd_weight_decay_maximize_cuda": 4,
"test_sgd_weight_decay_maximize_xpu": 4,
"test_sgd_weight_decay_maximize_cpu": 4,
"test_sgd_weight_decay_cpu": 4,
"test_sgd_weight_decay_cuda": 4,
"test_sgd_weight_decay_xpu": 4,
"test_sgd_momentum_weight_decay_foreach_cuda": 2,
"test_sgd_momentum_weight_decay_foreach_xpu": 2,
"test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
"test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
"test_sgd_cuda": 4,
"test_sgd_cpu": 4,
"test_sgd_xpu": 4,
"test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
"test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
"test_adagrad_weight_decay_foreach_xpu": 2,
"test_adagrad_weight_decay_maximize_foreach_xpu": 2,
"test_adagrad_tensor_lr_cpu": 6,
"test_adagrad_tensor_lr_cuda": 6,
"test_adagrad_tensor_lr_xpu": 6,
"test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
"test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
"test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
"test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
"test_sgd_tensor_lr_cpu": 2,
"test_sgd_tensor_lr_cuda": 2,
"test_sgd_tensor_lr_xpu": 2,
"test_sgd_weight_decay_cpu": 4,
"test_sgd_weight_decay_cuda": 4,
"test_sgd_weight_decay_maximize_cpu": 4,
"test_sgd_weight_decay_maximize_cuda": 4,
"test_sgd_weight_decay_maximize_xpu": 4,
"test_sgd_weight_decay_xpu": 4,
"test_sgd_xpu": 4,
}
# also tracks currently supported optimizers
@ -679,7 +631,7 @@ class CompiledOptimizerTests(TestCase):
test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01)
test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01)
test_asgd_recompile_single = make_recompile_test(
ASGD, kernel_count=3, lr=0.01, foreach=False
ASGD, kernel_count=8, lr=0.01, foreach=False
)
test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True)
test_sgd_recompile_single = make_recompile_test(

View file

@ -296,7 +296,6 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
)
and epilogue != "mul"
and epilogue != "div"
and epilogue != "leaky_relu"
or (
dtype in (torch.float16, torch.bfloat16)
and epilogue == "add"

View file

@ -13,7 +13,6 @@ from torch import nn
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase
@ -95,9 +94,6 @@ class MultiUserConvOp(nn.Module):
class EfficientConvBNEvalTemplate(TestCase):
# With specialize_float = False, momentum becomes an input and so
# the number of bytes accessed wobbles
@dynamo_config.patch(specialize_float=True)
@inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True})
def test_basic(self):
def test_conv_bn_eval(

View file

@ -114,9 +114,6 @@ def cal_conv_generated_kernel_number(mod, input, dtype):
return input_kernel + output_kernel
# The pattern match for this is kind of broken. I'll cc the
# person who wrote this test/match on the diff to see if they can help me fix it.
@torch._dynamo.config.patch(specialize_float=True)
@config.patch({"freezing": True})
class TestPatternMatcherBase(TestCase):
def _check_unary_is_decomposed(self, unary_fn):

View file

@ -5,7 +5,6 @@ from unittest.mock import patch
import functorch
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as config
import torch.autograd
from torch._inductor import metrics
@ -481,9 +480,6 @@ class FusionTests(TestCase):
inp = (T(10, 10), T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """500""")
# With specialize_float = False, epsilon becomes an input and so
# the number of bytes accessed wobbles
@dynamo_config.patch(specialize_float=True)
def test_reduction_pointwise_multi_level_reduction(self):
hidden_size = 4096
layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()

View file

@ -3939,7 +3939,7 @@ class TestAttnBias(NNTestCase):
SDPBackend.MATH,
SDPBackend.CUDNN_ATTENTION]):
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
@skipIfRocm
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])

View file

@ -65,7 +65,7 @@ specialize_int = False
# Whether or not to specialize on float inputs. Dynamo will always promote
# float inputs into Tensor inputs, but at the moment, backends inconsistently
# support codegen on float (this is to be fixed).
specialize_float = True if is_fbcode() else False
specialize_float = True
# legacy config, does nothing now!
dynamic_shapes = True