1 Star 0 Fork 2

xiaxiaxia110/Intrusion-Detection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Bayesnet.py 7.23 KB
一键复制 编辑 原始数据 按行查看 历史
Renovamen 提交于 2019-06-17 02:38 +08:00 . [Feat] save models and load models
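'''
--------------- Bayesian Network ---------------
NOTE:
    A --> B: B is the parent of A
F:
    key: a node
    values:
        children: list of child nodes
        parents: list of parent nodes
        values: all possible values of the key node
        cpt: conditional probability table
V:
    list of nodes
E:
    key: a node
    values: list of child nodes
'''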
from copy import copy, deepcopy
import numpy as np
import json
from collections import OrderedDict
import matplotlib.pyplot as plt
import networkx as nx
class BayesNet(object):
    """Discrete Bayesian network: graph structure plus per-node tables.

    Attributes:
        V: list of nodes, as returned by ``topsort``.
        E: dict mapping each node to the list of its child nodes.
        F: dict mapping each node to a dict with the keys
            ``'parents'`` (list of parent nodes), ``'values'`` (all
            possible values of the node) and ``'cpt'`` (the flattened
            conditional probability table).
    """

    def __init__(self, E=None, value_dict=None, file=None):
        """Build a network from a .bn file, from an edge dict, or empty.

        Args:
            E: optional dict mapping each node to its list of child nodes.
            value_dict: optional dict mapping each node to its possible
                values; only used together with ``E``.
            file: optional path of a .bn file; takes precedence over ``E``.
        """
        if file is not None:
            bn = read_bn(file)
            self.V = bn.V
            self.E = bn.E
            self.F = bn.F
        elif E is not None:
            self.set_structure(E, value_dict)
        else:
            self.V = []
            self.E = {}
            self.F = {}

    def copy(self):
        """Return a new BayesNet holding deep copies of V, E and F."""
        bn = BayesNet()
        bn.V = deepcopy(self.V)
        bn.E = deepcopy(self.E)
        # Only these three entries of each node's table are carried over.
        bn.F = {
            v: {
                'cpt': deepcopy(self.F[v]['cpt']),
                'parents': deepcopy(self.F[v]['parents']),
                'values': deepcopy(self.F[v]['values']),
            }
            for v in bn.V
        }
        return bn

    def nodes(self):
        """Yield the nodes in the order stored in ``V``."""
        for v in self.V:
            yield v

    def has_node(self, rv):
        """Return True if ``rv`` is a node of the network."""
        return rv in self.V

    def has_edge(self, u, v):
        """Return True if ``v`` is listed as a child of ``u``.

        An unknown ``u`` yields False instead of raising ``KeyError``.
        """
        return v in self.E.get(u, ())

    def edges(self):
        """Yield every ``(node, child)`` pair recorded in ``E``."""
        for u in self.nodes():
            for v in self.E[u]:
                yield (u, v)

    def cpt(self, rv):
        """Return the conditional probability table stored for ``rv``."""
        return self.F[rv]['cpt']

    def card(self, rv):
        """Return the number of possible values of ``rv``."""
        return len(self.F[rv]['values'])

    def scope(self, rv):
        """Return ``[rv]`` followed by the parents of ``rv``."""
        return [rv] + list(self.F[rv]['parents'])

    def parents(self, rv):
        """Return the list of parents of ``rv``."""
        return self.F[rv]['parents']

    def children(self, rv):
        """Return the list of children of ``rv``."""
        return self.E[rv]

    def values(self, rv):
        """Return the list of possible values of ``rv``."""
        return self.F[rv]['values']

    def value_idx(self, rv, val):
        """Return the position of ``val`` among the values of ``rv``.

        Returns -1 when ``val`` is not one of the values of ``rv``.
        """
        try:
            return self.F[rv]['values'].index(val)
        except ValueError:
            return -1

    def stride(self, rv, n):
        """Return the step between successive values of ``n`` in ``rv``'s CPT.

        ``rv`` itself has stride 1; the stride of a parent is the product
        of the cardinalities of ``rv`` and of the parents listed before it.

        Raises:
            ValueError: if ``n`` is neither ``rv`` nor one of its parents.
        """
        if n == rv:
            return 1
        parent_list = self.parents(rv)
        card_list = [self.card(rv)]
        card_list.extend(self.card(p) for p in parent_list)
        n_idx = parent_list.index(n) + 1
        return int(np.prod(card_list[0:n_idx]))

    def cpt_indices(self, target, val_dict):
        """Return the indices of ``target``'s CPT that agree with ``val_dict``.

        Args:
            target: node whose CPT is indexed.
            val_dict: dict mapping variables in ``scope(target)`` to the
                value each one is fixed to.

        Returns:
            list of matching CPT indices, in no particular order. An
            empty list is returned when a constrained variable has a zero
            stride or no values at all (the walk below could not advance).

        Raises:
            KeyError: if ``val_dict`` names a variable outside the scope.
        """
        scope = self.scope(target)
        stride = {n: self.stride(target, n) for n in scope}
        card = {n: self.card(n) for n in scope}
        cpt_len = len(self.cpt(target))
        idx = set(range(cpt_len))
        for rv, val in val_dict.items():
            run = stride[rv]
            step = run * card[rv]
            if step == 0:
                # A zero step would never advance s_idx (infinite loop).
                return []
            # value_idx gives -1 for an unknown value; the first run then
            # falls below index 0 and the following ones coincide with
            # the runs of the last value of rv.
            s_idx = self.value_idx(rv, val) * run
            rv_idx = set()
            while s_idx < cpt_len:
                rv_idx.update(range(s_idx, s_idx + run))
                s_idx += step
            idx &= rv_idx
        return list(idx)

    def set_structure(self, edge_dict, value_dict=None):
        """Install the network structure obtained from structure learning.

        Args:
            edge_dict: dict mapping each node to the list of its children.
                It is stored as ``E`` without being copied.
            value_dict: optional dict mapping each node to all of its
                possible values.
        """
        self.V = topsort(edge_dict)
        self.E = edge_dict
        self.F = {}
        for rv in self.nodes():
            self.F[rv] = {
                'parents': [p for p in self.nodes() if rv in self.children(p)],
                'cpt': [],
                'values': []
            }
            if value_dict is not None:
                self.F[rv]['values'] = value_dict[rv]

    def adj_list(self):
        """Return the adjacency list of the network.

        Entry ``i`` holds the positions in ``V`` of the children of
        ``V[i]``.
        """
        vi_map = {v: i for i, v in enumerate(self.V)}
        adj_list = [[] for _ in self.V]
        for u, v in self.edges():
            adj_list[vi_map[u]].append(vi_map[v])
        return adj_list

    def write_bn(self, path):
        """Save the network (structure and parameters) to a .bn file.

        Args:
            path: destination path; the file is written as JSON.
        """
        bn_dict = OrderedDict([('V', self.V), ('E', self.E), ('F', self.F)])
        with open(path, 'w') as outfile:
            json.dump(bn_dict, outfile, indent=2, cls=NpEncoder)

    def plot(self, labels=None):
        """Draw the network structure, save it to result.png and show it.

        Args:
            labels: optional lookup from node to display label; when it
                is omitted the nodes themselves are used as labels.
        """
        # Start from an empty directed graph.
        G = nx.DiGraph()
        for i in self.E:
            if labels is not None:
                G.add_node(labels[i])
            else:
                G.add_node(i)
            for j in self.E[i]:
                # The arrow is drawn from the child j to its parent i.
                if labels is not None:
                    G.add_edges_from([(labels[j], labels[i])])
                else:
                    G.add_edges_from([(j, i)])
        nx.draw(G, with_labels=True)
        plt.savefig("result.png")
        plt.show()
def read_bn(path):
    """Read a Bayesian network (structure and parameters) from a .bn file.

    The file is parsed as JSON. Every string in the parsed document that
    consists solely of the digits 0-9 is turned into an ``int``; all other
    strings are kept unchanged.

    Args:
        path: path of the .bn file.

    Returns:
        The loaded BayesNet. When the file content is not valid JSON,
        "Could not read file." is printed and an empty network is
        returned.
    """
    def byteify(obj):
        # Recursively rebuild dicts and lists, converting digit-only
        # strings (keys and values alike) to integers.
        if isinstance(obj, dict):
            return {byteify(key): byteify(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [byteify(element) for element in obj]
        if isinstance(obj, str):
            # Check every character: comparing the whole string against
            # '0' and '9' misses "95" and wrongly accepts "1abc".
            if obj and all('0' <= ch <= '9' for ch in obj):
                return int(obj)
        return obj

    bn = BayesNet()
    with open(path, 'r') as f:
        ftxt = f.read()
    try:
        data = byteify(json.loads(ftxt))
        bn.V = data['V']
        bn.E = data['E']
        bn.F = data['F']
    except ValueError:
        print("Could not read file.")
    # The node order is always recomputed from the edge dict.
    bn.V = topsort(bn.E)
    return bn
class NpEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars and arrays into plain Python."""

    def default(self, obj):
        """Convert NumPy objects; defer everything else to the base class."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)
def topsort(edge_dict, root=None):
    """Return the nodes of a directed graph in topological order.

    Every returned node appears after all of its returned parents. The
    previous breadth-first walk could emit a node before one of its
    parents when the node was reachable along paths of different length.

    Args:
        edge_dict: dict mapping each node to the list of its child nodes.
        root: optional start node. When given, only nodes reachable from
            it are returned; otherwise the walk starts from every node
            that is not listed as anybody's child.

    Returns:
        list of the reachable nodes in topological order. If the
        reachable part contains a cycle, the nodes that cannot be ordered
        are appended in breadth-first discovery order.

    Raises:
        KeyError: if a visited node is not a key of ``edge_dict``.
    """
    if root is not None:
        seeds = [root]
    else:
        all_children = set()
        for children in edge_dict.values():
            all_children.update(children)
        seeds = [rv for rv in edge_dict if rv not in all_children]

    # Breadth-first discovery of every node reachable from the seeds.
    reachable = list(seeds)
    seen = set(seeds)
    pos = 0
    while pos < len(reachable):
        vertex = reachable[pos]
        pos += 1
        for nbr in edge_dict[vertex]:
            if nbr not in seen:
                seen.add(nbr)
                reachable.append(nbr)

    # Count, for each reachable node, the incoming edges from reachable nodes.
    indegree = dict.fromkeys(reachable, 0)
    for vertex in reachable:
        for nbr in edge_dict[vertex]:
            indegree[nbr] += 1

    # Kahn's algorithm; ``order`` doubles as the FIFO work queue.
    order = [vertex for vertex in reachable if indegree[vertex] == 0]
    pos = 0
    while pos < len(order):
        vertex = order[pos]
        pos += 1
        for nbr in edge_dict[vertex]:
            indegree[nbr] -= 1
            if indegree[nbr] == 0:
                order.append(nbr)

    if len(order) < len(reachable):
        # Nodes on a cycle never reach in-degree zero; keep them anyway.
        placed = set(order)
        order.extend(v for v in reachable if v not in placed)
    return order
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xiaxiaxia110/Intrusion-Detection.git
git@gitee.com:xiaxiaxia110/Intrusion-Detection.git
xiaxiaxia110
Intrusion-Detection
Intrusion-Detection
master

搜索帮助