From c5f276cbd9500579e7e30f0a47829468e391beac Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Fri, 7 Feb 2025 14:00:17 -0800
Subject: [PATCH] Update base for Update on "[inductor] online softmax"

Softmax needs to do some preparation work that accesses the input tensor in two passes:
- compute the amax of each row
- compute `(x - amax).exp().sum()` for each row

When the row size is large, the cache cannot hold all the active data, and accessing the input in multiple passes increases execution time since the kernel is memory-bandwidth bound.

Online softmax uses a customized reduction to compute the max and sum at the same time, accessing the data in a single pass. Check this paper for more details: https://arxiv.org/abs/1805.02867. Also, here is an online softmax kernel generated by inductor as a reference: https://gist.github.com/shunting314/67ae4fffd45d4f2753c781780332fa54

## Microbenchmark

- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=0 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax`: without online softmax
  - eager_ms=6.671296119689941
  - opt_ms=8.06931209564209
- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=1 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax`: with online softmax
  - eager_ms=6.634047985076904
  - opt_ms=6.230591773986816

Ideally, online softmax should save about 2 ms here. We save about 1.84 ms in practice.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov peterbell10 ngimel

[ghstack-poisoned]
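To illustrate the idea (not the generated Triton code itself), here is a minimal NumPy sketch of the online-softmax reduction from the paper linked above: the running max `m` and the running sum `s` of `exp(x - m)` are maintained together, with `s` rescaled whenever the running max increases, so the max/sum reduction needs only one pass over the row instead of two. The function names and loop structure are illustrative, not from the PR.

```python
import numpy as np

def softmax_two_pass(x):
    # Baseline: pass 1 computes the row max, pass 2 exponentiates and sums.
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_online(x):
    # Fused reduction: one pass per row computes max and sum together.
    # When a new max is seen, the accumulated sum is rescaled by
    # exp(old_max - new_max) so previously accumulated terms stay correct.
    x = np.atleast_2d(x)
    out = np.empty_like(x, dtype=np.float64)
    for r, row in enumerate(x):
        m = -np.inf  # running max
        s = 0.0      # running sum of exp(row_i - m)
        for v in row:
            m_new = max(m, v)
            s = s * np.exp(m - m_new) + np.exp(v - m_new)
            m = m_new
        # The normalization still reads the row once more; online softmax
        # fuses the two reduction passes, not the final elementwise pass.
        out[r] = np.exp(row - m) / s
    return out

x = np.random.default_rng(0).normal(size=(4, 128))
assert np.allclose(softmax_two_pass(x), softmax_online(x))
```

The rescaling step is the key: it makes the pair `(m, s)` an associative combine, which is what lets inductor express it as a single customized reduction.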