Add minimal vLLM support

emilyworks 2024-12-08 23:39:36 +00:00
parent e968d35da4
commit 6fe1750b33


@@ -2,7 +2,8 @@ import os
import json
import argparse
import time
import vllm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.quantization
@@ -28,6 +29,9 @@ def load_llama(quant=None):
    elif args.quantize and args.quantize == 'qat':
        model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
        model = torch.compile(model, mode="max-autotune")
    elif args.vllm:
        sampling_params = SamplingParams(n=1, max_tokens=100)
        model = LLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, gpu_memory_utilization=0.9, max_model_len=2048) # Name or path of your model
    else:
        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
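
The hunk constructs the vLLM engine and sampling parameters but does not show the generation call itself. A minimal sketch of how the returned LLM and SamplingParams objects would typically be consumed downstream; the function name and prompt handling are illustrative, not part of this commit:

    from vllm import LLM, SamplingParams

    def generate_with_vllm(model: LLM, prompts: list[str], sampling_params: SamplingParams) -> list[str]:
        # LLM.generate runs all prompts through vLLM's continuous-batching engine in one call
        outputs = model.generate(prompts, sampling_params)
        # each RequestOutput carries n completions; n=1 above, so take the first
        return [out.outputs[0].text for out in outputs]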
@@ -219,6 +223,7 @@ def parse_args():
    args.add_argument('--breadth', type=int, default=3)
    args.add_argument('--greedy_n', type=int, default=1)
    args.add_argument('--concurrent', type=int, default=4)
    args.add_argument('--vllm', action='store_true')  # type=bool would coerce any non-empty string, even "False", to True
    args = args.parse_args()
    return args
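
Why action='store_true' rather than the original type=bool: argparse applies bool() to the raw command-line string, so any non-empty value, including the literal string "False", parses as True. A small self-contained demonstration of the pitfall (not from this repo):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--broken', type=bool, default=False)  # bool('False') == True
    p.add_argument('--vllm', action='store_true')         # present -> True, absent -> False

    print(p.parse_args(['--broken', 'False']).broken)  # True, surprisingly
    print(p.parse_args(['--vllm']).vllm)               # True
    print(p.parse_args([]).vllm)                       # False

With the store_true form, the flag is enabled by its presence alone, e.g. passing --vllm with no value.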