[fp8 rowwise] Simplify epilogue visitor tree via common blocks (#134223)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134223
Approved by: https://github.com/drisspg
This commit is contained in:
Luca Wehrstedt 2024-08-22 17:38:31 +00:00 committed by PyTorch MergeBot
parent 25b2e46573
commit 2f198605ac

View file

@ -74,6 +74,24 @@ using DtypeAccum = float;
using DtypeEpilogue = float;
using DtypeOutput = cutlass::bfloat16_t;
using Multiply = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
DtypeEpilogue,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using Add = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus,
DtypeEpilogue,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using Cast = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity,
DtypeOutput,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
template <bool PingPong, bool FastAccum>
struct Schedule;
@ -156,54 +174,29 @@ void f8f8bf16_rowwise_impl(
// the Collective Builder
// Implement rowwise scaling epilogue.
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
DtypeScale,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = PONG ? 2 : 1;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
PONG ? 2 : 1,
TileShape,
DtypeScale,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
PONG ? 2 : 1,
TileShape,
DtypeBias,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
DtypeEpilogue, // First stage output type.
DtypeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
DtypeEpilogue, // Second stage output type.
DtypeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 =
cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
using ComputeBias = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus,
DtypeOutput, // Final (optional) stage output type.
DtypeEpilogue, // Final stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeBias =
cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
using EpilogueEVT = EVTComputeBias;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
cutlass::epilogue::fusion::Sm90EVT<
Add,
Bias,
cutlass::epilogue::fusion::Sm90EVT<
Multiply,
XScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
@ -265,10 +258,10 @@ void f8f8bf16_rowwise_impl(
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),