remove clear_cuda_cache in forward

This commit is contained in:
StevenLiuWen 2024-12-30 14:19:34 +08:00
parent a8341f36dd
commit 9789f97283
2 changed files with 17 additions and 14 deletions

View File

@ -618,8 +618,6 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
cache_position=cache_position
)
self._clear_cuda_cache()
return outputs
def _clear_cuda_cache(self):

View File

@ -126,18 +126,20 @@ def main(args):
# print(key, value.shape, type(value))
with torch.no_grad():
# run image encoder to get the image embeddings
# inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
chunk_size=args.chunk_size
)
inputs_embeds = None
past_key_values = None
if args.chunk_size > 0:
# incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
chunk_size=args.chunk_size
)
# run the model to get the response
outputs = vl_gpt.generate(
@ -180,6 +182,9 @@ if __name__ == "__main__":
parser.add_argument("--model_path", type=str, required=True,
default="deepseek-ai/deepseek-vl2",
help="model name or local path to the model")
parser.add_argument("--chunk_size", type=int, default=512, help="chunk size for the model for prefiiling")
parser.add_argument("--chunk_size", type=int, default=-1,
help="chunk size for the model for prefiiling. "
"When using 40G gpu for vl2-small, set a chunk_size for incremental_prefilling."
"Otherwise, default value is -1, which means we do not use incremental_prefilling.")
args = parser.parse_args()
main(args)