From ad95ecfd4bbe1d328f8c54fb8ed495abcbe5a1c4 Mon Sep 17 00:00:00 2001
From: emilyworks
Date: Fri, 6 Dec 2024 19:59:26 +0000
Subject: [PATCH] adding quant model setup in scripts

---
 run.py | 32 +++++++++++++-------------------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/run.py b/run.py
index f9dcad1..a720a20 100644
--- a/run.py
+++ b/run.py
@@ -16,28 +16,22 @@ def run(args):
     '''
     main run function
     '''
-    #load in non-gpt model in this driver function for now to avoid repeated loading later on
+    #because of the way the original repo is structured, the llama models need to be loaded in run.py to avoid repeated loading
     if args.backend == 'llama':
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-
-        if args.quantize and args.quantize=='ptq':
-            model.train()
-            model.qconfig = torch.quantization.get_default_qconfig('x86')
-            torch.quantization.prepare(model, inplace=True)
-            for _, mod in model.named_modules():
-                if isinstance(mod, torch.nn.Embedding):
-                    mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
-            model = torch.quantization.convert(model, inplace=True)
-            model.load_state_dict(torch.load('quant_experiments/quantized_model.pth'))
-            model.eval()
+        if args.quantize and args.quantize=='ptq_int4':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
+        if args.quantize and args.quantize=='ptq_int8':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/ptq_int8", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
         elif args.backend == 'qat':
-            pass
-            # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
-            # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")
+            model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
+            model = torch.compile(model, mode="max-autotune")
         else:
-            pass
-    else:
+            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+
+    else: #gpt4 will be used later in this case
         model = None
         tokenizer = None
 
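
Note: the hunk above selects a checkpoint based on args.backend and args.quantize and then wraps it in torch.compile. A minimal sketch of that same selection logic, flattened into one helper, is below. The checkpoint paths ("src/tot/quant/hf_quant_int4", "src/tot/ptq_int8", "src/tot/qat_int8") and the torch.compile settings are taken from the diff; the helper name load_backend_model and the QUANT_CHECKPOINTS mapping are illustrative assumptions, not code from this patch.

    # Hypothetical helper mirroring the branch logic in the patch above.
    # Paths and compile settings come from the diff; the function name and
    # dict are illustrative only.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    QUANT_CHECKPOINTS = {
        'ptq_int4': "src/tot/quant/hf_quant_int4",  # path used in the diff
        'ptq_int8': "src/tot/ptq_int8",             # path used in the diff
    }

    def load_backend_model(backend, quantize=None):
        """Return (model, tokenizer) for the requested backend/quantization."""
        if backend == 'llama':
            if quantize in QUANT_CHECKPOINTS:
                # load a pre-quantized checkpoint onto the GPU and compile it
                model = AutoModelForCausalLM.from_pretrained(
                    QUANT_CHECKPOINTS[quantize], device_map="cuda")
                model = torch.compile(model, mode="max-autotune")
                tokenizer = None  # the patch does not load a tokenizer on this path
            else:
                # unquantized baseline, as in the patch's else branch
                tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
                model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
        elif backend == 'qat':
            model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
            model = torch.compile(model, mode="max-autotune")
            tokenizer = None
        else:
            # GPT-4 backend: model objects are resolved later, as in the patch
            model, tokenizer = None, None
        return model, tokenizer

One thing the sketch makes visible: in the patched run.py the 'qat' branch is written as an elif of the inner quantize check inside the 'llama' branch, so as committed it only triggers when args.backend is already 'llama'; the helper above instead keys it directly on the backend value.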