Update inference.py

This commit is contained in:
StevenLiuWen 2024-12-30 14:48:51 +08:00
parent 9789f97283
commit 66ec91081c

View File

@@ -127,10 +127,10 @@ def main(args):
with torch.no_grad():
inputs_embeds = None
if args.chunk_size == -1:
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
past_key_values = None
if args.chunk_size > 0:
else:
# incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,