代码拉取完成,页面将自动刷新
import json
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import time
import os
import re
from tqdm import tqdm
from colorama import Fore
from main import get_args
from gen_plot import DataUtils, STATIC
# Render Chinese labels with the SimHei font (fixes garbled CJK text) and draw
# minus signs as an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams.update({"font.sans-serif": ["SimHei"], "axes.unicode_minus": False})
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    specified to use) instead of the platform's locale default, which on a
    Chinese-locale Windows machine is GBK and would mis-decode or reject
    non-ASCII content.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
def read_save_info(args):
    """Load the ``<bee_num>_<max_itrs>_save_info.json`` file of one experiment.

    The experiment directory is
    ``output/Net-<net_idx>/<mode>/<evaluator>/<mutation>``, with
    ``/<task_name>`` appended unless ``args.task_name`` is the empty string.
    """
    base_dir = f"output/Net-{args.net_idx}/{args.mode}/{args.evaluator}/{args.mutation}"
    cur_path = base_dir if args.task_name == "" else f"{base_dir}/{args.task_name}"
    file_name = f"{args.bee_num}_{args.max_itrs}_save_info.json"
    return read_json(f"{cur_path}/{file_name}")
# A function to compare how the cost decreases across different experiments
def compare_cost():
    """Plot the best-cost convergence curves of several experiments on one figure.

    The experiments are Net-1 runs with 20 bees and 50 iterations, in modes
    D1..D6, using the "bic" evaluator and the "mutation" operator. Each curve
    is ``save_info["convergence"]["best"]``. The figure is saved to
    ``output/analysis/Net-<net_idx>/cost_comparison.png``.
    """
    args = get_args()
    args.bee_num, args.max_itrs, args.net_idx = 20, 50, 1
    plt.figure(figsize=(20, 10))

    modes = ["D1", "D2", "D3", "D4", "D5", "D6"]
    evaluators = ["bic"]  # other options: "k2", "bde"
    mutations = ["mutation"]  # other options: "cross_and_mutation", "cross"
    experiments = (
        (dim_num, loss, mutation)
        for dim_num in modes
        for loss in evaluators
        for mutation in mutations
    )
    for dim_num, loss, mutation in experiments:
        args.mode, args.evaluator, args.mutation = dim_num, loss, mutation
        best_cost = read_save_info(args)["convergence"]["best"]
        plt.plot(best_cost, label=f"{dim_num}_{loss}_{mutation}")

    plt.title(f"Net-{args.net_idx} cost comparison")
    plt.legend()
    out_dir = f"output/analysis/Net-{args.net_idx}"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_file = f"{out_dir}/cost_comparison.png"
    print(f"saving to {Fore.GREEN}{out_file}{Fore.RESET}")
    plt.savefig(out_file)
    plt.close()
def extract_score(data, score_name=None):
    """Parse confusion-matrix counts and precision/recall/F1 from an evaluation string.

    The input looks like::

        "TP=1, TN=6, FP=69, FN=14, Precision=0.014285714285714285,
         Recall=0.06666666666666667, F1=0.023529411764705882"

    Args:
        data: evaluation summary string in the format above.
        score_name: unused; kept for backward compatibility. The full metric
            dictionary is always returned (callers index it, e.g. ``["F1"]``).

    Returns:
        dict with int values for ``"TP"``, ``"TN"``, ``"FP"``, ``"FN"`` and
        float values for ``"Precision"``, ``"Recall"``, ``"F1"``. A float
        metric that is absent or not a plain number (e.g. ``nan``) is 0.0.

    Raises:
        ValueError: if one of the TP/TN/FP/FN counts is missing.
    """
    # Float metrics may be printed as integers ("F1=0"), decimals
    # ("0.25", ".5") or in scientific notation ("1e-05"); accept all of them.
    number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
    scores = {}
    for key in ("TP", "TN", "FP", "FN"):
        match = re.search(rf"{key}=(\d+)", data)
        if match is None:
            raise ValueError(f"missing {key}=<count> in evaluation string: {data!r}")
        scores[key] = int(match.group(1))
    for key in ("Precision", "Recall", "F1"):
        match = re.search(rf"{key}=({number})", data)
        scores[key] = float(match.group(1)) if match else 0.0
    return scores
