mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
This commit is contained in:
parent
9a85cc109e
commit
0cbb12ac14
1 changed files with 64 additions and 8 deletions
|
|
@ -37,7 +37,7 @@ GEMM_TEMPLATE_CUTLASS_3X = r"""
|
|||
extern "C" {
|
||||
PT_EXPORT {{kernel_call_signature}} {
|
||||
try {
|
||||
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
||||
int B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
||||
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
|
||||
using coord_t = cutlass::gemm::GemmCoord::Index;
|
||||
static cutlass::KernelHardwareInfo hw_info;
|
||||
|
|
@ -152,7 +152,7 @@ GEMM_TEMPLATE_CUTLASS_2X = r"""
|
|||
extern "C" {
|
||||
PT_EXPORT {{kernel_call_signature}} {
|
||||
try {
|
||||
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
||||
int B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
||||
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
|
||||
using coord_t = cutlass::gemm::GemmCoord::Index;
|
||||
static cutlass::KernelHardwareInfo hw_info;
|
||||
|
|
@ -264,8 +264,8 @@ GEMM_ARGS_SPARSE_CUTLASS_2X = r"""
|
|||
// Initialize GemmSparse arguments.
|
||||
arguments = {
|
||||
{
|
||||
static_cast<coord_t>({{M}}),
|
||||
static_cast<coord_t>({{N}}),
|
||||
static_cast<coord_t>(M),
|
||||
static_cast<coord_t>(N),
|
||||
static_cast<coord_t>(2 * K),
|
||||
}, // GemmCoord problem_size
|
||||
X_ref, // TensorRef<ElementA const, LayoutA> ref_A
|
||||
|
|
@ -302,20 +302,43 @@ bool initialize_block(
|
|||
if (block.size()<=0) return false;
|
||||
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
(Element*)block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
{% if kernel.cutlass_dtype(Meta, "void") != "void" %}
|
||||
template <class Element>
|
||||
bool initialize_block_meta(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed) {
|
||||
if (block.size()<=0) return false;
|
||||
cutlass::reference::device::BlockFillRandomSparseMeta(
|
||||
(Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits);
|
||||
return true;
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
extern "C" int run_standalone(uint64_t seed, int repetitions) {
|
||||
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
|
||||
size_t workspace_size = 0;
|
||||
size_t* workspace_size_ptr = &workspace_size;
|
||||
|
||||
int M = {{kernel.get_layout_args()[0]}};
|
||||
int N = {{kernel.get_layout_args()[1]}};
|
||||
int K = {{kernel.get_layout_args()[2]}};
|
||||
int lda = {{kernel.get_layout_args()[3]}};
|
||||
int ldb = {{kernel.get_layout_args()[4]}};
|
||||
int ldc = {{kernel.get_layout_args()[5]}};
|
||||
int ldd = {{kernel.get_layout_args()[6]}};
|
||||
|
||||
using ElementA = {{kernel.cutlass_dtype(X)}};
|
||||
using ElementB = {{kernel.cutlass_dtype(W)}};
|
||||
using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void
|
||||
using ElementD = {{kernel.cutlass_dtype(Y)}};
|
||||
{% if kernel.cutlass_dtype(Meta, "void") != "void" %}
|
||||
using ElementE = {{kernel.cutlass_dtype(Meta)}};
|
||||
{% endif %}
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> X_data({{kernel.max_valid_index(X)+1}});
|
||||
initialize_block(X_data, seed++);
|
||||
|
|
@ -324,6 +347,10 @@ extern "C" int run_standalone(uint64_t seed, int repetitions) {
|
|||
cutlass::DeviceAllocation<ElementC> Bias_data({{kernel.max_valid_index(Bias)+1}});
|
||||
initialize_block(Bias_data, seed++);
|
||||
cutlass::DeviceAllocation<ElementD> Y_data({{kernel.max_valid_index(Y)+1}});
|
||||
{% if kernel.cutlass_dtype(Meta, "void") != "void" %}
|
||||
cutlass::DeviceAllocation<ElementE> Meta_data({{kernel.max_valid_index(Meta)+1}});
|
||||
initialize_block_meta(Meta_data, seed++);
|
||||
{% endif %}
|
||||
|
||||
cutlass::DeviceAllocation<uint8_t> workspace_data;
|
||||
// Call once with workspace_size_ptr set to get workspace size
|
||||
|
|
@ -464,6 +491,14 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
|||
) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _update_arg_names_for_test_call_statement(
|
||||
self,
|
||||
arg_names: list[str],
|
||||
input_nodes: list[Buffer],
|
||||
) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _add_cutlass_gemm_choices(
|
||||
self,
|
||||
choices: list[ChoiceCaller],
|
||||
|
|
@ -977,13 +1012,14 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
|||
"""
|
||||
_, __, arg_types = kernel.args.cpp_argdefs()
|
||||
arg_names = [name.strip() for name in names_str.strip().split(",")]
|
||||
if input_nodes[2] is None:
|
||||
del arg_names[2]
|
||||
arg_names = self._update_arg_names_for_test_call_statement(
|
||||
arg_names, input_nodes
|
||||
)
|
||||
arguments = [
|
||||
f"(({arg_type}){arg_name}_data.get())"
|
||||
for arg_type, arg_name in zip(arg_types, arg_names)
|
||||
]
|
||||
return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"
|
||||
return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
|
||||
|
||||
|
||||
class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
||||
|
|
@ -1203,6 +1239,15 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
|
|||
names: list[str] = []
|
||||
return (Bias, inputs, names)
|
||||
|
||||
def _update_arg_names_for_test_call_statement(
|
||||
self,
|
||||
arg_names: list[str],
|
||||
input_nodes: list[Buffer],
|
||||
) -> list[str]:
|
||||
if input_nodes[2] is None:
|
||||
del arg_names[2]
|
||||
return arg_names
|
||||
|
||||
def render_gemm_arguments(
|
||||
self,
|
||||
argument_template: str,
|
||||
|
|
@ -1479,6 +1524,17 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
|
|||
names = ["Meta"]
|
||||
return (Bias, inputs, names)
|
||||
|
||||
def _update_arg_names_for_test_call_statement(
|
||||
self,
|
||||
arg_names: list[str],
|
||||
input_nodes: list[Buffer],
|
||||
) -> list[str]:
|
||||
if input_nodes[3] is None:
|
||||
del arg_names[3]
|
||||
if input_nodes[2] is None:
|
||||
del arg_names[2]
|
||||
return arg_names
|
||||
|
||||
def render_gemm_arguments(
|
||||
self,
|
||||
instance_type: str,
|
||||
|
|
|
|||
Loading…
Reference in a new issue