Mirror of https://github.com/princeton-nlp/tree-of-thought-llm (synced 2025-05-22 20:03:58 +00:00)
adding quant model setup in scripts
This commit is contained in:
parent 17c8f9b8dc
commit ad95ecfd4b
run.py: 32 changed lines
@@ -16,28 +16,22 @@ def run(args):
     '''
     main run function
     '''
-    #load in non-gpt model in this driver function for now to avoid repeated loading later on
+    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading
     if args.backend == 'llama':
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        if args.quantize and args.quantize=='ptq':
-            model.train()
-            model.qconfig = torch.quantization.get_default_qconfig('x86')
-            torch.quantization.prepare(model, inplace=True)
-            for _, mod in model.named_modules():
-                if isinstance(mod, torch.nn.Embedding):
-                    mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
-            model = torch.quantization.convert(model, inplace=True)
-            model.load_state_dict(torch.load('quant_experiments/quantized_model.pth'))
-            model.eval()
+        if args.quantize and args.quantize=='ptq_int4':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
+        if args.quantize and args.quantize=='ptq_int8':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/ptq_int8", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
         elif args.backend == 'qat':
-            pass
-            # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
-            # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")
+            model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
         else:
-            pass
-    else:
+            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
 
+    else: #gpt4 will be used later in this case
         model = None
         tokenizer = None
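What the new llama branch amounts to: pick a checkpoint (a locally saved quantized copy or the base Hugging Face model), load it onto the GPU, and wrap it in torch.compile. Below is a minimal sketch of that pattern, not the commit's code verbatim: the helper name load_llama_model is hypothetical, the sketch folds the qat case into the same quantize switch for brevity, and it assumes the local directories from the diff (src/tot/quant/hf_quant_int4, src/tot/ptq_int8, src/tot/qat_int8) contain checkpoints previously written with save_pretrained().

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_llama_model(quantize=None):
    # local quantized checkpoints referenced by the commit; assumed to be
    # save_pretrained() outputs from earlier quantization experiments
    checkpoints = {
        'ptq_int4': "src/tot/quant/hf_quant_int4",
        'ptq_int8': "src/tot/ptq_int8",
        'qat_int8': "src/tot/qat_int8",
    }
    path = checkpoints.get(quantize, "meta-llama/Llama-3.2-3B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda")
    # "max-autotune" spends more compile time up front in exchange for
    # faster steady-state generation
    model = torch.compile(model, mode="max-autotune")
    return tokenizer, model

One thing the sketch does differently from the diff: it always loads a tokenizer, whereas in the new code the tokenizer is only created in the unquantized else branch, so the quantized paths appear to leave tokenizer undefined.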
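The deleted block was an eager-mode post-training quantization recipe: start from the fp32 model, attach qconfigs, insert observers, convert modules to their int8 counterparts, then load previously saved quantized weights. Gathered into one self-contained sketch, using the same calls and paths as the removed lines (whether this round-trips a Llama checkpoint cleanly is not something the diff shows):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model.train()
# default server-CPU (x86) quantization config
model.qconfig = torch.quantization.get_default_qconfig('x86')
# insert observers that record activation/weight statistics
torch.quantization.prepare(model, inplace=True)
# embeddings only support weight-only quantization, so they get a float-qparams qconfig
for _, mod in model.named_modules():
    if isinstance(mod, torch.nn.Embedding):
        mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
# swap float modules for quantized ones, then restore the saved int8 state dict
model = torch.quantization.convert(model, inplace=True)
model.load_state_dict(torch.load('quant_experiments/quantized_model.pth'))
model.eval()

The commit drops this in favour of loading already quantized checkpoints from disk, which avoids redoing the prepare/convert step on every run.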