決定木(分類)

# Google Colaboratory で実行する場合はインストールする
if str(get_ipython()).startswith("<google.colab."):
    !pip install dtreeviz
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, plot_tree
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Input In [2], in <cell line: 2>()
      1 import numpy as np
----> 2 import matplotlib.pyplot as plt
      3 from sklearn.datasets import make_classification
      4 from sklearn.tree import DecisionTreeClassifier, plot_tree

ModuleNotFoundError: No module named 'matplotlib'
# 表示する文字サイズを調整
plt.rc("font", size=20)
plt.rc("legend", fontsize=16)
plt.rc("xtick", labelsize=14)
plt.rc("ytick", labelsize=14)

# 乱数
np.random.seed(777)

サンプルデータを作成

2クラス分類をするためのサンプルデータを作成します。

n_classes = 2
X, y = make_classification(
    n_samples=100,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    random_state=2,
    n_classes=n_classes,
    n_clusters_per_class=1,
)

決定木を作成

DecisionTreeClassifier(criterion="gini").fit(X, y)でモデルを訓練し、作成した木の決定境界を可視化します。 criterion="gini"は分岐を決めるための指標を指定するオプションです。

# モデルの訓練
clf = DecisionTreeClassifier(criterion="gini").fit(X, y)

# 決定境界のカラーマップ用データ
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 決定境界を可視化
plt.figure(figsize=(8, 8))
plt.tight_layout()
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Pastel1)
plt.xlabel("x1")
plt.ylabel("x2")

# ラベルごとに色を分けてプロット
for i, color, label_name in zip(range(n_classes), ["r", "b"], ["A", "B"]):
    idx = np.where(y == i)
    plt.scatter(X[idx, 0], X[idx, 1], c=color, label=label_name, cmap=plt.cm.Pastel1)

plt.legend()
plt.show()
../../../_images/決定木(分類)_7_0.png

決定木の構造を画像で出

plt.figure()
clf = DecisionTreeClassifier(criterion="gini").fit(X, y)
plt.figure(figsize=(12, 12))
plot_tree(clf, filled=True)
plt.show()
<Figure size 432x288 with 0 Axes>
../../../_images/決定木(分類)_9_1.png