diff --git a/src/tot/methods/bfs.py b/src/tot/methods/bfs.py
index 2976a32..029a6be 100644
--- a/src/tot/methods/bfs.py
+++ b/src/tot/methods/bfs.py
@@ -35,7 +35,7 @@ def get_votes(task, x, ys, n_evaluate_sample):
 
 def get_proposals(task, x, y):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = inference_model(propose_prompt, n=1, stop=None)[0].split('\n')
+    proposals = inference_model(propose_prompt, n=5, stop=None)[0].split('\n')
     return [y + _ + '\n' for _ in proposals]
diff --git a/src/tot/models.py b/src/tot/models.py
index 722b3d3..a6e6ca4 100644
--- a/src/tot/models.py
+++ b/src/tot/models.py
@@ -69,13 +69,13 @@ def chatgpt(model, messages, temperature=0.7, max_tokens=1000, n=5, stop=None) -> list:
     global completion_tokens, prompt_tokens
     outputs = []
     client = OpenAI()
+
     while n > 0:
         cnt = min(n, 20)
         n -= cnt
         res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, n=cnt, stop=stop)
-        res_answer = res.choices[0].message.content #answers get returned in a single message.content string even when n > 1; need to double check on why later
-        outputs.extend([res_answer])
+        outputs.extend([choice.message.content for choice in res.choices])
         # log completion tokens
         completion_tokens += res.usage.completion_tokens
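
Note on the models.py hunk: res.choices carries one Choice object per sampled completion, so the removed code silently discarded all but the first of the cnt requested answers (the deleted inline comment had misdiagnosed this as a single concatenated string). The bfs.py bump from n=1 to n=5 in get_proposals only pays off once every choice is collected. Below is a minimal standalone sketch of the corrected sampling loop, assuming the OpenAI Python SDK v1; sample_completions and the example prompt are hypothetical illustrations, not functions in this repo:

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    def sample_completions(messages, model="gpt-3.5-turbo", n=5,
                           temperature=0.7, stop=None):
        """Collect n sampled completions, batching at most 20 per API call.

        Hypothetical helper sketching the pattern used in chatgpt().
        """
        outputs = []
        while n > 0:
            cnt = min(n, 20)  # request up to 20 samples in a single call
            n -= cnt
            res = client.chat.completions.create(
                model=model, messages=messages,
                temperature=temperature, n=cnt, stop=stop)
            # res.choices has cnt entries; keep every sampled message,
            # not just res.choices[0]
            outputs.extend(choice.message.content for choice in res.choices)
        return outputs

    # usage (hypothetical prompt):
    # sample_completions([{"role": "user", "content": "Propose next steps."}], n=5)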