diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/README.osc.md b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/README.osc.md index bf33a116ef8b3bbee7c7a2e7d8a64afc058293fc..6dd1923d8bc7fe00fd184500be5db4ef538f4aac 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/README.osc.md +++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/README.osc.md @@ -219,9 +219,9 @@ bash test/train_grpo_performance_16p.sh --model_path=./models/xxx --dataset_path ```shell # 8卡训练 -bash test/train_kto_full_8p.sh --model_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/ultrafeedback-unpaired-preferences +bash test/train_kto_full_8p.sh --pretrain_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/ultrafeedback-unpaired-preferences # 8卡性能 -bash test/train_kto_performance_8p.sh --model_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/ultrafeedback-unpaired-preferences +bash test/train_kto_performance_8p.sh --pretrain_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/ultrafeedback-unpaired-preferences ``` #### RM算法 @@ -230,9 +230,9 @@ bash test/train_kto_performance_8p.sh --model_path=./models/Llama-3-8b-sft-mixtu ```shell # 8卡训练 -bash test/train_rm_full_8p.sh --model_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/preference_dataset_mixture2_and_safe_pku +bash test/train_rm_full_8p.sh --pretrain_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/preference_dataset_mixture2_and_safe_pku # 8卡性能 -bash test/train_rm_performance_8p.sh --model_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/preference_dataset_mixture2_and_safe_pku +bash test/train_rm_performance_8p.sh --pretrain_path=./models/Llama-3-8b-sft-mixture --dataset_path=./data/preference_dataset_mixture2_and_safe_pku ``` #### PRM算法 @@ -241,9 +241,9 @@ bash test/train_rm_performance_8p.sh --model_path=./models/Llama-3-8b-sft-mixtur ```shell # 8卡训练 -bash test/train_prm_full_8p.sh --model_path=./models/Mistral-7B-v0.1 --dataset_path=./data/Math-Shepherd/data +bash test/train_prm_full_8p.sh --pretrain_path=./models/Mistral-7B-v0.1 --dataset_path=./data/Math-Shepherd/data # 8卡性能 -bash test/train_prm_performance_8p.sh --model_path=./models/Mistral-7B-v0.1 --dataset_path=./data/Math-Shepherd/data +bash test/train_prm_performance_8p.sh --pretrain_path=./models/Mistral-7B-v0.1 --dataset_path=./data/Math-Shepherd/data ``` #### 训练结果展示 diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/test/train_prm_performance_8p.sh b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/test/train_prm_performance_8p.sh index 6ffa7c1dd410dae8d1ce47f2854eca9bea57a705..45ea8fa81a22ed3f6c311e46d43b649adf8e35f1 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/test/train_prm_performance_8p.sh +++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/test/train_prm_performance_8p.sh @@ -68,7 +68,7 @@ openrlhf.cli.train_prm \ --eval_steps 100 \ --train_batch_size 64 \ --micro_train_batch_size 8 \ - --max_samples 64000 \ + --max_samples 64000 \ --pretrain $pretrain_path \ --bf16 \ --max_epochs $max_epochs \