Add IsSupported support to Op functor (#13692)

Sometime it is a bit risky to call the Op directly to check whether the
impl supports consuming the param. This gives the user a way to actually
implement `IsSupported` for checking in non-compact way.
This commit is contained in:
cloudhan 2022-11-21 19:22:00 +08:00 committed by GitHub
parent 4a2a857030
commit 8de5381e84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 2 deletions

View file

@ -52,6 +52,17 @@ class Timer {
StreamT stream_;
};
template <typename T, typename Arg, typename E = void>
struct HasIsSupportedMethod {
constexpr static bool value = false;
};
template <typename T, typename Arg>
struct HasIsSupportedMethod<
T, Arg, std::enable_if_t<std::is_same_v<decltype(std::declval<T>().IsSupported(std::declval<Arg>())), Status>>> {
constexpr static bool value = true;
};
// A type erased Callable wrapper. We could have used std::function<Status<const ParamsT*>> here. However, std::function
// requires the callable object to be CopyConstructible and CopyAssignable. This is not suitable for move only functor
// or move captured lambda. So we create a simple wrapper for our purpose here.
@ -64,11 +75,13 @@ class Op {
template <typename T>
explicit Op(T&& c) : callable_{std::make_unique<CallableImpl<T>>(std::forward<T>(c))} {}
Status operator()(const ParamsT* param) { return (*callable_)(param); }
Status IsSupported(const ParamsT* param) { return (*callable_).IsSupported(param); }
private:
struct ICallable {
virtual ~ICallable() = default;
virtual Status operator()(const ParamsT*) = 0;
virtual Status IsSupported(const ParamsT*) = 0;
};
template <typename T>
@ -76,6 +89,14 @@ class Op {
explicit CallableImpl(T&& c) : c_{std::move(c)} {}
Status operator()(const ParamsT* param) override { return c_(param); }
Status IsSupported(const ParamsT* param) override {
if constexpr (HasIsSupportedMethod<T, const ParamsT*>::value) {
return c_.IsSupported(param);
} else {
return c_(param);
}
}
private:
T c_;
};
@ -86,7 +107,7 @@ class Op {
// NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the
// enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not
// valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism.
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
do { \
if (condition) { \
return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \
@ -163,7 +184,7 @@ class TunableOp {
}
static bool IsSupported(Op<ParamsT>& op, const ParamsT* param) {
Status status = op(param);
Status status = op.IsSupported(param);
if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) {
return false;
}

View file

@ -187,6 +187,7 @@ class VecAddMoveOnlyFunctor {
ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor);
Status operator()(const VecAddParams* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->c == nullptr, "output buffer cannot be nullptr");
LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
return Status::OK();
}
@ -205,6 +206,61 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) {
ASSERT_EQ(c, 7500042);
}
class VecAddWithIsSupportedMethod {
public:
VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod);
Status operator()(const VecAddParams* params) {
LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
return Status::OK();
}
Status IsSupported(const VecAddParams* params) {
// Purely for testing purpose. In real world, this methods must be crafted with excessive carefulness.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->num_elem != 4, "only support num_elem == 4");
return Status::OK();
}
};
TEST(TunableOp, OpWrapsFunctorWithExtendedIsSupported) {
constexpr const int a[] = {0, 1, 2, 3};
constexpr const int b[] = {42, 42, 42, 42};
int c[4] = {};
Status status;
// Test Op::IsSupported will have correct fallback if user does not implement it in its functor.
{
tunable::Op<VecAddParams> vec_add(VecAddMoveOnlyFunctor{});
VecAddParams params(a, b, nullptr, 1, 0);
status = vec_add.IsSupported(&params);
ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("output buffer cannot be nullptr"));
params.c = c;
status = vec_add.IsSupported(&params);
ASSERT_TRUE(status.IsOK());
}
// Test Op::IsSupported will use user provided one if they implemented it.
{
tunable::Op<VecAddParams> vec_add(VecAddWithIsSupportedMethod{});
VecAddParams params(a, b, c, 4, 0);
status = vec_add.IsSupported(&params);
ASSERT_TRUE(status.IsOK());
params.num_elem = 1;
status = vec_add.IsSupported(&params);
ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("only support num_elem == 4"));
}
}
} // namespace wrapper
namespace tuning {