Add an optional IsSupported() hook to tunable::Op.

Op<ParamsT>::IsSupported(param) forwards to the wrapped callable's own IsSupported(param) when
HasIsSupportedMethod detects one that returns Status, and otherwise falls back to invoking the
callable itself. TunableOp::IsSupported now goes through this hook instead of calling op(param).

NOTE(review): the original whitespace of this patch was lost. Indentation and the alignment of the
macro line-continuation backslashes are reconstructed, so apply with whitespace-insensitive
matching (e.g. `git apply --ignore-whitespace`) and confirm against the target tree.

diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h
index 21ab195b61..883135b2b6 100644
--- a/onnxruntime/core/framework/tunable.h
+++ b/onnxruntime/core/framework/tunable.h
@@ -52,6 +52,17 @@ class Timer {
   StreamT stream_;
 };
 
+template <typename T, typename Arg, typename E = void>
+struct HasIsSupportedMethod {
+  constexpr static bool value = false;
+};
+
+template <typename T, typename Arg>
+struct HasIsSupportedMethod<
+    T, Arg, std::enable_if_t<std::is_same_v<decltype(std::declval<T>().IsSupported(std::declval<Arg>())), Status>>> {
+  constexpr static bool value = true;
+};
+
 // A type erased Callable wrapper. We could have used std::function<Status<const ParamsT*>> here. However, std::function
 // requires the callable object to be CopyConstructible and CopyAssignable. This is not suitable for move only functor
 // or move captured lambda. So we create a simple wrapper for our purpose here.
@@ -64,11 +75,13 @@ class Op {
   template <typename T>
   explicit Op(T&& c) : callable_{std::make_unique<CallableImpl<T>>(std::forward<T>(c))} {}
   Status operator()(const ParamsT* param) { return (*callable_)(param); }
+  Status IsSupported(const ParamsT* param) { return (*callable_).IsSupported(param); }
 
  private:
   struct ICallable {
     virtual ~ICallable() = default;
     virtual Status operator()(const ParamsT*) = 0;
+    virtual Status IsSupported(const ParamsT*) = 0;
   };
 
   template <typename T>
@@ -76,6 +89,14 @@ class Op {
     explicit CallableImpl(T&& c) : c_{std::move(c)} {}
     Status operator()(const ParamsT* param) override { return c_(param); }
 
+    Status IsSupported(const ParamsT* param) override {
+      if constexpr (HasIsSupportedMethod<T, const ParamsT*>::value) {
+        return c_.IsSupported(param);
+      } else {
+        return c_(param);
+      }
+    }
+
    private:
     T c_;
   };
@@ -86,7 +107,7 @@ class Op {
 // NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the
 // enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not
 // valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism.
-#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
+#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...)  \
   do {                                                             \
     if (condition) {                                               \
       return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \
@@ -163,7 +184,7 @@ class TunableOp {
   }
 
   static bool IsSupported(Op<ParamsT>& op, const ParamsT* param) {
-    Status status = op(param);
+    Status status = op.IsSupported(param);
     if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) {
       return false;
     }
diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc
index 8869e4ffbd..128a1b3965 100644
--- a/onnxruntime/test/framework/tunable_op_test.cc
+++ b/onnxruntime/test/framework/tunable_op_test.cc
@@ -187,6 +187,7 @@ class VecAddMoveOnlyFunctor {
   ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor);
 
   Status operator()(const VecAddParams* params) {
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->c == nullptr, "output buffer cannot be nullptr");
     LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
     return Status::OK();
   }
@@ -205,6 +206,61 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) {
   ASSERT_EQ(c, 7500042);
 }
 
+
+class VecAddWithIsSupportedMethod {
+ public:
+  VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default;
+  ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod);
+
+  Status operator()(const VecAddParams* params) {
+    LaunchVecAddKernel(params->a, params->b, params->c, params->num_elem, params->beta);
+    return Status::OK();
+  }
+
+  Status IsSupported(const VecAddParams* params) {
+    // Purely for testing purposes. In the real world, this method must be crafted with great care.
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->num_elem != 4, "only support num_elem == 4");
+    return Status::OK();
+  }
+};
+
+TEST(TunableOp, OpWrapsFunctorWithExtendedIsSupported) {
+  constexpr const int a[] = {0, 1, 2, 3};
+  constexpr const int b[] = {42, 42, 42, 42};
+  int c[4] = {};
+
+  Status status;
+
+  // Test Op::IsSupported will have correct fallback if user does not implement it in its functor.
+  {
+    tunable::Op<VecAddParams> vec_add(VecAddMoveOnlyFunctor{});
+    VecAddParams params(a, b, nullptr, 1, 0);
+    status = vec_add.IsSupported(&params);
+    ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
+    ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
+    ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("output buffer cannot be nullptr"));
+
+    params.c = c;
+    status = vec_add.IsSupported(&params);
+    ASSERT_TRUE(status.IsOK());
+  }
+
+  // Test Op::IsSupported will use user provided one if they implemented it.
+  {
+    tunable::Op<VecAddParams> vec_add(VecAddWithIsSupportedMethod{});
+
+    VecAddParams params(a, b, c, 4, 0);
+    status = vec_add.IsSupported(&params);
+    ASSERT_TRUE(status.IsOK());
+
+    params.num_elem = 1;
+    status = vec_add.IsSupported(&params);
+    ASSERT_EQ(status.Category(), common::StatusCategory::NONE);
+    ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT);
+    ASSERT_THAT(status.ErrorMessage(), testing::HasSubstr("only support num_elem == 4"));
+  }
+}
+
 }  // namespace wrapper
 
 namespace tuning {