From 04e5acc08ed5b878225491bf62540ea10274fb29 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 6 Mar 2023 10:05:27 -0800
Subject: [PATCH] Fix a bug in 1D input shape (#5)

---
 cacheflow/models/attention.py      | 11 ++++++++---
 cacheflow/models/input_metadata.py |  2 +-
 server.py                          |  4 ++--
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index 34edeec02cbc2..7c77db5a819b6 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
             max_s=max_prompt_len,
             causal=True,
         )[0]
-        num_tokens = prefix_sum[-1]
         # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output[:num_tokens].copy_(out, non_blocking=True)
+        output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
@@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
 
         # Compute the attention op for prompts.
         if input_metadata.num_prompts > 0:
+            num_prompt_tokens = sum(input_metadata.prompt_lens)
             self.multi_query_kv_attention(
-                output, query, key, value, input_metadata.prompt_lens)
+                output[:num_prompt_tokens],
+                query[:num_prompt_tokens],
+                key[:num_prompt_tokens],
+                value[:num_prompt_tokens],
+                input_metadata.prompt_lens,
+            )
 
         # Wait until the cache op is done.
         if cache_event is not None:
diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py
index 86cc2e8f1f5a3..77f25054e38a6 100644
--- a/cacheflow/models/input_metadata.py
+++ b/cacheflow/models/input_metadata.py
@@ -24,7 +24,7 @@ class InputMetadata:
 
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = len(slot_mapping)
+        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
diff --git a/server.py b/server.py
index d70dab01abd45..04e2f6d726693 100644
--- a/server.py
+++ b/server.py
@@ -57,11 +57,11 @@ def main():
         'UC Berkeley is',
         'The future of cloud computing is',
     ]
-    for prompt in test_inputs:
-        frontend.query(prompt)
 
     # FIXME
     while True:
+        if test_inputs:
+            frontend.query(test_inputs.pop())
         scheduler.step()
         if not scheduler.pending and not scheduler.running:
             break
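
The hunks above rely on a 1D packed input layout: tokens from all sequences are flattened into a single token dimension, with prompt tokens first and the generation tokens after them, so the prompt kernel must be handed only the first sum(prompt_lens) rows. Below is a minimal sketch of that layout and the slicing, using PyTorch and the names from the patch; the concrete shapes are illustrative assumptions, not values taken from the repository.

    import torch

    # Two prompt sequences and four decoding sequences, packed into one
    # 1D token dimension (hypothetical sizes, for illustration only).
    prompt_lens = [3, 2]
    num_generation_tokens = 4
    num_prompt_tokens = sum(prompt_lens)                           # 5
    num_valid_tokens = num_prompt_tokens + num_generation_tokens   # 9

    hidden_size = 8
    query = torch.randn(num_valid_tokens, hidden_size)

    # Prompt tokens are packed first, so slicing with num_prompt_tokens
    # hands the multi-query (prompt) kernel exactly the prompt rows ...
    prompt_query = query[:num_prompt_tokens]          # shape: (5, 8)
    # ... while the single-query cached-KV kernel consumes the tail.
    generation_query = query[num_prompt_tokens:]      # shape: (4, 8)

    assert prompt_query.shape[0] + generation_query.shape[0] == num_valid_tokens

Slicing query/key/value at the call site, instead of copying a slice of the output afterwards, keeps the two attention paths from ever seeing each other's tokens.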