Fix TunableOp signature generation (#14709)

Currently, all generated op sigs are `TunableOp<ParamsT, TimerT>`, which
causes signature collisions when using the rocblas solution api. This was
not a problem before `TuningContext`, because the `KernelMap`s were not
shared.

The root cause is that the signature is initialized on Op construction,
specifically, only in the base class ctor, which causes the type info
to capture only the base class type info. That is, only the ParamsT and
the base class. After this change, we will encode the derived class type
in the op sig.
This commit is contained in:
cloudhan 2023-02-25 12:05:07 +08:00 committed by GitHub
parent bc5d0c83d1
commit 3bcdb0a83a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 11 deletions

View file

@ -130,6 +130,7 @@ class TunableOp {
public:
TunableOp() = default;
TunableOp(TunableOp&&) = default;
virtual ~TunableOp() = default;
Status operator()(const ParamsT* params) {
int id = default_id_;
@ -168,7 +169,15 @@ class TunableOp {
// Do nothing if we are not playing around with params
}
virtual ~TunableOp() = default;
std::string Signature() {
  // The signature is computed lazily on the first call instead of during construction.
  // Per the C++17 draft (https://wg21.link/n4659, [class.cdtor]), when typeid is applied to an
  // object that is still under construction or destruction, it yields the std::type_info of the
  // constructor's or destructor's own class, not of the most-derived class. Deferring the
  // computation until after construction avoids that.
  // See https://github.com/microsoft/onnxruntime/pull/14709
  const auto init_signature = [this]() {
    signature_ = CreateSignature();
  };
  // call_once ensures the initializer runs a single time for this object.
  std::call_once(signature_init_once_, init_signature);
  return signature_;
}
protected:
// set the default op to be used in non-tuning scenario
@ -190,10 +199,6 @@ class TunableOp {
});
}
// Returns a copy of the stored signature_ string.
std::string Signature() const {
return signature_;
}
private:
static void WarmUp(Op<ParamsT>& op, const ParamsT* param) {
constexpr const int num_iter = 4;
@ -272,7 +277,8 @@ class TunableOp {
#endif
}
std::string signature_{CreateSignature()};
mutable std::once_flag signature_init_once_;
std::string signature_;
// the default impl to use when tuning is disabled
int default_id_{0};

View file

@ -370,11 +370,6 @@ class TunableVecAddSelectFast : public TunableOp<VecAddParamsRecordLastRun> {
this->RegisterOp(FastFull);
}
// Re-export for testing purposes: forwards to the base-class Signature() so that
// test code can call it directly on this type.
std::string Signature() {
return onnxruntime::test::TunableOp<VecAddParamsRecordLastRun>::Signature();
}
constexpr static int kSlowFullId = 0;
constexpr static int kFastFullId = 1;
};
@ -579,6 +574,43 @@ TEST(TunableOp, HandleInplaceUpdate) {
#endif
}
TEST(TunableOp, OpSignatureMustNotChange) {
#ifdef ORT_NO_RTTI
  GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
#else
  // Gathers the signature of a freshly constructed instance of every test op, in a fixed order.
  const auto collect_signatures = []() {
    std::vector<std::string> collected;
    collected.emplace_back(TunableVecAddSelectFast{}.Signature());
    collected.emplace_back(TunableVecAddSelectSupported{}.Signature());
    collected.emplace_back(TunableVecAddSelectFastestIfSupported{}.Signature());
    collected.emplace_back(TunableVecAddNotHandleInplaceUpdate{}.Signature());
    collected.emplace_back(TunableVecAddHandleInplaceUpdate{}.Signature());
    return collected;
  };

  // Two independent passes over new instances must produce identical signatures.
  const std::vector<std::string> first_pass = collect_signatures();
  const std::vector<std::string> second_pass = collect_signatures();

  ASSERT_EQ(first_pass, second_pass);
#endif
}
TEST(TunableOp, OpSignatureMustNotCollide) {
#ifdef ORT_NO_RTTI
  GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
#else
  // Each distinct op type must contribute a distinct entry to the set.
  std::unordered_set<std::string> unique_signatures;
  const auto record = [&unique_signatures](std::string sig) {
    unique_signatures.insert(std::move(sig));
  };

  record(TunableVecAddSelectFast{}.Signature());
  record(TunableVecAddSelectSupported{}.Signature());
  record(TunableVecAddSelectFastestIfSupported{}.Signature());
  record(TunableVecAddNotHandleInplaceUpdate{}.Signature());
  record(TunableVecAddHandleInplaceUpdate{}.Signature());

  // Five op types were recorded; any collision would shrink the set below five.
  ASSERT_THAT(unique_signatures, ::testing::SizeIs(5));
#endif
}
} // namespace tuning
namespace tuning_context {