diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 3029570c71a..d72f10405f3 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -74,6 +74,24 @@ using DtypeAccum = float; using DtypeEpilogue = float; using DtypeOutput = cutlass::bfloat16_t; +using Multiply = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + DtypeEpilogue, + DtypeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + +using Add = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + DtypeEpilogue, + DtypeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + +using Cast = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, + DtypeOutput, + DtypeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + template struct Schedule; @@ -156,54 +174,29 @@ void f8f8bf16_rowwise_impl( // the Collective Builder // Implement rowwise scaling epilogue. - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0, - TileShape, - DtypeScale, - cute::Stride, cute::Int<0>, cute::Int<0>>>; + constexpr int ColBroadcastStages = 0; + constexpr int RowBroadcastStages = PONG ? 2 : 1; - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - DtypeScale, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + using XScale = cutlass::epilogue::fusion:: + Sm90ColBroadcast; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - DtypeBias, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + using WScale = cutlass::epilogue::fusion:: + Sm90RowBroadcast; + + using Bias = cutlass::epilogue::fusion:: + Sm90RowBroadcast; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - DtypeEpilogue, // First stage output type. - DtypeEpilogue, // First stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - DtypeEpilogue, // Second stage output type. - DtypeEpilogue, // Second stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute1 = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< - cutlass::plus, - DtypeOutput, // Final (optional) stage output type. - DtypeEpilogue, // Final stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeBias = - cutlass::epilogue::fusion::Sm90EVT; - - using EpilogueEVT = EVTComputeBias; + using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< + Cast, + cutlass::epilogue::fusion::Sm90EVT< + Add, + Bias, + cutlass::epilogue::fusion::Sm90EVT< + Multiply, + XScale, + cutlass::epilogue::fusion::Sm90EVT>>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -265,10 +258,10 @@ void f8f8bf16_rowwise_impl( stride_a, reinterpret_cast(WQ.data_ptr()), stride_b}, - {{{bias.has_value() ? reinterpret_cast(bias->data_ptr()) - : nullptr}, - {{reinterpret_cast(x_scale.data_ptr())}, - {{reinterpret_cast(w_scale.data_ptr())}}}}, + {{{{bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr}, + {{reinterpret_cast(x_scale.data_ptr())}, + {{reinterpret_cast(w_scale.data_ptr())}}}}}, reinterpret_cast(out.data_ptr()), stride_output, reinterpret_cast(out.data_ptr()),