Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
Support fused_sgd_kernel support for CPU.

## Bench result:
32 core/sockets ICX

Test Scripts:
https://gist.github.com/zhuhaozhe/79e842e0a6e25d6d7fa1e4598807272c
https://gist.github.com/zhuhaozhe/b4c6998a509dcea1796dd05b3005c969

```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_adagrad time: 0.2500 seconds
_fused_adagrad time: 0.0933 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_adagrad time: 2.8819 seconds
_fused_adagrad time: 1.7591 seconds
```

## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124905
Approved by: https://github.com/jgong5, https://github.com/janeyx99
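The benchmark compares `_single_tensor_adagrad` with `_fused_adagrad`, and `test_fused_matches_forloop` checks that both paths produce the same parameters. As a rough reference for what each path computes, the per-element Adagrad step can be sketched in plain Python. This is a minimal sketch that assumes the update rule documented for `torch.optim.Adagrad` (learning-rate decay, optional weight decay, `eps` added after the square root). The function `adagrad_step_reference` is hypothetical and is not part of PyTorch.

```python
import math
from typing import List


def adagrad_step_reference(
    params: List[float],
    grads: List[float],
    state_sums: List[float],
    step: int,
    lr: float = 0.01,
    lr_decay: float = 0.0,
    weight_decay: float = 0.0,
    eps: float = 1e-10,
    maximize: bool = False,
) -> None:
    """Hypothetical element-wise Adagrad step; mutates params and state_sums.

    Assumes the update rule from the torch.optim.Adagrad docs. `step` is the
    1-based step count after it has been incremented for this update.
    """
    # Learning rate decayed according to how many steps came before this one.
    clr = lr / (1 + (step - 1) * lr_decay)
    for i, g in enumerate(grads):
        if maximize:
            g = -g  # ascend instead of descend
        if weight_decay != 0:
            g = g + weight_decay * params[i]  # L2 penalty folded into the gradient
        # Accumulate the squared gradient for this element.
        state_sums[i] += g * g
        # Divide by the root of the accumulated squares, with eps for stability.
        params[i] -= clr * g / (math.sqrt(state_sums[i]) + eps)


if __name__ == "__main__":
    p = [1.0, -2.0]
    s = [0.0, 0.0]
    adagrad_step_reference(p, [0.5, 0.5], s, step=1, lr=0.1)
    print(p, s)  # roughly [0.9, -2.1] and exactly [0.25, 0.25]
```

A fused implementation is expected to produce the same numbers (up to floating-point rounding) while doing this arithmetic in one kernel per tensor rather than as a sequence of separate tensor operations, which is where a speedup like the one in the benchmark would come from.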
| Name |
|---|
| codegen |
| data |
| distributed |
| generated |
| opinfo |
| optests |
| test_module |
| __init__.py |
| autocast_test_lists.py |
| autograd_function_db.py |
| check_kernel_launches.py |
| common_cuda.py |
| common_device_type.py |
| common_dist_composable.py |
| common_distributed.py |
| common_dtype.py |
| common_fsdp.py |
| common_jit.py |
| common_methods_invocations.py |
| common_mkldnn.py |
| common_modules.py |
| common_nn.py |
| common_optimizers.py |
| common_pruning.py |
| common_quantization.py |
| common_quantized.py |
| common_subclass.py |
| common_utils.py |
| composite_compliance.py |
| custom_op_db.py |
| dist_utils.py |
| dynamo_test_failures.py |
| hop_db.py |
| hypothesis_utils.py |
| inductor_utils.py |
| jit_metaprogramming_utils.py |
| jit_utils.py |
| logging_tensor.py |
| logging_utils.py |
| quantization_torch_package_models.py |
| static_module.py |
| torchbind_impls.py |
| triton_utils.py |
| two_tensor.py |