SVM不同核函数及其参数对分类性能的影响
王**
导入
[1]:
%pylab inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler
from matplotlib.colors import ListedColormap
from skimage import io,data
from sklearn.feature_selection import RFE
from sklearn.ensemble import GradientBoostingClassifier, IsolationForest
from sklearn.externals import joblib
from sklearn.model_selection import train_test_split
import numpy.ma as ma
import os, shutil
Populating the interactive namespace from numpy and matplotlib
读取数据
[2]:
from skimage.io import imread

# Input rasters. Raw strings: the plain literals contained "\H", "\d" and "\p", which are
# invalid escape sequences (SyntaxWarning since Python 3.12); the path values are unchanged.
# NOTE(review): absolute local paths — consider a configurable data directory.
IMG_PATH = r'E:\Hyperspectral_Project\dc.tif'
ROI_PATH = r'E:\Hyperspectral_Project\protest.tif'

img = imread(IMG_PATH)
roi = imread(ROI_PATH)

# Move axis 0 (the spectral bands) to the end -> (rows, cols, bands), (1280, 307, 191) here.
img = np.transpose(img, (1, 2, 0))
labels = np.unique(roi[roi > 0])

# One row per pixel, one column per band. Deriving the row count (-1) instead of
# hard-coding 392960 lets the same cell work for scenes of any size.
n_bands = img.shape[-1]
X = img.reshape(-1, n_bands)
t = img.reshape(-1, n_bands)  # full, unmasked pixel matrix (X is filtered below)

# Per-pixel class value taken from the first channel of the ROI image.
Y = roi[:, :, 0].ravel()
assert Y.shape[0] == X.shape[0], 'ROI and image must cover the same pixels'
print(Y.shape)
print(X.shape)

# Keep only pixels with a positive ROI value; 0 is treated as unlabelled.
labelled = Y > 0
X = X[labelled, :]
Y = Y[labelled]
np.unique(Y)
(392960,)
(392960, 191)
[2]:
array([ 2, 3, 51, 102, 153, 204, 255], dtype=int16)
[3]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from ipywidgets import interact,interact_manual
[4]:
X_train, X_test, y_train, y_test = train_test_split(
X,
Y,
train_size=0.75,
random_state= 42,
stratify=Y)
[5]:
X_train, X_valid, y_train, y_valid = train_test_split(
X_train,
y_train,
train_size=0.66,
random_state= 0,
stratify=y_train)
[6]:
# Train the model: a support vector classifier with a linear kernel, fit on the training split.
from sklearn.svm import SVC

clf = (
    SVC(kernel='linear')
    .fit(X_train, y_train)
)
[7]:
# Mean accuracy on the training split (the data the classifier was fit on).
clf.score(X=X_train, y=y_train)
[7]:
1.0
[8]:
# Mean accuracy on the validation split.
clf.score(X=X_valid, y=y_valid)
[8]:
0.99921104536489147
[9]:
# Predict the held-out test split and report the fraction of correct predictions.
y_model = clf.predict(X_test)
accuracy_score(y_true=y_test, y_pred=y_model)
[9]:
0.99758551307847088
[10]:
import seaborn as sns
[11]:
from sklearn.metrics import confusion_matrix

# Fix the row/column order to the classifier's classes so the matrix always has one row and
# column per class, and label the axes with the real class values (e.g. 2, 3, 51, ...)
# instead of the positional indices 0..k-1 that seaborn would show by default.
class_labels = clf.classes_
mat = confusion_matrix(y_test, y_model, labels=class_labels)

fig, ax = plt.subplots()
sns.heatmap(mat, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=class_labels, yticklabels=class_labels, ax=ax)
ax.set(xlabel='predicted value', ylabel='true value',
       title=f'SVC ({clf.kernel} kernel): test-set confusion matrix');