{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [],
"user_expressions": []
},
"source": [
"# Machine Learning, Starting with Decision Trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"In this hands-on session, we use the **decision tree**, one of the standard machine learning methods, to experience **supervised learning**: extracting rules from labeled data and applying them to classify unseen data.\n",
"The hands-on uses the following datasets:\n",
"\n",
"* Iris (flower) measurements\n",
"* Titanic passenger data"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"First, let's prepare the required libraries.\n",
"Run the code below to install the following two libraries into Google Colaboratory (or Jupyter):\n",
"\n",
"* graphviz\n",
"* category_encoders"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": [
"remove-out"
]
},
"outputs": [],
"source": [
"try:\n",
"    import category_encoders\n",
"    import graphviz\n",
"except ImportError:\n",
"    !pip install graphviz\n",
"    !pip install category_encoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Next, import the required libraries.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Library for working with tabular data\n",
"import pandas as pd\n",
"\n",
"# The scikit-learn machine learning library\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.tree import export_graphviz\n",
"\n",
"# Miscellaneous\n",
"import category_encoders\n",
"\n",
"# Plot/graph rendering libraries\n",
"from graphviz import Source\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"\n",
"---\n",
"## Example 1: Iris"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The Iris dataset is a classic example for learning data mining and machine learning ([iris](https://ja.wikipedia.org/wiki/%E3%82%A2%E3%83%A4%E3%83%A1) is a genus of flowering plants).\n",
"We use it here as well, as material for trying out the decision tree algorithm.\n",
"\n",
"Run the code below to load the Iris data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sepal_length sepal_width petal_length petal_width species\n",
"0 5.1 3.5 1.4 0.2 setosa\n",
"1 4.9 3.0 1.4 0.2 setosa\n",
"2 4.7 3.2 1.3 0.2 setosa\n",
"3 4.6 3.1 1.5 0.2 setosa\n",
"4 5.0 3.6 1.4 0.2 setosa"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn import datasets\n",
"\n",
"# Load the iris measurement data\n",
"iris = datasets.load_iris()\n",
"iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
"iris_df['species'] = iris.target_names[iris.target]\n",
"\n",
"# Rename the columns for convenience\n",
"iris_df = iris_df.rename(\n",
" columns = {\n",
" 'sepal length (cm)': 'sepal_length',\n",
" 'sepal width (cm)': 'sepal_width',\n",
" 'petal length (cm)': 'petal_length',\n",
" 'petal width (cm)': 'petal_width'\n",
" }\n",
")\n",
"\n",
"# Show the first few rows\n",
"iris_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The Iris data records petal length and width, sepal length and width, and the species.\n",
"The goal of Example 1 is to **build a predictive model that estimates the species from the petal and sepal measurements**.\n",
"Let's build such a model with a decision tree."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"In supervised learning, the data is generally split into a **training set** and a **test (evaluation) set** before a predictive model is built.\n",
"Run the code below to split the data prepared above into a training set (70%) and a test set (30%)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Split the data into training (70%) and test (30%) sets\n",
"iris_train_df, iris_test_df = train_test_split(\n",
" iris_df, test_size=0.3,\n",
" random_state=1,\n",
" stratify=iris_df.species)"
]
},
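{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"As a quick sanity check, we can confirm that the stratified split keeps the class proportions equal in both parts. The sketch below rebuilds the same data and split as above in a self-contained form:\n",
"\n",
"```python\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
"import pandas as pd\n",
"\n",
"# Rebuild the iris DataFrame and split it exactly as above\n",
"iris = datasets.load_iris()\n",
"df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
"df['species'] = iris.target_names[iris.target]\n",
"train_df, test_df = train_test_split(df, test_size=0.3, random_state=1, stratify=df.species)\n",
"\n",
"# With stratify=..., each species keeps (almost) the same share in both splits\n",
"train_share = train_df.species.value_counts(normalize=True)\n",
"test_share = test_df.species.value_counts(normalize=True)\n",
"print(train_share)  # each species: about 1/3\n",
"print(test_share)   # each species: about 1/3\n",
"```"
]
},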
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Note that ``iris_test_df`` still contains the species labels.\n",
"When evaluating the model, we predict as if the species were unknown and then compare the predictions against the (held-out) true labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Now let's apply the decision tree algorithm, one form of supervised learning.\n",
"We apply it to ``iris_train_df`` to extract (learn) rules that tell the species apart.\n",
"\n",
"The algorithm is available through the ``DecisionTreeClassifier`` class of the `sklearn` library.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(criterion='entropy', random_state=12345)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X_train holds every feature except the species\n",
"features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']\n",
"X_train = iris_train_df[features]\n",
"\n",
"# y_train holds the species labels\n",
"y_train = iris_train_df.species\n",
"\n",
"# Fit the model\n",
"model = DecisionTreeClassifier(criterion='entropy',\n",
"                               random_state=12345)  # fix the random seed\n",
"model.fit(X_train, y_train)"
]
},
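{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"A note on `criterion='entropy'`: the tree chooses each split so as to reduce the entropy (impurity) of the class labels in the resulting nodes. As a rough sketch of what entropy measures, it can be computed by hand:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def entropy(proportions):\n",
"    # Shannon entropy (in bits); p = 0 and p = 1 contribute nothing\n",
"    return -sum(p * math.log2(p) for p in proportions if 0 < p < 1)\n",
"\n",
"# A node with the three species in equal shares is maximally impure ...\n",
"print(entropy([1/3, 1/3, 1/3]))  # about 1.585 bits\n",
"\n",
"# ... while a pure node (all one species) has zero entropy\n",
"print(entropy([1.0, 0.0, 0.0]))\n",
"```\n",
"\n",
"Splits that drive the leaf nodes toward zero entropy are exactly the ones that separate the species cleanly."
]
},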
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Rules for predicting the species have been learned.\n",
"Run the code below to visualize them in an easy-to-read form."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Source(export_graphviz(model, out_file=None,\n",
" feature_names=features,\n",
" class_names=['setosa', 'versicolor', 'virginica'],\n",
" proportion=True,\n",
"                       filled=True, rounded=True  # cosmetic options\n",
" ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The classification rules are visualized as a tree-like branching structure.\n",
"This visualization is the reason the supervised learning algorithm used here is called a decision **\"tree\"**.\n",
"\n",
"Each box represents a branch point of the classification rules.\n",
"The condition printed in a box determines how samples are split at that node.\n",
"For the samples that satisfy all the conditions on the path down to a box, the box also reports\n",
"* what percentage of the whole dataset they make up\n",
"* the percentage breakdown of the class labels among them\n",
"\n",
"For example, the box reading \"class=versicolor, value=\\\\[0.0, 1.00, 0.0\\\\]\" on the left of the third level from the top indicates that\n",
"* a sample whose petal length is greater than 2.6 and at most 4.75 is versicolor with 100% probability\n",
"* 28.6% of the samples in the dataset match this condition"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"So far we have only built the rules (the model) for prediction.\n",
"Let's now use the model to predict unseen data.\n",
"Recall that at the beginning of this example we set aside, in ``iris_test_df``, **data that was not used to build the model**."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" sepal_length sepal_width petal_length petal_width species\n",
"148 6.2 3.4 5.4 2.3 virginica\n",
"5 5.4 3.9 1.7 0.4 setosa\n",
"6 4.6 3.4 1.4 0.3 setosa\n",
"106 4.9 2.5 4.5 1.7 virginica\n",
"75 6.6 3.0 4.4 1.4 versicolor"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the first few rows\n",
"iris_test_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Let's apply the model we just built to ``iris_test_df`` and predict the species of these unseen samples.\n",
"To make predictions with the fitted model ``model``, use its ``predict`` method."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['virginica', 'setosa', 'setosa', 'versicolor', 'versicolor',\n",
" 'versicolor', 'virginica', 'versicolor', 'virginica', 'setosa',\n",
" 'setosa', 'virginica', 'setosa', 'versicolor', 'setosa',\n",
" 'versicolor', 'virginica', 'versicolor', 'versicolor', 'virginica',\n",
" 'virginica', 'setosa', 'versicolor', 'virginica', 'versicolor',\n",
" 'versicolor', 'versicolor', 'virginica', 'setosa', 'virginica',\n",
" 'setosa', 'setosa', 'versicolor', 'versicolor', 'virginica',\n",
" 'virginica', 'setosa', 'setosa', 'setosa', 'versicolor',\n",
" 'virginica', 'virginica', 'versicolor', 'setosa', 'setosa'],\n",
" dtype=object)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extract the features and true labels of the test data\n",
"X_test = iris_test_df[features]\n",
"y_test = iris_test_df.species\n",
"\n",
"# Use the model to estimate the species of the (treated as unknown) samples\n",
"iris_predicted = model.predict(X_test)\n",
"\n",
"# Show the predictions\n",
"iris_predicted"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The predictions are stored in the variable ``iris_predicted``.\n",
"The ``species`` column of ``iris_test_df`` holds the true species; let's compare it with the predictions to evaluate the model.\n",
"\n",
"There are many metrics for prediction performance; here we compute the accuracy.\n",
"Accuracy is **the fraction of samples for which the predicted species matches the true species**.\n",
"It is computed with `sklearn`'s `accuracy_score` function, which takes the true labels as its first argument and the predictions as its second.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9777777777777777"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(iris_test_df.species, iris_predicted)"
]
},
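{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"To make the definition of accuracy concrete, the same number can be computed by hand. A minimal sketch with hypothetical label lists (in the notebook you would compare `iris_test_df.species` with `iris_predicted`):\n",
"\n",
"```python\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Hypothetical true and predicted labels for five samples\n",
"y_true = ['setosa', 'setosa', 'versicolor', 'virginica', 'virginica']\n",
"y_pred = ['setosa', 'setosa', 'versicolor', 'virginica', 'versicolor']\n",
"\n",
"# Accuracy = (number of matching labels) / (number of samples)\n",
"matches = sum(t == p for t, p in zip(y_true, y_pred))\n",
"print(matches / len(y_true))           # 0.8\n",
"print(accuracy_score(y_true, y_pred))  # 0.8\n",
"```"
]
},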
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The accuracy is about 97.8%, which means the model predicts the species very well."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"\n",
"---\n",
"## Example 2: Titanic Passenger Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"On the night of April 14, 1912, the luxury liner Titanic struck an iceberg on her maiden voyage and sank, taking many of her passengers with her.\n",
"The Titanic and her sinking are world-famous, not least through film adaptations.\n",
"\n",
"Because information about the passengers survived, many people have since analyzed the disaster.\n",
"Let's use the Titanic passenger data ourselves to analyze what separated those who lived from those who died.\n",
"Run the code below to load the data for (a subset of) the passengers ([★Quiz 1](#C2-Q1))."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 died 3 \n",
"1 2 survived 1 \n",
"2 3 survived 3 \n",
"3 4 survived 1 \n",
"4 5 died 3 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
"4 Allen, Mr. William Henry male 35.0 0 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 7.2500 NaN S \n",
"1 0 PC 17599 71.2833 C85 C \n",
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 0 113803 53.1000 C123 S \n",
"4 0 373450 8.0500 NaN S "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"url = \"https://raw.githubusercontent.com/hontolab-courses/ml-lecturenote/refs/heads/main/content/data/titanic_train.csv\"\n",
"titanic_df = pd.read_csv(url)\n",
"\n",
"# Make the survival flag human-readable\n",
"titanic_df = titanic_df.assign(\n",
"    Survived = lambda df: df.Survived.map({1: 'survived', 0: 'died'})\n",
")\n",
"\n",
"# Show the first few rows\n",
"titanic_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Many columns are displayed.\n",
"The attributes (columns) of the data stored in ``titanic_df`` are as follows:\n",
"\n",
"* PassengerId: ID identifying each passenger\n",
"* Survived: flag indicating whether the passenger survived the sinking\n",
"* Pclass: ticket class; 1, 2, and 3 denote first-, second-, and third-class passengers\n",
"* Name: passenger name\n",
"* Sex: sex\n",
"* Age: age\n",
"* SibSp: number of siblings or spouses aboard the Titanic\n",
"* Parch: number of parents or children aboard the Titanic\n",
"* Ticket: ticket number\n",
"* Fare: fare paid\n",
"* Cabin: cabin number\n",
"* Embarked: port of embarkation; C = Cherbourg, Q = Queenstown, S = Southampton"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Using this data, let's build a model that can predict which passengers survived."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Before applying the decision tree, let's run a quick exploratory analysis of ``titanic_df`` to see how each attribute relates to survival.\n",
"The code below cross-tabulates **ticket class (Pclass) against survival (Survived)** and shows, for each class, the fraction of passengers who survived."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pclass 1 2 3\n",
"Survived \n",
"died 0.37037 0.527174 0.757637\n",
"survived 0.62963 0.472826 0.242363"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.crosstab(titanic_df['Survived'], titanic_df['Pclass'], normalize='columns')"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The result suggests that the higher the class (the smaller the number), the larger the fraction of passengers who survived.\n",
"Let's run the same analysis on the other attributes.\n",
"For example, the relation between sex (Sex) and survival is obtained with the code below ([★Quiz 2](#C2-Q2), [★Quiz 3](#C2-Q3))."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sex female male\n",
"Survived \n",
"died 0.257962 0.811092\n",
"survived 0.742038 0.188908"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.crosstab(titanic_df['Survived'], titanic_df['Sex'], normalize='columns')"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Before applying the decision tree algorithm, let's check the data for missing values.\n",
"It is common for part of the collected data to be missing,\n",
"and missing values can keep machine learning algorithms from working properly.\n",
"\n",
"Typical ways to handle missing values are to\n",
"* drop the rows that contain them, or\n",
"* fill them in with a representative value.\n",
"\n",
"Dropping rows would shrink the valuable training data, so here we fill the missing values with representative values.\n",
"\n",
"First, run the code below to check for missing values."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PassengerId 0\n",
"Survived 0\n",
"Pclass 0\n",
"Name 0\n",
"Sex 0\n",
"Age 177\n",
"SibSp 0\n",
"Parch 0\n",
"Ticket 0\n",
"Fare 0\n",
"Cabin 687\n",
"Embarked 2\n",
"dtype: int64"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titanic_df.isnull().sum()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The result shows that Age, Cabin, and Embarked contain missing values.\n",
"Cabin is information unique to each passenger and is not useful for predicting survival,\n",
"so we fill in only Age and Embarked.\n",
"\n",
"Many imputation methods exist; here we fill\n",
"* Age with the median\n",
"* Embarked with the mode\n",
"\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Fill missing Embarked values with the mode\n",
"titanic_df[\"Embarked\"] = titanic_df[\"Embarked\"].fillna(titanic_df[\"Embarked\"].mode().iloc[0])\n",
"\n",
"# Fill missing Age values with the median\n",
"titanic_df[\"Age\"] = titanic_df[\"Age\"].fillna(titanic_df[\"Age\"].median())"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Now there are no missing values.\n",
"Let's apply the decision tree algorithm.\n",
"As in Example 1, we first split the data into a training set (70%) and a test set (30%)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Split the data into training (70%) and test (30%) sets\n",
"titanic_train_df, titanic_test_df = train_test_split(\n",
" titanic_df, test_size=0.3,\n",
" random_state=1,\n",
" stratify=titanic_df.Survived)"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"``titanic_test_df`` still contains the survival labels, but when evaluating the model we predict as if they were unknown and then compare the predictions against the (held-out) labels ([★Quiz 4](#C2-Q4))."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The quick analysis above suggests that some individual attributes may already be useful for telling survivors apart.\n",
"In reality, however, survival was most likely determined by several attributes in combination.\n",
"**Supervised learning** extracts prediction rules while accounting for such complex relationships among the features.\n",
"\n",
"Let's apply the decision tree algorithm.\n",
"First, we prepare the data it will operate on.\n",
"Looking at the data, name (Name), ticket number (Ticket), and cabin number (Cabin) are unique to each passenger,\n",
"so they are not useful for predicting survival; we use the remaining attributes.\n",
"\n",
"Run the code below to store the features the decision tree will use in the variable ``target_features``,\n",
"and to extract only those columns from ``titanic_train_df``."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Pclass Sex Age SibSp Parch Fare Embarked\n",
"472 2 female 33.0 1 2 27.7500 S\n",
"597 3 male 49.0 0 0 0.0000 S\n",
"843 3 male 34.5 0 0 6.4375 C\n",
"112 3 male 22.0 0 0 8.0500 S\n",
"869 3 male 4.0 1 1 11.1333 S\n",
".. ... ... ... ... ... ... ...\n",
"650 3 male 28.0 0 0 7.8958 S\n",
"241 3 female 28.0 1 0 15.5000 Q\n",
"265 2 male 36.0 0 0 10.5000 S\n",
"15 2 female 55.0 0 0 16.0000 S\n",
"464 3 male 28.0 0 0 8.0500 S\n",
"\n",
"[623 rows x 7 columns]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Features to use\n",
"target_features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']\n",
"\n",
"# Writing it this way extracts only the columns listed in target_features\n",
"titanic_train_df[target_features]"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Sex and port of embarkation (Embarked) are categorical, not numeric.\n",
"Most machine learning methods operate on numbers, so it is convenient to convert categorical features into numeric ones.\n",
"Here we perform a one-hot encoding: for example, a new column Embarked_S takes the value 1 if Embarked is S and 0 otherwise.\n",
"The code below performs this conversion."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OneHotEncoder(cols=['Embarked', 'Sex'], use_cat_names=True)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoder = category_encoders.OneHotEncoder(cols=['Embarked', 'Sex'], use_cat_names=True)\n",
"encoder.fit(titanic_train_df[target_features])"
]
},
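{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"To see what the encoder actually produces, here is a minimal sketch on a toy DataFrame (the column names mirror the Titanic data; which 0/1 columns appear, e.g. `Sex_male` or `Embarked_S`, depends on the category values seen during `fit`):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"import category_encoders\n",
"\n",
"# Toy data with the same kind of categorical columns\n",
"toy = pd.DataFrame({'Sex': ['male', 'female', 'female'],\n",
"                    'Embarked': ['S', 'C', 'S']})\n",
"\n",
"enc = category_encoders.OneHotEncoder(cols=['Sex', 'Embarked'], use_cat_names=True)\n",
"encoded = enc.fit_transform(toy)\n",
"\n",
"# Each category becomes its own 0/1 column (Sex_male, Sex_female, Embarked_S, Embarked_C)\n",
"print(encoded)\n",
"```"
]
},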
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Now let's apply the decision tree algorithm to ``titanic_train_df`` and extract (learn) rules for survival.\n",
"As before, the algorithm is run through the ``DecisionTreeClassifier`` class.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=12345)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X_train holds all the features used for prediction (everything except survival)\n",
"X_train = titanic_train_df[target_features]\n",
"\n",
"# Convert the categorical features to numeric ones\n",
"X_train = encoder.transform(X_train)\n",
"\n",
"# y_train holds the survival labels\n",
"y_train = titanic_train_df.Survived\n",
"\n",
"# Fit the model\n",
"model = DecisionTreeClassifier(criterion='entropy',\n",
"                               random_state=12345,  # fix the random seed\n",
"                               max_depth=3)  # limit the tree depth to 3\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Rules for predicting survival have been learned.\n",
"Run the code below to visualize them in an easy-to-read form."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Source(export_graphviz(model, out_file=None,\n",
" feature_names=X_train.columns,\n",
" class_names=['died', 'survived'],\n",
" proportion=True,\n",
"                       filled=True, rounded=True  # cosmetic options\n",
" ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"We obtained the classification rules.\n",
"\n",
"Let's interpret the result ([★Quiz 5](#C2-Q5)).\n",
"For example, the box reading \"class=survived, entropy=0.258\" at the left end of the third level from the top indicates that\n",
"* passengers who were female (Sex_male <= 0.5: True) and traveled in first or second class (Pclass <= 2.5: True) survived with 95.7% probability\n",
"* 18.5% of all passengers match this condition"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"We now have a rough picture of the prediction rules; next, let's examine how much each feature contributes to the prediction.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pclass\t0.2907129831238367\n",
"Sex_female\t0.0\n",
"Sex_male\t0.5350835392973032\n",
"Age\t0.1165652637163436\n",
"SibSp\t0.0\n",
"Parch\t0.0\n",
"Fare\t0.05763821386251649\n",
"Embarked_S\t0.0\n",
"Embarked_C\t0.0\n",
"Embarked_Q\t0.0\n"
]
}
],
"source": [
"for feature, importance in zip(X_train.columns, model.feature_importances_):\n",
" print(\"{}\\t{}\".format(feature, importance))"
]
},
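{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"With more features it is convenient to sort the importances before reading them. A sketch using a pandas Series (with made-up importance values standing in for `model.feature_importances_` and `X_train.columns`):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical feature names and importances\n",
"features = ['Pclass', 'Sex_male', 'Age', 'Fare']\n",
"importances = [0.29, 0.54, 0.12, 0.05]\n",
"\n",
"# A Series indexed by feature name sorts in one line\n",
"ranking = pd.Series(importances, index=features).sort_values(ascending=False)\n",
"print(ranking)  # Sex_male first, Fare last\n",
"```"
]
},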
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"This result also suggests that **sex** and **ticket class** had a large influence on survival ([★Quiz 6](#C2-Q6)).\n",
"\n",
"So far we have only built the rules (the model) for prediction.\n",
"Let's now use the model to predict unseen data.\n",
"Recall that at the beginning of this example we set aside, in ``titanic_test_df``, **data that was not used to build the model**."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" PassengerId Survived Pclass Name Sex \\\n",
"433 434 died 3 Kallio, Mr. Nikolai Erland male \n",
"221 222 died 2 Bracken, Mr. James H male \n",
"217 218 died 2 Jacobsohn, Mr. Sidney Samuel male \n",
"376 377 survived 3 Landergren, Miss. Aurora Adelia female \n",
"447 448 survived 1 Seward, Mr. Frederic Kimber male \n",
"\n",
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
"433 17.0 0 0 STON/O 2. 3101274 7.125 NaN S \n",
"221 27.0 0 0 220367 13.000 NaN S \n",
"217 42.0 1 0 243847 27.000 NaN S \n",
"376 22.0 0 0 C 7077 7.250 NaN S \n",
"447 34.0 0 0 113794 26.550 NaN S "
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the first few rows\n",
"titanic_test_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Let's apply the model we built to ``titanic_test_df`` and predict survival.\n",
"To make predictions with the fitted model ``model``, use its ``predict`` method."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['died', 'died', 'died', 'survived', 'survived', 'died', 'died',\n",
" 'survived', 'survived', 'died'], dtype=object)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X_test holds every feature except survival\n",
"X_test = titanic_test_df[target_features]\n",
"\n",
"# Convert the categorical features the same way as for training\n",
"X_test = encoder.transform(X_test)\n",
"\n",
"# Predict\n",
"y_predicted = model.predict(X_test)\n",
"\n",
"# Predictions (first 10)\n",
"y_predicted[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The predictions are stored in ``y_predicted``.\n",
"The ``Survived`` column of ``titanic_test_df`` holds the true survival outcomes.\n",
"Let's compare them with the predictions to evaluate the model.\n",
"Run the code below."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7611940298507462"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# y_test holds the survival labels\n",
"y_test = titanic_test_df.Survived\n",
"\n",
"accuracy_score(y_test, y_predicted)"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"Look at the accuracy value above.\n",
"Here, accuracy is **the fraction of passengers for whom the model's prediction matches the actual outcome: survivors predicted as \"survived\" and victims predicted as \"died\"**.\n",
"The accuracy is about 76.1%, so the model predicts survival reasonably well ([★Quiz 7](#C2-Q7)).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [],
"user_expressions": []
},
"source": [
"## Quiz"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [],
"user_expressions": []
},
"source": [
"※ If you would like to use Google Colaboratory to answer the quizzes below, click [here](https://colab.research.google.com/github/hontolab-courses/ml-lecturenote/blob/main/content/quiz/introduction-to-ml.ipynb).\n",
"\n",
"Running the code below stores in `income_df` data from a national census conducted in the United States in a certain year."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 39 State-gov 77516 Bachelors 13 \n",
"1 50 Self-emp-not-inc 83311 Bachelors 13 \n",
"2 38 Private 215646 HS-grad 9 \n",
"3 53 Private 234721 11th 7 \n",
"4 28 Private 338409 Bachelors 13 \n",
"\n",
" marital-status occupation relationship race sex \\\n",
"0 Never-married Adm-clerical Not-in-family White Male \n",
"1 Married-civ-spouse Exec-managerial Husband White Male \n",
"2 Divorced Handlers-cleaners Not-in-family White Male \n",
"3 Married-civ-spouse Handlers-cleaners Husband Black Male \n",
"4 Married-civ-spouse Prof-specialty Wife Black Female \n",
"\n",
" capital-gain capital-loss hours-per-week native-country income \n",
"0 2174 0 40 United-States <=50K \n",
"1 0 0 13 United-States <=50K \n",
"2 0 0 40 United-States <=50K \n",
"3 0 0 40 United-States <=50K \n",
"4 0 0 40 Cuba <=50K "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"income_df = pd.read_csv(\"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\", header=None)\n",
"\n",
"# Name the columns (features)\n",
"income_df.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', \n",
"                     'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']\n",
"\n",
"# Show the first 5 rows\n",
"income_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"The columns (features) in the data have the following meanings:\n",
"\n",
"* age: age (integer)\n",
"* workclass: type of employment (government, private company, etc.)\n",
"* fnlwgt: not used\n",
"* education: education level\n",
"* education-num: not used\n",
"* marital-status: marital status\n",
"* occupation: occupation\n",
"* relationship: role within the family\n",
"* race: race\n",
"* sex: sex\n",
"* capital-gain: not used\n",
"* capital-loss: not used\n",
"* hours-per-week: working hours per week (integer)\n",
"* native-country: country of origin\n",
"* income: annual income (binary: at least 50K dollars, or under 50K dollars)\n",
"\n",
"We want to apply the decision tree algorithm to this data and build a machine learning model that classifies whether a person's annual income is at least 50K dollars."
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"(C2-Q1)=\n",
"### Q1: Histogram\n",
"Before building the machine learning model, we want to know the age distribution of the people surveyed in `income_df`.\n",
"Draw a histogram of age with 10 bins.\n",
"\n",
"※ Hint: use the `pandas.Series.hist` method ([reference](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.hist.html))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(C2-Q2)=\n",
"### Q2: Frequency counts\n",
"Before building the machine learning model, we want to know the distributions of sex and income in `income_df`.\n",
"For sex (male, female) and income (at least 50K, under 50K), count how many people have each attribute value.\n",
"\n",
"※ Hint: use the `pandas.Series.value_counts` method to count occurrences ([reference](https://note.nkmk.me/python-pandas-value-counts/))"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"(C2-Q3)=\n",
"### Q3: Aggregating the data\n",
"Aggregate the ``income_df`` data and examine, for each education level, the breakdown (fractions) of the annual income classes.\n",
"\n",
"※ Hint: use pandas' [crosstab](https://pandas.pydata.org/docs/reference/api/pandas.crosstab.html) function (we also used it in the Titanic example, so look back at it)"
]
},
{
"cell_type": "markdown",
"metadata": {
"user_expressions": []
},
"source": [
"(C2-Q4)=\n",
"### Q4: Splitting the data for learning\n",
"To apply the decision tree algorithm to `income_df`, split the data 7:3, using 70% as training data (`income_train_df`) and 30% as test data (`income_test_df`)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(C2-Q5)=\n",
"### Q5: Building the decision tree\n",
"\n",
"The code below builds a decision tree that predicts the income class in `income_df` from the attributes age, workclass, education, marital status, occupation, relationship, race, sex, hours per week, and native country.\n",
"Fill in the section between the `# ----------` markers to complete the code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# Features to use\n",
"target_features = ['age', 'workclass', 'education', 'marital-status', 'occupation', \n",
"                   'relationship', 'race', 'sex', 'hours-per-week', 'native-country']\n",
"\n",
"# Categorical features to convert to numbers\n",
"encoded_features = ['education', 'workclass', 'marital-status', 'relationship', 'occupation', 'native-country', 'race', 'sex']\n",
"\n",
"# Convert the categorical features to numeric ones\n",
"encoder = category_encoders.OneHotEncoder(cols=encoded_features, use_cat_names=True)\n",
"encoder.fit(income_train_df[target_features])\n",
"\n",
"# ---------------------\n",
"# Fill in the required code starting here\n",
"\n",
"\n",
"# and ending here\n",
"# ---------------------\n",
"\n",
"# Fit the model on the training data\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"(C2-Q6)=\n",
"### Q6: Feature contributions in the decision tree\n",
"Using the decision tree model you built (`model`), display how much each attribute (column) contributes to the classification of income (`income`).\n",
"Attributes with zero contribution may be omitted."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [],
"user_expressions": []
},
"source": [
"(C2-Q7)=\n",
"### Q7: Rebuilding the decision tree\n",
"Based on the result of Q6, identify the features (at most five) that contribute to income classification, and rebuild the decision tree model using only those features.\n",
"Keep the tree from growing too deep, so that the model stays as simple as possible."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}