mirror of https://github.com/princeton-nlp/tree-of-thought-llm
synced 2025-06-26 18:26:00 +00:00

Commit message: adding all files on benchmarks, data, new model tot scripts, new prompts - pt1

parent ff04bbcd41
commit c69681e3a0
=0.26.0 (new file, +39; likely stray pip output: running `pip install accelerate>=0.26.0` unquoted makes the shell treat `>=0.26.0` as a redirect, writing pip's stdout to a file named "=0.26.0")
@@ -0,0 +1,39 @@
+Collecting accelerate
+  Downloading accelerate-1.2.0-py3-none-any.whl.metadata (19 kB)
+Requirement already satisfied: numpy<3.0.0,>=1.17 in /opt/conda/lib/python3.10/site-packages (from accelerate) (1.25.2)
+Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (24.1)
+Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate) (5.9.3)
+Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from accelerate) (6.0.2)
+Requirement already satisfied: torch>=1.10.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (2.5.1)
+Requirement already satisfied: huggingface-hub>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.26.3)
+Requirement already satisfied: safetensors>=0.4.3 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.4.5)
+Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.10.0)
+Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)
+Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.67.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)
+Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.4.2)
+Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.4)
+Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (9.1.0.70)
+Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.5.8)
+Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.2.1.3)
+Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.5.147)
+Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.6.1.9)
+Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.3.1.170)
+Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.21.5)
+Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
+Requirement already satisfied: triton==3.1.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.0)
+Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.1)
+Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)
+Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (3.0.2)
+Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)
+Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)
+Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (1.26.20)
+Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)
+Downloading accelerate-1.2.0-py3-none-any.whl (336 kB)
+Installing collected packages: accelerate
+Successfully installed accelerate-1.2.0
run.py (75 changes)
@@ -4,7 +4,8 @@ import argparse
 import time
 
 from src.tot.tasks import get_task
-from src.tot.methods.bfs import solve, naive_solve
+from src.tot.methods.bfs import solve, naive_solve, solve_bench
+# from src.tot.methods.a_star import solve, naive_solve, solve_bench
 from src.tot.models import gpt_usage
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -16,19 +17,21 @@ def run(args):
     '''
     main run function
     '''
-    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading 
+    #load in specific llama model, if applicable
+    #bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py
     if args.backend == 'llama':
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+
         if args.quantize and args.quantize=='ptq_int4':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda", weights_only=False)
             model = torch.compile(model, mode="max-autotune")
-        if args.quantize and args.quantize=='ptq_int8':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/ptq_int8", device_map="cuda")
+        elif args.quantize and args.quantize=='ptq_int8':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/ptq_int8", device_map="cuda")
             model = torch.compile(model, mode="max-autotune")
-        elif args.backend == 'qat':
-            model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
+        elif args.quantize and args.quantize == 'qat':
+            model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
             model = torch.compile(model, mode="max-autotune")
         else:
-            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
             model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+
     else: #gpt4 will be used later in this case
@@ -36,6 +39,44 @@ def run(args):
         tokenizer = None
 
     #set up
+    if args.task == 'comp_bench':
+        test_data = torch.load('src/tot/data/agg_dl_test.pt')
+
+        for i, sample in enumerate(test_data):
+
+            start_timer = time.perf_counter()
+
+            #solve (while we go through the tree)
+            ys, info = solve_bench(args, model, tokenizer)
+
+            runtime = time.perf_counter()-start_timer
+            print(f"""For iteration {i} -- 
+            Total Time to Solve: {runtime} seconds""")
+
+            #evaluate/rate the proposals
+
+            #compare the proposed final answer vs the ground truth
+            gt = tokenizer.decode(sample['label'][0]).replace("<|eot_id|>", "").replace("<|begin_of_text|>", "").strip()
+
+            # log
+            # infos = [task.test_output(i, y) for y in ys]
+            # info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
+            # logs.append(info)
+            # with open(file, 'w') as f:
+            #     json.dump(logs, f, indent=4)
+
+            # log main metric
+            # accs = [info['r'] for info in infos]
+            # cnt_avg += sum(accs) / len(accs)
+            # cnt_any += any(accs)
+            # print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
+
+        # n = args.task_end_index - args.task_start_index
+        # print(cnt_avg / n, cnt_any / n)
+        # print('usage_so_far', gpt_usage(args.backend))
+
+    else: #the original tot tasks from the paper
         task = get_task(args.task)
         logs, cnt_avg, cnt_any = [], 0, 0
         if args.naive_run:
@@ -82,25 +123,23 @@ def parse_args():
     '''
     args = argparse.ArgumentParser()
+
-    #what model to use
+    #the arguments to use for our purposes
     args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o')
-    args.add_argument('--quantize', type=str, choices=['qat', 'ptq', 'spinquant'])
-
-    #what temperature to use
+    args.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8'])
     args.add_argument('--temperature', type=float, default=0.0)
+    args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
+    args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
+    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
 
-    #the problem task
-    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
-
-    #which tasks from the data file to solve
+    #other args from the original repo 
+    args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'comp_bench'])
     args.add_argument('--task_start_index', type=int, default=900)
     args.add_argument('--task_end_index', type=int, default=1000)
     
     args.add_argument('--naive_run', action='store_true')
     args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'])  # only used when method_generate = sample, or naive_run
-    args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
-    args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
-    args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
+
     args.add_argument('--n_generate_sample', type=int, default=1)  # only thing needed if naive_run
     args.add_argument('--n_evaluate_sample', type=int, default=1)
     args.add_argument('--n_select_sample', type=int, default=1)
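Note: a minimal sketch of how the reworked arguments above would be driven, assuming run.py is executed directly (the flag values are illustrative, not from this commit):

    # e.g. python run.py --backend llama --quantize ptq_int8 --task comp_bench
    if __name__ == '__main__':
        args = parse_args()
        run(args)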
							
								
								
									
run_bench.py (43 changes)
@@ -27,9 +27,6 @@ def load_llama(quant=None):
 
     return model, tokenizer
 
-def propose(problem, current_state, tokenizer, model, device):
-    pass
-
 def value_proposals(problem, current_state, proposals, tokenizer, model, device):
     '''
     Takes in string values of problem, current state, and proposals. 
@@ -38,7 +35,7 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device):
     valuations = []
     prompts = []
     for p in proposals:
-        prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=proposal))
+        prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
     
     values = tokenizer(prompts, return_tensors='pt')
     value_inputs = values['input_ids'].to(device)
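Note: tokenizing a batch of variable-length prompts, as value_proposals does above, generally requires padding; a hedged sketch (assumes a Hugging Face tokenizer with a pad token set, e.g. tokenizer.pad_token = tokenizer.eos_token):

    values = tokenizer(prompts, padding=True, return_tensors='pt')
    value_inputs = values['input_ids'].to(device)
    mask = values['attention_mask'].to(device)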
@@ -58,8 +55,19 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device):
 
     return valuations
 
-def final_eval():
-    pass
+def final_eval(gt, final_prop):
+    '''
+    compare the ground truth and final proposed solution by the model
+    '''
+    print("THIS IS THE FINAL PROP")
+    print(final_prop)
+    print("THIS IS THE GT")
+    print(gt)
+
+    if gt in final_prop:
+        return 1.0
+    else:
+        return 0.0
+
 
 def run(args):
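Note: final_eval implements substring-containment grading; a usage sketch:

    final_eval("24", "Answer: the result is 24")   # -> 1.0
    final_eval("24", "no valid expression found")  # -> 0.0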
@@ -83,6 +91,9 @@ def run(args):
     #set up
     test_data = torch.load('src/tot/data/agg_dl_test.pt')
 
+    total = 0
+    right = 0
+
     for samples in test_data:
         
         for sample in samples: #going to do this one problem at a time for now.
@@ -97,6 +108,7 @@ def run(args):
             #start solving via tot
             start_timer = perf.counter()
             
+            selected = ""
             for i in range(args.depth): #args.depth number of attempts to reach the solution
                 
                 #propose next step/solutions per node/prompt
@@ -105,7 +117,7 @@ def run(args):
                     input_ids,
                     attention_mask=mask,
                     temperature=args.temperature, 
-                    max_new_tokens=max_tokens, 
+                    max_new_tokens=args.max_tokens, 
                     num_return_sequences=args.breadth)
 
                 
@@ -123,21 +135,28 @@ def run(args):
                 if 100.0 in valuations:
                     break
                 
-                #select the proposals that act as input to the next iteration
+                #select the best proposal
                 val_props = list(zip(proposals, valuations))
-                valuations.sort(key = lambda ele: ele[1], descending=True)
-                selected = valuations[:args.greedy_n]
-                inputs = tokenizer(valuations, return_tensors='pt')
+                val_props.sort(key = lambda ele: ele[1], reverse=True)
+                selected = val_props[:args.greedy_n]
+
+                #format the chosen proposal for the next iteration
+                next_prompt = propose_prompt.format(problem=problem, current_state=selected)
+                inputs = tokenizer(next_prompt, return_tensors='pt')
                 input_ids = inputs['input_ids'].to(device)
                 mask = inputs['attention_mask'].to(device)
 
 
             #compare the proposed final answer vs the ground truth
             gt = tokenizer.decode(label, skip_special_token=True)
+            judgement = final_eval(gt, selected)
 
-            judgement = final_eval(gt, final_proposal)
+            #keep track of the running totals
+            total += 1.0
+            right += judgement
+            print("Accuracy so far: ", right/total)
 
+    total_accuracy = right/total
+
 def parse_args():
     '''
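Note: the greedy selection above, shown on toy data (names illustrative):

    proposals = ["step A", "step B", "step C"]
    valuations = [1.0, 20.0, 0.001]
    # highest-value proposals first; list.sort/sorted take reverse=True
    val_props = sorted(zip(proposals, valuations), key=lambda ele: ele[1], reverse=True)
    selected = val_props[:1]   # [('step B', 20.0)]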
							
								
								
									
src/tot/data/benchmark/bench.py (new file, empty)

src/tot/data/test.ipynb (new file, empty)
src/tot/methods/bfs.py
@@ -36,7 +36,7 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
     return values, times
 
 def get_votes(task, x, ys, n_evaluate_sample):
-    vote_prompt = task.vote_prompt_wrap(x, ys)
+    vote_prompt = task.cot_prompt_wrap(x, ys)
     vote_outputs = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
src/tot/models.py
@@ -1,7 +1,7 @@
 import os
 import openai
 from openai import OpenAI
-import backoff 
+# import backoff 
 import torch
 import transformers
 import time
@@ -41,7 +41,8 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, stop=None):
     """
     outputs = []
     all_times = []
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = 'cpu'
 
     #tokenize inputs
     inputs = tokenizer(prompt, return_tensors="pt")
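Note: a hedged sketch of the generation call hf_model presumably wraps downstream of this hunk, using its parameter names and the standard transformers API (not taken verbatim from this commit):

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    out = model.generate(**inputs, max_new_tokens=max_tokens,
                         temperature=temperature, do_sample=True,
                         num_return_sequences=n)
    texts = tokenizer.batch_decode(out, skip_special_tokens=True)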
							
								
								
									
src/tot/quant/hf_quant_int4/config.json (new file, 48 lines)
@@ -0,0 +1,48 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "quantization_config": {
+    "modules_to_not_convert": null,
+    "quant_method": "torchao",
+    "quant_type": "int4_weight_only",
+    "quant_type_kwargs": {
+      "group_size": 128
+    }
+  },
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
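Note: a config like the one above ("quant_method": "torchao", int4 weight-only, group_size 128) can be produced with transformers' TorchAoConfig; a hedged sketch, not taken from this commit:

    from transformers import AutoModelForCausalLM, TorchAoConfig

    quant_config = TorchAoConfig("int4_weight_only", group_size=128)  # matches quantization_config above
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct",
        device_map="cuda",
        quantization_config=quant_config,
    )
    # torchao-quantized weights cannot be saved as safetensors
    model.save_pretrained("src/tot/quant/hf_quant_int4", safe_serialization=False)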
							
								
								
									
src/tot/quant/hf_quant_int4/generation_config.json (new file, 12 lines)
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
							
								
								
									
src/tot/quant/ptq_int8/config.json (new file, 40 lines)
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
							
								
								
									
src/tot/quant/ptq_int8/generation_config.json (new file, 12 lines)
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
							
								
								
									
src/tot/quant/qat_int8/config.json (new file, 40 lines)
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 32.0,
+    "high_freq_factor": 4.0,
+    "low_freq_factor": 1.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.2",
+  "use_cache": true,
+  "vocab_size": 128256
+}
							
								
								
									
src/tot/quant/qat_int8/generation_config.json (new file, 12 lines)
@@ -0,0 +1,12 @@
+{
+  "bos_token_id": 128000,
+  "do_sample": true,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "temperature": 0.6,
+  "top_p": 0.9,
+  "transformers_version": "4.46.2"
+}
src/tot/tasks/__init__.py
@@ -8,5 +8,8 @@ def get_task(name):
     elif name == 'crosswords':
         from src.tot.tasks.crosswords import MiniCrosswordsTask
         return MiniCrosswordsTask()
+    elif name == 'comp_bench':
+        from src.tot.tasks.bench import Bench
+        return Bench()
     else:
         raise NotImplementedError
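Note: with the new branch above, the benchmark task is constructed the same way as the original tasks:

    task = get_task('comp_bench')   # returns a Bench() instance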
src/tot/tasks/base.py
@@ -5,11 +5,11 @@ class Task:
     def __init__(self):
         pass
 
-    def __len__(self) -> int:
-        pass
+    # def __len__(self) -> int:
+    #     pass
 
-    def get_input(self, idx: int) -> str:
-        pass
+    # def get_input(self, idx: int) -> str:
+    #     pass
 
     def test_output(self, idx: int, output: str):
         pass
							
								
								
									
src/tot/tasks/bench.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+import re
+import os
+import sympy
+import torch
+import pandas as pd
+from src.tot.tasks.base import Task, DATA_PATH
+from src.tot.prompts.game24 import * 
+
+
+def get_current_numbers(y: str) -> str:
+    last_line = y.strip().split('\n')[-1]
+    return last_line.split('left: ')[-1].split(')')[0]
+
+
+class Bench(Task):
+    """
+    Input (x)   : a string of 4 numbers
+    Output (y)  : a trajectory of 3 steps to reach 24
+    Reward (r)  : 0 or 1, depending on whether the trajectory is correct
+    Input Example: 
+        1 2 3 4
+    Output Example: 
+        1 + 2 = 3 (left: 3 3 4)
+        3 + 3 = 6 (left: 4 6)
+        6 * 4 = 24 (left: 24)
+        (1 + 2 + 3) * 4 = 24
+    """
+    def __init__(self, file='24.csv', depth=5):
+        """
+        file: a csv file (fixed)
+        """
+        super().__init__()
+        self.data = torch.load('../data/agg_dl_test.pt')
+        self.value_cache = {}
+        self.steps = depth
+        self.stops = None
+
+    def __len__(self) -> int:
+        return len(self.data)
+    
+    def get_input(self, idx: int) -> str:
+        return self.data[idx]
+
+    def test_output(self, idx: int, output: str):
+        expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
+        numbers = re.findall(r'\d+', expression)
+        problem_numbers = re.findall(r'\d+', self.data[idx])
+        if sorted(numbers) != sorted(problem_numbers):
+            print("sorted numbers length not equal to sorted problem numbers length")
+            return {'r': 0}
+        try:
+            # print(sympy.simplify(expression))
+            return {'r': int(sympy.simplify(expression) == 24)}
+        except Exception as e:
+            # print(e)
+            print("entered exception!!")
+            return {'r': 0}
+            
+    @staticmethod
+    def standard_prompt_wrap(x: str, y:str='') -> str:
+        return standard_prompt.format(input=x) + y
+
+    @staticmethod
+    def cot_prompt_wrap(x: str, y:str='') -> str:
+        return cot_prompt.format(input=x) + y
+    
+    @staticmethod
+    def propose_prompt_wrap(x: str, y: str='') -> str:
+        current_numbers = get_current_numbers(y if y else x)
+        if current_numbers == '24':
+            prompt = cot_prompt.format(input=x) + 'Steps:' + y
+            # print([prompt])
+        else:
+            prompt = propose_prompt.format(input=current_numbers)
+        return prompt
+    
+    @staticmethod
+    def value_prompt_wrap(x: str, y: str) -> str:
+        last_line = y.strip().split('\n')[-1]
+        if 'left: ' not in last_line:  # last step
+            ans = last_line.lower().replace('answer: ', '')
+            # print([value_last_step_prompt.format(input=x, answer=ans)])
+            return value_last_step_prompt.format(input=x, answer=ans)
+        current_numbers = get_current_numbers(y)
+        return value_prompt.format(input=current_numbers)
+    
+    @staticmethod
+    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
+        if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
+            return 0
+        value_names = [_.split('\n')[-1] for _ in value_outputs]
+        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc
+        value = sum(value * value_names.count(name) for name, value in value_map.items())
+        normalized_value = value / (len(value_names) * value_map['sure'])   # scale to [0, 1]
+        return normalized_value
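Note: a worked example of the grading and valuation logic above (toy values):

    import sympy
    sympy.simplify("(1 + 2 + 3) * 4") == 24   # True -> test_output returns {'r': 1}

    # value_outputs whose last lines are 'sure', 'likely', 'impossible':
    # value = 20 + 1 + 0.001 = 21.001
    # normalized_value = 21.001 / (3 * 20) ~= 0.35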