Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
This PR aims for parity+ compared to the old testing for the simplest foreach test case. Test coverage increases: we now test foreach optimizers on CPU as well as on GPU.

Before:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (19136605)]$ python test/test_optim.py -v -k test_multi_tensor_optimizers
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
test_multi_tensor_optimizers (optim.test_optim.TestOptim) ... ok

----------------------------------------------------------------------
Ran 1 test in 7.253s

OK
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (19136605)]$
```

Now we get granular test cases, at the cost of some overhead:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (19136605)]$ python test/test_optim.py -v -k test_foreach
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
test_foreach_ASGD_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_Adadelta_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_Adagrad_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_AdamW_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_Adam_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_Adamax_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_NAdam_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_RAdam_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_RMSprop_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_Rprop_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_SGD_cpu_float64 (__main__.TestOptimRenewedCPU) ... ok
test_foreach_ASGD_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_Adadelta_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_Adagrad_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_AdamW_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_Adam_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_Adamax_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_NAdam_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_RAdam_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_RMSprop_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_Rprop_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok
test_foreach_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA) ... ok

----------------------------------------------------------------------
Ran 22 tests in 30.954s

OK
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (19136605)]$
```

Why the increase in time? Two reasons:
1. Overhead. Any _CUDA_ *Info test (OpInfo, ModuleInfo, OptimizerInfo) wraps itself in the `CudaNonDefaultStream` policy, and `CudaNonDefaultStream.__enter__`, when called for the first time, goes through all visible CUDA devices and synchronizes each of them, forcing the CUDA context to be initialized on each one. Doing this for all 8 devices takes ~10-15s. Test parametrization also costs a little overhead, but nowhere near the level that initializing the CUDA context does.
2. We test more! We now have 72 configs in the foreach optimizer world, whereas we only had 59 before.

Next steps for the future:
- Consider adding more Tensor LR configs (like a Tensor LR without capturable in the single-tensor case).
- Likely over the next PR or two: migrate all uses of `_test_derived_optimizers` in `test_optim` to `TestOptimRenewed`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114797
Approved by: https://github.com/albanD
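For readers unfamiliar with what "foreach" means here, the following is a minimal sketch, in plain Python, of the difference between the single-tensor and foreach ("multi-tensor") optimizer paths these tests compare. Lists of floats stand in for parameter tensors; none of this is PyTorch's actual implementation.

```python
# Single-tensor path: update each parameter in its own loop iteration
# (in real PyTorch, one kernel launch per op per parameter tensor).
def sgd_single_tensor(params, grads, lr):
    updated = []
    for p, g in zip(params, grads):
        updated.append([x - lr * gx for x, gx in zip(p, g)])
    return updated

# Foreach path: operate on the whole list of parameters at once, mirroring
# the batched torch._foreach_* ops (fewer, fused kernel launches).
def sgd_foreach(params, grads, lr):
    scaled = [[lr * gx for gx in g] for g in grads]      # like _foreach_mul
    return [[x - s for x, s in zip(p, sc)]               # like _foreach_sub
            for p, sc in zip(params, scaled)]

params = [[1.0, 2.0], [3.0]]
grads = [[0.5, 0.5], [1.0]]
# Both paths must produce identical updates; that equivalence is what the
# foreach tests verify against the single-tensor reference.
assert sgd_single_tensor(params, grads, 0.1) == sgd_foreach(params, grads, 0.1)
```

The design point being tested is purely about performance: the foreach path batches work across parameters but must remain numerically equivalent to the per-tensor loop.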
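The per-config test names in the log above come from parametrizing one test over optimizers, devices, and dtypes. A hypothetical reconstruction of that cross-product (not PyTorch's actual test harness, which uses its device-type test instantiation machinery):

```python
from itertools import product

# Optimizers covered by the foreach tests, per the log above.
OPTIMIZERS = [
    "ASGD", "Adadelta", "Adagrad", "AdamW", "Adam", "Adamax",
    "NAdam", "RAdam", "RMSprop", "Rprop", "SGD",
]
DEVICES = ["cpu", "cuda"]
DTYPES = ["float64"]

def generate_test_names(prefix="test_foreach"):
    """Build one granular test name per (optimizer, device, dtype) combo."""
    return [
        f"{prefix}_{opt}_{dev}_{dt}"
        for opt, dev, dt in product(OPTIMIZERS, DEVICES, DTYPES)
    ]

names = generate_test_names()
# 11 optimizers x 2 devices x 1 dtype = 22 granular cases, matching
# "Ran 22 tests" in the log.
```

Each name maps to an independently reported (and independently skippable) test case, which is the parity+ granularity win over the single monolithic `test_multi_tensor_optimizers`.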
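The one-time CUDA cost described in reason 1 can be sketched as follows. This is an illustrative stand-in, not the real `CudaNonDefaultStream` (which also swaps the current stream); the point is only that the first synchronize on each visible device forces lazy CUDA context initialization there.

```python
import torch

class SyncAllDevicesOnEnter:
    """Sketch of the expensive first __enter__: touch every visible CUDA
    device, paying context-init cost once per device (hence ~10-15s for 8)."""

    def __enter__(self):
        if torch.cuda.is_available():
            for dev in range(torch.cuda.device_count()):
                # torch.cuda.synchronize initializes the device's CUDA
                # context on first use; subsequent entries are cheap.
                torch.cuda.synchronize(dev)
        return self

    def __exit__(self, *exc):
        return False

with SyncAllDevicesOnEnter():
    pass  # on a multi-GPU box, the expensive part already happened above
```

On a CPU-only machine the guard makes this a no-op, which is why the CPU test variants do not pay this cost.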