From 66ec91081c0eaa7db4212ba34358fa3b1d159f4f Mon Sep 17 00:00:00 2001 From: StevenLiuWen Date: Mon, 30 Dec 2024 14:48:51 +0800 Subject: [PATCH] Update inference.py --- inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/inference.py b/inference.py index f1b3ff8..9f94aa6 100644 --- a/inference.py +++ b/inference.py @@ -127,10 +127,10 @@ def main(args): with torch.no_grad(): - inputs_embeds = None - past_key_values = None - - if args.chunk_size > 0: + if args.chunk_size == -1: + inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + past_key_values = None + else: # incremental_prefilling when using 40G GPU for vl2-small inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( input_ids=prepare_inputs.input_ids,