diff --git a/inference/convert.py b/inference/convert.py index 9a7ea90..f6fb5e2 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -78,7 +78,7 @@ if __name__ == "__main__": parser.add_argument("--hf-ckpt-path", type=str, required=True) parser.add_argument("--save-path", type=str, required=True) parser.add_argument("--n-experts", type=int, required=True) - parser.add_argument("--model-parallel", type=int, default=1) + parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() assert args.n_experts % args.model_parallel == 0 main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)