# Load the hyperspectral image and the label raster, then build the labelled
# sample matrix shared by the interactive cells further down.

# The raw data was converted to BIP layout beforehand; read the image and
# flatten it to (nsamples, nfeatures).
hiperSpectral = io.imread('dcBIP.tif')
# hang / lie / NumberBands are the sizes of axes 0, 1 and 2
# (presumably rows, columns and spectral bands -- confirm against the file).
hang, lie, NumberBands = (hiperSpectral.shape[axis] for axis in (0, 1, 2))
AllXdata = hiperSpectral.reshape(hang * lie, NumberBands)

# Rescale every feature column with a min-max scaler fitted on all pixels.
mms = MinMaxScaler()
AllxNormalization = mms.fit_transform(AllXdata)

# Read the sample (label) matrix and flatten it to one label per pixel.
Ylabel = sio.loadmat('allSampe.mat')['S']
YLabelOnedime = Ylabel.reshape(hang * lie)

# Keep only the pixels whose label value is greater than zero.
index = np.flatnonzero(YLabelOnedime > 0)
Ylabelsample = YLabelOnedime[index]
Xdatasamples = AllxNormalization[index, :]

Failed to display Jupyter Widget of type interactive.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

from ipywidgets import interact
import ipywidgets as widgets


def plot_RFtree_interactive(numberTree, testsize, Nodesplit, maxDepth):
    """Train a random forest on the labelled samples and show its feature importances.

    The module-level ``Xdatasamples`` / ``Ylabelsample`` arrays are split into
    stratified train and test parts, a ``RandomForestClassifier`` is fitted on
    the training part, the importances are drawn as a scatter plot, and the
    test accuracy plus the ten most important features are printed.

    Parameters
    ----------
    numberTree : int
        Number of trees (``n_estimators``).
    testsize : float
        Fraction of the samples held out for testing (``test_size``).
    Nodesplit : str
        Split criterion passed to the forest (``'gini'`` or ``'entropy'``).
    maxDepth : int
        Maximum depth of each tree (``max_depth``).
    """
    X_train, X_test, y_train, y_test = train_test_split(
        Xdatasamples, Ylabelsample,
        test_size=testsize, random_state=2, stratify=Ylabelsample)

    # Seed the forest as well as the split, so the same slider positions
    # always reproduce the same importances and accuracy.
    clfRF = RandomForestClassifier(
        n_estimators=numberTree, criterion=Nodesplit, max_depth=maxDepth,
        random_state=2)
    clfRF.fit(X_train, y_train)
    featureImport = clfRF.feature_importances_

    # Rank features by descending importance.  A stable argsort keeps tied
    # importances (e.g. the many exact zeros of a shallow forest) as distinct
    # features in index order; looking each value up with list.index() would
    # report the first matching feature over and over.
    order = np.argsort(-featureImport, kind='mergesort')[:10]
    TopValue = featureImport[order].tolist()
    TopIndex = (order + 1).tolist()  # 1-based, matching the plot's x axis

    fig = plt.figure(figsize=(12, 6), dpi=100)
    ax = fig.add_subplot(211)
    ax.grid(True)
    ax.set_xlabel("Index of features")
    ax.set_ylabel("feature_importances")
    # Size the x axis from the importance vector itself rather than from a
    # global band count.
    ax.scatter(range(1, len(featureImport) + 1), featureImport, s=3, c='r')
    plt.show()

    y_model = clfRF.predict(X_test)
    testscore = accuracy_score(y_test, y_model)
    print('训练集的样本数:',len(y_train),' ','测试集的样本数:',len(y_test),' ','测试精度:',testscore)
    print('featureImportance的TOP10:', TopValue)
    print('featureImportance的TOP10对应的波段:', TopIndex)


