[ROCm] TunableOp: Update rocBLAS get_solutions API (since ROCm5.6) (#16657)

### Description
- Replace the existing rocBLAS `*_get_solutions` calls with
`*_get_solutions_by_type` (supported since ROCm 5.6) and remove the original
nested TunableOp logic.
- Update kernel_explorer to list and select the individual rocBLAS solutions.
This commit is contained in:
mindest 2023-07-13 11:20:26 +08:00 committed by GitHub
parent ebc311365b
commit b7fd5af48b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 322 additions and 396 deletions

View file

@ -75,6 +75,11 @@ elseif (onnxruntime_USE_ROCM)
if (onnxruntime_USE_HIPBLASLT)
target_compile_definitions(kernel_explorer PRIVATE USE_HIPBLASLT)
endif()
# Forward the onnxruntime_USE_ROCBLAS_EXTENSION_API build option to the
# kernel_explorer target as compile definitions, so code guarded by
# `#ifdef USE_ROCBLAS_EXTENSION_API` is compiled into it.
if (onnxruntime_USE_ROCBLAS_EXTENSION_API)
target_compile_definitions(kernel_explorer PRIVATE USE_ROCBLAS_EXTENSION_API)
# NOTE(review): presumably silences rocBLAS deprecation warnings -- confirm against the rocBLAS headers.
target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS)
# NOTE(review): presumably required for rocBLAS to expose the beta *_get_solutions_by_type API -- confirm.
target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_BETA_FEATURES_API)
endif()
endif()
add_dependencies(kernel_explorer onnxruntime_pybind11_state)

View file

