From 36c6194349e42969bfb782fe7583529e4ba86783 Mon Sep 17 00:00:00 2001 From: ZihanWang314 <510642032wzh@gmail.com> Date: Fri, 9 Aug 2024 18:06:57 +0800 Subject: [PATCH] update eval and readme --- README.md | 59 +++++++---- eval_multigpu.py | 92 ++++++++++++++++++ scripts/eval.sh | 18 +++- ...nerate_expert_config.sh => eval_expert.sh} | 12 ++- scripts/expert/generate_expert_config.py | 97 +++++++++++++++++++ scripts/expert/get_expert_scores.py | 78 +++++++++++++++ scripts/generate_expert_config.py | 50 ---------- scripts/get_expert_scores.py | 75 -------------- scripts/train.sh | 2 +- scripts/train_ep.sh | 2 +- 10 files changed, 333 insertions(+), 152 deletions(-) create mode 100644 eval_multigpu.py rename scripts/{generate_expert_config.sh => eval_expert.sh} (58%) create mode 100644 scripts/expert/generate_expert_config.py create mode 100644 scripts/expert/get_expert_scores.py delete mode 100644 scripts/generate_expert_config.py delete mode 100644 scripts/get_expert_scores.py diff --git a/README.md b/README.md index f2ce48e..e001aab 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ Y. Wu. **ESFT** aims to efficiently customize Large Language Models (LLMs) with Mixture-of-Experts (MoE) architecture by adjusting only task-relevant parts, improving efficiency and performance while using fewer resources and storage. +## πŸ“° News + +πŸ“… **2024.8.11:** We now release the **ESFT training code**! ✨ You can now try it with your own models and dataset! ## πŸš€ Quick Start @@ -19,9 +22,9 @@ git clone https://github.com/deepseek-ai/ESFT.git cd esft ``` -### Install dependencies +### Install required dependencies ```bash -pip install transformers torch safetensors +pip install transformers torch safetensors accelerate ``` ### Download necessary adapters @@ -32,35 +35,38 @@ bash scripts/download_adapters.sh ## πŸ”§Key Scripts -1. **eval.py** -This script evaluates the performance of the model on various datasets. **Usage:** +1. **eval_multigpu.py** +This script evaluates the performance of the model on various datasets. See **scripts/eval.sh** for detailed configs and explanations. + +**Usage:** ```bash -python scripts/eval.py \ - --eval_datasets=translation \ +python eval_multigpu.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --adapter_dir=all_models/adapters/token \ - --output_dir=results/completions/token \ - --max_new_tokens=512 \ - --openai_api_key=REPLACE_WITH_YOUR_KEY \ - --eval_batch_size=2 + --adapter_dir=all_models/adapters/token/translation \ + --output_path=results/completions/token/translation.jsonl \ + --openai_api_key=YOUR_OPENAI_API_KEY ``` + 2. **get_expert_scores.py** This script calculates the scores for each expert based on the evaluation datasets. **Usage:** ```bash -python scripts/get_expert_scores.py \ - --eval_datasets=intent,summary,law,translation \ +python scripts/expert/get_expert_scores.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --output_dir=results/expert_scores \ - --n_sample_tokens=8192 # the sample size hyperparameter + --output_dir=results/expert_scores/translation \ + --n_sample_tokens=131072 \ + --world_size=4 \ + --gpus_per_rank=2 ``` 3. **generate_expert_config.py** This script generates the configuration to convert a MoE model with only task-relevant tasks trained based on evaluation scores. **Usage:** ```bash -python scripts/generate_expert_config.py \ +python scripts/expert/generate_expert_config.py \ --eval_datasets=intent,summary,law,translation \ --expert_scores_dir=results/expert_scores \ --output_dir=results/expert_configs \ @@ -68,13 +74,32 @@ python scripts/generate_expert_config.py \ --top_p=0.2 # the scoring function and top_p are hyperparameters ``` +4. **train.py** and **train_ep.py** +This script trains the model with the expert configuration generated by the previous script. The train_ep.py file uses expert parallel and has been optimized for multi-GPU training. +**Usage:** +```bash +python train.py \ + --base_model_path=deepseek-ai/ESFT-vanilla-lite \ + --expert_config=results/expert_configs/intent.json \ + --train_dataset=intent \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/intent + +torchrun --nproc-per-node=8 train_ep.py \ + --base_model_path=deepseek-ai/ESFT-vanilla-lite \ + --expert_config=results/expert_configs/translation.json \ + --train_dataset=translation \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/translation + +``` ## Contact and Support For bug reports, feature requests, and general inquiries, please open an issue on our GitHub Issues page. Make sure to include as much detail as possible to help us address your issue quickly. ## 🌟Todo list - β˜‘οΈ πŸ“ Update models, evaluation scripts, and expert selection scripts -- πŸ”² πŸ”§ Update training scripts +- β˜‘οΈ πŸ”§ Update training scripts - πŸ”² πŸš€ More... diff --git a/eval_multigpu.py b/eval_multigpu.py new file mode 100644 index 0000000..5516166 --- /dev/null +++ b/eval_multigpu.py @@ -0,0 +1,92 @@ +import json +import argparse + +from torch import device +from benchmarks import * +import os +from esft import load_base_model, add_adapter +import torch.multiprocessing as mp +from itertools import accumulate +from accelerate import dispatch_model +from transformers import AutoModelForCausalLM, AutoTokenizer + +def infer_auto_device_map(model, pp_splits, visible_devices): + assert len(pp_splits) == len(visible_devices) + device_map = { + "model.embed_tokens": 0, + "model.norm": len(pp_splits) - 1, + "lm_head": len(pp_splits) - 1 + } + assert len(model.model.layers) == sum(pp_splits) + pp_splits = [0, *list(accumulate(pp_splits))] + for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])): + for i in range(start, end): + device_map.update({f"model.layers.{i}": idx}) + for k, v in device_map.items(): + device_map[k] = visible_devices[v] + return device_map + + +def eval_model(rank, args, model, dataset): + config = { + "max_new_tokens": args.max_new_tokens, + "eval_batch_size": args.eval_batch_size, + "openai_api_key": args.openai_api_key + } + evaluator_map = { + "intent": IntentEvaluator, + "summary": SummaryEvaluator, + "law": LawEvaluator, + "translation": TranslationEvaluator + } + try: + evaluator_cls = evaluator_map[args.eval_dataset] + print(f"Rank {rank} starting evaluation...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) + visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank)) + device_map = infer_auto_device_map(model, [14, 13], visible_devices) + model = dispatch_model(model, device_map) + cur_dataset = dataset[rank::args.world_size] + evaluator = evaluator_cls(cur_dataset, config) + with torch.no_grad(): + results, metrics = evaluator.evaluate(model, tokenizer) + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path + f".rank_{rank}", "w") as f: + for res, m in zip(results, metrics): + obj = { + "example": res, + "score": m + } + f.write(json.dumps(obj, ensure_ascii=False) + "\n") + + except Exception as e: + print(f"Error in process {rank}: {e}", flush=True) + raise + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.") + parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset") + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model") + parser.add_argument("--adapter_dir", type=str, required=True, help="Directory containing the adapter") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the evaluation results") + parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens") + parser.add_argument("--openai_api_key", type=str, required=True, help="API key for OpenAI") + parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation") + parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation") + parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process") + + args = parser.parse_args() + + + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock + + print(f"Running evaluation on {args.eval_dataset}...") + dataset = [json.loads(i) for i in open(f"datasets/eval/{args.eval_dataset}.jsonl").readlines()] + + print("Adding adapter...") + model = add_adapter(model, args.adapter_dir, return_original_states=False) + + print("Start Evaluating...") + mp.spawn(eval_model, args=(args, model, dataset), nprocs=args.world_size, join=True) diff --git a/scripts/eval.sh b/scripts/eval.sh index dc1f69a..b86d88c 100644 --- a/scripts/eval.sh +++ b/scripts/eval.sh @@ -1,12 +1,24 @@ -# first, download adapter models and put them to the corresponding directories +# first: download adapter models and put them to the corresponding directories -python scripts/eval.py \ +python eval_multigpu.py \ --eval_datasets=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ --adapter_dir=all_models/adapters/token \ --output_dir=results/completions/token \ --max_new_tokens=512 \ --openai_api_key=REPLACE_WITH_YOUR_KEY \ - --eval_batch_size=2 + --eval_batch_size=2 \ + --world_size=4 \ + --gpus_per_rank=2 +# this script is used for single-gpu training and has been deprecated. If you have no multiple gpus, you can set above world_size=1 and gpus_per_rank=1 + +# python scripts/eval.py \ +# --eval_datasets=translation \ +# --base_model_path=deepseek-ai/ESFT-vanilla-lite \ +# --adapter_dir=all_models/adapters/token \ +# --output_dir=results/completions/token \ +# --max_new_tokens=512 \ +# --openai_api_key=REPLACE_WITH_YOUR_KEY \ +# --eval_batch_size=2 diff --git a/scripts/generate_expert_config.sh b/scripts/eval_expert.sh similarity index 58% rename from scripts/generate_expert_config.sh rename to scripts/eval_expert.sh index 13be3f4..2c8e8aa 100644 --- a/scripts/generate_expert_config.sh +++ b/scripts/eval_expert.sh @@ -1,10 +1,12 @@ -python scripts/get_expert_scores.py \ - --eval_datasets=intent,summary,law,translation \ +python scripts/expert/get_expert_scores.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --output_dir=results/expert_scores \ - --n_sample_tokens=8192 # this sample size is a hyperparameter + --output_dir=results/expert_scores/translation \ + --n_sample_tokens=131072 \ + --world_size=4 \ + --gpus_per_rank=2 -python scripts/generate_expert_config.py \ +python scripts/expert/generate_expert_config.py \ --eval_datasets=intent,summary,law,translation \ --expert_scores_dir=results/expert_scores \ --output_dir=results/expert_configs \ diff --git a/scripts/expert/generate_expert_config.py b/scripts/expert/generate_expert_config.py new file mode 100644 index 0000000..8ed4a8e --- /dev/null +++ b/scripts/expert/generate_expert_config.py @@ -0,0 +1,97 @@ +import argparse +import json +import os +from multiprocessing import Pool +import numpy as np + + +def parse_line(line): + expert_ids, expert_weights = line.split("\t\t") + expert_ids = [int(i) for i in expert_ids.split("\t")] + expert_weights = [float(i) for i in expert_weights.split("\t")] + return expert_ids, expert_weights + + +def get_summary(files): + TOP_K=6 + N_EXPERTS=64 + N_LAYERS=26 # 27 layers totally, the first layer is not MoE + + gate_scores = np.zeros((N_LAYERS, N_EXPERTS)) + token_scores = np.zeros((N_LAYERS, N_EXPERTS)) + + print("loading files") + for rank, file in files: + layer_id = int(file.split(".")[0].split("_")[2]) - 1 + + with open(os.path.join(args.expert_scores_dir, rank, file)) as f: + data = f.readlines() + for line in data: + expert_ids, expert_weights = parse_line(line) + np.add.at(gate_scores[layer_id], expert_ids, expert_weights) + np.add.at(token_scores[layer_id], expert_ids, np.ones_like(expert_weights) / TOP_K) + + gate_scores = gate_scores / np.sum(gate_scores, axis=0) + token_scores = token_scores / np.sum(token_scores, axis=0) + + summary = {"token_scores": token_scores, "gate_scores": gate_scores} + summary = {k: {str(i+1): {str(j): round(v, 4) for j, v in enumerate(l)} for i, l in enumerate(v)} for k, v in summary.items()} + + return summary + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--eval_dataset", type=str, required=True) + parser.add_argument("--expert_scores_dir", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--score_function", type=str, required=True) + parser.add_argument("--top_p", type=float, required=True) + parser.add_argument("--train_shared_experts", action="store_true") + parser.add_argument("--train_non_expert_modules", action="store_true") + + args = parser.parse_args() + + expert_cfg = { # initialize expert config + "experts": {}, + "shared_experts": args.train_shared_experts, + "non_expert_modules": args.train_non_expert_modules + } + + # let's walk inside args.expert_scores_dir and get abs file names + file_names = [] + for rank in [i for i in os.listdir(args.expert_scores_dir) if 'rank' in i]: + for file in os.listdir(os.path.join(args.expert_scores_dir, rank)): + file_names.append([rank, file]) + + + summary_file = os.path.join(args.expert_scores_dir, "summary.json") + summary = get_summary(file_names) + + with open(summary_file, "w") as f: + f.write(json.dumps(summary)) + + + scores = summary[f"{args.score_function}_scores"] + for layer, l_score in scores.items(): + l_score = [(int(k), v) for k,v in l_score.items()] + l_score = sorted(l_score, key=lambda x: x[1], reverse=True) + selected_experts = [] + current_score = 0 + for expert, score in l_score: + if current_score >= args.top_p: + break + selected_experts.append(expert) + current_score += score + expert_cfg["experts"][layer] = selected_experts + + top_p = args.top_p + train_shared_experts = args.train_shared_experts + train_non_expert_modules = args.train_non_expert_modules + + + + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path, "w") as f: + json.dump(expert_cfg, f) diff --git a/scripts/expert/get_expert_scores.py b/scripts/expert/get_expert_scores.py new file mode 100644 index 0000000..bd94e88 --- /dev/null +++ b/scripts/expert/get_expert_scores.py @@ -0,0 +1,78 @@ +import json +import os +import torch +import argparse +import random +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import get_formatted_input_and_target +import torch.multiprocessing as mp +from itertools import accumulate +from accelerate import dispatch_model + + +def infer_auto_device_map(model, pp_splits, visible_devices): + assert len(pp_splits) == len(visible_devices) + device_map = { + "model.embed_tokens": 0, + "model.norm": len(pp_splits) - 1, + "lm_head": len(pp_splits) - 1 + } + assert len(model.model.layers) == sum(pp_splits) + pp_splits = [0, *list(accumulate(pp_splits))] + for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])): + for i in range(start, end): + device_map.update({f"model.layers.{i}": idx}) + for k, v in device_map.items(): + device_map[k] = visible_devices[v] + return device_map + + +def eval_expert(rank, args, model, dataset): + try: + print(f"Rank {rank} starting expert evaluation...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) + visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank)) + device_map = infer_auto_device_map(model, [14, 13], visible_devices) + model = dispatch_model(model, device_map) + model.config.expert_log_dir = os.path.join(args.output_dir, f"rank_{rank}") + n_sample_tokens = args.n_sample_tokens // args.world_size + os.makedirs(os.path.join(args.output_dir, f"rank_{rank}"), exist_ok=True) + done_tokens = 0 + cur_dataset = dataset[rank::args.world_size] + for instance in cur_dataset: + input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) + model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0)) + done_tokens += len(input_ids) + if done_tokens >= n_sample_tokens: + break + + + except Exception as e: + print(f"Error in process {rank}: {e}", flush=True) + raise + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.") + parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset") + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model") + parser.add_argument("--output_dir", type=str, required=True, help="Path to save the evaluation results") + parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation") + parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process") + parser.add_argument("--n_sample_tokens", type=int, required=True, help="Token to sample for expert evaluation") + args = parser.parse_args() + random.seed(5934875) + + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock + model.config.log_expert_weights = True + + + print(f"Running expert evaluation on {args.eval_dataset}...") + dataset = [json.loads(i) for i in open(f"datasets/train/{args.eval_dataset}.jsonl").readlines()] + random.shuffle(dataset) + + + print("Start Evaluating...") + mp.spawn(eval_expert, args=(args, model, dataset), nprocs=args.world_size, join=True) diff --git a/scripts/generate_expert_config.py b/scripts/generate_expert_config.py deleted file mode 100644 index d641cff..0000000 --- a/scripts/generate_expert_config.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import json -import os - -parser = argparse.ArgumentParser() -parser.add_argument("--eval_datasets", type=str, required=True) -parser.add_argument("--expert_scores_dir", type=str, required=True) -parser.add_argument("--output_dir", type=str, required=True) -parser.add_argument("--score_function", type=str, required=True) -parser.add_argument("--top_p", type=float, required=True) -parser.add_argument("--train_shared_experts", action="store_true") -parser.add_argument("--train_non_expert_modules", action="store_true") - -args = parser.parse_args() - -eval_datasets = args.eval_datasets.split(",") -expert_scores_dir = args.expert_scores_dir -output_dir = args.output_dir -score_function = args.score_function -top_p = args.top_p -train_shared_experts = args.train_shared_experts -train_non_expert_modules = args.train_non_expert_modules - -for dataset_name in eval_datasets: - summary_file = f"{expert_scores_dir}/{dataset_name}/summary.json" - expert_cfg = {"experts": {}, "shared_experts": train_shared_experts, "non_expert_modules": train_non_expert_modules} - - with open(summary_file) as f: - data = json.load(f) - assert score_function in ["gate", "token"], f"Unknown score function: {score_function}" - scores = data[f"{score_function}_scores"] - - for layer, l_score in scores.items(): - l_score = [(int(k), v) for k,v in l_score.items()] - l_score = sorted(l_score, key=lambda x: x[1], reverse=True) - # get the top experts that make the threshold exceed top_p - selected_experts = [] - current_score = 0 - for expert, score in l_score: - if current_score >= top_p: - break - selected_experts.append(expert) - current_score += score - expert_cfg["experts"][layer] = selected_experts - - os.makedirs(output_dir, exist_ok=True) - with open(f"{output_dir}/{dataset_name}.json", "w") as f: - json.dump(expert_cfg, f) - - diff --git a/scripts/get_expert_scores.py b/scripts/get_expert_scores.py deleted file mode 100644 index 34dfd90..0000000 --- a/scripts/get_expert_scores.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -from benchmarks import * -import os -import torch -from torch import nn -import argparse -from random import shuffle -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import get_formatted_input_and_target - -# constants for deepseek-v2-lite -TOP_K=6 -N_EXPERTS=64 - -parser = argparse.ArgumentParser() -parser.add_argument("--base_model_path", type=str, required=True) -parser.add_argument("--eval_datasets", type=str, required=True) -parser.add_argument("--output_dir", type=str, required=True) -parser.add_argument("--n_sample_tokens", type=int, required=True) -args = parser.parse_args() - -eval_datasets = args.eval_datasets.split(",") -output_dir = args.output_dir -base_model_path = args.base_model_path -n_sample_tokens = args.n_sample_tokens - -model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) -model.config.log_expert_weights = True - -for dataset_name in eval_datasets: - dataset = [json.loads(i) for i in open(f"datasets/train/{dataset_name}.jsonl").readlines()] - shuffle(dataset) - model.config.expert_log_dir = os.path.join(args.output_dir, dataset_name) - # make dir -p this - os.makedirs(os.path.join(args.output_dir, dataset_name), exist_ok=True) - done_tokens = 0 - for instance in dataset: - input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) - model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0)) - done_tokens += len(input_ids) - if done_tokens >= n_sample_tokens: - break - - # open all files under os.path.join(args.output_dir, dataset_name). For each file, generate a summary of it - # and write it to a file in the same directory - files = os.listdir(os.path.join(args.output_dir, dataset_name)) - summary_file = os.path.join(args.output_dir, dataset_name, "summary.json") - token_scores = {} - gate_scores = {} - - for file in files: - if not file.endswith(".txt"): - continue - layer_idx = file.split("_")[2].split(".")[0] - token_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)} - gate_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)} - - with open(os.path.join(args.output_dir, dataset_name, file)) as f: - data = f.readlines() - for line in data: - expert_ids, expert_weights = line.split("\t\t") - expert_ids = [int(i) for i in expert_ids.split("\t")] - expert_weights = [float(i) for i in expert_weights.split("\t")] - for expert_id, expert_weight in zip(expert_ids, expert_weights): - gate_scores[layer_idx][expert_id] += expert_weight - token_scores[layer_idx][expert_id] += 1. / TOP_K - total = sum(token_scores[layer_idx].values()) - gate_scores[layer_idx] = {expert: round(gate_scores[layer_idx][expert] / total, 4) for expert in gate_scores[layer_idx]} - token_scores[layer_idx] = {expert: round(token_scores[layer_idx][expert] / total, 4) for expert in token_scores[layer_idx]} - - - with open(summary_file, "w") as f: - f.write(json.dumps({"token_scores": token_scores, "gate_scores": gate_scores})) - - diff --git a/scripts/train.sh b/scripts/train.sh index 5671421..2a33603 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,7 +2,7 @@ export TOKENIZERS_PARALLELISM=false exp_name="test/eval_translation" -base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +base_model_path="deepseek-ai/esft-vanilla-lite" # turn above to for loop python train.py \ --base_model_path=${base_model_path} \ diff --git a/scripts/train_ep.sh b/scripts/train_ep.sh index 0785e9a..e6f702d 100644 --- a/scripts/train_ep.sh +++ b/scripts/train_ep.sh @@ -2,7 +2,7 @@ export TOKENIZERS_PARALLELISM=false exp_name="test/eval_translation" -base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +base_model_path="deepseek-ai/esft-vanilla-lite" torchrun --nproc-per-node=8 train_ep.py \ --base_model_path=${base_model_path} \ --expert_config=results/expert_configs/translation.json \