mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Add IsSupported support to Op functor (#13692)
Sometime it is a bit risky to call the Op directly to check whether the impl supports consuming the param. This gives the user a way to actually implement `IsSupported` for checking in non-compact way.
This commit is contained in:
parent
4a2a857030
commit
8de5381e84
2 changed files with 79 additions and 2 deletions
|
|
@ -52,6 +52,17 @@ class Timer {
|
|||
StreamT stream_;
|
||||
};
|
||||
|
||||
template <typename T, typename Arg, typename E = void>
|
||||
struct HasIsSupportedMethod {
|
||||
constexpr static bool value = false;
|
||||
};
|
||||
|
||||
template <typename T, typename Arg>
|
||||
struct HasIsSupportedMethod<
|
||||
T, Arg, std::enable_if_t<std::is_same_v<decltype(std::declval<T>().IsSupported(std::declval<Arg>())), Status>>> {
|
||||
constexpr static bool value = true;
|
||||
};
|
||||
|
||||
// A type erased Callable wrapper. We could have used std::function<Status<const ParamsT*>> here. However, std::function
|
||||
// requires the callable object to be CopyConstructible and CopyAssignable. This is not suitable for move only functor
|
||||
// or move captured lambda. So we create a simple wrapper for our purpose here.
|
||||
|
|
@ -64,11 +75,13 @@ class Op {
|
|||
template <typename T>
|
||||
explicit Op(T&& c) : callable_{std::make_unique<CallableImpl<T>>(std::forward<T>(c))} {}
|
||||
Status operator()(const ParamsT* param) { return (*callable_)(param); }
|
||||
Status IsSupported(const ParamsT* param) { return (*callable_).IsSupported(param); }
|
||||
|
||||
private:
|
||||
struct ICallable {
|
||||
virtual ~ICallable() = default;
|
||||
virtual Status operator()(const ParamsT*) = 0;
|
||||
virtual Status IsSupported(const ParamsT*) = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -76,6 +89,14 @@ class Op {
|
|||
explicit CallableImpl(T&& c) : c_{std::move(c)} {}
|
||||
Status operator()(const ParamsT* param) override { return c_(param); }
|
||||
|
||||
Status IsSupported(const ParamsT* param) override {
|
||||
if constexpr (HasIsSupportedMethod<T, const ParamsT*>::value) {
|
||||
return c_.IsSupported(param);
|
||||
} else {
|
||||
return c_(param);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
T c_;
|
||||
};
|
||||
|
|
@ -86,7 +107,7 @@ class Op {
|
|||
// NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the
|
||||
// enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not
|
||||
// valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism.
|
||||
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
|
||||
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
|
||||
do { \
|
||||
if (condition) { \
|
||||
return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \
|
||||
|
|
@ -163,7 +184,7 @@ class TunableOp {
|
|||
}
|
||||
|
||||
static bool IsSupported(Op<ParamsT>& op, const ParamsT* param) {
|
||||
Status status = op(param);
|
||||
Status status = op.IsSupported(param);
|
||||
if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ class VecAddMoveOnlyFunctor {
|
|||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor);
|
||||
|
||||
Status operator()(const VecAddParams* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->c == nullptr, "output buffer cannot be nullptr");
|
||||
LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -205,6 +206,61 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) {
|
|||
ASSERT_EQ(c, 7500042);
|
||||
}
|
||||
|
||||
|
||||
class VecAddWithIsSupportedMethod {
|
||||
public:
|
||||
VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default;
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod);
|
||||
|
||||
Status operator()(const VecAddParams* params) {
|
||||
LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IsSupported(const VecAddParams* params) {
|
||||
// Purely for testing purpose. In real world, this methods must be crafted with excessive carefulness.
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->num_elem != 4, "only support num_elem == 4");
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(TunableOp, OpWrapsFunctorWithExtendedIsSupported) {
|
||||
constexpr const int a[] = {0, 1, 2, 3};
|
||||
constexpr const int b[] = {42, 42, 42, 42};
|
||||
int c[4] = {};
|
||||
|
||||
Status status;
|
||||
|
||||
// Test Op::IsSupported will have correct fallback if user does not implement it in its functor.
|
||||
{
|
||||
tunable::Op<VecAddParams> vec_add(VecAddMoveOnlyFunctor{});
|
||||
VecAddParams params(a, b, nullptr, 1, 0);
|
||||
status = vec_add.IsSupported(¶ms);
|
||||
ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
|
||||
ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
|
||||
ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("output buffer cannot be nullptr"));
|
||||
|
||||
params.c = c;
|
||||
status = vec_add.IsSupported(¶ms);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
// Test Op::IsSupported will use user provided one if they implemented it.
|
||||
{
|
||||
tunable::Op<VecAddParams> vec_add(VecAddWithIsSupportedMethod{});
|
||||
|
||||
VecAddParams params(a, b, c, 4, 0);
|
||||
status = vec_add.IsSupported(¶ms);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
params.num_elem = 1;
|
||||
status = vec_add.IsSupported(¶ms);
|
||||
ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
|
||||
ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
|
||||
ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("only support num_elem == 4"));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
namespace tuning {
|
||||
|
|
|
|||
Loading…
Reference in a new issue