mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] Use signed integers in CalculatedNecessaryArgs
x was underflowing:
```
size_t x = ...
while (x >= 0) {
x--;
}
```
Changed the variables to ssize_t.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79331
Approved by: https://github.com/yuhc, https://github.com/tugsbayasgalan
This commit is contained in:
parent
c727ec6129
commit
91a2e953e5
2 changed files with 19 additions and 3 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include <sstream>
|
||||
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
|
|
@ -259,5 +260,17 @@ TEST(SerializationTest, ParentDirNotExist) {
|
|||
"Parent directory ./doesnotexist does not exist.");
|
||||
}
|
||||
|
||||
TEST(SerializationTest, CalculateNecessaryArgsTest) {
|
||||
auto schema = torch::schema(
|
||||
"sync_stream(int stream_id = -1) -> ()",
|
||||
c10::AliasAnalysisKind::CONSERVATIVE);
|
||||
|
||||
auto graph = std::make_shared<Graph>();
|
||||
auto one_val = graph->insertConstant(-1);
|
||||
auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true);
|
||||
EXPECT_EQ(0, necessary.first);
|
||||
EXPECT_EQ(0, necessary.second);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
inline std::pair<size_t, size_t> CalculateNecessaryArgs(
|
||||
// Calculates the number of args that need to be passed in.
|
||||
// Less args may be needed if defaults are provided.
|
||||
// Returns: {number args needed, number of out args}
|
||||
inline std::pair<int64_t, int64_t> CalculateNecessaryArgs(
|
||||
const std::vector<Argument>& schema_args,
|
||||
at::ArrayRef<Value*> actual_inputs,
|
||||
bool allow_trailing_out_args) {
|
||||
|
|
@ -16,7 +19,7 @@ inline std::pair<size_t, size_t> CalculateNecessaryArgs(
|
|||
}
|
||||
|
||||
// count number of out arguments
|
||||
auto schema_idx = schema_args.size() - 1;
|
||||
int64_t schema_idx = static_cast<int64_t>(schema_args.size()) - 1;
|
||||
if (allow_trailing_out_args) {
|
||||
// skip over out arguments in the end.
|
||||
while (schema_idx >= 0) {
|
||||
|
|
@ -28,7 +31,7 @@ inline std::pair<size_t, size_t> CalculateNecessaryArgs(
|
|||
}
|
||||
}
|
||||
|
||||
auto num_out = schema_args.size() - schema_idx - 1;
|
||||
int64_t num_out = static_cast<int64_t>(schema_args.size()) - schema_idx - 1;
|
||||
|
||||
if (schema_args.size() < actual_inputs.size()) {
|
||||
return std::make_pair(actual_inputs.size(), num_out);
|
||||
|
|
|
|||
Loading…
Reference in a new issue