adding all files on benchmarks, data, new model tot scripts, new prompts - pt1

This commit is contained in:
emilyworks 2024-12-07 03:23:12 +00:00
parent ff04bbcd41
commit c69681e3a0
16 changed files with 430 additions and 71 deletions

39
=0.26.0 Normal file
View File

@ -0,0 +1,39 @@
Collecting accelerate
Downloading accelerate-1.2.0-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: numpy<3.0.0,>=1.17 in /opt/conda/lib/python3.10/site-packages (from accelerate) (1.25.2)
Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (24.1)
Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate) (5.9.3)
Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from accelerate) (6.0.2)
Requirement already satisfied: torch>=1.10.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (2.5.1)
Requirement already satisfied: huggingface-hub>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.26.3)
Requirement already satisfied: safetensors>=0.4.3 in /opt/conda/lib/python3.10/site-packages (from accelerate) (0.4.5)
Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)
Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.10.0)
Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)
Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.67.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)
Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.4.2)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.4)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (9.1.0.70)
Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.5.8)
Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.2.1.3)
Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.5.147)
Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.6.1.9)
Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.3.1.170)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
Requirement already satisfied: triton==3.1.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.0)
Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (3.0.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (1.26.20)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)
Downloading accelerate-1.2.0-py3-none-any.whl (336 kB)
Installing collected packages: accelerate
Successfully installed accelerate-1.2.0

141
run.py
View File

