mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Fix ordered dict loading with LibTorch (#100743)"
This reverts commit d371a890a2.
Reverted https://github.com/pytorch/pytorch/pull/100743 on behalf of https://github.com/jeanschmidt due to New test introduced SerializationTest.SaveStateDict is adding regressions ([comment](https://github.com/pytorch/pytorch/pull/100743#issuecomment-1542400538))
This commit is contained in:
parent
cb668b1291
commit
9ff547a57f
3 changed files with 4 additions and 48 deletions
|
|
@ -152,27 +152,6 @@ TEST(SerializationTest, TypeTags) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(SerializationTest, SaveStateDict) {
|
||||
// Requires the state_dict that should have been written in tests_setup.py
|
||||
// Refer: SaveStateDict in test/cpp/jit/tests_setup.py
|
||||
std::ifstream file("state_dict.pt", std::ios::binary);
|
||||
std::vector<char> data(
|
||||
(std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
auto dict = torch::pickle_load(data).toGenericDict();
|
||||
|
||||
for (auto& el : dict) {
|
||||
auto key = el.key().toStringRef();
|
||||
auto ten = el.value().toTensor();
|
||||
if (key == "weight") {
|
||||
ASSERT_TRUE(ten.eq(2.0).all().item().toBool());
|
||||
} else if (key == "bias") {
|
||||
ASSERT_TRUE(ten.eq(3.0).all().item().toBool());
|
||||
} else {
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SerializationTest, TestJitStream_CUDA) {
|
||||
torch::jit::Module model;
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
|
|
|
|||
|
|
@ -50,15 +50,6 @@ class SerializationInterop(FileSetup):
|
|||
|
||||
torch.save(value, self.path, _use_new_zipfile_serialization=True)
|
||||
|
||||
class SaveStateDict(FileSetup):
|
||||
path = 'state_dict.pt'
|
||||
|
||||
def setup(self):
|
||||
model = torch.nn.Linear(10, 10)
|
||||
torch.nn.init.constant_(model.weight, 2.0)
|
||||
torch.nn.init.constant_(model.bias, 3.0)
|
||||
|
||||
torch.save(model.state_dict(), self.path, _use_new_zipfile_serialization=True)
|
||||
|
||||
# See testTorchSaveError in test/cpp/jit/tests.h for usage
|
||||
class TorchSaveError(FileSetup):
|
||||
|
|
@ -102,7 +93,6 @@ tests = [
|
|||
EvalModeForLoadedModule(),
|
||||
SerializationInterop(),
|
||||
TorchSaveError(),
|
||||
SaveStateDict(),
|
||||
TorchSaveJitStream_CUDA()
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -511,17 +511,6 @@ PickleOpCode Unpickler::readInstruction() {
|
|||
"Parsing error: stack_ contains ",
|
||||
stack_.size(),
|
||||
" elements, at least 2 expected");
|
||||
|
||||
// In the OrderedDict case, the id has already been materialized
|
||||
// and added to the stack, thus there's no <functor_idx> but a Dict
|
||||
// there, in this case we can just pop the functor args and break.
|
||||
// The functor args in this case contain some other metadata like
|
||||
// '{_metadata: {: {version: 1}}}' which seem to be safe to ignore.
|
||||
if (stack_.at(stack_.size() - 2).isGenericDict()) {
|
||||
stack_.pop_back();
|
||||
break;
|
||||
}
|
||||
|
||||
std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
|
||||
size_t idx = stack_.back().toInt();
|
||||
stack_.pop_back();
|
||||
|
|
@ -658,7 +647,6 @@ void Unpickler::readGlobal(
|
|||
TORCH_CHECK(false, "INVALID VALUES")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
|
||||
// is only here for bc-compatibility reasons
|
||||
if (module_name == "__main__") {
|
||||
|
|
@ -766,12 +754,11 @@ void Unpickler::readGlobal(
|
|||
} else if (module_name == "collections" && class_name == "OrderedDict") {
|
||||
// collections.OrderedDict is used in tensor serialization for a tensor's
|
||||
// backward hooks (but they are not actually saved with this Pickler)
|
||||
// Python's model.state_dict() is an OrderedDict, but this is not used
|
||||
// for model loading.
|
||||
globals_.emplace_back([this] {
|
||||
// The OrderedDict becomes a GenericDict. The inputs which are in
|
||||
// stack.back() are fully ignored, but they are empty anyways.
|
||||
stack_.back() = c10::impl::GenericDict(AnyType::get(), AnyType::get());
|
||||
// drop the Tuple that was argument to OrderedDict, and replace it
|
||||
// with None OrderedDicts only appear in tensor deserialization and
|
||||
// their value is never used
|
||||
stack_.back() = IValue();
|
||||
});
|
||||
} else if (module_name == "torch" && class_name == "device") {
|
||||
globals_.emplace_back([this] {
|
||||
|
|
|
|||
Loading…
Reference in a new issue