adding ptq model support

emilyworks 2024-12-04 00:26:31 -05:00
parent a0015eb0f7
commit 975a6e8608

run.py (19 changed lines)

@@ -9,6 +9,7 @@ from src.tot.models import gpt_usage
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import torch.quantization

 def run(args):
@@ -17,14 +18,20 @@ def run(args):
     '''
     #load in non-gpt model in this driver function for now to avoid repeated loading later on
     if args.backend == 'llama':
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-    elif args.quantize == 'qat':
-        pass
-        # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
-        # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8")
-    elif args.backend == 'spinquant':
+        if not args.quantize:
+            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+            model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+        if args.quantize and args.quantize=='ptq':
+            model.train()
+            model.qconfig = torch.quantization.get_default_qconfig('x86')
+            torch.quantization.prepare(model, inplace=True)
+            for _, mod in model.named_modules():
+                if isinstance(mod, torch.nn.Embedding):
+                    mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
+            model = torch.quantization.convert(model, inplace=True)
+            model.load_state_dict(torch.load('quant_experiments/quantized_model.pth'))
+            model.eval()
+    elif args.backend == 'qat':
         pass
         # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
         # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8")
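
For context, the load path added above assumes quant_experiments/quantized_model.pth was produced earlier by the same eager-mode PTQ recipe: attach qconfigs, prepare() to insert observers, calibrate, convert(), save. The commit does not include that producing side, so the following is only a minimal sketch of it, using a stand-in TinyNet module and random placeholder calibration batches rather than the Llama model; note that eager-mode quantization expects qconfig assignments (including any embedding override) to be attached before prepare() is called.

import torch
import torch.nn as nn
import torch.ao.quantization as tq

class TinyNet(nn.Module):
    """Stand-in for the real model; shows where the quant stubs go."""
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # quantize float inputs at entry
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)
        self.dequant = tq.DeQuantStub()  # return float outputs at exit

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

model = TinyNet().eval()  # PTQ calibrates in eval mode

# Attach qconfigs BEFORE prepare(): observers are inserted based on them.
model.qconfig = tq.get_default_qconfig('x86')

tq.prepare(model, inplace=True)

# Calibration: run representative inputs so observers record activation
# ranges. torch.randn here is a placeholder for real calibration data.
with torch.no_grad():
    for _ in range(8):
        model(torch.randn(4, 16))

# Swap float modules for quantized ones, then save the checkpoint that
# run.py later restores with model.load_state_dict(...).
tq.convert(model, inplace=True)
torch.save(model.state_dict(), 'quant_experiments/quantized_model.pth')

One design note on the loading code in this commit: because load_state_dict() only restores weights into an existing module tree, run.py has to rebuild the same quantized structure first, which is why it repeats the qconfig/prepare/convert steps before loading rather than loading the float checkpoint directly.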