# 多尺度分割 (multi-scale segmentation)

# 刘**

import numpy as np
import gdal
import pandas as pd
from sklearn import svm
from sklearn import naive_bayes
import lightgbm as lgb
import ecognition_read
import optimal_scale
import os


class ImageClassify:
    '''
    Pixel-based image classifier.

    Reads a multi-band training image and a single-band label raster with GDAL,
    trains a classifier on every pixel whose label is non-zero, and predicts
    class maps / class probabilities for other images of the same size and band
    count (for example images whose pixels carry per-object features at another
    segmentation scale).

    All arrays use the GDAL/NumPy layout ``(rows, cols[, bands])``, i.e.
    ``(ysize, xsize[, bands])``.
    '''

    def __init__(self, train_image_path='july_06_quac.tif', train_label_path='train.dat', rootpath=r'.\data\img\\'):
        '''
        Load the training image and label raster, then fit the default classifier.

        :param train_image_path: file name of the training image, relative to ``rootpath``
        :param train_label_path: file name of the label raster, relative to ``rootpath``;
            band 1 is read and pixels labelled 0 are treated as unlabelled
        :param rootpath: directory prefix, joined to both file names by plain
            string concatenation
        :raises IOError: if GDAL cannot open either file
        '''
        self.dataset_image = self._open_dataset(rootpath + train_image_path)
        self.bands = self.dataset_image.RasterCount
        self.xsize = self.dataset_image.RasterXSize  # number of columns
        self.ysize = self.dataset_image.RasterYSize  # number of rows
        # ReadAsArray() returns (rows, cols) == (ysize, xsize); allocating the
        # stack as (xsize, ysize) breaks for every non-square image.
        self.image = np.zeros((self.ysize, self.xsize, self.bands))
        for band in range(self.bands):
            self.image[:, :, band] = self.dataset_image.GetRasterBand(band + 1).ReadAsArray()

        dataset_train = self._open_dataset(rootpath + train_label_path)
        self.image_label = dataset_train.GetRasterBand(1).ReadAsArray()
        self.fit(self.image, self.image_label)

    @staticmethod
    def _open_dataset(path):
        '''Open ``path`` read-only with GDAL, raising IOError instead of returning None.'''
        dataset = gdal.Open(path, gdal.GA_ReadOnly)
        if dataset is None:
            raise IOError('GDAL could not open {}'.format(path))
        return dataset

    def fit(self, image=None, image_label=None, clf=None):
        '''
        Train a classifier on the labelled pixels of ``image``.

        :param image: array of shape (ysize, xsize, bands); defaults to ``self.image``
        :param image_label: array of shape (ysize, xsize); pixels equal to 0 are
            unlabelled and skipped. Defaults to ``self.image_label``.
        :param clf: estimator exposing ``fit`` and ``score``; when None a fresh
            ``naive_bayes.GaussianNB()`` is used. Stored as ``self.clf``.
        :return: None (the training-set score is printed)
        '''
        if image is None:
            image = self.image
        if image_label is None:
            image_label = self.image_label
        samples = np.asarray(image).reshape(-1, self.bands)  # [pixel_num, bands]
        labels = np.asarray(image_label).reshape(-1)  # [pixel_num]
        labelled = labels != 0  # mask of pixels that carry a class label
        samples = samples[labelled]  # [labelled_num, bands]
        labels = labels[labelled]  # [labelled_num]
        if clf is None:
            self.clf = naive_bayes.GaussianNB()
            # Alternatives: naive_bayes.BernoulliNB(), svm.SVC(), lgb.LGBMClassifier()
        else:
            self.clf = clf
        self.clf.fit(samples, labels)
        print(self.clf.score(samples, labels))

    def predict(self, test_image):
        '''
        Classify every pixel of a new image.

        :param test_image: array of shape (ysize, xsize, bands) with the same size and
            features as the training image, e.g. an image whose pixels hold the
            mean spectra / features of their object at another segmentation scale
        :return: (classify_image, proba_image); classify_image has shape
            (ysize, xsize), proba_image has shape (ysize, xsize, n_classes)
        '''
        expected_shape = (self.ysize, self.xsize, self.bands)
        # Training and test data must have the same image size and feature count.
        assert test_image.shape == expected_shape, \
            'test image shape {} != training shape {}'.format(test_image.shape, expected_shape)
        pixels = test_image.reshape(-1, self.bands)
        classify_image = self.clf.predict(pixels).reshape(self.ysize, self.xsize)
        proba = self.clf.predict_proba(pixels)
        proba_image = proba.reshape(self.ysize, self.xsize, proba.shape[-1])
        return classify_image, proba_image

    def write_array(self, data, path):
        '''
        Write ``data`` to a GeoTIFF at ``path``, reusing the training image's
        projection and geotransform.

        uint8 and bool arrays are written as Byte; everything else is cast to
        float32.

        :param data: array of shape (rows, cols) or (rows, cols, bands)
        :param path: output file path
        :return: True on success
        :raises IOError: if the GTiff driver cannot create ``path``
        '''
        if data.dtype == np.uint8 or data.dtype == np.bool_:
            gdal_type = gdal.GDT_Byte
            data = data.astype('uint8')
        else:
            gdal_type = gdal.GDT_Float32
            data = data.astype('float32')
        if data.ndim == 2:
            data = data[:, :, np.newaxis]  # handle a single band as (rows, cols, 1)
        rows, cols, band_num = data.shape
        out_ds = gdal.GetDriverByName('GTiff').Create(path, cols, rows, band_num, gdal_type)
        if out_ds is None:
            raise IOError('GDAL could not create {}'.format(path))
        try:
            out_ds.SetProjection(self.dataset_image.GetProjection())
            out_ds.SetGeoTransform(self.dataset_image.GetGeoTransform())
            for i in range(band_num):
                out_band = out_ds.GetRasterBand(i + 1)
                out_band.WriteArray(data[:, :, i])
                out_band.FlushCache()
            # out_ds.BuildOverviews('average', [2, 4, 8, 16, 32])
        finally:
            # Dropping the last reference closes the dataset and flushes it to disk.
            del out_ds
        return True

    def get_object_images(self, ecord: 'ecognition_read.ecoRead'):
        '''
        Yield, per segmentation scale, an image whose pixels carry the features
        of the object they belong to. A generator is used so only one scale's
        image is held in memory at a time.

        Pixels are matched to rows of the scale's feature table on the ``id``
        column with a pandas left merge, which keeps pixel order and is far
        faster than looping over every object id in Python.

        :param ecord: reader exposing ``get_id_images()`` (object-id rasters stacked
            on the last axis, one per scale), ``get_feature_csv()`` (one DataFrame
            per scale with an ``id`` column) and ``names`` (one name per scale)
        :return: generator of (object_image, scale_name); object_image has the
            id raster's (rows, cols) plus a trailing ``self.bands`` axis
        '''
        id_images = ecord.get_id_images()
        feature_dfs = ecord.get_feature_csv()
        self.scales = id_images.shape[-1]
        # Merged frame: column 0 is 'id'. Columns 2 .. 2 + bands are taken as the
        # per-object features (iloc[:, 2:8] for 6 bands); assumes the feature
        # table layout from eCognition -- TODO confirm column 1 is not a feature.
        feature_cols = slice(2, 2 + self.bands)
        for scale_i in range(self.scales):
            id_image = id_images[:, :, scale_i]
            feature_df = feature_dfs[scale_i]
            object_num = np.max(id_image)
            assert object_num <= 65536
            id_image_pd = pd.DataFrame({'id': id_image.reshape(-1)})  # one row per pixel
            id_image_pd = pd.merge(id_image_pd, feature_df, on='id', how='left')
            # Reshape back to the id raster's own (rows, cols) instead of assuming
            # (xsize, ysize), which is transposed for non-square images.
            object_image = id_image_pd.iloc[:, feature_cols].values.reshape(
                id_image.shape[0], id_image.shape[1], self.bands)
            yield object_image, ecord.names[scale_i]


if __name__ == "__main__":
    # Constructing ImageClassify loads the training data and calls fit() itself;
    # call image_classify.fit(...) again only to train on new data.
    image_classify = ImageClassify()
    ecord = ecognition_read.ecoRead()  # used by the per-scale workflow commented out below
    # image, image_proba = image_classify.predict(image_classify.image)
    # image_classify.write_array(image, 'pixel_classify.tif')
    # Deliberately not named `os`: that would rebind the imported os module
    # for the whole module.
    optimal = optimal_scale.OptimalScale()
    best_classify_map = optimal.get_class_map()
    image_classify.write_array(best_classify_map, 'best_classify.tif')

    # object_image_iteration = image_classify.get_object_images(ecord)
    # for object_image, scale_name in object_image_iteration:
    #     image, image_proba = image_classify.predict(object_image)
    #     image_classify.write_array(image, 'level{}_classify.tif'.format(scale_name))
    #     image_classify.write_array(image_proba, 'level{}_proba.tif'.format(scale_name))
# [ ]: