diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 097cc5b28..34d8321af 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -256,7 +256,7 @@ class ContinuousBatch: self.cache_index[k] += [new_block * cache.block_size] else: self.cache_index[k] += [self.cache_index[k][-1] +1] - position_ids += [self.position_ids[k][-1]] + position_ids += torch.tensor([[self.position_ids[k][-1]]]) next_full_cache_position += self.cache_index[k] # how to efficiently select the next block? -> we probably just take the next longest sequence for now! @@ -280,7 +280,7 @@ class ContinuousBatch: self.cumulative_seqlens_q.append(sample_length) self.cumulative_seqlens_k.append(sample_length) - position_ids += torch.arange(sample_length) + position_ids += [list(range(sample_length))] if sample_length < cache.block_size: current_cache_index = list(range(blocks_to_use[0] * cache.block_size, (blocks_to_use[0] * cache.block_size) + sample_length)) else: @@ -303,11 +303,9 @@ class ContinuousBatch: self.cumulative_seqlens_k = torch.tensor(self.cumulative_seqlens_k) self.cumulative_seqlens_q = torch.tensor(self.cumulative_seqlens_q) self.position_ids = position_ids - assert len(new_ids) == len(next_full_cache_position) == len(position_ids), "Some preprocessing went wrong" - position_ids = torch.tensor([position_ids]) - new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).reshape(1, -1) # new sequence placed at the end - # position_ids[0, self.cumulative_seqlens_k -1] -> the last index? + position_ids = torch.cat([torch.tensor(k) for k in position_ids])[ None,:] + new_ids = torch.cat((self.next_ids, torch.tensor(new_ids))).long().reshape(1, -1) # new sequence placed at the end return new_ids, position_ids, torch.tensor(next_full_cache_position) def update(self, generated_ids):