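# 代码拉取完成,页面将自动刷新
#   (Gitee page residue: "Code pull complete; the page will refresh automatically.")
# 同步操作将从 Renovamen/Intrusion-Detection 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
#   ("Sync will force-overwrite from Renovamen/Intrusion-Detection, discarding all changes made since the fork; this cannot be undone!")
# 确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
#   ("Once confirmed, the sync runs in the background and the page refreshes when it finishes; please wait.")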
import time
import numpy as np
from sklearn.linear_model import SGDClassifier, LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, mean_squared_error,mean_absolute_error
from sklearn.neural_network import MLPClassifier
import pickle
from sklearn.externals import joblib
from Config import Config
from Feature import get_data
class Model(object):
    """Thin wrapper around a scikit-learn style estimator that adds timing and reporting.

    The wrapped estimator is stored in ``self.model``; ``train`` and ``predict``
    delegate to its ``fit`` and ``predict`` methods.
    """

    def __init__(self, model, params=None):
        """Instantiate the wrapped estimator.

        Args:
            model: an estimator class (or any callable) whose return value has
                ``fit`` and ``predict`` methods.
            params: optional dict of keyword arguments passed to ``model``;
                when ``None`` the estimator is constructed with no arguments.
        """
        if params is None:
            self.model = model()
        else:
            self.model = model(**params)

    def train(self, x_train, y_train):
        """Fit the wrapped estimator and print it together with the fit time."""
        print("------------------ Start Training ------------------")
        print(self.model)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        self.model.fit(x_train, y_train)
        train_time = time.perf_counter() - start_time
        print("train_time:%.3fs" % train_time)
        print("------------------- End Training -------------------")

    def predict(self, x):
        """Return the wrapped estimator's predictions for ``x``."""
        return self.model.predict(x)

    @staticmethod
    def _rates_from_confusion(cm):
        """Return ``(tpr, fpr)`` computed from a confusion matrix.

        Rows are true classes and columns predicted classes (scikit-learn's
        ``confusion_matrix`` layout). Index 0 is treated as the negative class
        and every other index as positive, so for a 2x2 matrix this is
        TPR = TP / (TP + FN) = cm[1][1] / sum(cm[1]) and
        FPR = FP / (FP + TN) = cm[0][1] / sum(cm[0]) -- i.e. label index 1 is
        positive, the same convention as ``average="binary"`` (pos_label=1).
        NOTE(review): for the multiclass case this assumes index 0 is the
        "normal" class and all others are attacks -- confirm against Feature.

        A rate whose denominator is zero (no samples of that kind, e.g. a 1x1
        matrix) is returned as NaN instead of raising.
        """
        cm = np.asarray(cm)

        def ratio(num, den):
            return float(num) / float(den) if den else float("nan")

        positives = cm[1:].sum()          # samples whose true class is not index 0
        detected = cm[1:, 1:].sum()       # ...that were predicted as a non-zero class
        negatives = cm[:1].sum()          # samples whose true class is index 0
        false_alarms = cm[:1, 1:].sum()   # ...that were predicted as a non-zero class
        return ratio(detected, positives), ratio(false_alarms, negatives)

    def evaluate(self, x_test, y_test):
        """Predict on the test set and print the confusion matrix and metrics.

        Prints the confusion matrix, TPR/FPR (see ``_rates_from_confusion``),
        accuracy, precision, recall, F-score and prediction time. Binary or
        multiclass averaging is chosen by ``Config.if_multi``.
        """
        print("----------------- Start Evaluating -----------------")
        start_time = time.perf_counter()
        predictions = self.model.predict(x_test)
        test_time = time.perf_counter() - start_time
        accuracy = accuracy_score(y_test, predictions)
        if not Config.if_multi:
            recall = recall_score(y_test, predictions, average="binary")
            precision = precision_score(y_test, predictions, average="binary")
            f1 = f1_score(y_test, predictions, average="binary")
        else:
            # NOTE(review): precision/recall are micro-averaged but F1 is
            # weighted-averaged, so the printed F1 is not the harmonic mean of
            # the printed precision/recall -- confirm this mix is intended.
            recall = recall_score(y_test, predictions, average="micro")
            precision = precision_score(y_test, predictions, average="micro")
            f1 = f1_score(y_test, predictions, average="weighted")
        # Confusion matrix
        cm = confusion_matrix(y_test, predictions)
        tpr, fpr = self._rates_from_confusion(cm)
        print(cm)
        print("tpr:%.3f" % tpr)
        print("fpr:%.3f" % fpr)
        print("accuracy:%.3f" % accuracy)
        print("precision:%.3f" % precision)
        print("recall:%.3f" % recall)
        print("f-score:%.3f" % f1)
        print("test_time:%.3fs" % test_time)
        print("------------------ End Evaluating ------------------")

    def save_model(self, model_name):
        """Pickle the wrapped estimator (not this wrapper) to ``Models/<model_name>.m``.

        The ``Models`` directory must already exist; ``open`` raises otherwise.
        """
        save_path = 'Models/' + model_name + '.m'
        # Close (and flush) the file deterministically instead of relying on GC.
        with open(save_path, "wb") as f:
            pickle.dump(self.model, f)
def load_model(model_name):
    """Deserialize and return the object stored at ``Models/<model_name>.m``.

    For files written by ``Model.save_model`` the result is the bare estimator
    (``save_model`` dumps ``self.model``), not a ``Model`` wrapper.
    NOTE(review): save_model writes with ``pickle.dump`` while this reads with
    ``joblib.load``; joblib is expected to read plain pickles too -- confirm
    with the installed joblib version.
    Security: loading executes code embedded in the file; only load trusted files.
    """
    return joblib.load('Models/' + model_name + '.m')
if __name__ == '__main__':
    config = Config()
    x_train, x_test, y_train, y_test = get_data()

    # Create the models; tuned ones take their hyper-parameters from Config.
    rf = Model(model=RandomForestClassifier, params=config.rf_params)
    et = Model(model=ExtraTreesClassifier, params=config.et_params)
    ada = Model(model=AdaBoostClassifier, params=config.ada_params)
    gb = Model(model=GradientBoostingClassifier, params=config.gb_params)
    svm = Model(model=SVC, params=config.svc_params)
    dt = Model(model=DecisionTreeClassifier)
    sgd = Model(model=SGDClassifier)
    lr = Model(model=LogisticRegression)
    gnb = Model(model=GaussianNB)
    kn = Model(model=KNeighborsClassifier)
    mlp = Model(model=MLPClassifier, params=config.mlp_params)

    # Train, evaluate, save and reload the random forest.
    rf.train(x_train, y_train)
    rf.evaluate(x_test, y_test)
    rf.save_model("rf")
    # load_model returns the bare estimator (save_model pickles rf.model), which
    # has no evaluate(); wrap it back into a Model. Model() calls the zero-arg
    # callable when no params are given, so this reuses the public constructor.
    rf_load = Model(model=lambda: load_model("rf"))
    rf_load.evaluate(x_test, y_test)

    # Train and evaluate the remaining models in the original order.
    for clf in (et, ada, gb, svm, dt, sgd, lr, gnb, kn, mlp):
        clf.train(x_train, y_train)
        clf.evaluate(x_test, y_test)
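# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
#   (Gitee page residue: "Content here may be unsuitable for display and is hidden. You can review and edit it with the editing tools.")
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
#   ("If you confirm the content contains no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, you can submit an appeal and it will be handled as soon as possible.")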