diff --git a/deepseek_vl2/models/modeling_deepseek_vl_v2.py b/deepseek_vl2/models/modeling_deepseek_vl_v2.py index fa6e182..957464f 100644 --- a/deepseek_vl2/models/modeling_deepseek_vl_v2.py +++ b/deepseek_vl2/models/modeling_deepseek_vl_v2.py @@ -618,8 +618,6 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel): cache_position=cache_position ) - self._clear_cuda_cache() - return outputs def _clear_cuda_cache(self): diff --git a/inference.py b/inference.py index 4722f33..f1b3ff8 100644 --- a/inference.py +++ b/inference.py @@ -126,18 +126,20 @@ def main(args): # print(key, value.shape, type(value)) with torch.no_grad(): - # run image encoder to get the image embeddings - # inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) - # incremental_prefilling when using 40G GPU for vl2-small - inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( - input_ids=prepare_inputs.input_ids, - images=prepare_inputs.images, - images_seq_mask=prepare_inputs.images_seq_mask, - images_spatial_crop=prepare_inputs.images_spatial_crop, - attention_mask=prepare_inputs.attention_mask, - chunk_size=args.chunk_size - ) + inputs_embeds = None + past_key_values = None + + if args.chunk_size > 0: + # incremental_prefilling when using 40G GPU for vl2-small + inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( + input_ids=prepare_inputs.input_ids, + images=prepare_inputs.images, + images_seq_mask=prepare_inputs.images_seq_mask, + images_spatial_crop=prepare_inputs.images_spatial_crop, + attention_mask=prepare_inputs.attention_mask, + chunk_size=args.chunk_size + ) # run the model to get the response outputs = vl_gpt.generate( @@ -180,6 +182,9 @@ if __name__ == "__main__": parser.add_argument("--model_path", type=str, required=True, default="deepseek-ai/deepseek-vl2", help="model name or local path to the model") - parser.add_argument("--chunk_size", type=int, default=512, help="chunk size for the model for prefiiling") + parser.add_argument("--chunk_size", type=int, default=-1, + help="chunk size for the model for prefiiling. " + "When using 40G gpu for vl2-small, set a chunk_size for incremental_prefilling." + "Otherwise, default value is -1, which means we do not use incremental_prefilling.") args = parser.parse_args() main(args)