CNN实现卫星图像分类(tensorflow)

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

python 复制代码
import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2、加载并显示训练数据

从文件夹中获取所有数据路径

python 复制代码
import glob
import random

all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

python 复制代码
def load_and_preprocess_image(path):
    img_raw = tf.io.read_file(path)
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)
    img_tensor = tf.image.resize(img_tensor,[256,256])
    img_tensor = tf.cast(img_tensor,tf.float32)
    img_tensor = img_tensor/255
    return img_tensor

处理标签

python 复制代码
label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

python 复制代码
import matplotlib.pyplot as plt

def plot_images_lables(all_image_path,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[i])
        title = 'label=' + index_to_label.get(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

plot_images_lables(all_image_path,labels,0,5)

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

python 复制代码
img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

python 复制代码
test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_count

train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

python 复制代码
BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

python 复制代码
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential([
    Input(shape=(256,256,3)),
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),  # 默认2*2. 池化层扩大视野
    Dropout(0.2),  # 防止过拟合
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    GlobalAvgPool2D(),  # 全局平均池化
    Dense(1024,activation='relu'),
    Dense(256,activation='relu'),
    Dense(1,activation='sigmoid') 
])

model.summary()

6、模型编译与训练

python 复制代码
model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了
              metrics=['acc'])

steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZE

H = model.fit(train_ds,
             epochs=10,
             steps_per_epoch=steps_per_epoch,
             validation_data=test_ds,
             validation_steps=val_step,
             verbose=1)

7、模型评估

python 复制代码
import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

8、模型预测

python 复制代码
def pred_img(img_path):
    img = load_and_preprocess_image(img_path)
    img = tf.expand_dims(img, axis=0)
    pred = model.predict(img)
    pred = index_to_label.get((pred>0.5).astype('int')[0][0])
    return pred
    
img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()
相关推荐
沅_Yuan8 小时前
基于GRU门控循环神经网络的多分类预测【MATLAB】
matlab·分类·gru
IT古董9 小时前
【机器学习】机器学习的基本分类-强化学习-Actor-Critic 方法
人工智能·机器学习·分类
martian6659 小时前
【人工智能数学基础】——深入详解贝叶斯理论:掌握贝叶斯定理及其在分类和预测中的应用
人工智能·数学·分类·数据挖掘·贝叶斯
yusaisai大鱼1 天前
tensorflow_probability与tensorflow版本依赖关系
人工智能·python·tensorflow
18号房客1 天前
一个简单的深度学习模型例程,使用Keras(基于TensorFlow)构建一个卷积神经网络(CNN)来分类MNIST手写数字数据集。
人工智能·深度学习·机器学习·生成对抗网络·语言模型·自然语言处理·tensorflow
数据分析能量站2 天前
目标检测-R-CNN
目标检测·r语言·cnn
醒了就刷牙2 天前
transformer用作分类任务
深度学习·分类·transformer
四口鲸鱼爱吃盐2 天前
Pytorch | 从零构建ParNet/Non-Deep Networks对CIFAR10进行分类
人工智能·pytorch·分类
小陈phd2 天前
深度学习实战之超分辨率算法(tensorflow)——ESPCN
网络·深度学习·神经网络·tensorflow
IT古董2 天前
【机器学习】机器学习的基本分类-强化学习-模型预测控制(MPC:Model Predictive Control)
人工智能·机器学习·分类