diff --git a/run.py b/run.py
index dd21b7a..b52a15c 100644
--- a/run.py
+++ b/run.py
@@ -9,6 +9,7 @@
 from src.tot.models import gpt_usage
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import torch.quantization
 
 
 def run(args):
@@ -17,14 +18,20 @@ def run(args):
     '''
     #load in non-gpt model in this driver function for now to avoid repeated loading later on
     if args.backend == 'llama':
-        if not args.quantize:
-            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        elif args.quantize == 'qat':
-            pass
-            # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-            # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8")
-    elif args.backend == 'spinquant':
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+
+        if args.quantize and args.quantize=='ptq':
+            model.train()
+            model.qconfig = torch.quantization.get_default_qconfig('x86')
+            torch.quantization.prepare(model, inplace=True)
+            for _, mod in model.named_modules():
+                if isinstance(mod, torch.nn.Embedding):
+                    mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
+            model = torch.quantization.convert(model, inplace=True)
+            model.load_state_dict(torch.load('quant_experiments/quantized_model.pth'))
+            model.eval()
+    elif args.backend == 'qat':
         pass
         # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
         # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")