diff --git a/policy/enemy_policy_selection.py b/policy/enemy_policy_selection.py index 3c8c2390c30348576cffd18dc70f45787990dcb9..3350340f7e22048e5069f1558f663a6180d6b638 100644 --- a/policy/enemy_policy_selection.py +++ b/policy/enemy_policy_selection.py @@ -41,14 +41,16 @@ def move_new_to_old(): # 读入目前新版本文件夹里的所有模型 for root, dirs, files in os.walk(new_model_files): if "saved_model.pb" in files: - tmp.append((os.stat(root).st_ctime, root)) + if os.path.exists(root): + tmp.append((os.stat(root).st_ctime, root)) sorted_list = sorted(tmp, key=lambda x: (x[0], x[1]), reverse=True) # 根据时间将应该被划分为旧版本的文件筛选出来,转移到旧版本文件夹里面 for model in sorted_list: if (sorted_list[0][0] - model[0]) / 3600 > time_from_newest_to_old: create_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(model[0])) - shutil.move(model[1], old_model_files + '/' + create_time + '/eval_policy') + if os.path.exists(model[1]): + shutil.move(model[1], old_model_files + '/' + create_time + '/eval_policy') # # for model in existed_models_path: # if(sorted_list[0][0] - model[0]) / 3600 <= time_from_newest_to_old: @@ -70,30 +72,39 @@ def generate_enemy_policy(): for path in existed_version_models: for root, dirs, files in os.walk(path): if "saved_model.pb" in files: - existed_models_path.append(root) + if os.path.exists(root): + existed_models_path.append(root) # 读入新版本模型的路径 for root, dirs, files in os.walk(new_model_files): if "saved_model.pb" in files: - new_models_path.append((os.stat(root).st_ctime, root)) + if os.path.exists(root): + new_models_path.append((os.stat(root).st_ctime, root)) new_models_path = sorted(new_models_path, key=lambda x: (x[0], x[1]), reverse=True) # 读入旧版本模型的路径 for root, dirs, files in os.walk(old_model_files): if "saved_model.pb" in files: - old_models_path.append(root) + if os.path.exists(root): + old_models_path.append(root) rand = random.random() if rand < new_version_rate: - print("env_constructed with new_version") - return new_models_path[0][1] + if os.path.exists(new_models_path[0][1]): + print("env_constructed with new_version") + return new_models_path[0][1] if len(old_models_path) > 0 and new_version_rate <= rand < new_version_rate + old_version_rate: + idx = random.randint(0, len(old_models_path) - 1) print("env_constructed with old_version") - return old_models_path[random.randint(0, len(old_models_path) - 1)] - - print("env_constructed with existed_version") - return existed_models_path[random.randint(0, len(existed_models_path) - 1)] + if os.path.exists(old_models_path[idx]): + return old_models_path[idx] + + if len(existed_models_path) > 0: + print("env_constructed with existed_version") + idx = random.randint(0, len(existed_models_path) - 1) + if os.path.exists(existed_models_path[idx]): + return existed_models_path[idx] if __name__ == '__main__':