[JIT] Use signed integers in CalculatedNecessaryArgs

x was underflowing:
```
size_t x = ...
while (x >= 0) {
  x--;
}
```

Changed the variables to ssize_t.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79331

Approved by: https://github.com/yuhc, https://github.com/tugsbayasgalan
This commit is contained in:
David Berard 2022-06-10 19:25:23 -07:00 committed by PyTorch MergeBot
parent c727ec6129
commit 91a2e953e5
2 changed files with 19 additions and 3 deletions

View file

@ -4,6 +4,7 @@
#include <sstream>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
@ -259,5 +260,17 @@ TEST(SerializationTest, ParentDirNotExist) {
"Parent directory ./doesnotexist does not exist.");
}
TEST(SerializationTest, CalculateNecessaryArgsTest) {
auto schema = torch::schema(
"sync_stream(int stream_id = -1) -> ()",
c10::AliasAnalysisKind::CONSERVATIVE);
auto graph = std::make_shared<Graph>();
auto one_val = graph->insertConstant(-1);
auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true);
EXPECT_EQ(0, necessary.first);
EXPECT_EQ(0, necessary.second);
}
} // namespace jit
} // namespace torch

View file

@ -7,7 +7,10 @@
namespace torch {
namespace jit {
inline std::pair<size_t, size_t> CalculateNecessaryArgs(
// Calculates the number of args that need to be passed in.
// Less args may be needed if defaults are provided.
// Returns: {number args needed, number of out args}
inline std::pair<int64_t, int64_t> CalculateNecessaryArgs(
const std::vector<Argument>& schema_args,
at::ArrayRef<Value*> actual_inputs,
bool allow_trailing_out_args) {
@ -16,7 +19,7 @@ inline std::pair<size_t, size_t> CalculateNecessaryArgs(
}
// count number of out arguments
auto schema_idx = schema_args.size() - 1;
int64_t schema_idx = static_cast<int64_t>(schema_args.size()) - 1;
if (allow_trailing_out_args) {
// skip over out arguments in the end.
while (schema_idx >= 0) {
@ -28,7 +31,7 @@ inline std::pair<size_t, size_t> CalculateNecessaryArgs(
}
}
auto num_out = schema_args.size() - schema_idx - 1;
int64_t num_out = static_cast<int64_t>(schema_args.size()) - schema_idx - 1;
if (schema_args.size() < actual_inputs.size()) {
return std::make_pair(actual_inputs.size(), num_out);