Update inference.py

This commit is contained in:
StevenLiuWen 2024-12-30 14:48:51 +08:00
parent 9789f97283
commit 66ec91081c

View File

@@ -127,10 +127,10 @@ def main(args):
with torch.no_grad(): with torch.no_grad():
inputs_embeds = None if args.chunk_size == -1:
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
past_key_values = None past_key_values = None
else:
if args.chunk_size > 0:
# incremental_prefilling when using 40G GPU for vl2-small # incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids, input_ids=prepare_inputs.input_ids,