mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
up
This commit is contained in:
parent
c800a2c913
commit
517cae97bb
1 changed files with 4 additions and 6 deletions
|
|
@ -256,7 +256,7 @@ class ContinuousBatch:
|
|||
self.cache_index[k] += [new_block * cache.block_size]
|
||||
else:
|
||||
self.cache_index[k] += [self.cache_index[k][-1] +1]
|
||||
position_ids += [self.position_ids[k][-1]]
|
||||
position_ids += torch.tensor([[self.position_ids[k][-1]]])
|
||||
next_full_cache_position += self.cache_index[k]
|
||||
|
||||
# how to efficiently select the next block? -> we probably just take the next longest sequence for now!
|
||||
|
|
@ -280,7 +280,7 @@ class ContinuousBatch:
|
|||
self.cumulative_seqlens_q.append(sample_length)
|
||||
self.cumulative_seqlens_k.append(sample_length)
|
||||
|
||||
position_ids += torch.arange(sample_length)
|
||||
position_ids += [list(range(sample_length))]
|
||||
if sample_length < cache.block_size:
|
||||
current_cache_index = list(range(blocks_to_use[0] * cache.block_size, (blocks_to_use[0] * cache.block_size) + sample_length))
|
||||
else:
|
||||
|
|
@ -303,11 +303,9 @@ class ContinuousBatch:
|
|||
self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k)
|
||||
self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q)
|
||||
self.position_ids = position_ids
|
||||
assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong"
|
||||
position_ids = torch.tensor([position_ids])
|
||||
new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1) # new sequence placed at the end
|
||||
|
||||
# position_ids[0, self.cumulative_seqlens_k -1] -> the last index?
|
||||
position_ids = torch.cat([torch.tensor(k) for k in position_ids])[ None,:]
|
||||
new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).long().reshape(1, -1) # new sequence placed at the end
|
||||
return new_ids, position_ids, torch.tensor(next_full_cache_position)
|
||||
|
||||
def update(self, generated_ids):
|
||||
|
|
|
|||
Loading…
Reference in a new issue