diff --git a/README.md b/README.md index 7268012..6a6b4d9 100644 --- a/README.md +++ b/README.md @@ -317,10 +317,11 @@ messages_list = [ [{"role": "user", "content": "What can you do?"}], [{"role": "user", "content": "Explain Transformer briefly."}], ] -prompts = [tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) for messages in messages_list] +# Avoid adding bos_token repeatedly +prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True) for messages in messages_list] sampling_params.stop = [tokenizer.eos_token] -outputs = llm.generate(prompts, sampling_params) +outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params) generated_text = [output.outputs[0].text for output in outputs] print(generated_text)