mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Replace expect_int with guard_int (#113921)
The idea is that instead of erroring, we will just specialize at these sites. Fixes https://github.com/pytorch/pytorch/issues/113142 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113921 Approved by: https://github.com/zou3519
This commit is contained in:
parent
59ad51e10a
commit
8c4812be80
10 changed files with 16 additions and 15 deletions
|
|
@ -59,7 +59,7 @@ inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
|
|||
|
||||
template <>
|
||||
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
|
||||
return x.expect_int();
|
||||
return x.guard_int(__FILE__, __LINE__);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -69,7 +69,7 @@ inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIn
|
|||
|
||||
template <>
|
||||
inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
|
||||
return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt;
|
||||
return x.has_value() ? c10::make_optional(x->guard_int(__FILE__, __LINE__)) : c10::nullopt;
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ template <>
|
|||
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
|
||||
|
||||
template <>
|
||||
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.expect_int(); }
|
||||
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
|
||||
|
||||
template <typename T>
|
||||
static inline void checkInBoundsForStorage(
|
||||
|
|
|
|||
|
|
@ -214,8 +214,8 @@ Tensor narrow_nested_symint(const at::Tensor& self, int64_t dim, SymInt start, S
|
|||
auto storage_offsets = nt_impl->get_storage_offsets();
|
||||
auto storage_offsets_ptr = storage_offsets.data_ptr<int64_t>();
|
||||
|
||||
auto start_int = start.expect_int();
|
||||
auto length_int = length.expect_int();
|
||||
auto start_int = start.guard_int(__FILE__, __LINE__);
|
||||
auto length_int = length.guard_int(__FILE__, __LINE__);
|
||||
auto buffer_offset = storage_offsets_ptr[start_int];
|
||||
|
||||
nested_sizes = nested_sizes.narrow(0, start_int, length_int);
|
||||
|
|
|
|||
|
|
@ -414,7 +414,7 @@ int64_t TensorImpl::storage_offset_custom() const {
|
|||
// TODO: fix this
|
||||
return pyobj_slot_.load_pyobj_interpreter()
|
||||
->sym_storage_offset(this)
|
||||
.expect_int();
|
||||
.guard_int(__FILE__, __LINE__);
|
||||
}
|
||||
return storage_offset_default();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1737,7 +1737,6 @@ symbolic_tensor_failures = {
|
|||
xfail('renorm', ''),
|
||||
|
||||
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
|
||||
xfail('randint_like', ''), # when unpacking SymInt, expected int but got s0
|
||||
|
||||
# many complex operators incorrect striding, metadata
|
||||
xfail('fft.fft', ''),
|
||||
|
|
|
|||
|
|
@ -2106,7 +2106,7 @@ Tensor _nested_split_with_sizes_backward(
|
|||
if (grads[i].defined()) {
|
||||
grads_all_defined.push_back(static_cast<Tensor>(grads[i]));
|
||||
} else {
|
||||
const auto& length = split_sizes[i].expect_int();
|
||||
const auto& length = split_sizes[i].guard_int(__FILE__, __LINE__);
|
||||
auto nt_split_size = nt_sizes.clone();
|
||||
auto nt_split_size_ptr = nt_split_size.data_ptr<int64_t>();
|
||||
for (int64_t j : c10::irange(static_cast<int64_t>(nt_sizes.size(0)))) {
|
||||
|
|
|
|||
|
|
@ -161,7 +161,8 @@ struct TensorArgs {
|
|||
|
||||
struct AutogradCompilerCall {
|
||||
void add_size_input(const c10::SymInt& s) {
|
||||
all_size_inputs.emplace_back(SizeInput(default_dyn_type, s.expect_int()));
|
||||
all_size_inputs.emplace_back(
|
||||
SizeInput(default_dyn_type, s.guard_int(__FILE__, __LINE__)));
|
||||
}
|
||||
|
||||
int emplace_hook(c10::SafePyObject&& fn) {
|
||||
|
|
|
|||
|
|
@ -614,7 +614,7 @@ void addInputs(Node* n, const char* name, int64_t value) {
|
|||
}
|
||||
|
||||
void addInputs(Node* n, const char* name, c10::SymInt value) {
|
||||
addInputs(n, name, value.expect_int());
|
||||
addInputs(n, name, value.guard_int(__FILE__, __LINE__));
|
||||
}
|
||||
|
||||
void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
|
||||
|
|
@ -815,8 +815,9 @@ void addInputs(Node* n, const char* name, c10::optional<c10::SymInt> value) {
|
|||
addInputs(
|
||||
n,
|
||||
name,
|
||||
value.has_value() ? c10::make_optional(value->expect_int())
|
||||
: c10::nullopt);
|
||||
value.has_value()
|
||||
? c10::make_optional(value->guard_int(__FILE__, __LINE__))
|
||||
: c10::nullopt);
|
||||
}
|
||||
|
||||
void addInputs(
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ inline Value GetSymIntValue(c10::SymInt a) {
|
|||
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
|
||||
std::vector<int64_t> r;
|
||||
for (const auto& a : arr) {
|
||||
r.emplace_back(a.expect_int());
|
||||
r.emplace_back(a.guard_int(__FILE__, __LINE__));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -350,12 +350,12 @@ Check this module for more information.
|
|||
return f"{argname}.has_value() ? c10::make_optional(c10::SymInt(*{argname})) : c10::nullopt"
|
||||
elif goal.type == BaseCType(longT):
|
||||
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
|
||||
return f"{symInt_type}.expect_int()"
|
||||
return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
|
||||
elif goal.type == OptionalCType(BaseCType(longT)):
|
||||
argname = direct_solve(
|
||||
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
|
||||
)
|
||||
return f"{argname}.has_value() ? c10::make_optional({argname}->expect_int()) : c10::nullopt"
|
||||
return f"{argname}.has_value() ? c10::make_optional({argname}->guard_int(__FILE__, __LINE__)) : c10::nullopt"
|
||||
elif goal.type == BaseCType(optionalIntArrayRefT):
|
||||
try:
|
||||
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
|
||||
|
|
|
|||
Loading…
Reference in a new issue