mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix TunableOp signature generation (#14709)
Currently, all generated op sigs are `TunableOp<ParamsT, TimerT>` which is causing signature collision when using rocblas solution api. This was not a problem before `TuningContext`, because the `KernelMap`s were not shared. The root cause is the signature is initialized on Op consturction, specificially, only in base class ctor, which is causing the type info only caputre base class type info. That is, only the ParamsT + and base class. After this change, the we will encode the derived class type in the op sig.
This commit is contained in:
parent
bc5d0c83d1
commit
3bcdb0a83a
2 changed files with 49 additions and 11 deletions
|
|
@ -130,6 +130,7 @@ class TunableOp {
|
|||
public:
|
||||
TunableOp() = default;
|
||||
TunableOp(TunableOp&&) = default;
|
||||
virtual ~TunableOp() = default;
|
||||
|
||||
Status operator()(const ParamsT* params) {
|
||||
int id = default_id_;
|
||||
|
|
@ -168,7 +169,15 @@ class TunableOp {
|
|||
// Do nothing if we are not playing around with params
|
||||
}
|
||||
|
||||
virtual ~TunableOp() = default;
|
||||
std::string Signature() {
|
||||
// According to C++17 standard https://wg21.link/n4659 section 15.7.4
|
||||
// > if the operand of typeid refers to the
|
||||
// > object under construction or destruction, typeid yields the std::type_info object representing the constructor
|
||||
// > or destructor’s class.
|
||||
// So delay the op signature generation. See https://github.com/microsoft/onnxruntime/pull/14709
|
||||
std::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
|
||||
return signature_;
|
||||
}
|
||||
|
||||
protected:
|
||||
// set the default op to be used in non-tuning scenario
|
||||
|
|
@ -190,10 +199,6 @@ class TunableOp {
|
|||
});
|
||||
}
|
||||
|
||||
std::string Signature() const {
|
||||
return signature_;
|
||||
}
|
||||
|
||||
private:
|
||||
static void WarmUp(Op<ParamsT>& op, const ParamsT* param) {
|
||||
constexpr const int num_iter = 4;
|
||||
|
|
@ -272,7 +277,8 @@ class TunableOp {
|
|||
#endif
|
||||
}
|
||||
|
||||
std::string signature_{CreateSignature()};
|
||||
mutable std::once_flag signature_init_once_;
|
||||
std::string signature_;
|
||||
|
||||
// the default impl to use when tuning is disabled
|
||||
int default_id_{0};
|
||||
|
|
|
|||
|
|
@ -370,11 +370,6 @@ class TunableVecAddSelectFast : public TunableOp<VecAddParamsRecordLastRun> {
|
|||
this->RegisterOp(FastFull);
|
||||
}
|
||||
|
||||
// Re export for testing purpose
|
||||
std::string Signature() {
|
||||
return onnxruntime::test::TunableOp<VecAddParamsRecordLastRun>::Signature();
|
||||
}
|
||||
|
||||
constexpr static int kSlowFullId = 0;
|
||||
constexpr static int kFastFullId = 1;
|
||||
};
|
||||
|
|
@ -579,6 +574,43 @@ TEST(TunableOp, HandleInplaceUpdate) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(TunableOp, OpSignatureMustNotChange) {
|
||||
#ifdef ORT_NO_RTTI
|
||||
GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
|
||||
#else
|
||||
std::vector<std::string> signatures1;
|
||||
std::vector<std::string> signatures2;
|
||||
signatures1.emplace_back(TunableVecAddSelectFast{}.Signature());
|
||||
signatures1.emplace_back(TunableVecAddSelectSupported{}.Signature());
|
||||
signatures1.emplace_back(TunableVecAddSelectFastestIfSupported{}.Signature());
|
||||
signatures1.emplace_back(TunableVecAddNotHandleInplaceUpdate{}.Signature());
|
||||
signatures1.emplace_back(TunableVecAddHandleInplaceUpdate{}.Signature());
|
||||
|
||||
signatures2.emplace_back(TunableVecAddSelectFast{}.Signature());
|
||||
signatures2.emplace_back(TunableVecAddSelectSupported{}.Signature());
|
||||
signatures2.emplace_back(TunableVecAddSelectFastestIfSupported{}.Signature());
|
||||
signatures2.emplace_back(TunableVecAddNotHandleInplaceUpdate{}.Signature());
|
||||
signatures2.emplace_back(TunableVecAddHandleInplaceUpdate{}.Signature());
|
||||
|
||||
ASSERT_EQ(signatures1, signatures2);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TunableOp, OpSignatureMustNotCollide) {
|
||||
#ifdef ORT_NO_RTTI
|
||||
GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
|
||||
#else
|
||||
std::unordered_set<std::string> signatures;
|
||||
signatures.insert(TunableVecAddSelectFast{}.Signature());
|
||||
signatures.insert(TunableVecAddSelectSupported{}.Signature());
|
||||
signatures.insert(TunableVecAddSelectFastestIfSupported{}.Signature());
|
||||
signatures.insert(TunableVecAddNotHandleInplaceUpdate{}.Signature());
|
||||
signatures.insert(TunableVecAddHandleInplaceUpdate{}.Signature());
|
||||
|
||||
ASSERT_THAT(signatures, ::testing::SizeIs(5));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tuning
|
||||
|
||||
namespace tuning_context {
|
||||
|
|
|
|||
Loading…
Reference in a new issue