mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[fp8 rowwise] Simplify epilogue visitor tree via common blocks (#134223)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134223 Approved by: https://github.com/drisspg
This commit is contained in:
parent
25b2e46573
commit
2f198605ac
1 changed file with 40 additions and 47 deletions
|
|
@@ -74,6 +74,24 @@ using DtypeAccum = float;
|
|||
using DtypeEpilogue = float;
|
||||
using DtypeOutput = cutlass::bfloat16_t;
|
||||
|
||||
// Shared element-wise multiply node for the epilogue visitor tree.
// Consumes DtypeEpilogue operands and produces a DtypeEpilogue result, so the
// product stays in the epilogue precision. EpilogueEVT instantiates this node
// twice: once with WScale and once with XScale.
using Multiply = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
DtypeEpilogue, // Output type.
|
||||
DtypeEpilogue, // Input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
// Shared element-wise addition node for the epilogue visitor tree.
// Consumes DtypeEpilogue operands and produces a DtypeEpilogue result; the
// conversion to the output type is left to a separate node. EpilogueEVT uses
// it to combine Bias with the scaled accumulator.
using Add = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus,
|
||||
DtypeEpilogue, // Output type.
|
||||
DtypeEpilogue, // Input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
// Conversion node for the epilogue visitor tree: applies the Identity functor
// to a DtypeEpilogue input and emits DtypeOutput, i.e. it only changes the
// element type (rounding to nearest). EpilogueEVT places it at the root so the
// conversion happens once, after all scaling and bias arithmetic.
using Cast = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity,
|
||||
DtypeOutput, // Output type.
|
||||
DtypeEpilogue, // Input type.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
// Forward declaration of the Schedule class template, keyed on two
// compile-time booleans. Its specializations are outside this excerpt.
// NOTE(review): presumably each (PingPong, FastAccum) combination maps to a
// kernel/epilogue schedule — confirm against the full file.
template <bool PingPong, bool FastAccum>
|
||||
struct Schedule;
|
||||
|
||||
|
|
@@ -156,54 +174,29 @@ void f8f8bf16_rowwise_impl(
|
|||
// the Collective Builder
|
||||
|
||||
// Implement rowwise scaling epilogue.
|
||||
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
DtypeScale,
|
||||
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
|
||||
constexpr int ColBroadcastStages = 0;
|
||||
constexpr int RowBroadcastStages = PONG ? 2 : 1;
|
||||
|
||||
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
PONG ? 2 : 1,
|
||||
TileShape,
|
||||
DtypeScale,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
using XScale = cutlass::epilogue::fusion::
|
||||
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
PONG ? 2 : 1,
|
||||
TileShape,
|
||||
DtypeBias,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
using WScale = cutlass::epilogue::fusion::
|
||||
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::
|
||||
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
DtypeEpilogue, // First stage output type.
|
||||
DtypeEpilogue, // First stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
DtypeEpilogue, // Second stage output type.
|
||||
DtypeEpilogue, // Second stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
|
||||
|
||||
using ComputeBias = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus,
|
||||
DtypeOutput, // Final (optional) stage output type.
|
||||
DtypeEpilogue, // Final stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeBias =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
|
||||
|
||||
using EpilogueEVT = EVTComputeBias;
|
||||
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Cast,
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
Add,
|
||||
Bias,
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
Multiply,
|
||||
XScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
|
|
@@ -265,10 +258,10 @@ void f8f8bf16_rowwise_impl(
|
|||
stride_a,
|
||||
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
|
||||
stride_b},
|
||||
{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
|
||||
: nullptr},
|
||||
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
|
||||
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}},
|
||||
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
|
||||
: nullptr},
|
||||
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
|
||||
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
|
||||
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
|
||||
stride_output,
|
||||
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
|
||||
|
|
|
|||
Loading…
Reference in a new issue