diff --git a/inference.py b/inference.py
index f1b3ff8..9f94aa6 100644
--- a/inference.py
+++ b/inference.py
@@ -127,10 +127,10 @@ def main(args):
     with torch.no_grad():
-        inputs_embeds = None
-        past_key_values = None
-
-        if args.chunk_size > 0:
+        if args.chunk_size == -1:
+            inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+            past_key_values = None
+        else:
             # incremental_prefilling when using 40G GPU for vl2-small
             inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
                 input_ids=prepare_inputs.input_ids,