{ "cells": [ { "cell_type": "markdown", "id": "bb568697", "metadata": {}, "source": [ "# RuleFit\n", "`Friedman, Jerome H and Bogdan E Popescu. “Predictive learning via rule ensembles.” The Annals of Applied Statistics. JSTOR, 916–54. (2008).`([pdf](https://jerryfriedman.su.domains/ftp/RuleFit.pdf))\n", "\n", "## Get the data for the experiment\n", "We build a regression model on the [house_sales](https://www.openml.org/d/42092) dataset published on OpenML.\n", "\n", "Note: the OpenML page above does not state where the data come from, but as far as I could find, the original provider appears to be [this source](https://gis-kingcounty.opendata.arcgis.com/datasets/zipcodes-for-king-county-and-surrounding-area-shorelines-zipcode-shore-area/explore?location=47.482924%2C-121.477600%2C8.00&showTable=true).\n", "\n", "```{hint}\n", "[sklearn.datasets.fetch_openml](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_openml.html)\n", "```\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "fb38444d", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# Install when running on Google Colaboratory\n", "if str(get_ipython()).startswith(\"
\n", "" ], "text/plain": [ " bedrooms bathrooms sqft_living sqft_lot floors waterfront view \\\n", "0 3.0 1.00 1180.0 5650.0 1.0 0.0 0.0 \n", "1 3.0 2.25 2570.0 7242.0 2.0 0.0 0.0 \n", "2 2.0 1.00 770.0 10000.0 1.0 0.0 0.0 \n", "3 4.0 3.00 1960.0 5000.0 1.0 0.0 0.0 \n", "4 3.0 2.00 1680.0 8080.0 1.0 0.0 0.0 \n", "5 4.0 4.50 5420.0 101930.0 1.0 0.0 0.0 \n", "6 3.0 2.25 1715.0 6819.0 2.0 0.0 0.0 \n", "7 3.0 1.50 1060.0 9711.0 1.0 0.0 0.0 \n", "8 3.0 1.00 1780.0 7470.0 1.0 0.0 0.0 \n", "9 3.0 2.50 1890.0 6560.0 2.0 0.0 0.0 \n", "\n", " condition grade sqft_above sqft_basement yr_built yr_renovated \\\n", "0 3.0 7.0 1180.0 0.0 1955.0 0.0 \n", "1 3.0 7.0 2170.0 400.0 1951.0 1991.0 \n", "2 3.0 6.0 770.0 0.0 1933.0 0.0 \n", "3 5.0 7.0 1050.0 910.0 1965.0 0.0 \n", "4 3.0 8.0 1680.0 0.0 1987.0 0.0 \n", "5 3.0 11.0 3890.0 1530.0 2001.0 0.0 \n", "6 3.0 7.0 1715.0 0.0 1995.0 0.0 \n", "7 3.0 7.0 1060.0 0.0 1963.0 0.0 \n", "8 3.0 7.0 1050.0 730.0 1960.0 0.0 \n", "9 3.0 7.0 1890.0 0.0 2003.0 0.0 \n", "\n", " lat long sqft_living15 sqft_lot15 \n", "0 47.5112 -122.257 1340.0 5650.0 \n", "1 47.7210 -122.319 1690.0 7639.0 \n", "2 47.7379 -122.233 2720.0 8062.0 \n", "3 47.5208 -122.393 1360.0 5000.0 \n", "4 47.6168 -122.045 1800.0 7503.0 \n", "5 47.6561 -122.005 4760.0 101930.0 \n", "6 47.3097 -122.327 2238.0 6819.0 \n", "7 47.4095 -122.315 1650.0 9711.0 \n", "8 47.5123 -122.337 1780.0 8113.0 \n", "9 47.3684 -122.031 2390.0 7570.0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.head(10)" ] }, { "cell_type": "markdown", "id": "dd470966", "metadata": {}, "source": [ "## Run RuleFit\n", "We run RuleFit using the implementation from [Python implementation of the rulefit algorithm - GitHub](https://github.com/christophM/rulefit).\n", "\n", "Note: when you run this yourself, remove `import warnings;warnings.simplefilter('ignore')` so that warnings such as `ConvergenceWarning` are not hidden." ] }, { "cell_type": "code", "execution_count": 5, "id": "55e74f29", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RuleFit(max_rules=100,\n", "
tree_generator=GradientBoostingRegressor(learning_rate=0.01,\n", " max_depth=100,\n", " max_leaf_nodes=5,\n", " n_estimators=28,\n", " random_state=27,\n", " subsample=0.04543939429397564))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from rulefit import RuleFit\n", "import warnings\n", "\n", "warnings.simplefilter(\"ignore\") # ConvergenceWarning\n", "\n", "# Generate roughly 100 candidate rules (the actual number is usually lower)\n", "rf = RuleFit(max_rules=100)\n", "# Fit on the NumPy array; feature_names lets the rules refer to the column names\n", "rf.fit(X.values, y, feature_names=list(X.columns))" ] }, { "cell_type": "markdown", "id": "5703e9bd", "metadata": {}, "source": [ "## Inspect the generated rules" ] }, { "cell_type": "code", "execution_count": 6, "id": "b5f587f7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " rule type coef \\\n", "8 grade linear 6.199314e+04 \n", "29 sqft_living > 9475.0 rule 1.927942e+06 \n", "43 grade > 8.5 & sqft_living > 3405.0 & long <= -... rule 3.570118e+05 \n", "2 sqft_living linear 5.532732e+01 \n", "11 yr_built linear -1.522004e+03 \n", "15 sqft_living15 linear 5.344916e+01 \n", "62 lat <= 47.516000747680664 & sqft_living <= 3920.0 rule -6.549757e+04 \n", "103 sqft_basement <= 3660.0 & grade > 9.5 rule 1.240216e+05 \n", "48 sqft_living <= 9475.0 & grade > 9.5 & long > -... rule -1.473030e+05 \n", "67 sqft_living <= 4695.0 & waterfront > 0.5 & sqf... rule 3.936285e+05 \n", "\n", " support importance \n", "8 1.000000 66184.725645 \n", "29 0.001018 61491.753935 \n", "43 0.024440 55126.384264 \n", "2 1.000000 46347.924165 \n", "11 1.000000 44393.726859 \n", "15 1.000000 34501.058499 \n", "62 0.361507 31467.457947 \n", "103 0.068228 31270.434139 \n", "48 0.040733 29117.596559 \n", "67 0.005092 28016.079499 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rules = rf.get_rules()\n", "rules = rules[rules.coef != 0].sort_values(by=\"importance\", ascending=False)\n", "rules.head(10)" ] }, { "cell_type": "markdown", "id": "1224f87a", "metadata": {}, "source": [ "## ルールが正しいか確認してみる" ] }, { "cell_type": "markdown", "id": "c7d600a6", "metadata": {}, "source": [ "`sqft_above\tlinear\t8.632149e+01\t1.000000\t66243.550192` のルールに基づいて、`sqft_above` が増加すると y(`price`)が増える傾向にあるかどうか確認します。\n", "\n", "```{hint}\n", "[matplotlib.pyplot.boxplot — Matplotlib 3.5.1 documentation](https://matplotlib.org/3.5.1/api/_as_gen/matplotlib.pyplot.boxplot.html)\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "id": "3a4b3b9a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'price')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6, 6))\n", "plt.scatter(X[\"sqft_above\"], y, marker=\".\")\n", "plt.xlabel(\"sqft_above\")\n", "plt.ylabel(\"price\")" ] }, { "cell_type": "markdown", "id": "aa296e6c", "metadata": {}, "source": [ "`sqft_living <= 3935.0 & lat <= 47.5314998626709\trule\t-8.271074e+04\t0.377800\t40101.257833` のルールに該当するデータのみ抽出してみます。\n", "係数がマイナスになっているので、このルールに該当するデータのy(`price`)は低い傾向にあるはずです。\n", "log(y)を箱髭図で確認すると、確かにルールに該当しているデータのyはルールに該当しないデータのyと比較すると低くなっています。" ] }, { "cell_type": "code", "execution_count": 8, "id": "39513993", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "applicable_data = np.log(\n", " y[X.query(\"sqft_living <= 3935.0 & lat <= 47.5314998626709\").index]\n", ")\n", "not_applicable_data = np.log(\n", " y[X.query(\"not(sqft_living <= 3935.0 & lat <= 47.5314998626709)\").index]\n", ")\n", "\n", "plt.figure(figsize=(10, 6))\n", "plt.boxplot([applicable_data, not_applicable_data], labels=[\"ルールに該当\", \"ルールに該当しない\"])\n", "plt.ylabel(\"log(price)\")\n", "plt.grid()\n", "plt.show()" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }