mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Optimize is_symbolic test and some refactor (#86230)
Our SymInt rep can be represented more efficiently as just a greater than test, but the compiler doesn't seem to figure it out. Help it out. There is also some refactoring to simplify the code and add more debugging. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/86230 Approved by: https://github.com/albanD
This commit is contained in:
parent
8c6d352bcf
commit
3b1ec7511e
2 changed files with 33 additions and 56 deletions
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
namespace c10 {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
|
||||
SymIntNode a, b;
|
||||
if (a_.is_symbolic())
|
||||
|
|
@ -37,21 +36,6 @@ c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) {
|
|||
auto rep = (ptr & ~MASK) | IS_SYM;
|
||||
return c10::SymInt(UNCHECKED, static_cast<int64_t>(rep));
|
||||
}
|
||||
#else
|
||||
// this code should never be executed on mobile due to inlining of `is_symbolic`
|
||||
// which always returns `false` on mobile.
|
||||
// However, if we decide to strip off `SymIntNode` completely from mobile builds
|
||||
// We would need to stub these methods anyways
|
||||
c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) {
|
||||
TORCH_INTERNAL_ASSERT(false, "SymInts aren't available on mobile");
|
||||
}
|
||||
SymIntNode SymInt::toSymIntNodeImpl() const {
|
||||
TORCH_INTERNAL_ASSERT(false, "SymInts aren't available on mobile");
|
||||
}
|
||||
static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
|
||||
TORCH_INTERNAL_ASSERT(false, "SymInts aren't available on mobile");
|
||||
}
|
||||
#endif
|
||||
|
||||
int64_t SymInt::guard_int(const char* file, int64_t line) const {
|
||||
if (!is_symbolic()) {
|
||||
|
|
|
|||
|
|
@ -31,14 +31,6 @@ class SymFloat;
|
|||
// SymIntNodeImpl*] which will be implemented as a single packed int64_t field
|
||||
// named data_.
|
||||
|
||||
#ifdef C10_MOBILE
|
||||
#define SKIP_IS_SYMBOLIC_ON_MOBILE(_) \
|
||||
do { \
|
||||
} while (0)
|
||||
#else
|
||||
#define SKIP_IS_SYMBOLIC_ON_MOBILE(X) TORCH_CHECK(X)
|
||||
#endif
|
||||
|
||||
class C10_API SymInt {
|
||||
public:
|
||||
enum Unchecked {
|
||||
|
|
@ -46,7 +38,10 @@ class C10_API SymInt {
|
|||
};
|
||||
|
||||
/*implicit*/ SymInt(int64_t d) : data_(d) {
|
||||
SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic());
|
||||
// NB: this relies on exception in constructor inhibiting
|
||||
// destructor; otherwise we would attempt to deallocate
|
||||
// the garbage data!
|
||||
TORCH_CHECK(!is_symbolic());
|
||||
};
|
||||
SymInt() : data_(0) {}
|
||||
|
||||
|
|
@ -90,18 +85,14 @@ class C10_API SymInt {
|
|||
}
|
||||
|
||||
SymInt clone() const {
|
||||
#ifndef C10_MOBILE
|
||||
if (is_symbolic()) {
|
||||
return toSymIntNodeImplUnowned()->clone()->toSymInt();
|
||||
}
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT(!is_symbolic());
|
||||
#endif
|
||||
return *this;
|
||||
}
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
SymIntNodeImpl* toSymIntNodeImplUnowned() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
|
||||
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
|
||||
uint64_t sign_bit_mask = 1ULL << (62 - 1);
|
||||
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
|
||||
|
|
@ -117,18 +108,15 @@ class C10_API SymInt {
|
|||
}
|
||||
|
||||
SymIntNodeImpl* release() && {
|
||||
#ifndef C10_MOBILE
|
||||
TORCH_INTERNAL_ASSERT(is_symbolic());
|
||||
auto* r = toSymIntNodeImplUnowned();
|
||||
data_ = 0; // transfer ownership
|
||||
return r;
|
||||
}
|
||||
#else
|
||||
void release_() {}
|
||||
|
||||
SymIntNodeImpl* release() && {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
SymIntNode toSymIntNodeImpl() const;
|
||||
static c10::SymInt toSymInt(SymIntNode sin);
|
||||
|
|
@ -142,7 +130,7 @@ class C10_API SymInt {
|
|||
// shapes, and you don't have time to fix it immediately, as if we
|
||||
// try to trigger the path in C++ you'll appropriately get an error
|
||||
int64_t expect_int() const {
|
||||
SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic());
|
||||
TORCH_CHECK(!is_symbolic());
|
||||
return data_;
|
||||
}
|
||||
|
||||
|
|
@ -159,12 +147,12 @@ class C10_API SymInt {
|
|||
|
||||
// N.B. It's important to keep this definition in the header
|
||||
// as we expect if checks to be folded for mobile builds
|
||||
// where `is_symbolic` is always false
|
||||
// where `is_symbolic` is always false and optimize dead code paths
|
||||
C10_ALWAYS_INLINE bool is_symbolic() const {
|
||||
#ifdef C10_MOBILE
|
||||
return false;
|
||||
#else
|
||||
return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM;
|
||||
return !check_range(data_);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -193,41 +181,46 @@ class C10_API SymInt {
|
|||
operator SymFloat() const;
|
||||
|
||||
int64_t as_int_unchecked() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_symbolic());
|
||||
return data_;
|
||||
}
|
||||
|
||||
// Return whether the integer is representable as a SymInt.
|
||||
static bool check_range(int64_t i) {
|
||||
return i > MIN_INT;
|
||||
return i > MAX_UNREPRESENTABLE_INT;
|
||||
}
|
||||
|
||||
private:
|
||||
// Constraints on the internal representation:
|
||||
// - Should represent positive and small negative ints
|
||||
// - No conversion necessary for operations on ints.
|
||||
// - Must represent valid 64-bit pointers
|
||||
//
|
||||
// So, the scheme is to reserve large negative numbers:
|
||||
// - 0b0.... means we are a positive int (following two's complement)
|
||||
// - 0b11... means we are a negative int (following two's complement)
|
||||
// - Should represent positive and small negative ints
|
||||
// - No conversion necessary for operations on ints
|
||||
// - Must represent valid 64-bit pointers
|
||||
// - Is symbolic test should be FAST (two arithmetic instructions is too
|
||||
// much).
|
||||
// This code being a hotpath is based on Strobelight profiles of
|
||||
// is_symbolic(). FB only: https://fburl.com/strobelight/5l50ncxd
|
||||
// (you will need to change the time window).
|
||||
//
|
||||
// So, the scheme is to reserve large negative numbers (assuming
|
||||
// two's complement):
|
||||
//
|
||||
// - 0b0.... means we are a positive int
|
||||
// - 0b11... means we are a small negative int
|
||||
// - 0b10... means we are a pointer. This means that
|
||||
// [-2^63, -2^62-1] are not representable as ints.
|
||||
// We don't actually need all of this space, because on x86_64
|
||||
// the top 16 bits aren't used for anything
|
||||
static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62;
|
||||
static constexpr uint64_t IS_SYM = 1ULL << 63;
|
||||
// Since we use the top two bits to determine whether something is symbolic,
|
||||
// we cannot represent symbolic indices that are large enough to use those
|
||||
// bits. This will probably never happen.
|
||||
static constexpr uint64_t MAX_SYM_IDX = 1ULL << 62;
|
||||
// Since 0b10... is reserved for symbolic indices, any integers lower than
|
||||
// this value would collide with our representation.
|
||||
static constexpr int64_t MIN_INT = -1LL & static_cast<int64_t>(~(1ULL << 62));
|
||||
static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
|
||||
static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
|
||||
// We must manually translate the bit pattern test into a greater
|
||||
// than test because the compiler doesn't figure it out:
|
||||
// https://godbolt.org/z/356aferaW
|
||||
static constexpr int64_t MAX_UNREPRESENTABLE_INT =
|
||||
-1LL & static_cast<int64_t>(~(1ULL << 62));
|
||||
int64_t data_;
|
||||
};
|
||||
|
||||
#undef SKIP_IS_SYMBOLIC_ON_MOBILE
|
||||
|
||||
/// Sum of a list of SymInt; accumulates into the c10::SymInt expression
|
||||
template <
|
||||
typename C,
|
||||
|
|
|
|||
Loading…
Reference in a new issue