CNN-based satellite image classification (TensorFlow)

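The dataset contains satellite images of two classes, airplane and lake, with 700 images per class. Each image is a 256×256 three-channel RGB color satellite image. We build a deep convolutional neural network to perform binary classification of these images.
Data: Baidu Netdisk link, extraction code: cq47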

1. Check the TensorFlow version

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

2. Load the training data

Collect all image paths from the data folder

import glob
import random

all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob is more concise than pathlib here
random.shuffle(all_image_path)
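The call to `random.shuffle` uses the unseeded global generator. In addition, `glob` returns files in an order that depends on the filesystem. As a result, the train/test split can differ from run to run. If you want a reproducible split, one option is to sort the paths and then shuffle them with a dedicated `random.Random` instance. The sketch below assumes a seed of 42, and `shuffled_paths` is a hypothetical helper name that is not part of the original code.

```python
import random

def shuffled_paths(paths, seed=42):
    """Return a shuffled copy of `paths`; the same seed always gives the same order."""
    rng = random.Random(seed)   # dedicated generator, independent of the global random state
    result = list(paths)        # copy so the caller's list is not modified
    rng.shuffle(result)
    return result

# Hypothetical file names, for demonstration only
paths = [f'./data/air_lake_dataset/airplane/airplane_{i}.jpg' for i in range(5)]
print(shuffled_paths(paths) == shuffled_paths(paths))  # True
```

To use it here, the two lines above would become `all_image_path = shuffled_paths(sorted(glob.glob('./data/air_lake_dataset/*/*.jpg')))`.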

Read and preprocess an image

def load_and_preprocess_image(path):
    img_raw = tf.io.read_file(path)
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)
    img_tensor = tf.image.resize(img_tensor,[256,256])
    img_tensor = tf.cast(img_tensor,tf.float32)
    img_tensor = img_tensor/255
    return img_tensor

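Process the labels

import os

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
# Take the class name from the parent folder; os.path also handles Windows "\" separators
labels = [label_to_index.get(os.path.basename(os.path.dirname(img))) for img in all_image_path]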
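Before building the dataset, it is worth checking two things: that every path maps to a known class, and that the classes are balanced (the dataset description says 700 images per class). Below is a minimal sketch using `collections.Counter`. `class_counts` is a hypothetical helper, and the three paths are made-up examples.

```python
import os
from collections import Counter

label_to_index = {'airplane': 0, 'lake': 1}

def class_counts(paths):
    """Count images per class, using the parent folder name as the class name."""
    names = [os.path.basename(os.path.dirname(p)) for p in paths]
    unknown = [n for n in names if n not in label_to_index]
    if unknown:
        # an unexpected folder would otherwise silently produce a None label
        raise ValueError(f'unexpected class folders: {sorted(set(unknown))}')
    return Counter(names)

paths = ['./data/air_lake_dataset/airplane/airplane_1.jpg',
         './data/air_lake_dataset/lake/lake_1.jpg',
         './data/air_lake_dataset/lake/lake_2.jpg']
print(class_counts(paths))  # Counter({'lake': 2, 'airplane': 1})
```

On the real data, `class_counts(all_image_path)` should report 700 for each class if the download is complete.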

3. Display the satellite images

import matplotlib.pyplot as plt

def plot_images_labels(all_image_path,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    # Load exactly `num` images starting at start_idx
    images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+num]]
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[i])
        title = 'label=' + index_to_label.get(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

plot_images_labels(all_image_path,labels,0,5)

4. Build the training/test data with tf.data.Dataset

Create the Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

Split into training and test sets

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_count

train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)
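`take(n)` yields the first n elements and `skip(n)` yields everything after them. The two subsets therefore do not overlap, as long as the element order is the same each time the dataset is iterated. That holds here because the path list was shuffled once in plain Python and no `Dataset.shuffle()` comes before the split. A reshuffling `shuffle()` placed before `take`/`skip` would break this.

The sketch below works through the arithmetic for the 1,400 images. It uses list slicing as a stand-in for `take`/`skip`, and it also computes the step counts that `model.fit` receives later with `BATCH_SIZE = 16`.

```python
BATCH_SIZE = 16
total = 1400                       # 700 airplane + 700 lake images

test_count = int(total * 0.2)      # 20% for testing
train_count = total - test_count

indices = list(range(total))
test_part = indices[:test_count]   # plays the role of img_label_ds.take(test_count)
train_part = indices[test_count:]  # plays the role of img_label_ds.skip(test_count)

steps_per_epoch = train_count // BATCH_SIZE   # training batches per epoch
val_step = test_count // BATCH_SIZE           # validation batches per pass

print(test_count, train_count, steps_per_epoch, val_step)  # 280 1120 70 17
print(set(test_part).isdisjoint(train_part))               # True
print(val_step * BATCH_SIZE)  # 272 images evaluated per validation pass, not 280
```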

Load the data in batches

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)
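`shuffle(100)` does not shuffle the whole dataset. It keeps a buffer of 100 elements, emits a randomly chosen one, and refills that slot with the next incoming element. The pure-Python generator below imitates this behaviour as a sketch; it is not TensorFlow's actual implementation.

It shows why a buffer much smaller than the 1,120 training images only mixes elements that are close together in the input order. Also, because `repeat()` comes before `shuffle()`, the buffer can mix elements across epoch boundaries.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Imitate a shuffle buffer: emit a random buffered element, then refill its slot."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)          # fill the buffer first
            continue
        i = rng.randrange(buffer_size)   # pick a random slot
        yield buffer[i]
        buffer[i] = item                 # replace it with the next incoming element
    rng.shuffle(buffer)                  # input exhausted: drain the rest in random order
    yield from buffer

out = list(buffered_shuffle(range(10), buffer_size=3))
print(sorted(out) == list(range(10)))  # True: every element appears exactly once
```

With `buffer_size=1` the output order equals the input order, and the first element emitted can only come from the first `buffer_size` inputs.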

5. Build the CNN model

# Import the layers from tf.keras so they match tf.keras.Sequential
from tensorflow.keras.layers import Input,Dense,Dropout
from tensorflow.keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential([
    Input(shape=(256,256,3)),
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # more filters give the model more capacity to fit the data
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),  # 2x2 by default; pooling enlarges the receptive field
    Dropout(0.2),  # reduces overfitting
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    GlobalAvgPool2D(),  # global average pooling: one value per feature map
    Dense(1024,activation='relu'),
    Dense(256,activation='relu'),
    Dense(1,activation='sigmoid')  # single output: probability of class 1 (lake)
])

model.summary()
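Only the convolution and dense layers carry weights, so the parameter count printed by `model.summary()` can be checked by hand:

- A `Conv2D` layer with a k×k kernel has (k·k·C_in + 1)·C_out parameters.
- A `Dense` layer has (n_in + 1)·n_out parameters.

The sketch below applies these formulas to the architecture above. With `padding='same'`, only the three `MaxPool2D` layers change the spatial size.

```python
# (in_channels, out_channels) of the eight 3x3 Conv2D layers above
conv_layers = [(3, 64), (64, 64), (64, 128), (128, 128),
               (128, 256), (256, 256), (256, 512), (512, 512)]
# (in_units, out_units) of the three Dense layers; GlobalAvgPool2D outputs 512 features
dense_layers = [(512, 1024), (1024, 256), (256, 1)]

def conv_params(c_in, c_out, k=3):
    return (k * k * c_in + 1) * c_out   # kernel weights + one bias per filter

def dense_params(n_in, n_out):
    return (n_in + 1) * n_out           # weights + one bias per unit

total = (sum(conv_params(i, o) for i, o in conv_layers)
         + sum(dense_params(i, o) for i, o in dense_layers))

size = 256
for _ in range(3):   # each default 2x2 MaxPool2D halves height and width
    size //= 2

print(total)  # 5473345
print(size)   # 32: the last two Conv2D layers work on 32x32 feature maps
```

Roughly 2.36 million of these parameters sit in the final 512→512 convolution alone.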

6. Compile and train the model

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # the output already passes through sigmoid
              metrics=['acc'])

steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZE

H = model.fit(train_ds,
              epochs=10,
              steps_per_epoch=steps_per_epoch,
              validation_data=test_ds,
              validation_steps=val_step,
              verbose=1)
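The compile call uses `from_logits=False` because the final `Dense(1, activation='sigmoid')` layer already outputs a probability p. For one sample with label y ∈ {0, 1}, binary cross-entropy is -[y·log(p) + (1-y)·log(1-p)].

The stdlib sketch below computes this directly. It then checks that the result equals the numerically stable formula max(z, 0) - z·y + log(1 + exp(-|z|)) applied to the raw pre-sigmoid value z. That is the form documented for `tf.nn.sigmoid_cross_entropy_with_logits`, and it is what a loss with `from_logits=True` would work with. The clipping constant `eps` is an illustrative choice.

```python
import math

def bce(y, p, eps=1e-7):
    """Binary cross-entropy for a probability p (the from_logits=False case)."""
    p = min(max(p, eps), 1 - eps)   # clip to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_from_logit(y, z):
    """Same loss computed from a raw logit z, in numerically stable form."""
    return max(z, 0) - z * y + math.log1p(math.exp(-abs(z)))

def sigmoid(z):
    return 1 / (1 + math.exp(-z))

z = 2.0  # an example pre-activation value
print(bce(1, sigmoid(z)), bce_from_logit(1, z))  # the two values agree
```

Passing probabilities to a loss configured with `from_logits=True` would apply the sigmoid a second time, which is why the flag must match the last layer.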

7. Evaluate the model

import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()
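The curves above only show loss and accuracy during training. To see which class the model confuses, you can threshold the sigmoid outputs at 0.5 (the same rule `pred_img` uses) and count the four outcomes of a confusion matrix. The sketch below does this in plain Python. `y_true` and `y_prob` are hypothetical example values; in practice you would collect them by running `model.predict` on the test images.

```python
def confusion_counts(y_true, y_prob, threshold=0.5):
    """Return (tp, fp, fn, tn), treating label 1 ('lake') as the positive class."""
    tp = fp = fn = tn = 0
    for y, p in zip(y_true, y_prob):
        pred = 1 if p > threshold else 0   # same rule as (pred > 0.5) in pred_img
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn

# Hypothetical labels (0 = airplane, 1 = lake) and sigmoid outputs
y_true = [0, 0, 1, 1, 1, 0]
y_prob = [0.10, 0.70, 0.90, 0.40, 0.80, 0.20]
tp, fp, fn, tn = confusion_counts(y_true, y_prob)
accuracy = (tp + tn) / len(y_true)
print(tp, fp, fn, tn, accuracy)  # 2 1 1 2 0.6666666666666666
```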

8. Predict with the model

def pred_img(img_path):
    img = load_and_preprocess_image(img_path)
    img = tf.expand_dims(img, axis=0)
    pred = model.predict(img)
    pred = index_to_label.get((pred>0.5).astype('int')[0][0])
    return pred
    
img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()
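`pred_img` returns only the label. The single sigmoid output is the predicted probability of class 1 (`lake`), so the probability of `airplane` is one minus that value. This means a confidence can be reported alongside the label. Below is a small sketch; `label_with_confidence` is a hypothetical helper and the probability values are made up.

```python
index_to_label = {0: 'airplane', 1: 'lake'}

def label_with_confidence(prob_lake, threshold=0.5):
    """Map a sigmoid output to (label, probability of that label)."""
    idx = 1 if prob_lake > threshold else 0
    confidence = prob_lake if idx == 1 else 1 - prob_lake
    return index_to_label[idx], confidence

print(label_with_confidence(0.93))  # ('lake', 0.93)
print(label_with_confidence(0.12))
```

Inside `pred_img`, the value to pass in would be `float(pred[0][0])`, taken from the array returned by `model.predict`.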