Mirror of https://github.com/princeton-nlp/tree-of-thought-llm (synced 2025-06-26 18:26:00 +00:00)
minor vllm add

commit 6fe1750b33
parent e968d35da4
@@ -2,7 +2,8 @@ import os
import json
import argparse
import time
import vllm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.quantization
@@ -28,6 +29,9 @@ def load_llama(quant=None):
    elif args.quantize and args.quantize == 'qat':
        model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
        model = torch.compile(model, mode="max-autotune")
    elif args.vllm:
        sampling_params = SamplingParams(n=1, max_tokens=100)
        model = LLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, gpu_memory_utilization=0.9, max_model_len=2048)  # Name or path of your model
    else:
        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
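For reference, a minimal sketch of how the vLLM branch above could be driven for generation. The prompt and variable names are illustrative assumptions, not part of this commit, but LLM.generate and SamplingParams are vLLM's standard offline-inference API:

from vllm import LLM, SamplingParams

# Same settings as the args.vllm branch in load_llama above.
sampling_params = SamplingParams(n=1, max_tokens=100)
model = LLM(model="meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True,
            gpu_memory_utilization=0.9, max_model_len=2048)

# generate() takes a batch of prompts and returns one RequestOutput per prompt;
# with n=1, each result carries a single completion.
outputs = model.generate(["Use 4 4 6 8 to make 24:"], sampling_params)  # hypothetical prompt
for out in outputs:
    print(out.outputs[0].text)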
@@ -219,6 +223,7 @@ def parse_args():
    args.add_argument('--breadth', type=int, default=3)
    args.add_argument('--greedy_n', type=int, default=1)
    args.add_argument('--concurrent', type=int, default=4)
    args.add_argument('--vllm', type=bool, default=False)

    args = args.parse_args()
    return args
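A note on the new flag, offered as general argparse behavior rather than a change in this commit: type=bool does not parse the string "False", since bool("False") is True, so any explicit value passed to --vllm enables it. A presence-style flag avoids the surprise:

import argparse

parser = argparse.ArgumentParser()
# store_true: flag absent -> False, '--vllm' given -> True; no string-to-bool parsing.
parser.add_argument('--vllm', action='store_true')

print(parser.parse_args(['--vllm']).vllm)  # True
print(parser.parse_args([]).vllm)          # False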