# The trailing semicolon keeps the notebook from echoing interact()'s return value.
interact(
    plot_RFtree_interactive,
    numberTree=widgets.IntSlider(min=1, max=20, step=1, value=5, description='numTree'),
    testsize=widgets.FloatSlider(min=0.1, max=0.6, step=0.05, value=0.1, description='testRatio'),
    Nodesplit=['gini', 'entropy'],
    maxDepth=widgets.IntSlider(min=1, max=20, step=1, value=1, description='maxDepth'),
);
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9cce01d95cfe40e5a35fcb9cb5b88ddf", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type interactive.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

def plot_RFtree_interactivewithPCA(numberTree, testsize, Nodesplit, maxDepth, nuberCON):
    """Train a random forest on PCA-reduced samples and show its feature importances.

    A PCA with ``nuberCON`` components is fitted on the module-level
    ``AllxNormalization`` matrix, and the labelled rows (selected by the
    module-level ``index``) are projected onto those components.  The
    projected samples are split into stratified train and test parts, a
    ``RandomForestClassifier`` is fitted, the per-component importances are
    drawn as a scatter plot, and the test accuracy plus the (up to) ten most
    important components are printed.

    Parameters
    ----------
    numberTree : int
        Number of trees (``n_estimators``).
    testsize : float
        Fraction of the samples held out for testing (``test_size``).
    Nodesplit : str
        Split criterion passed to the forest (``'gini'`` or ``'entropy'``).
    maxDepth : int
        Maximum depth of each tree (``max_depth``).
    nuberCON : int
        Number of principal components kept (``n_components``).
    """
    # The projection is still fitted on every pixel, but only the labelled
    # rows are transformed: the rest of the image was projected and then
    # discarded on every slider change.
    pcaUse = PCA(n_components=nuberCON, random_state=2).fit(AllxNormalization)
    XdatasamplesPCA = pcaUse.transform(AllxNormalization[index, :])

    X_train, X_test, y_train, y_test = train_test_split(
        XdatasamplesPCA, Ylabelsample,
        test_size=testsize, random_state=2, stratify=Ylabelsample)

    # Seed the forest as well as the split, so the same slider positions
    # always reproduce the same importances and accuracy.
    clfRF = RandomForestClassifier(
        n_estimators=numberTree, criterion=Nodesplit, max_depth=maxDepth,
        random_state=2)
    clfRF.fit(X_train, y_train)
    featureImport = clfRF.feature_importances_

    # Rank components by descending importance.  A stable argsort keeps tied
    # importances as distinct components in index order; looking each value up
    # with list.index() would report the first matching component repeatedly.
    order = np.argsort(-featureImport, kind='mergesort')[:10]
    TopValue = featureImport[order].tolist()
    TopIndex = (order + 1).tolist()  # 1-based, matching the plot's x axis

    fig = plt.figure(figsize=(12, 6), dpi=100)
    ax = fig.add_subplot(211)
    ax.grid(True)
    ax.set_xlabel("Index of features")
    ax.set_ylabel("feature_importances")
    ax.scatter(range(1, len(featureImport) + 1), featureImport, s=3, c='r')
    plt.show()

    y_model = clfRF.predict(X_test)
    testscore = accuracy_score(y_test, y_model)
    print('主成分个数:',nuberCON,' 训练集的样本数:',len(y_train),' ','测试集的样本数:',len(y_test),' ','测试精度:',testscore)
    print('featureImportance的TOP10:', TopValue)
    # The indices refer to principal components here, not to spectral bands,
    # so the label says so.
    print('featureImportance的TOP10对应的主成分:', TopIndex)


# `interact` and `widgets` are the names imported in the previous cell.
# The trailing semicolon keeps the notebook from echoing interact()'s return value.
interact(
    plot_RFtree_interactivewithPCA,
    numberTree=widgets.IntSlider(min=1, max=20, step=1, value=5, description='numTree'),
    testsize=widgets.FloatSlider(min=0.1, max=0.6, step=0.05, value=0.1, description='testRatio'),
    Nodesplit=['gini', 'entropy'],
    maxDepth=widgets.IntSlider(min=1, max=20, step=1, value=5, description='maxDepth'),
    nuberCON=widgets.IntSlider(min=1, max=30, step=1, value=5, description='nuberCON'),
);