@ -4,7 +4,8 @@ import argparse
import time
from src.tot.tasks import get_task
from src.tot.methods.bfs import solve, naive_solve
from src.tot.methods.bfs import solve, naive_solve, solve_bench
# from src.tot.methods.a_star import solve, naive_solve, solve_bench
from src.tot.models import gpt_usage
from transformers import AutoTokenizer, AutoModelForCausalLM
@ -16,19 +17,21 @@ def run(args):
'''
main run function
'''
#bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading
#load in specific llama model, if applicable
#bc of the way the original repo is structured, will need to load in llama models in run.py to avoid repeated loading in models.py
if args.backend == 'llama':
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
if args.quantize and args.quantize=='ptq_int4':
model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda")
model = AutoModelForCausalLM.from_pretrained("src/tot/quant/hf_quant_int4", device_map="cuda", weights_only=False)
model = torch.compile(model, mode="max-autotune")
if args.quantize and args.quantize=='ptq_int8':
model = AutoModelForCausalLM.from_pretrained("src/tot/ptq_int8", device_map="cuda")
elif args.quantize and args.quantize=='ptq_int8':
model = AutoModelForCausalLM.from_pretrained("src/tot/quant/ptq_int8", device_map="cuda")
model = torch.compile(model, mode="max-autotune")
elif args.backend == 'qat':
model = AutoModelForCausalLM.from_pretrained("src/tot/qat_int8", device_map="cuda")
elif args.quantize and args.quantize == 'qat':
model = AutoModelForCausalLM.from_pretrained("src/tot/quant/qat_int8", device_map="cuda", weights_only=False)
model = torch.compile(model, mode="max-autotune")
else:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
else: #gpt4 will be used later in this case
@ -36,44 +39,82 @@ def run(args):
tokenizer = None
#set up
task = get_task(args.task)
logs, cnt_avg, cnt_any = [], 0, 0
if args.naive_run:
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
else:
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
os.makedirs(os.path.dirname(file), exist_ok=True)
if args.task == 'comp_bench':
test_data = torch.load('src/tot/data/agg_dl_test.pt')
#run the specified range of tasks
for i in range(args.task_start_index, args.task_end_index):
for sample in test_data:
# solve
start_timer = time.perf_counter()
if args.naive_run:
ys, info = naive_solve(args, task, i, model, tokenizer)
else:
ys, info = solve(args, task, i, model, tokenizer)
start_timer = perf.counter()
WHILE WE GO THROUGH THE TREE
#solve
ys, info = solve_bench(args, model, tokenizer)
runtime = time.perf_counter()-start_timer
print(f"""For iteration {i} --
Total Time to Solve: {runtime} seconds""")
runtime = time.perf_counter()-start_timer
print(f"""For iteration {i} --
Total Time to Solve: {runtime} seconds""")
#evaluate/rate the proposals
# log
infos = [task.test_output(i, y) for y in ys]
info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
logs.append(info)
with open(file, 'w') as f:
json.dump(logs, f, indent=4)
#compare the proposed final answer vs the ground truth
gt = tokenizer.decode(sample['label'][0], return_tensors='pt').replace("<|eot_id|>", "").replace("<|begin_of_text|>", "").strip()
# log
# infos = [task.test_output(i, y) for y in ys]
# info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
# logs.append(info)
# with open(file, 'w') as f:
# json.dump(logs, f, indent=4)
# log main metric
# accs = [info['r'] for info in infos]
# cnt_avg += sum(accs) / len(accs)
# cnt_any += any(accs)
# print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
# log main metric
accs = [info['r'] for info in infos]
cnt_avg += sum(accs) / len(accs)
cnt_any += any(accs)
print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
n = args.task_end_index - args.task_start_index
print(cnt_avg / n, cnt_any / n)
print('usage_so_far', gpt_usage(args.backend))
# n = args.task_end_index - args.task_start_index
# print(cnt_avg / n, cnt_any / n)
# print('usage_so_far', gpt_usage(args.backend))
else: #the original tot tasks from the paper
task = get_task(args.task)
logs, cnt_avg, cnt_any = [], 0, 0
if args.naive_run:
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
else:
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
os.makedirs(os.path.dirname(file), exist_ok=True)
#run the specified range of tasks
for i in range(args.task_start_index, args.task_end_index):
# solve
start_timer = time.perf_counter()
if args.naive_run:
ys, info = naive_solve(args, task, i, model, tokenizer)
else:
ys, info = solve(args, task, i, model, tokenizer)
runtime = time.perf_counter()-start_timer
print(f"""For iteration {i} --
Total Time to Solve: {runtime} seconds""")
# log
infos = [task.test_output(i, y) for y in ys]
info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far (gpt only)': gpt_usage(args.backend), 'total_runtime': runtime})
logs.append(info)
with open(file, 'w') as f:
json.dump(logs, f, indent=4)
# log main metric
accs = [info['r'] for info in infos]
cnt_avg += sum(accs) / len(accs)
cnt_any += any(accs)
print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
n = args.task_end_index - args.task_start_index
print(cnt_avg / n, cnt_any / n)
print('usage_so_far', gpt_usage(args.backend))
def parse_args():
@ -82,25 +123,23 @@ def parse_args():
'''
args = argparse.ArgumentParser()
#what model to use
#the arguments to use for our purposes
args.add_argument('--backend', type=str, choices=['gpt-4o', 'llama'], default='gpt-4o')
args.add_argument('--quantize', type=str, choices=['qat', 'ptq', 'spinquant'])
#what temperature to use
args.add_argument('--quantize', type=str, choices=['qat', 'ptq_int4', 'ptq_int8'])
args.add_argument('--temperature', type=float, default=0.0)
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
#the problem task
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
#other args from the original repo
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords', 'comp_bench'])
#which tasks from the data file to solve
args.add_argument('--task_start_index', type=int, default=900)
args.add_argument('--task_end_index', type=int, default=1000)
args.add_argument('--naive_run', action='store_true')
args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot']) # only used when method_generate = sample, or naive_run
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
args.add_argument('--n_evaluate_sample', type=int, default=1)
args.add_argument('--n_select_sample', type=int, default=1)

View File

@ -27,9 +27,6 @@ def load_llama(quant=None):
return model, tokenizer
def propose(problem, current_state, tokenizer, model, device):
pass
def value_proposals(problem, current_state, proposals, tokenizer, model, device):
'''
Takes in string values of problem, current state, and proposals.
@ -38,7 +35,7 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device)
valuations = []
prompts = []
for p in proposals:
prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=proposal))
prompts.append(value_prompt.format(problem=problem, current_state=current_state, proposal=p))
values = tokenizer(prompts, return_tensors='pt')
value_inputs = values['input_ids'].to(device)
@ -58,8 +55,19 @@ def value_proposals(problem, current_state, proposals, tokenizer, model, device)
return valuations
def final_eval():
pass
def final_eval(gt, final_prop):
'''
compare the ground truth and final proposed solution by the model
'''
print("THIS IS THE FINAL PROP")
print(final_prop)
print("THIS IS THE GT")
print(gt)
if gt in final_prop:
return 1.0
else:
return 0.0
def run(args):
@ -83,6 +91,9 @@ def run(args):
#set up
test_data = torch.load('src/tot/data/agg_dl_test.pt')
total = 0
right = 0
for samples in test_data:
for sample in samples: #going to do this one problem at a time for now.
@ -97,6 +108,7 @@ def run(args):
#start solving via tot
start_timer = perf.counter()
selected = ""
for i in range(args.depth): #args.depth number of attempts to reach the solution
#propose next step/solutions per node/prompt
@ -105,7 +117,7 @@ def run(args):
input_ids,
attention_mask=mask,
temperature=args.temperature,
max_new_tokens=max_tokens,
max_new_tokens=args.max_tokens,
num_return_sequences=args.breadth)
@ -123,21 +135,28 @@ def run(args):
if 100.0 in valuations:
break
#select the proposals that act as input to the next iteration
#select the best proposal
val_props = list(zip(proposals, valuations))
valuations.sort(key = lambda ele: ele[1], descending=True)
selected = valuations[:args.greedy_n]
val_props.sort(key = lambda ele: ele[1], descending=True)
selected = val_props[:args.greedy_n]
inputs = tokenizer(valuations, return_tensors='pt')
#format the chosen proposal for the next iteration
next_prompt = propose_prompt.format(problem=problem, current_state=selected)
inputs = tokenizer(next_prompt, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)
mask = inputs['attention_mask'].to(device)
#compare the proposed final answer vs the ground truth
gt = tokenizer.decode(label, skip_special_token=True)
judgement = final_eval(gt, final_proposal)
judgement = final_eval(gt, selected)
#keep track of the running totals
total += 1.0
right += judgement
print("Accuracy so far: ", total/right)
total_accuracy = right/total
def parse_args():
'''

View File

0
src/tot/data/test.ipynb Normal file
View File

View File

@ -36,7 +36,7 @@ def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
return values, times
def get_votes(task, x, ys, n_evaluate_sample):
vote_prompt = task.vote_prompt_wrap(x, ys)
vote_prompt = task.cot_prompt_wrap(x, ys)
vote_outputs = inference_model(vote_prompt, n=n_evaluate_sample, stop=None)
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
return values

View File

@ -1,7 +1,7 @@
import os
import openai
from openai import OpenAI
import backoff
# import backoff
import torch
import transformers
import time
@ -41,7 +41,8 @@ def hf_model(model, tokenizer, prompt, temperature=0.7, max_tokens=1000, n=5, st
"""
outputs = []
all_times = []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
#tokenize inputs
inputs = tokenizer(prompt, return_tensors="pt")

View File

@ -0,0 +1,48 @@
{
"_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 24,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"quantization_config": {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "int4_weight_only",
"quant_type_kwargs": {
"group_size": 128
}
},
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.46.2",
"use_cache": true,
"vocab_size": 128256
}

View File

@ -0,0 +1,12 @@
{
"bos_token_id": 128000,
"do_sample": true,
"eos_token_id": [
128001,
128008,
128009
],
"temperature": 0.6,
"top_p": 0.9,
"transformers_version": "4.46.2"
}

View File

@ -0,0 +1,40 @@
{
"_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 24,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": true,
"torch_dtype": "float32",
"transformers_version": "4.46.2",
"use_cache": true,
"vocab_size": 128256
}

View File

@ -0,0 +1,12 @@
{
"bos_token_id": 128000,
"do_sample": true,
"eos_token_id": [
128001,
128008,
128009
],
"temperature": 0.6,
"top_p": 0.9,
"transformers_version": "4.46.2"
}

View File

@ -0,0 +1,40 @@
{
"_name_or_path": "meta-llama/Llama-3.2-3B-Instruct",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 24,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": true,
"torch_dtype": "float32",
"transformers_version": "4.46.2",
"use_cache": true,
"vocab_size": 128256
}

View File

@ -0,0 +1,12 @@
{
"bos_token_id": 128000,
"do_sample": true,
"eos_token_id": [
128001,
128008,
128009
],
"temperature": 0.6,
"top_p": 0.9,
"transformers_version": "4.46.2"
}

View File

@ -8,5 +8,8 @@ def get_task(name):
elif name == 'crosswords':
from src.tot.tasks.crosswords import MiniCrosswordsTask
return MiniCrosswordsTask()
elif name == 'comp_bench':
from src.tot.tasks.bench import Bench
return Bench()
else:
raise NotImplementedError

View File

@ -5,11 +5,11 @@ class Task:
def __init__(self):
pass
def __len__(self) -> int:
pass
# def __len__(self) -> int:
# pass
def get_input(self, idx: int) -> str:
pass
# def get_input(self, idx: int) -> str:
# pass
def test_output(self, idx: int, output: str):
pass

94
src/tot/tasks/bench.py Normal file
View File

@ -0,0 +1,94 @@
import re
import os
import sympy
import pandas as pd
from src.tot.tasks.base import Task, DATA_PATH
from src.tot.prompts.game24 import *
def get_current_numbers(y: str) -> str:
last_line = y.strip().split('\n')[-1]
return last_line.split('left: ')[-1].split(')')[0]
class Game24Task(Task):
"""
Input (x) : a string of
Output (y) : a trajectory of 3 steps to reach 24
Reward (r) : 0 or 1, depending on whether the trajectory is correct
Input Example:
1 2 3 4
Output Example:
1 + 2 = 3 (left: 3 3 4)
3 + 3 = 6 (left: 4 6)
6 * 4 = 24 (left: 24)
(1 + 2 + 3) * 4 = 24
"""
def __init__(self, file='24.csv', depth=5):
"""
file: a csv file (fixed)
"""
super().__init__()
self.data = torch.load('../data/agg_dl_test.pt')
self.value_cache = {}
self.steps = depth
self.stops = None
def __len__(self) -> int:
return len(self.data)
def get_input(self, idx: int) -> str:
return self.data[idx]
def test_output(self, idx: int, output: str):
expression = output.strip().split('\n')[-1].lower().replace('answer: ', '').split('=')[0]
numbers = re.findall(r'\d+', expression)
problem_numbers = re.findall(r'\d+', self.data[idx])
if sorted(numbers) != sorted(problem_numbers):
print("sorted numbers length not equal to sorted problem numbers length")
return {'r': 0}
try:
# print(sympy.simplify(expression))
return {'r': int(sympy.simplify(expression) == 24)}
except Exception as e:
# print(e)
print("entered exception!!")
return {'r': 0}
@staticmethod
def standard_prompt_wrap(x: str, y:str='') -> str:
return standard_prompt.format(input=x) + y
@staticmethod
def cot_prompt_wrap(x: str, y:str='') -> str:
return cot_prompt.format(input=x) + y
@staticmethod
def propose_prompt_wrap(x: str, y: str='') -> str:
current_numbers = get_current_numbers(y if y else x)
if current_numbers == '24':
prompt = cot_prompt.format(input=x) + 'Steps:' + y
# print([prompt])
else:
prompt = propose_prompt.format(input=current_numbers)
return prompt
@staticmethod
def value_prompt_wrap(x: str, y: str) -> str:
last_line = y.strip().split('\n')[-1]
if 'left: ' not in last_line: # last step
ans = last_line.lower().replace('answer: ', '')
# print([value_last_step_prompt.format(input=x, answer=ans)])
return value_last_step_prompt.format(input=x, answer=ans)
current_numbers = get_current_numbers(y)
return value_prompt.format(input=current_numbers)
@staticmethod
def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
return 0
value_names = [_.split('\n')[-1] for _ in value_outputs]
value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} # TODO: ad hoc
value = sum(value * value_names.count(name) for name, value in value_map.items())
normalized_value = value / (len(value_names) * value_map['sure']) # scale to [0, 1]
return normalized_value