This commit is contained in:
Arthur Zucker 2025-01-15 18:07:25 +01:00
parent c800a2c913
commit 517cae97bb

View file

@ -256,7 +256,7 @@ class ContinuousBatch:
self.cache_index[k] += [new_block * cache.block_size]
else:
self.cache_index[k] += [self.cache_index[k][-1] +1]
position_ids += [self.position_ids[k][-1]]
position_ids += torch.tensor([[self.position_ids[k][-1]]])
next_full_cache_position += self.cache_index[k]
# how to efficiently select the next block? -> we probably just take the next longest sequence for now!
@ -280,7 +280,7 @@ class ContinuousBatch:
self.cumulative_seqlens_q.append(sample_length)
self.cumulative_seqlens_k.append(sample_length)
position_ids += torch.arange(sample_length)
position_ids += [list(range(sample_length))]
if sample_length < cache.block_size:
current_cache_index = list(range(blocks_to_use[0] * cache.block_size, (blocks_to_use[0] * cache.block_size) + sample_length))
else:
@ -303,11 +303,9 @@ class ContinuousBatch:
self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k)
self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q)
self.position_ids = position_ids
assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong"
position_ids = torch.tensor([position_ids])
new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1) # new sequence placed at the end
# position_ids[0, self.cumulative_seqlens_k -1] -> the last index?
position_ids = torch.cat([torch.tensor(k) for k in position_ids])[ None,:]
new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).long().reshape(1, -1) # new sequence placed at the end
return new_ids, position_ids, torch.tensor(next_full_cache_position)
def update(self, generated_ids):