@ -50,7 +50,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ").");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->hw % hw_size != 0, "Arg hw_size (", hw_size ") is not a divisor of hw (", params->hw, ").");
params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
if constexpr (WithSwish) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish.");
} else {

View file

@ -36,50 +36,38 @@ class RocblasHandleStreamGuard {
#ifdef USE_ROCBLAS_EXTENSION_API
template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor(const T*) {
static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation.");
// The code below should be unreachable due to the static_assert above.
// But the compiler doesn't like not having a return statement, so we
// return something sensible.
constexpr rocblas_datatype RocBlasDataTypeFor();
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>(const float*) {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<half>(const half*) {
constexpr rocblas_datatype RocBlasDataTypeFor<half>() {
return rocblas_datatype_f16_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>(const double*) {
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
return rocblas_datatype_f64_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>(const BFloat16*) {
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
return rocblas_datatype_bf16_r;
}
template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor(const T*) {
static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation.");
// The code below should be unreachable due to the static_assert above.
// But the compiler doesn't like not having a return statement, so we
// return something sensible.
constexpr rocblas_datatype RocBlasComputeTypeFor();
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>(const float*) {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<half>(const half*) {
constexpr rocblas_datatype RocBlasComputeTypeFor<half>() {
// Note that we're returning the _compute_ type for a given datatype.
// As of 12/2022, using compute type FP16 for 16-bit floats was much
// slower than using compute type FP32. So we use FP32 compute even for
@ -89,12 +77,12 @@ constexpr rocblas_datatype RocBlasComputeTypeFor<half>(const half*) {
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>(const double*) {
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
return rocblas_datatype_f64_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>(const BFloat16*) {
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
// Note that we're returning the _compute_ type for a given datatype.
// As of 12/2022, using compute type FP16 for 16-bit floats was much
// slower than using compute type FP32. So we use FP32 compute even for
@ -109,359 +97,219 @@ auto DoCastForHalfOrBfloat16(const T fp) {
}
template <>
auto DoCastForHalfOrBfloat16<half>(const half fp) {
inline auto DoCastForHalfOrBfloat16<half>(const half fp) {
// alpha and beta should be the same as compute_type, in half case it is float.
float h = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(&fp));
return h;
}
template <>
auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
// alpha and beta should be the same as compute_type, in bfloat16 case it is float.
float h = fp.ToFloat();
return h;
}
template <typename T>
class IndexedRocBlasGemmOp {
public:
IndexedRocBlasGemmOp()
: index_(0) {}
IndexedRocBlasGemmOp(int index)
: index_(index) {}
auto GetRocBlasGemmTypeStringAndOps() {
rocblas_handle handle;
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
Status operator()(const GemmParams<T>* params) {
RocblasHandleStreamGuard guard(params->handle, params->stream);
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
return ROCBLAS_CALL(
rocblas_gemm_ex(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb,
params->a, RocBlasDataTypeFor(params->a), params->lda,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
index_,
rocblas_gemm_flags_none));
int solution_size;
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
// Get the number of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
nullptr,
&solution_size));
std::vector<int> solutions(solution_size);
// Get the list of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
for (auto solution : solutions) {
// Op that runs GEMM through one specific rocBLAS solution.
// `solution`, `input_output_type` and `compute_type` are captured by copy ([=]).
// Returns Status::OK() only when rocblas_gemm_ex reports rocblas_status_success;
// any other status goes through TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF.
// NOTE(review): the previous IndexedRocBlasGemmOp::operator() constructed a
// RocblasHandleStreamGuard(params->handle, params->stream) before calling
// rocblas_gemm_ex; this lambda does not -- confirm params->handle is already
// bound to params->stream wherever these ops are invoked.
auto rocblas_gemm_op = [=](const GemmParams<T>* params) -> Status {
// alpha/beta are converted with DoCastForHalfOrBfloat16 before being passed by pointer.
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
auto status = rocblas_gemm_ex(
params->handle,
// Operands are passed swapped: opb/b/ldb first, then opa/a/lda, and n before m.
// NOTE(review): presumably the usual row-major-on-column-major trick -- confirm.
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, input_output_type, params->ldb,
params->a, input_output_type, params->lda,
&h_b,
// params->c is passed twice, as both the C and the D argument of rocblas_gemm_ex.
params->c, input_output_type, params->ldc,
params->c, input_output_type, params->ldc,
compute_type,
// Select the kernel by explicit solution index instead of the default heuristic.
rocblas_gemm_algo_solution_index,
solution,
rocblas_gemm_flags_none);
// rocblas_status_invalid_size gets its own message; every other non-success
// status falls through to the generic "failed" message below.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
return Status::OK();
};
ret.emplace_back(std::make_pair(onnxruntime::MakeString("RocBlasGemm_", solution), std::move(rocblas_gemm_op)));
}
Status IsSupported(const GemmParams<T>*) {
return Status::OK();
}
private:
int index_;
};
return ret;
}
template <typename T>
class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
public:
RocBlasGemmTunableOp() {
// Ensure that the default implementation is always present
this->RegisterOp(IndexedRocBlasGemmOp<T>{0});
auto GetRocBlasBatchedGemmTypeStringAndOps() {
rocblas_handle handle;
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
int solution_size;
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
// Get the number of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
nullptr,
&solution_size));
std::vector<int> solutions(solution_size);
// Get the list of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
std::vector<std::pair<std::string, Op<BatchedGemmParams<T>>>> ret;
for (auto solution : solutions) {
// Op that runs batched GEMM through one specific rocBLAS solution.
// `solution`, `input_output_type` and `compute_type` are captured by copy ([=]).
// Returns Status::OK() only when rocblas_gemm_batched_ex reports
// rocblas_status_success; any other status goes through
// TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF.
// NOTE(review): the previous IndexedRocBlasBatchedGemmOp::operator() constructed a
// RocblasHandleStreamGuard(params->handle, params->stream) before the rocBLAS
// call; this lambda does not -- confirm params->handle is already bound to
// params->stream wherever these ops are invoked.
auto rocblas_gemm_op = [=](const BatchedGemmParams<T>* params) -> Status {
// alpha/beta are converted with DoCastForHalfOrBfloat16 before being passed by pointer.
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
auto status = rocblas_gemm_batched_ex(
params->handle,
// Operands are passed swapped: opb/bs/ldb first, then opa/as/lda, and n before m.
// NOTE(review): presumably the usual row-major-on-column-major trick -- confirm.
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->bs, input_output_type, params->ldb,
params->as, input_output_type, params->lda,
&h_b,
// params->cs is passed twice, as both the C and the D argument.
params->cs, input_output_type, params->ldc,
params->cs, input_output_type, params->ldc,
params->batch,
compute_type,
// Select the kernel by explicit solution index instead of the default heuristic.
rocblas_gemm_algo_solution_index,
solution,
rocblas_gemm_flags_none);
// rocblas_status_invalid_size gets its own message; every other non-success
// status falls through to the generic "failed" message below.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
return Status::OK();
};
ret.emplace_back(std::make_pair(
onnxruntime::MakeString("RocBlasBatchedGemm_", solution), std::move(rocblas_gemm_op)));
}
Status IsSupported(const GemmParams<T>* params) {
ORT_UNUSED_PARAMETER(params);
return Status::OK();
}
protected:
virtual int FindFastest(const GemmParams<T>* params) override {
auto solution_indices = this->GetSolutions(params);
std::vector<Op<GemmParams<T>>> candidates;
for (int solution_idx : solution_indices) {
candidates.emplace_back(IndexedRocBlasGemmOp<T>{solution_idx});
}
auto id = this->FindFastestImpl(params, candidates);
// memoize the result
this->RegisterOp(std::move(candidates[id]));
return this->NumberOfOps() - 1;
}
private:
std::vector<int> GetSolutions(const GemmParams<T>* params) {
int num_solutions = 0;
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
// Get the number of candidate solutions
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb,
params->a, RocBlasDataTypeFor(params->a), params->lda,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
NULL,
&num_solutions));
// Get the actual candidate solutions
std::vector<int> solutions(num_solutions);
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb,
params->a, RocBlasDataTypeFor(params->a), params->lda,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
solutions.data(),
&num_solutions));
return solutions;
}
};
return ret;
}
template <typename T>
class IndexedRocBlasBatchedGemmOp {
public:
IndexedRocBlasBatchedGemmOp()
: index_(0) {}
IndexedRocBlasBatchedGemmOp(int index)
: index_(index) {}
auto GetRocBlasStridedBatchedGemmTypeStringAndOps() {
rocblas_handle handle;
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
Status operator()(const BatchedGemmParams<T>* params) {
RocblasHandleStreamGuard guard(params->handle, params->stream);
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
return ROCBLAS_CALL(
rocblas_gemm_batched_ex(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
&h_b,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->batch,
RocBlasComputeTypeFor(*(params->as)),
rocblas_gemm_algo_solution_index,
index_,
rocblas_gemm_flags_none));
int solution_size;
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
// Get the number of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
nullptr,
&solution_size));
std::vector<int> solutions(solution_size);
// Get the list of available solutions
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
std::vector<std::pair<std::string, Op<StridedBatchedGemmParams<T>>>> ret;
for (auto solution : solutions) {
// Op that runs strided-batched GEMM through one specific rocBLAS solution.
// `solution`, `input_output_type` and `compute_type` are captured by copy ([=]).
// Returns Status::OK() only when rocblas_gemm_strided_batched_ex reports
// rocblas_status_success; any other status goes through
// TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF.
// NOTE(review): the previous IndexedRocBlasStridedBatchedGemmOp::operator()
// constructed a RocblasHandleStreamGuard(params->handle, params->stream) before
// the rocBLAS call; this lambda does not -- confirm params->handle is already
// bound to params->stream wherever these ops are invoked.
// NOTE(review): the solution list iterated here is obtained from
// rocblas_gemm_ex_get_solutions_by_type (the non-batched query) -- confirm those
// indices are valid for rocblas_gemm_strided_batched_ex.
auto rocblas_gemm_op = [=](const StridedBatchedGemmParams<T>* params) -> Status {
// alpha/beta are converted with DoCastForHalfOrBfloat16 before being passed by pointer.
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
auto status = rocblas_gemm_strided_batched_ex(
params->handle,
// Operands are passed swapped: opb/b/ldb/stride_b first, then opa/a/lda/stride_a,
// and n before m.
// NOTE(review): presumably the usual row-major-on-column-major trick -- confirm.
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, input_output_type, params->ldb, params->stride_b,
params->a, input_output_type, params->lda, params->stride_a,
&h_b,
// params->c (with ldc/stride_c) is passed twice, as both the C and the D argument.
params->c, input_output_type, params->ldc, params->stride_c,
params->c, input_output_type, params->ldc, params->stride_c,
params->batch,
compute_type,
// Select the kernel by explicit solution index instead of the default heuristic.
rocblas_gemm_algo_solution_index,
solution,
rocblas_gemm_flags_none);
// rocblas_status_invalid_size gets its own message; every other non-success
// status falls through to the generic "failed" message below.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
return Status::OK();
};
ret.emplace_back(std::make_pair(
onnxruntime::MakeString("RocBlasStridedBatchedGemm_", solution), std::move(rocblas_gemm_op)));
}
return ret;
}
Status IsSupported(const BatchedGemmParams<T>*) {
return Status::OK();
}
private:
int index_;
};
template <typename T>
class RocBlasBatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
public:
RocBlasBatchedGemmTunableOp() {
// Ensure that the default implementation is always present
this->RegisterOp(IndexedRocBlasBatchedGemmOp<T>{0});
}
Status IsSupported(const BatchedGemmParams<T>* params) {
ORT_UNUSED_PARAMETER(params);
return Status::OK();
}
protected:
virtual int FindFastest(const BatchedGemmParams<T>* params) override {
auto solution_indices = this->GetSolutions(params);
std::vector<Op<BatchedGemmParams<T>>> candidates;
for (int solution_idx : solution_indices) {
candidates.emplace_back(IndexedRocBlasBatchedGemmOp<T>{solution_idx});
}
auto id = this->FindFastestImpl(params, candidates);
// memoize the result
this->RegisterOp(std::move(candidates[id]));
return this->NumberOfOps() - 1;
}
private:
std::vector<int> GetSolutions(const BatchedGemmParams<T>* params) {
int num_solutions = 0;
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
// Get the number of candidate solutions
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
&h_b,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->batch,
RocBlasComputeTypeFor(*(params->as)),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
NULL,
&num_solutions));
// Get the actual candidate solutions
std::vector<int> solutions(num_solutions);
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
&h_b,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
params->batch,
RocBlasComputeTypeFor(*(params->as)),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
solutions.data(),
&num_solutions));
return solutions;
}
};
template <typename T>
class IndexedRocBlasStridedBatchedGemmOp {
public:
IndexedRocBlasStridedBatchedGemmOp()
: index_(0) {}
IndexedRocBlasStridedBatchedGemmOp(int index)
: index_(index) {}
Status operator()(const StridedBatchedGemmParams<T>* params) {
RocblasHandleStreamGuard guard(params->handle, params->stream);
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
return ROCBLAS_CALL(
rocblas_gemm_strided_batched_ex(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->batch,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
index_,
rocblas_gemm_flags_none));
}
Status IsSupported(const StridedBatchedGemmParams<T>*) {
return Status::OK();
}
private:
int index_;
};
template <typename T>
class RocBlasStridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>> {
public:
RocBlasStridedBatchedGemmTunableOp() {
// Ensure that the default implementation is always present
this->RegisterOp(IndexedRocBlasStridedBatchedGemmOp<T>{0});
}
Status IsSupported(const StridedBatchedGemmParams<T>* params) {
ORT_UNUSED_PARAMETER(params);
return Status::OK();
}
protected:
virtual int FindFastest(const StridedBatchedGemmParams<T>* params) override {
auto solution_indices = this->GetSolutions(params);
std::vector<Op<StridedBatchedGemmParams<T>>> candidates;
for (int solution_idx : solution_indices) {
candidates.emplace_back(IndexedRocBlasStridedBatchedGemmOp<T>{solution_idx});
}
auto id = this->FindFastestImpl(params, candidates);
// memoize the result
this->RegisterOp(std::move(candidates[id]));
return this->NumberOfOps() - 1;
}
private:
std::vector<int> GetSolutions(const StridedBatchedGemmParams<T>* params) {
int num_solutions = 0;
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
// Get the number of candidate solutions
ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->batch,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
NULL,
&num_solutions));
// Get the actual candidate solutions
std::vector<int> solutions(num_solutions);
ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions(
params->handle,
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
params->n, params->m, params->k,
&h_a,
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
&h_b,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
params->batch,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
solutions.data(),
&num_solutions));
return solutions;
}
};
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
#endif // USE_ROCBLAS_EXTENSION_API
template <typename T>
Status RocBlasGemmOp(const GemmParams<T>* params) {

View file

@ -44,8 +44,11 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
#endif
#ifdef USE_ROCBLAS_EXTENSION_API
this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_);
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
for (auto&& [_, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : GetCKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
@ -79,11 +82,6 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
delete params;
}
}
private:
#ifdef USE_ROCBLAS_EXTENSION_API
RocBlasGemmTunableOp<T> rocblas_gemm_tunable_op_;
#endif
};
template <typename T, typename ALayout, typename BLayout>
@ -93,8 +91,11 @@ class BatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
this->RegisterOp(RocBlasBatchedGemmOp<T>);
#ifdef USE_ROCBLAS_EXTENSION_API
this->RegisterNestedTunableOp(&rocblas_batched_gemm_tunable_op_);
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
for (auto&& [_, op] : GetRocBlasBatchedGemmTypeStringAndOps<T>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
}
const BatchedGemmParams<T>* PreTuning(const BatchedGemmParams<T>* params) override {
@ -131,11 +132,6 @@ class BatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
delete params;
}
}
private:
#ifdef USE_ROCBLAS_EXTENSION_API
RocBlasBatchedGemmTunableOp<T> rocblas_batched_gemm_tunable_op_;
#endif
};
template <typename T, typename ALayout, typename BLayout>
@ -149,8 +145,11 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
#endif
#ifdef USE_ROCBLAS_EXTENSION_API
this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_);
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
for (auto&& [_, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps<T>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif
#ifdef USE_COMPOSABLE_KERNEL
for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps<T, ALayout, BLayout>()) {
@ -178,11 +177,6 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
delete params;
}
}
private:
#ifdef USE_ROCBLAS_EXTENSION_API
RocBlasStridedBatchedGemmTunableOp<T> rocblas_strided_batched_gemm_tunable_op_;
#endif
};
} // namespace internal

