From 91a2e953e52cf3c48b322c373587f609f59a12b5 Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 10 Jun 2022 19:25:23 -0700 Subject: [PATCH] [JIT] Use signed integers in CalculatedNecessaryArgs x was underflowing: ``` size_t x = ... while (x >= 0) { x--; } ``` Changed the variables to ssize_t. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79331 Approved by: https://github.com/yuhc, https://github.com/tugsbayasgalan --- test/cpp/jit/test_save_load.cpp | 13 +++++++++++++ torch/csrc/jit/runtime/calculate_necessary_args.h | 9 ++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index 6ecf67917ec..6a98e23a673 100644 --- a/test/cpp/jit/test_save_load.cpp +++ b/test/cpp/jit/test_save_load.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -259,5 +260,17 @@ TEST(SerializationTest, ParentDirNotExist) { "Parent directory ./doesnotexist does not exist."); } +TEST(SerializationTest, CalculateNecessaryArgsTest) { + auto schema = torch::schema( + "sync_stream(int stream_id = -1) -> ()", + c10::AliasAnalysisKind::CONSERVATIVE); + + auto graph = std::make_shared(); + auto one_val = graph->insertConstant(-1); + auto necessary = CalculateNecessaryArgs(schema.arguments(), {one_val}, true); + EXPECT_EQ(0, necessary.first); + EXPECT_EQ(0, necessary.second); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/calculate_necessary_args.h b/torch/csrc/jit/runtime/calculate_necessary_args.h index 401353f7411..f0149f27663 100644 --- a/torch/csrc/jit/runtime/calculate_necessary_args.h +++ b/torch/csrc/jit/runtime/calculate_necessary_args.h @@ -7,7 +7,10 @@ namespace torch { namespace jit { -inline std::pair CalculateNecessaryArgs( +// Calculates the number of args that need to be passed in. +// Less args may be needed if defaults are provided. +// Returns: {number args needed, number of out args} +inline std::pair CalculateNecessaryArgs( const std::vector& schema_args, at::ArrayRef actual_inputs, bool allow_trailing_out_args) { @@ -16,7 +19,7 @@ inline std::pair CalculateNecessaryArgs( } // count number of out arguments - auto schema_idx = schema_args.size() - 1; + int64_t schema_idx = static_cast(schema_args.size()) - 1; if (allow_trailing_out_args) { // skip over out arguments in the end. while (schema_idx >= 0) { @@ -28,7 +31,7 @@ inline std::pair CalculateNecessaryArgs( } } - auto num_out = schema_args.size() - schema_idx - 1; + int64_t num_out = static_cast(schema_args.size()) - schema_idx - 1; if (schema_args.size() < actual_inputs.size()) { return std::make_pair(actual_inputs.size(), num_out);