mirror of
https://github.com/deepseek-ai/ESFT
synced 2025-06-26 18:15:50 +00:00
add training code
This commit is contained in:
58
utils.py
58
utils.py
@@ -1,3 +1,7 @@
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
# given a message object, convert to prompt and response
|
||||
|
||||
PROMPT_USER: str = 'User: {input}\n\n'
|
||||
@@ -38,3 +42,57 @@ def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, ma
|
||||
assert False, f"Unknown role: {message['role']}"
|
||||
|
||||
return [input_ids, target_ids]
|
||||
|
||||
|
||||
def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100):
|
||||
all_input_ids_list, all_target_ids_list = [], []
|
||||
all_input_ids, all_target_ids = [], []
|
||||
|
||||
for input_ids, target_ids in buffer:
|
||||
if len(input_ids) > seq_length - len(all_input_ids):
|
||||
input_ids = input_ids[-(seq_length - len(all_input_ids)):]
|
||||
target_ids = target_ids[-(seq_length - len(all_target_ids)):]
|
||||
if len(all_input_ids) > 0 and random.random() < random_concat_ratio:
|
||||
input_ids = input_ids[1:]
|
||||
target_ids = target_ids[1:]
|
||||
all_input_ids.extend(input_ids)
|
||||
all_target_ids.extend(target_ids)
|
||||
if len(all_input_ids) >= seq_length:
|
||||
assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}"
|
||||
all_input_ids_list.append(all_input_ids)
|
||||
all_target_ids_list.append(all_target_ids)
|
||||
all_input_ids, all_target_ids = [], []
|
||||
|
||||
all_input_ids = all_input_ids + [tokenizer.pad_token_id for i in range(seq_length - len(all_input_ids))]
|
||||
all_target_ids = all_target_ids + [IGNORE_TOKEN_ID for i in range(seq_length - len(all_target_ids))]
|
||||
all_input_ids_list.append(all_input_ids)
|
||||
all_target_ids_list.append(all_target_ids)
|
||||
|
||||
if len(all_input_ids) <= 0:
|
||||
return None
|
||||
return {
|
||||
"input_ids": torch.tensor(all_input_ids_list, dtype=torch.long),
|
||||
"labels": torch.tensor(all_target_ids_list, dtype=torch.long)
|
||||
}
|
||||
|
||||
|
||||
def init_parallel_groups(ep_size=1):
|
||||
dist.init_process_group("nccl")
|
||||
world_size = int(os.getenv("WORLD_SIZE", "0"))
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
torch.cuda.set_device(local_rank)
|
||||
ep_group = edp_group = None
|
||||
for i in range(0, world_size, ep_size):
|
||||
ranks = list(range(i, i + ep_size))
|
||||
group = dist.new_group(ranks)
|
||||
if local_rank in ranks:
|
||||
ep_group = group
|
||||
edp_group = None
|
||||
for i in range(ep_size):
|
||||
ranks = list(range(i, world_size, ep_size))
|
||||
group = dist.new_group(ranks)
|
||||
if local_rank in ranks:
|
||||
edp_group = group
|
||||
dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
|
||||
dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
|
||||
return world_size, local_rank, ep_group, edp_group
|
||||
|
||||
Reference in New Issue
Block a user