mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[ROCm] TunableOp: Update rocBLAS get_solutions API (since ROCm5.6) (#16657)
### Description
- Update the existing rocBLAS get_solutions usage to the `*_get_solutions_by_type` API (supported since ROCm 5.6), and remove the original nested TunableOp logic.
- Update kernel_explorer.
This commit is contained in:
parent
ebc311365b
commit
b7fd5af48b
11 changed files with 322 additions and 396 deletions
|
|
@ -75,6 +75,11 @@ elseif (onnxruntime_USE_ROCM)
|
|||
if (onnxruntime_USE_HIPBLASLT)
|
||||
target_compile_definitions(kernel_explorer PRIVATE USE_HIPBLASLT)
|
||||
endif()
|
||||
if (onnxruntime_USE_ROCBLAS_EXTENSION_API)
|
||||
target_compile_definitions(kernel_explorer PRIVATE USE_ROCBLAS_EXTENSION_API)
|
||||
target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS)
|
||||
target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_BETA_FEATURES_API)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_dependencies(kernel_explorer onnxruntime_pybind11_state)
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
|
|||
params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size,
|
||||
"Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ").");
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->hw % hw_size != 0, "Arg hw_size (", hw_size ") is not a divisor of hw (", params->hw, ").");
|
||||
params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
|
||||
if constexpr (WithSwish) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish.");
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -36,50 +36,38 @@ class RocblasHandleStreamGuard {
|
|||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
|
||||
template <typename T>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor(const T*) {
|
||||
static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation.");
|
||||
// The code below should be unreachable due to the static_assert above.
|
||||
// But the compiler doesn't like not having a return statement, so we
|
||||
// return something sensible.
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<float>(const float*) {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<half>(const half*) {
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<half>() {
|
||||
return rocblas_datatype_f16_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<double>(const double*) {
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
|
||||
return rocblas_datatype_f64_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>(const BFloat16*) {
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
|
||||
return rocblas_datatype_bf16_r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor(const T*) {
|
||||
static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation.");
|
||||
// The code below should be unreachable due to the static_assert above.
|
||||
// But the compiler doesn't like not having a return statement, so we
|
||||
// return something sensible.
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<float>(const float*) {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<half>(const half*) {
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<half>() {
|
||||
// Note that we're returning the _compute_ type for a given datatype.
|
||||
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
||||
// slower than using compute type FP32. So we use FP32 compute even for
|
||||
|
|
@ -89,12 +77,12 @@ constexpr rocblas_datatype RocBlasComputeTypeFor<half>(const half*) {
|
|||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<double>(const double*) {
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
|
||||
return rocblas_datatype_f64_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>(const BFloat16*) {
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
|
||||
// Note that we're returning the _compute_ type for a given datatype.
|
||||
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
||||
// slower than using compute type FP32. So we use FP32 compute even for
|
||||
|
|
@ -109,359 +97,219 @@ auto DoCastForHalfOrBfloat16(const T fp) {
|
|||
}
|
||||
|
||||
template <>
|
||||
auto DoCastForHalfOrBfloat16<half>(const half fp) {
|
||||
inline auto DoCastForHalfOrBfloat16<half>(const half fp) {
|
||||
// alpha and beta should be the same as compute_type, in half case it is float.
|
||||
float h = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(&fp));
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
|
||||
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
|
||||
// alpha and beta should be the same as compute_type, in bfloat16 case it is float.
|
||||
float h = fp.ToFloat();
|
||||
return h;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class IndexedRocBlasGemmOp {
|
||||
public:
|
||||
IndexedRocBlasGemmOp()
|
||||
: index_(0) {}
|
||||
IndexedRocBlasGemmOp(int index)
|
||||
: index_(index) {}
|
||||
auto GetRocBlasGemmTypeStringAndOps() {
|
||||
rocblas_handle handle;
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
|
||||
|
||||
Status operator()(const GemmParams<T>* params) {
|
||||
RocblasHandleStreamGuard guard(params->handle, params->stream);
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
return ROCBLAS_CALL(
|
||||
rocblas_gemm_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
index_,
|
||||
rocblas_gemm_flags_none));
|
||||
int solution_size;
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
|
||||
// Get the number of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
nullptr,
|
||||
&solution_size));
|
||||
|
||||
std::vector<int> solutions(solution_size);
|
||||
|
||||
// Get the list of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&solution_size));
|
||||
|
||||
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
|
||||
|
||||
std::vector<std::pair<std::string, Op<GemmParams<T>>>> ret;
|
||||
for (auto solution : solutions) {
|
||||
auto rocblas_gemm_op = [=](const GemmParams<T>* params) -> Status {
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
auto status = rocblas_gemm_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, input_output_type, params->ldb,
|
||||
params->a, input_output_type, params->lda,
|
||||
&h_b,
|
||||
params->c, input_output_type, params->ldc,
|
||||
params->c, input_output_type, params->ldc,
|
||||
compute_type,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
solution,
|
||||
rocblas_gemm_flags_none);
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status != rocblas_status_success, "Solution ", solution, " failed.");
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(onnxruntime::MakeString("RocBlasGemm_", solution), std::move(rocblas_gemm_op)));
|
||||
}
|
||||
|
||||
Status IsSupported(const GemmParams<T>*) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int index_;
|
||||
};
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
|
||||
public:
|
||||
RocBlasGemmTunableOp() {
|
||||
// Ensure that the default implementation is always present
|
||||
this->RegisterOp(IndexedRocBlasGemmOp<T>{0});
|
||||
auto GetRocBlasBatchedGemmTypeStringAndOps() {
|
||||
rocblas_handle handle;
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
|
||||
|
||||
int solution_size;
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
|
||||
// Get the number of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
nullptr,
|
||||
&solution_size));
|
||||
|
||||
std::vector<int> solutions(solution_size);
|
||||
|
||||
// Get the list of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&solution_size));
|
||||
|
||||
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
|
||||
|
||||
std::vector<std::pair<std::string, Op<BatchedGemmParams<T>>>> ret;
|
||||
for (auto solution : solutions) {
|
||||
auto rocblas_gemm_op = [=](const BatchedGemmParams<T>* params) -> Status {
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
auto status = rocblas_gemm_batched_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->bs, input_output_type, params->ldb,
|
||||
params->as, input_output_type, params->lda,
|
||||
&h_b,
|
||||
params->cs, input_output_type, params->ldc,
|
||||
params->cs, input_output_type, params->ldc,
|
||||
params->batch,
|
||||
compute_type,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
solution,
|
||||
rocblas_gemm_flags_none);
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status != rocblas_status_success, "Solution ", solution, " failed.");
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(
|
||||
onnxruntime::MakeString("RocBlasBatchedGemm_", solution), std::move(rocblas_gemm_op)));
|
||||
}
|
||||
|
||||
Status IsSupported(const GemmParams<T>* params) {
|
||||
ORT_UNUSED_PARAMETER(params);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int FindFastest(const GemmParams<T>* params) override {
|
||||
auto solution_indices = this->GetSolutions(params);
|
||||
std::vector<Op<GemmParams<T>>> candidates;
|
||||
for (int solution_idx : solution_indices) {
|
||||
candidates.emplace_back(IndexedRocBlasGemmOp<T>{solution_idx});
|
||||
}
|
||||
|
||||
auto id = this->FindFastestImpl(params, candidates);
|
||||
// memoize the result
|
||||
this->RegisterOp(std::move(candidates[id]));
|
||||
return this->NumberOfOps() - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> GetSolutions(const GemmParams<T>* params) {
|
||||
int num_solutions = 0;
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
// Get the number of candidate solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
NULL,
|
||||
&num_solutions));
|
||||
|
||||
// Get the actual candidate solutions
|
||||
std::vector<int> solutions(num_solutions);
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&num_solutions));
|
||||
|
||||
return solutions;
|
||||
}
|
||||
};
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class IndexedRocBlasBatchedGemmOp {
|
||||
public:
|
||||
IndexedRocBlasBatchedGemmOp()
|
||||
: index_(0) {}
|
||||
IndexedRocBlasBatchedGemmOp(int index)
|
||||
: index_(index) {}
|
||||
auto GetRocBlasStridedBatchedGemmTypeStringAndOps() {
|
||||
rocblas_handle handle;
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&handle));
|
||||
|
||||
Status operator()(const BatchedGemmParams<T>* params) {
|
||||
RocblasHandleStreamGuard guard(params->handle, params->stream);
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
return ROCBLAS_CALL(
|
||||
rocblas_gemm_batched_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
|
||||
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
|
||||
&h_b,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(*(params->as)),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
index_,
|
||||
rocblas_gemm_flags_none));
|
||||
int solution_size;
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
|
||||
// Get the number of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
nullptr,
|
||||
&solution_size));
|
||||
|
||||
std::vector<int> solutions(solution_size);
|
||||
|
||||
// Get the list of available solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&solution_size));
|
||||
|
||||
ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle));
|
||||
|
||||
std::vector<std::pair<std::string, Op<StridedBatchedGemmParams<T>>>> ret;
|
||||
for (auto solution : solutions) {
|
||||
auto rocblas_gemm_op = [=](const StridedBatchedGemmParams<T>* params) -> Status {
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
auto status = rocblas_gemm_strided_batched_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, input_output_type, params->ldb, params->stride_b,
|
||||
params->a, input_output_type, params->lda, params->stride_a,
|
||||
&h_b,
|
||||
params->c, input_output_type, params->ldc, params->stride_c,
|
||||
params->c, input_output_type, params->ldc, params->stride_c,
|
||||
params->batch,
|
||||
compute_type,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
solution,
|
||||
rocblas_gemm_flags_none);
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");
|
||||
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
status != rocblas_status_success, "Solution ", solution, " failed.");
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
ret.emplace_back(std::make_pair(
|
||||
onnxruntime::MakeString("RocBlasStridedBatchedGemm_", solution), std::move(rocblas_gemm_op)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status IsSupported(const BatchedGemmParams<T>*) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int index_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class RocBlasBatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
|
||||
public:
|
||||
RocBlasBatchedGemmTunableOp() {
|
||||
// Ensure that the default implementation is always present
|
||||
this->RegisterOp(IndexedRocBlasBatchedGemmOp<T>{0});
|
||||
}
|
||||
|
||||
Status IsSupported(const BatchedGemmParams<T>* params) {
|
||||
ORT_UNUSED_PARAMETER(params);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int FindFastest(const BatchedGemmParams<T>* params) override {
|
||||
auto solution_indices = this->GetSolutions(params);
|
||||
std::vector<Op<BatchedGemmParams<T>>> candidates;
|
||||
for (int solution_idx : solution_indices) {
|
||||
candidates.emplace_back(IndexedRocBlasBatchedGemmOp<T>{solution_idx});
|
||||
}
|
||||
|
||||
auto id = this->FindFastestImpl(params, candidates);
|
||||
// memoize the result
|
||||
this->RegisterOp(std::move(candidates[id]));
|
||||
return this->NumberOfOps() - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> GetSolutions(const BatchedGemmParams<T>* params) {
|
||||
int num_solutions = 0;
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
// Get the number of candidate solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
|
||||
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
|
||||
&h_b,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(*(params->as)),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
NULL,
|
||||
&num_solutions));
|
||||
|
||||
// Get the actual candidate solutions
|
||||
std::vector<int> solutions(num_solutions);
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb,
|
||||
params->as, RocBlasDataTypeFor(*(params->as)), params->lda,
|
||||
&h_b,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(*(params->as)),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&num_solutions));
|
||||
|
||||
return solutions;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IndexedRocBlasStridedBatchedGemmOp {
|
||||
public:
|
||||
IndexedRocBlasStridedBatchedGemmOp()
|
||||
: index_(0) {}
|
||||
IndexedRocBlasStridedBatchedGemmOp(int index)
|
||||
: index_(index) {}
|
||||
|
||||
Status operator()(const StridedBatchedGemmParams<T>* params) {
|
||||
RocblasHandleStreamGuard guard(params->handle, params->stream);
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
return ROCBLAS_CALL(
|
||||
rocblas_gemm_strided_batched_ex(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
index_,
|
||||
rocblas_gemm_flags_none));
|
||||
}
|
||||
|
||||
Status IsSupported(const StridedBatchedGemmParams<T>*) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int index_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class RocBlasStridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>> {
|
||||
public:
|
||||
RocBlasStridedBatchedGemmTunableOp() {
|
||||
// Ensure that the default implementation is always present
|
||||
this->RegisterOp(IndexedRocBlasStridedBatchedGemmOp<T>{0});
|
||||
}
|
||||
|
||||
Status IsSupported(const StridedBatchedGemmParams<T>* params) {
|
||||
ORT_UNUSED_PARAMETER(params);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int FindFastest(const StridedBatchedGemmParams<T>* params) override {
|
||||
auto solution_indices = this->GetSolutions(params);
|
||||
std::vector<Op<StridedBatchedGemmParams<T>>> candidates;
|
||||
for (int solution_idx : solution_indices) {
|
||||
candidates.emplace_back(IndexedRocBlasStridedBatchedGemmOp<T>{solution_idx});
|
||||
}
|
||||
|
||||
auto id = this->FindFastestImpl(params, candidates);
|
||||
// memoize the result
|
||||
this->RegisterOp(std::move(candidates[id]));
|
||||
return this->NumberOfOps() - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> GetSolutions(const StridedBatchedGemmParams<T>* params) {
|
||||
int num_solutions = 0;
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
// Get the number of candidate solutions
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
NULL,
|
||||
&num_solutions));
|
||||
|
||||
// Get the actual candidate solutions
|
||||
std::vector<int> solutions(num_solutions);
|
||||
ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions(
|
||||
params->handle,
|
||||
params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose,
|
||||
params->n, params->m, params->k,
|
||||
&h_a,
|
||||
params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b,
|
||||
params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a,
|
||||
&h_b,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c,
|
||||
params->batch,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&num_solutions));
|
||||
|
||||
return solutions;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
#endif // USE_ROCBLAS_EXTENSION_API
|
||||
|
||||
template <typename T>
|
||||
Status RocBlasGemmOp(const GemmParams<T>* params) {
|
||||
|
|
|
|||
|
|
@ -44,8 +44,11 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
#endif
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_);
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
for (auto&& [_, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
for (auto&& [_, op] : GetCKGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
|
|
@ -79,11 +82,6 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
delete params;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
RocBlasGemmTunableOp<T> rocblas_gemm_tunable_op_;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T, typename ALayout, typename BLayout>
|
||||
|
|
@ -93,8 +91,11 @@ class BatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
|
|||
this->RegisterOp(RocBlasBatchedGemmOp<T>);
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
this->RegisterNestedTunableOp(&rocblas_batched_gemm_tunable_op_);
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
for (auto&& [_, op] : GetRocBlasBatchedGemmTypeStringAndOps<T>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const BatchedGemmParams<T>* PreTuning(const BatchedGemmParams<T>* params) override {
|
||||
|
|
@ -131,11 +132,6 @@ class BatchedGemmTunableOp : public TunableOp<BatchedGemmParams<T>> {
|
|||
delete params;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
RocBlasBatchedGemmTunableOp<T> rocblas_batched_gemm_tunable_op_;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T, typename ALayout, typename BLayout>
|
||||
|
|
@ -149,8 +145,11 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
|
|||
#endif
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_);
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
for (auto&& [_, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps<T>()) {
|
||||
ORT_UNUSED_PARAMETER(_);
|
||||
this->RegisterOp(std::move(op));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
|
|
@ -178,11 +177,6 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
|
|||
delete params;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
RocBlasStridedBatchedGemmTunableOp<T> rocblas_strided_batched_gemm_tunable_op_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
|
|
|||
|
|
@ -51,6 +51,10 @@ class DeviceArray {
|
|||
CALL_THROW(MEMCPY(host_, device_.get(), size_ * itemsize_, MEMCPY_DEVICE_TO_HOST));
|
||||
}
|
||||
|
||||
void UpdateDeviceArray() {
|
||||
CALL_THROW(MEMCPY(device_.get(), host_, size_ * itemsize_, MEMCPY_HOST_TO_DEVICE));
|
||||
}
|
||||
|
||||
void* ptr() const {
|
||||
return device_.get();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ PYBIND11_PLUGIN_IMPL(_kernel_explorer) {
|
|||
KE_REGISTER(m) {
|
||||
py::class_<DeviceArray>(m, "DeviceArray")
|
||||
.def(py::init<py::array>())
|
||||
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray);
|
||||
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray)
|
||||
.def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray);
|
||||
|
||||
m.def("is_composable_kernel_available", []() {
|
||||
#ifdef USE_COMPOSABLE_KERNEL
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
class DeviceArray:
|
||||
def __init__(self, ndarray) -> None: ...
|
||||
def UpdateHostNumpyArray(self) -> None: ... # noqa: N802
|
||||
def UpdateDeviceArray(self) -> None: ... # noqa: N802
|
||||
|
||||
class blas_op: # noqa: N801
|
||||
T: int
|
||||
|
|
|
|||
|
|
@ -66,6 +66,11 @@ def _test_batched_gemm(
|
|||
if not my_gemm.SelectOp(impl):
|
||||
continue
|
||||
|
||||
# Restore C Arrays
|
||||
for my_c in my_cs:
|
||||
my_c.fill(1.0)
|
||||
for dev_c in dev_cs:
|
||||
dev_c.UpdateDeviceArray()
|
||||
my_gemm.Run()
|
||||
for dev_c in dev_cs:
|
||||
dev_c.UpdateHostNumpyArray()
|
||||
|
|
|
|||
|
|
@ -50,7 +50,9 @@ def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k:
|
|||
for impl in my_gemm.ListOps():
|
||||
if not my_gemm.SelectOp(impl):
|
||||
continue
|
||||
|
||||
# Restore C Array
|
||||
my_c.fill(1.0)
|
||||
dev_c.UpdateDeviceArray()
|
||||
my_gemm.Run()
|
||||
dev_c.UpdateHostNumpyArray()
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ class RocBlasGemm : public IKernelExplorer {
|
|||
DeviceArray& a, int64_t lda,
|
||||
DeviceArray& b, int64_t ldb,
|
||||
double beta,
|
||||
DeviceArray& c, int64_t ldc) {
|
||||
DeviceArray& c, int64_t ldc)
|
||||
: params_{} {
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
|
||||
params_.tuning_ctx = TuningContext();
|
||||
params_.stream = Stream();
|
||||
|
|
@ -48,6 +49,16 @@ class RocBlasGemm : public IKernelExplorer {
|
|||
params_.beta = beta;
|
||||
params_.c = static_cast<T*>(c.ptr());
|
||||
params_.ldc = ldc;
|
||||
|
||||
type_strings_.emplace_back("RocBlasGemmDefault");
|
||||
ops_.emplace_back([](auto* params) { return RocBlasGemmOp<T>(params); });
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
for (auto&& [type_string, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
~RocBlasGemm() {
|
||||
|
|
@ -56,15 +67,23 @@ class RocBlasGemm : public IKernelExplorer {
|
|||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(&params_));
|
||||
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
|
||||
}
|
||||
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {"Rocblas"};
|
||||
return type_strings_;
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
return name == "Rocblas";
|
||||
for (size_t i = 0; i < ops_.size(); i++) {
|
||||
if (type_strings_[i] == name) {
|
||||
selected_op_ = i;
|
||||
Status status = ops_[i](&params_);
|
||||
return status.IsOK();
|
||||
}
|
||||
}
|
||||
|
||||
ORT_THROW("Cannot find implementation ", name);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -74,7 +93,9 @@ class RocBlasGemm : public IKernelExplorer {
|
|||
using OpT = Op<ParamsT>;
|
||||
|
||||
ParamsT params_{};
|
||||
OpT op_{RocBlasGemmOp<T>};
|
||||
std::vector<OpT> ops_;
|
||||
std::vector<std::string> type_strings_;
|
||||
size_t selected_op_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -87,7 +108,8 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
|
|||
std::vector<DeviceArray>& bs, int64_t ldb,
|
||||
double beta,
|
||||
std::vector<DeviceArray>& cs, int64_t ldc,
|
||||
int64_t batch) {
|
||||
int64_t batch)
|
||||
: params_{} {
|
||||
this->CopyAsBsCsPointersToDevice(as, bs, cs, batch);
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
|
||||
params_.tuning_ctx = this->TuningContext();
|
||||
|
|
@ -107,6 +129,16 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
|
|||
params_.cs = this->dev_cs_.get();
|
||||
params_.ldc = ldc;
|
||||
params_.batch = batch;
|
||||
|
||||
type_strings_.emplace_back("RocBlasBatchedGemmDefault");
|
||||
ops_.emplace_back([](auto* params) { return RocBlasBatchedGemmOp<T>(params); });
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
for (auto&& [type_string, op] : GetRocBlasBatchedGemmTypeStringAndOps<T>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
~RocBlasBatchedGemm() {
|
||||
|
|
@ -115,15 +147,23 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
|
|||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(&params_));
|
||||
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
|
||||
}
|
||||
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {"Rocblas"};
|
||||
return type_strings_;
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
return name == "Rocblas";
|
||||
for (size_t i = 0; i < ops_.size(); i++) {
|
||||
if (type_strings_[i] == name) {
|
||||
selected_op_ = i;
|
||||
Status status = ops_[i](&params_);
|
||||
return status.IsOK();
|
||||
}
|
||||
}
|
||||
|
||||
ORT_THROW("Cannot find implementation ", name);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -133,7 +173,9 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer<T> {
|
|||
using OpT = Op<ParamsT>;
|
||||
|
||||
ParamsT params_{};
|
||||
OpT op_{RocBlasBatchedGemmOp<T>};
|
||||
std::vector<OpT> ops_;
|
||||
std::vector<std::string> type_strings_;
|
||||
size_t selected_op_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -146,7 +188,8 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
|
|||
DeviceArray& b, int64_t ldb, int64_t stride_b,
|
||||
double beta,
|
||||
DeviceArray& c, int64_t ldc, int64_t stride_c,
|
||||
int64_t batch) {
|
||||
int64_t batch)
|
||||
: params_{} {
|
||||
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
|
||||
params_.tuning_ctx = TuningContext();
|
||||
params_.stream = Stream();
|
||||
|
|
@ -168,6 +211,16 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
|
|||
params_.ldc = ldc;
|
||||
params_.stride_c = stride_c;
|
||||
params_.batch = batch;
|
||||
|
||||
type_strings_.emplace_back("RocBlasStridedBatchedGemmDefault");
|
||||
ops_.emplace_back([](auto* params) { return RocBlasStridedBatchedGemmOp<T>(params); });
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
for (auto&& [type_string, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps<T>()) {
|
||||
type_strings_.emplace_back(std::move(type_string));
|
||||
ops_.emplace_back(std::move(op));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
~RocBlasStridedBatchedGemm() {
|
||||
|
|
@ -176,15 +229,23 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
|
|||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(&params_));
|
||||
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
|
||||
}
|
||||
|
||||
std::vector<std::string> ListOps() const {
|
||||
return {"Rocblas"};
|
||||
return type_strings_;
|
||||
}
|
||||
|
||||
bool SelectOp(const std::string& name) {
|
||||
return name == "Rocblas";
|
||||
for (size_t i = 0; i < ops_.size(); i++) {
|
||||
if (type_strings_[i] == name) {
|
||||
selected_op_ = i;
|
||||
Status status = ops_[i](&params_);
|
||||
return status.IsOK();
|
||||
}
|
||||
}
|
||||
|
||||
ORT_THROW("Cannot find implementation ", name);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -194,7 +255,9 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer {
|
|||
using OpT = Op<ParamsT>;
|
||||
|
||||
ParamsT params_{};
|
||||
OpT op_{RocBlasStridedBatchedGemmOp<T>};
|
||||
std::vector<OpT> ops_;
|
||||
std::vector<std::string> type_strings_;
|
||||
size_t selected_op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP_COMMON(type, dtype) \
|
||||
|
|
|
|||
|
|
@ -73,6 +73,9 @@ def _test_strided_batched_gemm(
|
|||
if not my_gemm.SelectOp(impl):
|
||||
continue
|
||||
|
||||
# Restore C Array
|
||||
my_c.fill(1.0)
|
||||
dev_c.UpdateDeviceArray()
|
||||
my_gemm.Run()
|
||||
dev_c.UpdateHostNumpyArray()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue