Revert "Fix ordered dict loading with LibTorch (#100743)"

This reverts commit d371a890a2.

Reverted https://github.com/pytorch/pytorch/pull/100743 on behalf of https://github.com/jeanschmidt due to New test introduced SerializationTest.SaveStateDict is adding regressions ([comment](https://github.com/pytorch/pytorch/pull/100743#issuecomment-1542400538))
This commit is contained in:
PyTorch MergeBot 2023-05-10 15:29:13 +00:00
parent cb668b1291
commit 9ff547a57f
3 changed files with 4 additions and 48 deletions

View file

@ -152,27 +152,6 @@ TEST(SerializationTest, TypeTags) {
}
}
TEST(SerializationTest, SaveStateDict) {
  // Loads the state_dict fixture that tests_setup.py is expected to have
  // written beforehand (see SaveStateDict in test/cpp/jit/tests_setup.py),
  // then verifies every entry: "weight" must be all 2.0, "bias" all 3.0,
  // and any other key is a failure.
  std::ifstream input("state_dict.pt", std::ios::binary);
  std::vector<char> bytes(
      (std::istreambuf_iterator<char>(input)), std::istreambuf_iterator<char>());
  auto state = torch::pickle_load(bytes).toGenericDict();
  for (const auto& entry : state) {
    auto name = entry.key().toStringRef();
    auto tensor = entry.value().toTensor();
    if (name == "weight") {
      ASSERT_TRUE(tensor.eq(2.0).all().item().toBool());
    } else if (name == "bias") {
      ASSERT_TRUE(tensor.eq(3.0).all().item().toBool());
    } else {
      // Unexpected key in the saved state_dict.
      ASSERT_TRUE(false);
    }
  }
}
TEST(SerializationTest, TestJitStream_CUDA) {
torch::jit::Module model;
std::vector<torch::jit::IValue> inputs;

View file

@ -50,15 +50,6 @@ class SerializationInterop(FileSetup):
torch.save(value, self.path, _use_new_zipfile_serialization=True)
class SaveStateDict(FileSetup):
    # Produces the fixture consumed by the C++ SaveStateDict test:
    # a Linear(10, 10) state_dict whose weight is all 2.0 and bias all 3.0,
    # saved with the zipfile serialization format.
    path = 'state_dict.pt'

    def setup(self):
        linear = torch.nn.Linear(10, 10)
        torch.nn.init.constant_(linear.weight, 2.0)
        torch.nn.init.constant_(linear.bias, 3.0)
        torch.save(
            linear.state_dict(), self.path, _use_new_zipfile_serialization=True)
# See testTorchSaveError in test/cpp/jit/tests.h for usage
class TorchSaveError(FileSetup):
@ -102,7 +93,6 @@ tests = [
EvalModeForLoadedModule(),
SerializationInterop(),
TorchSaveError(),
SaveStateDict(),
TorchSaveJitStream_CUDA()
]

View file

@ -511,17 +511,6 @@ PickleOpCode Unpickler::readInstruction() {
"Parsing error: stack_ contains ",
stack_.size(),
" elements, at least 2 expected");
// In the OrderedDict case, the id has already been materialized
// and added to the stack, thus there's no <functor_idx> but a Dict
// there, in this case we can just pop the functor args and break.
// The functor args in this case contain some other metadata like
// '{_metadata: {: {version: 1}}}' which seem to be safe to ignore.
if (stack_.at(stack_.size() - 2).isGenericDict()) {
stack_.pop_back();
break;
}
std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
size_t idx = stack_.back().toInt();
stack_.pop_back();
@ -658,7 +647,6 @@ void Unpickler::readGlobal(
TORCH_CHECK(false, "INVALID VALUES")
}
}
// TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
// is only here for bc-compatibility reasons
if (module_name == "__main__") {
@ -766,12 +754,11 @@ void Unpickler::readGlobal(
} else if (module_name == "collections" && class_name == "OrderedDict") {
// collections.OrderedDict is used in tensor serialization for a tensor's
// backward hooks (but they are not actually saved with this Pickler)
// Python's model.state_dict() is an OrderedDict, but this is not used
// for model loading.
globals_.emplace_back([this] {
// The OrderedDict becomes a GenericDict. The inputs which are in
// stack.back() are fully ignored, but they are empty anyways.
stack_.back() = c10::impl::GenericDict(AnyType::get(), AnyType::get());
// drop the Tuple that was argument to OrderedDict, and replace it
// with None OrderedDicts only appear in tensor deserialization and
// their value is never used
stack_.back() = IValue();
});
} else if (module_name == "torch" && class_name == "device") {
globals_.emplace_back([this] {