
xielinjiang/machine-learning-course

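This repository does not declare an open-source license file (LICENSE); before using it, check the project description and its upstream code dependencies.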
decisiontrees.py 3.07 KB
mergu committed on 2019-04-19 08:03 +08:00: Add code and splitting calculations
import graphviz
import itertools
import random
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.preprocessing import OneHotEncoder
# The possible values for each class
classes = {
    'supplies': ['low', 'med', 'high'],
    'weather': ['raining', 'cloudy', 'sunny'],
    'worked?': ['yes', 'no']
}
# Our example data from the documentation
data = [
    ['low',  'sunny',   'yes'],
    ['high', 'sunny',   'yes'],
    ['med',  'cloudy',  'yes'],
    ['low',  'raining', 'yes'],
    ['low',  'cloudy',  'no' ],
    ['high', 'sunny',   'no' ],
    ['high', 'raining', 'no' ],
    ['med',  'cloudy',  'yes'],
    ['low',  'raining', 'yes'],
    ['low',  'raining', 'no' ],
    ['med',  'sunny',   'no' ],
    ['high', 'sunny',   'yes']
]
# Our target variable, whether someone went shopping
target = ['yes', 'no', 'no', 'no', 'yes', 'no', 'no', 'no', 'no', 'yes', 'yes', 'no']
# scikit-learn's DecisionTreeClassifier requires numeric features, so one-hot encode the categorical data above
# Native categorical data support may be added in the future: https://github.com/scikit-learn/scikit-learn/pull/4899
categories = [classes['supplies'], classes['weather'], classes['worked?']]
encoder = OneHotEncoder(categories=categories)
x_data = encoder.fit_transform(data)
# Form and fit our decision tree to the now-encoded data
classifier = DecisionTreeClassifier()
tree = classifier.fit(x_data, target)
# Now that we have our decision tree, let's predict some outcomes from random data
# Build 5 random data points, each taking one random value from every class
prediction_data = []
for _ in itertools.repeat(None, 5):
    prediction_data.append([
        random.choice(classes['supplies']),
        random.choice(classes['weather']),
        random.choice(classes['worked?'])
    ])
# Use our tree to predict the outcome of the random values
prediction_results = tree.predict(encoder.transform(prediction_data))
# =============================================================================
# Output code
def format_array(arr):
    return "".join(["| {:<10}".format(item) for item in arr])
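def print_table(data, results):
    # Pad "day" to the same 5-character width used for the day numbers below
    line = "{:<5}".format("day") + format_array(list(classes.keys()) + ["went shopping?"])
    print("-" * len(line))
    print(line)
    print("-" * len(line))
    # One row per day: its input values followed by the (actual or predicted) outcome
    for day, row in enumerate(data):
        print("{:<5}".format(day + 1) + format_array(row + [results[day]]))
    print("")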
feature_names = (
['supplies-' + x for x in classes["supplies"]] +
['weather-' + x for x in classes["weather"]] +
['worked-' + x for x in classes["worked?"]]
)
# Shows a visualization of the decision tree using graphviz
# scikit-learn only builds binary trees, so each split tests a single one-hot column (one option of a class) rather than all options of a class at once
dot_data = export_graphviz(tree, filled=True, proportion=True, feature_names=feature_names)
graph = graphviz.Source(dot_data)
graph.render(filename='decision_tree', cleanup=True, view=True)
# Display our training and prediction data and results
print("Training Data:")
print_table(data, target)
print("Predicted Random Results:")
print_table(prediction_data, prediction_results)
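The `OneHotEncoder` step turns each row into a 0/1 vector with one column per option, in the order given by `categories`. Below is a minimal pure-Python sketch of that mapping. It assumes the same category order as the script; `one_hot` is a hypothetical helper for illustration, not part of scikit-learn.

```python
# Same category order as the `categories` list passed to OneHotEncoder in the script
categories = [
    ['low', 'med', 'high'],          # supplies
    ['raining', 'cloudy', 'sunny'],  # weather
    ['yes', 'no'],                   # worked?
]


def one_hot(row, categories):
    """Hypothetical helper: encode one row as a flat 0/1 list, one column per option."""
    encoded = []
    for value, options in zip(row, categories):
        if value not in options:
            # OneHotEncoder's default handle_unknown='error' also rejects unseen values
            raise ValueError(f"unknown category {value!r}")
        # Exactly one column for this class is set to 1
        encoded.extend(1 if value == option else 0 for option in options)
    return encoded


print(one_hot(['low', 'sunny', 'yes'], categories))
# → [1, 0, 0, 0, 0, 1, 1, 0]
```

Each class contributes exactly one 1, so every encoded row has three ones spread across eight columns. Those eight columns are the ones the script's `feature_names` labels in the graphviz output.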
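The splitting calculations can be checked by hand. `DecisionTreeClassifier` uses Gini impurity by default and only makes binary splits, so every candidate split is "this one-hot column is 1" versus "it is 0". The sketch below scores each such split on the training data by weighted Gini impurity. At a given node, choosing the lowest weighted child impurity is equivalent to choosing the largest impurity decrease. `gini` and `split_impurity` are hypothetical helpers for illustration; they are not scikit-learn's implementation.

```python
from collections import Counter

# Training data and labels copied from the script above
data = [
    ['low', 'sunny', 'yes'], ['high', 'sunny', 'yes'], ['med', 'cloudy', 'yes'],
    ['low', 'raining', 'yes'], ['low', 'cloudy', 'no'], ['high', 'sunny', 'no'],
    ['high', 'raining', 'no'], ['med', 'cloudy', 'yes'], ['low', 'raining', 'yes'],
    ['low', 'raining', 'no'], ['med', 'sunny', 'no'], ['high', 'sunny', 'yes'],
]
target = ['yes', 'no', 'no', 'no', 'yes', 'no', 'no', 'no', 'no', 'yes', 'yes', 'no']
options = [['low', 'med', 'high'], ['raining', 'cloudy', 'sunny'], ['yes', 'no']]
prefixes = ['supplies', 'weather', 'worked']


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions (0 for a pure node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_impurity(column, value):
    """Weighted Gini impurity of splitting on `row[column] == value` (one one-hot column)."""
    matches = [label for row, label in zip(data, target) if row[column] == value]
    others = [label for row, label in zip(data, target) if row[column] != value]
    n = len(target)
    return len(matches) / n * gini(matches) + len(others) / n * gini(others)


# Score every candidate binary split, named like the script's feature_names
scores = {
    f"{prefixes[col]}-{value}": split_impurity(col, value)
    for col, values in enumerate(options)
    for value in values
}

print(f"root impurity: {gini(target):.3f}")
for name, score in sorted(scores.items(), key=lambda item: item[1]):
    print(f"{name:<16}{score:.3f}")
```

On this data the root impurity is 4/9 ≈ 0.444. The split with the lowest weighted impurity is `supplies-high` (1/3 ≈ 0.333), because all four high-supplies days are "no" and that child node is pure.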