Replace expect_int with guard_int (#113921)

The idea is that instead of erroring, we will just specialize at these sites.

Fixes https://github.com/pytorch/pytorch/issues/113142

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113921
Approved by: https://github.com/zou3519
This commit is contained in:
Edward Z. Yang 2023-11-20 09:24:41 -08:00 committed by PyTorch MergeBot
parent 59ad51e10a
commit 8c4812be80
10 changed files with 16 additions and 15 deletions

View file

@ -59,7 +59,7 @@ inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
// Unwrap a SymInt to a concrete int64_t. Uses guard_int (not expect_int) so
// that a symbolic value is specialized at this call site — recording a guard
// tagged with this file/line — rather than raising an error.
// NOTE(review): the stripped diff showed both the old expect_int() return and
// the new guard_int() return; only the post-change statement is kept here.
template <>
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
  return x.guard_int(__FILE__, __LINE__);
}
template <>
@ -69,7 +69,7 @@ inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIn
// Optional overload of unpackSymInt: empty stays empty; a present SymInt is
// specialized to a concrete int via guard_int (guards instead of erroring,
// recording this file/line as the guard provenance).
// NOTE(review): duplicated old/new statements from the stripped diff collapsed
// to the post-change guard_int form.
template <>
inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
  return x.has_value() ? c10::make_optional(x->guard_int(__FILE__, __LINE__)) : c10::nullopt;
}
template <>

View file

@ -75,7 +75,7 @@ template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
// Specialization converting SymInt -> int64_t. guard_int specializes symbolic
// ints at this site (guarding on the concrete value) instead of erroring like
// the old expect_int did.
// NOTE(review): the stripped diff left two conflicting definitions (expect_int
// and guard_int); keeping only the post-change one — two bodies for the same
// specialization would be an ODR violation.
template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
template <typename T>
static inline void checkInBoundsForStorage(

View file

@ -214,8 +214,8 @@ Tensor narrow_nested_symint(const at::Tensor& self, int64_t dim, SymInt start, S
auto storage_offsets = nt_impl->get_storage_offsets();
auto storage_offsets_ptr = storage_offsets.data_ptr<int64_t>();
auto start_int = start.expect_int();
auto length_int = length.expect_int();
auto start_int = start.guard_int(__FILE__, __LINE__);
auto length_int = length.guard_int(__FILE__, __LINE__);
auto buffer_offset = storage_offsets_ptr[start_int];
nested_sizes = nested_sizes.narrow(0, start_int, length_int);

View file

@ -414,7 +414,7 @@ int64_t TensorImpl::storage_offset_custom() const {
// TODO: fix this
return pyobj_slot_.load_pyobj_interpreter()
->sym_storage_offset(this)
.expect_int();
.guard_int(__FILE__, __LINE__);
}
return storage_offset_default();
}

View file

@ -1737,7 +1737,6 @@ symbolic_tensor_failures = {
xfail('renorm', ''),
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
xfail('randint_like', ''), # when unpacking SymInt, expected int but got s0
# many complex operators incorrect striding, metadata
xfail('fft.fft', ''),

View file

@ -2106,7 +2106,7 @@ Tensor _nested_split_with_sizes_backward(
if (grads[i].defined()) {
grads_all_defined.push_back(static_cast<Tensor>(grads[i]));
} else {
const auto& length = split_sizes[i].expect_int();
const auto& length = split_sizes[i].guard_int(__FILE__, __LINE__);
auto nt_split_size = nt_sizes.clone();
auto nt_split_size_ptr = nt_split_size.data_ptr<int64_t>();
for (int64_t j : c10::irange(static_cast<int64_t>(nt_sizes.size(0)))) {

View file

@ -161,7 +161,8 @@ struct TensorArgs {
struct AutogradCompilerCall {
// Record a size input for the compiled autograd graph. The SymInt is
// specialized to a concrete int via guard_int (file/line tag identifies this
// guard site); dynamism classification comes from default_dyn_type.
// NOTE(review): the stripped diff showed both the old expect_int emplace and
// the new guard_int emplace; only the post-change call is kept — keeping both
// would insert every size twice.
void add_size_input(const c10::SymInt& s) {
  all_size_inputs.emplace_back(
      SizeInput(default_dyn_type, s.guard_int(__FILE__, __LINE__)));
}
int emplace_hook(c10::SafePyObject&& fn) {

View file

@ -614,7 +614,7 @@ void addInputs(Node* n, const char* name, int64_t value) {
}
// Trace a SymInt argument onto node `n` by delegating to the int64_t overload.
// guard_int specializes a symbolic value (guarding at this file/line) instead
// of erroring as expect_int did.
// NOTE(review): stripped diff duplicated the old/new call; keeping only the
// post-change guard_int call — both together would add the input twice.
void addInputs(Node* n, const char* name, c10::SymInt value) {
  addInputs(n, name, value.guard_int(__FILE__, __LINE__));
}
void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
@ -815,8 +815,9 @@ void addInputs(Node* n, const char* name, c10::optional<c10::SymInt> value) {
addInputs(
n,
name,
value.has_value() ? c10::make_optional(value->expect_int())
: c10::nullopt);
value.has_value()
? c10::make_optional(value->guard_int(__FILE__, __LINE__))
: c10::nullopt);
}
void addInputs(

View file

@ -141,7 +141,7 @@ inline Value GetSymIntValue(c10::SymInt a) {
// Materialize a SymIntArrayRef into a vector of concrete int64_t values.
// Each element is specialized via guard_int (guards at this file/line rather
// than erroring like expect_int).
// NOTE(review): the stripped diff showed both the old expect_int emplace and
// the new guard_int emplace inside the loop; only the post-change call is
// kept — both together would emit every element twice.
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
  std::vector<int64_t> r;
  r.reserve(arr.size()); // one allocation instead of growth per element
  for (const auto& a : arr) {
    r.emplace_back(a.guard_int(__FILE__, __LINE__));
  }
  return r;
}

View file

@ -350,12 +350,12 @@ Check this module for more information.
return f"{argname}.has_value() ? c10::make_optional(c10::SymInt(*{argname})) : c10::nullopt"
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.expect_int()"
return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
elif goal.type == OptionalCType(BaseCType(longT)):
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
)
return f"{argname}.has_value() ? c10::make_optional({argname}->expect_int()) : c10::nullopt"
return f"{argname}.has_value() ? c10::make_optional({argname}->guard_int(__FILE__, __LINE__)) : c10::nullopt"
elif goal.type == BaseCType(optionalIntArrayRefT):
try:
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))