mirror of
https://github.com/deepseek-ai/ESFT
synced 2025-06-26 18:15:50 +00:00
add training code
This commit is contained in:
33
esft.py
33
esft.py
@@ -7,6 +7,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
def to_buffer(module, mark_param=True):
|
||||
"""Turns all parameters of a module into buffers."""
|
||||
if module is None:
|
||||
return
|
||||
modules = module.modules()
|
||||
module = next(modules)
|
||||
delattrs = []
|
||||
@@ -25,6 +27,8 @@ def to_buffer(module, mark_param=True):
|
||||
|
||||
def to_param(module):
|
||||
"""Turns all buffers of a module into parameterss."""
|
||||
if module is None:
|
||||
return
|
||||
modules = module.modules()
|
||||
module = next(modules)
|
||||
param_list = getattr(module, 'param_list', [])
|
||||
@@ -57,7 +61,7 @@ def to_esft(model, adapter_config):
|
||||
to_buffer(model)
|
||||
else:
|
||||
to_param(model)
|
||||
for idx, layer in enumerate(model.layers):
|
||||
for idx, layer in enumerate(model.model.layers):
|
||||
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
||||
continue
|
||||
if adapter_config.get('shared_experts', False):
|
||||
@@ -72,15 +76,25 @@ def to_esft(model, adapter_config):
|
||||
to_buffer(layer.mlp.experts[expert_id])
|
||||
return model
|
||||
|
||||
|
||||
def load_state_dict(folder_path):
|
||||
# 初始化空的 state_dict
|
||||
combined_state_dict = {}
|
||||
|
||||
# 遍历文件夹中的所有文件
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith('.safetensors'):
|
||||
file_path = os.path.join(folder_path, file_name)
|
||||
state_dict = load_file(file_path)
|
||||
combined_state_dict.update(state_dict)
|
||||
|
||||
|
||||
# legacy for loading v1 checkpoints: add prefix "model." for parameters
|
||||
for k in list(combined_state_dict.keys()):
|
||||
if k.startswith("layers"):
|
||||
k_new = "model." + k
|
||||
combined_state_dict[k_new] = combined_state_dict[k]
|
||||
del combined_state_dict[k]
|
||||
|
||||
return combined_state_dict
|
||||
|
||||
|
||||
@@ -89,21 +103,24 @@ def load_esft_model(base_model_path, adapter_dir):
|
||||
adapter_state_dict = load_state_dict(adapter_dir)
|
||||
|
||||
# load pretrained model:
|
||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
|
||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
|
||||
|
||||
to_esft(model.model, adapter_config)
|
||||
model.model.load_state_dict(adapter_state_dict)
|
||||
to_esft(model, adapter_config)
|
||||
model.load_state_dict(adapter_state_dict)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def load_base_model(base_model_path):
|
||||
# load pretrained model:
|
||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
|
||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def add_adapter(base_model, adapter_dir, return_original_states=False):
|
||||
adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
|
||||
def add_adapter(base_model, adapter_dir, return_original_states=False, expert_config=None):
|
||||
if expert_config is not None:
|
||||
adapter_config = json.load(open(expert_config))
|
||||
else:
|
||||
adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
|
||||
adapter_state_dict = load_state_dict(adapter_dir)
|
||||
|
||||
to_esft(base_model, adapter_config)
|
||||
|
||||
Reference in New Issue
Block a user