def compare_F1():
    """Bar-chart the F1 scores of several Net-1 experiments, best first.

    Each experiment's ``union_eval_result`` string is parsed with
    ``extract_score``. Bars are sorted by F1 in descending order, and the
    figure is saved to ``output/analysis/Net-<net_idx>/F1_comparison.png``.
    """
    args = get_args()
    args.net_idx, args.bee_num, args.max_itrs = 1, 20, 50
    plt.figure(figsize=(20, 10))

    modes = ("D1", "D2", "D3", "D4", "D5", "D6")
    evaluators = ("bic",)  # other options: "k2", "bde"
    mutations = ("mutation",)  # other options: "cross_and_mutation", "cross"

    ranked = []  # (label, metrics dict) pairs
    for dim_num in modes:
        for loss in evaluators:
            for mutation in mutations:
                args.mode, args.evaluator, args.mutation = dim_num, loss, mutation
                eval_text = read_save_info(args)["union_eval_result"]
                metrics = extract_score(eval_text, "F1")
                ranked.append((f"{dim_num}_{loss}_{mutation}", metrics))

    # Highest F1 first; list.sort is stable, just like sorted().
    ranked.sort(key=lambda pair: pair[1]["F1"], reverse=True)
    labels = [label for label, _ in ranked]
    scores = [metrics["F1"] for _, metrics in ranked]
    plt.bar(labels, scores)
    plt.title(f"Net-{args.net_idx} F1 score comparison")
    plt.xticks(rotation=45, fontsize=16)
    plt.tight_layout()

    out_dir = f"output/analysis/Net-{args.net_idx}"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_file = f"{out_dir}/F1_comparison.png"
    print(f"saving to {Fore.GREEN}{out_file}{Fore.RESET}")
    plt.savefig(out_file)
    plt.close()
def steps_compare(size, normal_list, faster_list, show=True, save_dir=None):
    """Bar chart of convergence steps per network, with and without DTW initialisation.

    One pair of bars is drawn per network (Net-1, Net-2, ...). Dashed
    horizontal lines mark the mean of each series.

    Args:
        size: data size shown in the title and used in the output file name.
        normal_list: steps needed without DTW initialisation, one per network.
        faster_list: steps needed with DTW initialisation, same order/length.
        show: call ``plt.show()`` after drawing.
        save_dir: if given, save the figure as
            ``<save_dir>/steps_compare_size_<size>.png`` (the directory is
            created if needed).

    Raises:
        ValueError: if the two lists have different lengths.
    """
    normal_color, faster_color, faster_mean_color = "#FFA07A", "#20B2AA", "#4682B4"
    if len(normal_list) != len(faster_list):
        raise ValueError(
            f"normal_list ({len(normal_list)}) and faster_list "
            f"({len(faster_list)}) must have the same length"
        )
    n_nets = len(normal_list)
    x = np.arange(1, n_nets + 1)
    width = 0.35

    plt.figure(figsize=(6, 5))
    # Place the two series side by side so neither bar can hide the other.
    plt.bar(x - width / 2, normal_list, label="不使用DTW初始化", width=width, color=normal_color)
    plt.bar(x + width / 2, faster_list, label="使用DTW初始化", width=width, color=faster_color)
    plt.axhline(
        y=np.mean(normal_list), color=normal_color, linestyle="--", label="不使用DTW的平均值"
    )
    plt.axhline(
        y=np.mean(faster_list),
        color=faster_mean_color,
        linestyle="--",
        label="使用DTW的平均值",
    )
    plt.xticks(x, [f"Net-{i}" for i in range(1, n_nets + 1)])
    plt.ylabel("迭代次数")
    plt.title(f"收敛所需的迭代次数比较(Size-{size})")
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)
    plt.tight_layout()

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        size_path = f"{save_dir}/steps_compare_size_{size}.png"
        plt.savefig(size_path)
        print(f"{Fore.BLUE}save to {size_path}{Fore.RESET}")
    if show:
        plt.show()
    # Release the figure; otherwise repeated calls with show=False keep every
    # figure alive in pyplot's figure manager.
    plt.close()
def _relative_speedup(baseline, faster):
    """Return ``(baseline - faster) / baseline`` rounded to 2 decimals, or None if baseline is 0."""
    if baseline == 0:
        # A relative speed-up is undefined here; None serialises to JSON null
        # instead of the non-standard NaN a 0/0 division would produce.
        return None
    return round(float((baseline - faster) / baseline), 2)


