import os
import random

import torch
import torch.distributed as dist

# Given a message object, convert it to a prompt and a response.
PROMPT_USER: str = 'User: {input}\n\n'
PROMPT_ASSISTANT: str = 'Assistant:'  # should not have a space at the end
ASSISTANT_RESPONSE: str = ' {input}'


def get_formatted_question(line):
    return PROMPT_USER.format(input=str(line).strip()) + PROMPT_ASSISTANT


def get_formatted_answer(line):
    return ASSISTANT_RESPONSE.format(input=str(line).strip())


def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, mask_prompt=True):
    """Tokenize a list of chat messages into (input_ids, target_ids).

    Prompt tokens (and assistant turns flagged with mask == 1) are replaced by
    IGNORE_TOKEN_ID in the targets so they do not contribute to the loss.
    """
    input_ids = []
    target_ids = []
    for idx, message in enumerate(messages):
        if idx == 0:
            # Start the sequence with a single BOS token.
            input_ids.extend([tokenizer.bos_token_id])
            target_ids.extend([tokenizer.bos_token_id])
        if message['role'] == "user":
            formatted_question = get_formatted_question(message['content'])
            tokenized_line = tokenizer.encode(formatted_question, add_special_tokens=False)
            input_ids.extend(tokenized_line)
            if mask_prompt:
                target_ids.extend([IGNORE_TOKEN_ID] * len(tokenized_line))
            else:
                target_ids.extend(tokenized_line)
        elif message['role'] == "assistant":
            formatted_answer = get_formatted_answer(message['content'])
            tokenized_line = tokenizer.encode(formatted_answer, add_special_tokens=False) + [tokenizer.eos_token_id]
            input_ids.extend(tokenized_line)
            if message.get('mask', 0) == 1:
                target_ids.extend([IGNORE_TOKEN_ID] * len(tokenized_line))
            else:
                target_ids.extend(tokenized_line)
        else:
            assert False, f"Unknown role: {message['role']}"
    return [input_ids, target_ids]


def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100):
    """Pack a buffer of (input_ids, target_ids) pairs into fixed-length rows.

    Examples are concatenated until a row reaches seq_length; the last,
    partially filled row is padded with pad_token_id / IGNORE_TOKEN_ID.
    """
    all_input_ids_list, all_target_ids_list = [], []
    all_input_ids, all_target_ids = [], []
    for input_ids, target_ids in buffer:
        # If the example does not fit in the remaining space, keep only its tail.
        if len(input_ids) > seq_length - len(all_input_ids):
            input_ids = input_ids[-(seq_length - len(all_input_ids)):]
            target_ids = target_ids[-(seq_length - len(all_target_ids)):]
        # Occasionally drop the leading BOS token when concatenating onto a
        # non-empty row, so packed examples are not always BOS-separated.
        if len(all_input_ids) > 0 and random.random() < random_concat_ratio:
            input_ids = input_ids[1:]
            target_ids = target_ids[1:]
        all_input_ids.extend(input_ids)
        all_target_ids.extend(target_ids)
        if len(all_input_ids) >= seq_length:
            assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}"
            all_input_ids_list.append(all_input_ids)
            all_target_ids_list.append(all_target_ids)
            all_input_ids, all_target_ids = [], []
    # Pad the final partial row to seq_length.
    all_input_ids = all_input_ids + [tokenizer.pad_token_id for i in range(seq_length - len(all_input_ids))]
    all_target_ids = all_target_ids + [IGNORE_TOKEN_ID for i in range(seq_length - len(all_target_ids))]
    all_input_ids_list.append(all_input_ids)
    all_target_ids_list.append(all_target_ids)
    if len(all_input_ids) <= 0:
        return None
    return {
        "input_ids": torch.tensor(all_input_ids_list, dtype=torch.long),
        "labels": torch.tensor(all_target_ids_list, dtype=torch.long)
    }


def init_parallel_groups(ep_size=1):
    """Initialize NCCL and build expert-parallel (EP) and expert-data-parallel
    (EDP) process groups: ranks [i, i + ep_size) share an EP group, and ranks
    with the same offset modulo ep_size share an EDP group. Note that group
    membership is checked against LOCAL_RANK, which assumes a single node
    where the local rank equals the global rank.
    """
    dist.init_process_group("nccl")
    world_size = int(os.getenv("WORLD_SIZE", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    ep_group = edp_group = None
    for i in range(0, world_size, ep_size):
        ranks = list(range(i, i + ep_size))
        group = dist.new_group(ranks)
        if local_rank in ranks:
            ep_group = group
    edp_group = None
    for i in range(ep_size):
        ranks = list(range(i, world_size, ep_size))
        group = dist.new_group(ranks)
        if local_rank in ranks:
            edp_group = group
    # Warm up both groups with a dummy all-reduce.
    dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
    dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
    return world_size, local_rank, ep_group, edp_group
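

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline): shows how the
# formatting, masking, and packing helpers above fit together. The "gpt2"
# tokenizer and the example messages are illustrative assumptions only; any
# Hugging Face tokenizer exposing bos/eos/pad token ids would work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer choice
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
    # Prompt tokens are mapped to -100 in the targets, so the loss is
    # computed on the assistant response only.
    example = get_formatted_input_and_target(messages, tokenizer)

    # Pack the single example into one fixed-length, padded row.
    batch = get_examples_from_buffer_pad(
        buffer=[example],
        seq_length=64,
        tokenizer=tokenizer,
        random_concat_ratio=0.0,
    )
    print(batch["input_ids"].shape, batch["labels"].shape)  # -> (1, 64) each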