View file

@ -51,6 +51,10 @@ class DeviceArray {
CALL_THROW(MEMCPY(host_, device_.get(), size_ * itemsize_, MEMCPY_DEVICE_TO_HOST));
}
// Copies size_ * itemsize_ bytes from host_ into the device buffer
// (MEMCPY_HOST_TO_DEVICE), i.e. pushes the current host-side contents to the device.
// The copy is wrapped in CALL_THROW (presumably throws on failure -- confirm macro).
void UpdateDeviceArray() {
CALL_THROW(MEMCPY(device_.get(), host_, size_ * itemsize_, MEMCPY_HOST_TO_DEVICE));
}
void* ptr() const {
return device_.get();
}

View file

@ -32,7 +32,8 @@ PYBIND11_PLUGIN_IMPL(_kernel_explorer) {
KE_REGISTER(m) {
py::class_<DeviceArray>(m, "DeviceArray")
.def(py::init<py::array>())
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray)
.def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray);
m.def("is_composable_kernel_available", []() {
#ifdef USE_COMPOSABLE_KERNEL

View file

@ -1,6 +1,7 @@
class DeviceArray:
def __init__(self, ndarray) -> None: ...
def UpdateHostNumpyArray(self) -> None: ... # noqa: N802
def UpdateDeviceArray(self) -> None: ... # noqa: N802
class blas_op: # noqa: N801
T: int

View file

@ -66,6 +66,11 @@ def _test_batched_gemm(
if not my_gemm.SelectOp(impl):
continue
# Restore C Arrays
for my_c in my_cs:
my_c.fill(1.0)
for dev_c in dev_cs:
dev_c.UpdateDeviceArray()
my_gemm.Run()
for dev_c in dev_cs:
dev_c.UpdateHostNumpyArray()

View file

@ -50,7 +50,9 @@ def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k:
for impl in my_gemm.ListOps():
if not my_gemm.SelectOp(impl):
continue
# Restore C Array
my_c.fill(1.0)
dev_c.UpdateDeviceArray()
my_gemm.Run()
dev_c.UpdateHostNumpyArray()

View file

@ -30,7 +30,8 @@ class RocBlasGemm : public IKernelExplorer {
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
double beta,
DeviceArray& c, int64_t ldc) {
DeviceArray& c, int64_t ldc)
: params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
@ -48,6 +49,16 @@ class RocBlasGemm : public IKernelExplorer {
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
type_strings_.emplace_back("RocBlasGemmDefault");
ops_.emplace_back([](auto* params) { return RocBlasGemmOp<T>(params); });
#ifdef USE_ROCBLAS_EXTENSION_API
for (auto&& [type_string, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
#endif
}
~RocBlasGemm() {
@ -56,15 +67,23 @@ class RocBlasGemm : public IKernelExplorer {
}
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
}
std::vector<std::string> ListOps() const {
return {"Rocblas"};
return type_strings_;
}
bool SelectOp(const std::string& name) {
return name == "Rocblas";
for (size_t i = 0; i < ops_.size(); i++) {
if (type_strings_[i] == name) {
selected_op_ = i;
Status status = ops_[i](&params_);
return status.IsOK();
}
}
ORT_THROW("Cannot find implementation ", name);
}
private:
@ -74,7 +93,9 @@ class RocBlasGemm : public IKernelExplorer {
using OpT = Op<ParamsT>;
ParamsT params_{};
OpT op_{RocBlasGemmOp<T>};
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;
size_t selected_op_{};
};
template <typename T>
@ -87,7 +108,8 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
std::vector<DeviceArray>& bs, int64_t ldb,
double beta,
std::vector<DeviceArray>& cs, int64_t ldc,
int64_t batch) {
int64_t batch)
: params_{} {
this->CopyAsBsCsPointersToDevice(as, bs, cs, batch);
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning_ctx = this->TuningContext();
@ -107,6 +129,16 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
params_.cs = this->dev_cs_.get();
params_.ldc = ldc;
params_.batch = batch;
type_strings_.emplace_back("RocBlasBatchedGemmDefault");
ops_.emplace_back([](auto* params) { return RocBlasBatchedGemmOp<T>(params); });
#ifdef USE_ROCBLAS_EXTENSION_API
for (auto&& [type_string, op] : GetRocBlasBatchedGemmTypeStringAndOps<T>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
#endif
}
~RocBlasBatchedGemm() {
@ -115,15 +147,23 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
}
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
}
std::vector<std::string> ListOps() const {
return {"Rocblas"};
return type_strings_;
}
bool SelectOp(const std::string& name) {
return name == "Rocblas";
for (size_t i = 0; i < ops_.size(); i++) {
if (type_strings_[i] == name) {
selected_op_ = i;
Status status = ops_[i](&params_);
return status.IsOK();
}
}
ORT_THROW("Cannot find implementation ", name);
}
private:
@ -133,7 +173,9 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
using OpT = Op<ParamsT>;
ParamsT params_{};
OpT op_{RocBlasBatchedGemmOp<T>};
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;
size_t selected_op_{};
};
template <typename T>
@ -146,7 +188,8 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
DeviceArray& b, int64_t ldb, int64_t stride_b,
double beta,
DeviceArray& c, int64_t ldc, int64_t stride_c,
int64_t batch) {
int64_t batch)
: params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
@ -168,6 +211,16 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
params_.ldc = ldc;
params_.stride_c = stride_c;
params_.batch = batch;
type_strings_.emplace_back("RocBlasStridedBatchedGemmDefault");
ops_.emplace_back([](auto* params) { return RocBlasStridedBatchedGemmOp<T>(params); });
#ifdef USE_ROCBLAS_EXTENSION_API
for (auto&& [type_string, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps<T>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
#endif
}
~RocBlasStridedBatchedGemm() {
@ -176,15 +229,23 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
}
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
}
std::vector<std::string> ListOps() const {
return {"Rocblas"};
return type_strings_;
}
bool SelectOp(const std::string& name) {
return name == "Rocblas";
for (size_t i = 0; i < ops_.size(); i++) {
if (type_strings_[i] == name) {
selected_op_ = i;
Status status = ops_[i](&params_);
return status.IsOK();
}
}
ORT_THROW("Cannot find implementation ", name);
}
private:
@ -194,7 +255,9 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
using OpT = Op<ParamsT>;
ParamsT params_{};
OpT op_{RocBlasStridedBatchedGemmOp<T>};
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;
size_t selected_op_{};
};
#define REGISTER_OP_COMMON(type, dtype) \

View file

@ -73,6 +73,9 @@ def _test_strided_batched_gemm(
if not my_gemm.SelectOp(impl):
continue
# Restore C Array
my_c.fill(1.0)
dev_c.UpdateDeviceArray()
my_gemm.Run()
dev_c.UpdateHostNumpyArray()