def compareDTW_steps(show=True, save_dir=None):
    """Compare the iterations needed to converge with and without DTW initialisation.

    For Net-1 .. Net-5 and data sizes 10 and 100, the cost curves of the
    normal task and the faster (DTW-initialised) task are loaded via
    ``DataUtils.read_cost_data``. The step count is the index of the first
    minimum of the "best" curve.

    Args:
        show: forwarded to ``steps_compare``; display each bar chart.
        save_dir: if given, the bar charts and a ``steps_compare.json``
            summary are written to this directory.

    The summary dict (per-network step lists, their means and the relative
    speed-ups) is printed and, when ``save_dir`` is set, saved as JSON.
    """
    step_lists = {
        "size_10": [],
        "size_100": [],
        "size_10_faster": [],
        "size_100_faster": [],
    }
    # (key in step_lists, data size, task), in the order the data is read.
    series = (
        ("size_10_faster", 10, STATIC.FASTER_TASK),
        ("size_100_faster", 100, STATIC.FASTER_TASK),
        ("size_10", 10, STATIC.NORMAL_TASK),
        ("size_100", 100, STATIC.NORMAL_TASK),
    )
    for idx in range(1, 6):
        for key, size, task in series:
            best_list, _, _ = DataUtils.read_cost_data(
                idx, size, scale_factor=1, add_factor=0, task=task
            )
            # np.argmin returns the first index at which the minimum occurs.
            step_lists[key].append(int(np.argmin(best_list)))

    result = {"step_lists": step_lists}
    for key in ("size_10", "size_100", "size_10_faster", "size_100_faster"):
        result[key] = float(np.mean(step_lists[key]))
    # Relative reduction in steps achieved by DTW initialisation.
    result["size_10_speedup"] = _relative_speedup(result["size_10"], result["size_10_faster"])
    result["size_100_speedup"] = _relative_speedup(result["size_100"], result["size_100_faster"])
    print(result)

    steps_compare(10, step_lists["size_10"], step_lists["size_10_faster"], show, save_dir)
    steps_compare(100, step_lists["size_100"], step_lists["size_100_faster"], show, save_dir)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        json_path = f"{save_dir}/steps_compare.json"
        print(f"{Fore.BLUE}saving to {json_path}{Fore.RESET}")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=4)
def compare_with_dyn(
    abc_metrics_path="output/report/ours/metrics.xlsx",
    dyn_path="output/report/dyn/dyn_size10_info.json",
    save_path="output/report/dyn/F1_score_compare.png",
):
    """Plot per-network F1 scores of the ABC method next to dynGENIE3.

    An extra "Average" bar pair is appended after the per-network bars. The
    defaults reproduce the size-10 comparison; for another data size, pass
    the matching input and output paths.

    Args:
        abc_metrics_path: Excel file with an ``F1`` column, one row per network.
        dyn_path: JSON file whose ``"results"`` mapping holds one entry per
            network, each with an ``"f1_score"`` field.
        save_path: where the bar chart image is written.

    Raises:
        ValueError: if no ABC F1 scores are found, or the number of dynGENIE3
            scores does not match the number of networks.
    """
    light_color, green_color = "#C4A5DE", "#20B2AA"

    print(f"{Fore.GREEN}Reading from {abc_metrics_path}{Fore.RESET}")
    df = pd.read_excel(abc_metrics_path)
    f1_scores = df["F1"].values.tolist()
    n_nets = len(f1_scores)
    if n_nets == 0:
        raise ValueError(f"no F1 scores found in {abc_metrics_path}")
    # The average goes last, as the "Average" bar.
    f1_scores.append(np.mean(f1_scores))

    print(f"{Fore.GREEN}Reading from {dyn_path}{Fore.RESET}")
    with open(dyn_path, "r", encoding="utf-8") as f:
        dyn_info = json.load(f)
    dyn_f1_scores = [item["f1_score"] for item in dyn_info["results"].values()]
    if len(dyn_f1_scores) == n_nets:
        # One score per network: add the matching "Average" bar so both
        # series have the same length as the x positions.
        dyn_f1_scores.append(np.mean(dyn_f1_scores))
    elif len(dyn_f1_scores) != n_nets + 1:
        # n_nets + 1 entries are used as-is; presumably the last one is
        # already an aggregate — TODO confirm against the dyn JSON producer.
        raise ValueError(
            f"{dyn_path} has {len(dyn_f1_scores)} results, expected {n_nets} "
            f"(one per network) to match {abc_metrics_path}"
        )

    x = np.arange(n_nets + 1)
    width = 0.35
    plt.figure(figsize=(10, 6))
    plt.bar(x - width / 2, f1_scores, width, label="ABC Method", color=green_color)
    plt.bar(
        x + width / 2, dyn_f1_scores, width, label="dynGENIE3 Method", color=light_color
    )
    plt.xlabel("Experiment")
    plt.ylabel("F1 Score")
    plt.grid()
    plt.title("Comparison of F1 Scores")
    plt.xticks(x, [f"Net-{i}" for i in range(1, n_nets + 1)] + ["Average"])
    plt.legend()
    plt.tight_layout()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print(f"{Fore.BLUE}save to {save_path}{Fore.RESET}")
    plt.savefig(save_path)
    plt.close()
if __name__ == "__main__":
    # Script entry point: uncomment the analyses to run; as written, only the
    # completion message is printed.
    # compare_cost()
    # compare_F1()
    # compareDTW_steps(
    #     show=False,
    #     save_dir="output/report/compare/DTW",
    # )
    # compare_with_dyn()
    print("DONE~!")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。