Fix file path for shard_num 1 with mllama converter (#35053)
"#35049 fix path for num_shard 1"
This commit is contained in:
parent 0938b57770
commit 80f2b1610f
1 changed file with 5 additions and 1 deletion
@@ -338,7 +338,11 @@ def write_model(
     print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
     if num_shards == 1:
-        loaded = [torch.load(os.path.join(input_base_path, "consolidated.pth"), map_location="cpu", mmap=True)]
+        if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")):
+            path = os.path.join(input_base_path, "consolidated.00.pth")
+        else:
+            path = os.path.join(input_base_path, "consolidated.pth")
+        loaded = [torch.load(path, map_location="cpu", mmap=True)]
     else:
         loaded = [
             torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", mmap=True)
             for i in range(num_shards)
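
For context, here is the patched logic in isolation: a minimal, runnable sketch assuming a PyTorch environment and a Llama-style checkpoint directory. The helper name load_consolidated_checkpoint and the example path below are hypothetical illustrations, not part of the converter itself.

import os

import torch


def load_consolidated_checkpoint(input_base_path: str, num_shards: int) -> list:
    """Load consolidated checkpoint shard(s) onto CPU, mirroring the patch above."""
    if num_shards == 1:
        # Some single-shard exports ship "consolidated.00.pth", others
        # "consolidated.pth" (the mismatch behind #35049): probe for the
        # zero-padded name first, then fall back to the plain one.
        numbered = os.path.join(input_base_path, "consolidated.00.pth")
        path = numbered if os.path.exists(numbered) else os.path.join(input_base_path, "consolidated.pth")
        return [torch.load(path, map_location="cpu", mmap=True)]
    # Multi-shard checkpoints are always zero-padded: consolidated.00.pth, 01, ...
    return [
        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", mmap=True)
        for i in range(num_shards)
    ]

Called as, e.g., loaded = load_consolidated_checkpoint("/checkpoints/mllama", num_shards=1), this returns a one-element list whichever of the two single-shard file names is present, which is the behavior the fix restores.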