深度学习之超分辨率算法——SRCNN

  • 网络为基础卷积层

  • tensorflow 1.14

  • scipy 1.2.1(代码使用 scipy.misc.imsave 保存图像,这类 scipy.misc 图像函数在更新的 SciPy 版本中已被移除,因此需固定较旧版本)

  • numpy 1.16

  • 大致思路:先把原图裁剪成缩放因子 scale 的整数倍尺寸,作为高分辨率标签;再把它缩小为原来的 1/scale,然后对小图做插值放大回原尺寸,得到模糊的低分辨率图像作为网络输入。网络学习的就是从这张低分辨率输入恢复出高分辨率标签。

  • main.py:入口脚本及参数(flags)配置

python 复制代码
from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os

flags = tf.app.flags
# Number of training epochs.
flags.DEFINE_integer("epoch", 1000, "Number of epoch [1000]")
# Mini-batch size.
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
# Side length of the square input patches fed to the network.
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
# Side length of the square label patches; the 9x9/1x1/5x5 'VALID' conv stack
# in model.py shrinks each side by 12, so this must equal image_size - 12.
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
# Learning rate for plain gradient descent.
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
# Number of image channels (1 = grayscale).
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
# Down/up-scaling factor used to synthesize the low-resolution inputs.
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# Step (in pixels) between neighbouring patches when cutting images.
flags.DEFINE_integer("stride", 14, "The size of stride to apply input image [14]")
# Checkpoint directory; SRCNN.train also looks for the .h5 patch files here.
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
# Directory where the test output image is written.
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
# True to train, False to run the model on a test image.
# The help text's "[...]" shows the default, which is False.
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
FLAGS = flags.FLAGS

# Pretty-printer used to dump the flag values at start-up.
pp = pprint.PrettyPrinter()

def main(_):
    """Entry point invoked by tf.app.run after the flags are parsed.

    Prints the flag values, makes sure the checkpoint and sample directories
    exist, then builds an SRCNN in a new session and calls its train(), which
    trains when FLAGS.is_train is True and runs the test path otherwise.
    """
    # flag_values_dict() maps every flag name to its parsed value, which is
    # what we want to see (rather than the internal Flag objects).
    pp.pprint(FLAGS.flag_values_dict())

    # Create the output directories if they do not exist yet.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # The session provides the environment for running ops and evaluating tensors.
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        srcnn.train(FLAGS)
    
if __name__ == '__main__':
    # tf.app.run parses the command-line flags and then calls main(argv).
    tf.app.run(main=main)
  • model.py:SRCNN 模型定义及训练/测试流程
from utils import (
    read_data,
    input_setup,
    imsave,
    merge
)
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

try:
    xrange
except NameError:
    # Python 3 has no xrange; range is its lazy equivalent there.
    xrange = range


class SRCNN(object):
    """SRCNN super-resolution network built as a TensorFlow 1.x graph.

    Three 'VALID' convolutions (9x9 -> 64, 1x1 -> 32, 5x5 -> c_dim) map an
    upscaled low-resolution patch of side image_size to a patch whose side is
    image_size - 12; that size must equal label_size (33 -> 21 by default).
    """

    def __init__(self,
                 sess,
                 image_size=33,
                 label_size=21,
                 batch_size=128,
                 c_dim=1,
                 checkpoint_dir=None,
                 sample_dir=None):
        """Store the configuration and build the graph.

        Args:
          sess: tf.Session used for initialization, training, inference and
            checkpointing.
          image_size: side length of the square input patches.
          label_size: side length of the square label (output) patches.
          batch_size: stored on the instance; train() uses config.batch_size.
          c_dim: number of image channels (1 for grayscale).
          checkpoint_dir: directory load() restores checkpoints from.
          sample_dir: stored on the instance; train() uses config.sample_dir.
        """
        self.sess = sess
        # A single channel means the network works on grayscale images.
        self.is_grayscale = (c_dim == 1)

        self.image_size = image_size
        self.label_size = label_size
        self.batch_size = batch_size

        self.c_dim = c_dim

        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir

        self.build_model()

    def build_model(self):
        """Create placeholders, weights, the forward pass, the MSE loss and the saver."""
        # Inputs and labels are NHWC: [batch, height, width, channels].
        self.images = tf.placeholder(
            tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
        self.labels = tf.placeholder(
            tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')

        # Filters are [filter_height, filter_width, in_channels, out_channels].
        # The first layer reads c_dim channels and the last emits c_dim, so the
        # network matches the placeholders for any channel count. The variable
        # names ('w1'..'b3') are what the Saver keys checkpoints by.
        self.weights = {
            'w1': tf.Variable(tf.random_normal([9, 9, self.c_dim, 64], stddev=1e-3), name='w1'),
            'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
            'w3': tf.Variable(tf.random_normal([5, 5, 32, self.c_dim], stddev=1e-3), name='w3'),
        }
        # One bias per output channel of each layer.
        self.biases = {
            'b1': tf.Variable(tf.zeros([64]), name='b1'),
            'b2': tf.Variable(tf.zeros([32]), name='b2'),
            'b3': tf.Variable(tf.zeros([self.c_dim]), name='b3'),
        }
        # Forward pass; output is [batch, label_size, label_size, c_dim].
        self.pred = self.model()

        # Mean squared error between the labels and the prediction.
        self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
        # Keep the 4 most recent checkpoints, plus one every 2 hours.
        self.saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

    def train(self, config):
        """Train the network, or run it on a test image, depending on config.is_train.

        Both modes do the same setup first:
        - call input_setup() to cut the images into patches and write them to
          ./<checkpoint_dir>/train.h5 or ./<checkpoint_dir>/test.h5;
        - read that file back and build the SGD train op;
        - initialize all variables and try to restore the latest checkpoint
          from self.checkpoint_dir.

        Training runs config.epoch epochs over full mini-batches of
        config.batch_size (a trailing partial batch is skipped). It prints the
        loss every 10 steps and saves a checkpoint every 500 steps.

        Testing predicts every patch, stitches the predictions back together
        with merge() and writes <cwd>/<sample_dir>/test_image.png.

        Args:
          config: flag object; this method reads is_train, checkpoint_dir,
            sample_dir, learning_rate, epoch and batch_size (input_setup reads
            further fields).
        """
        if config.is_train:
            input_setup(self.sess, config)
            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
        else:
            # nx, ny: size of the patch grid, needed to stitch the test output.
            nx, ny = input_setup(self.sess, config)
            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")

        train_data, train_label = read_data(data_dir)

        # Stochastic gradient descent with standard backpropagation.
        self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

        # global_variables_initializer replaces the deprecated
        # initialize_all_variables; run it on self.sess explicitly instead of
        # relying on a default session being active.
        self.sess.run(tf.global_variables_initializer())

        counter = 0
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        if config.is_train:
            print("Training...")
            batch_size = config.batch_size
            # Number of full mini-batches per epoch; the remainder is unused.
            batch_idxs = len(train_data) // batch_size
            for ep in xrange(config.epoch):
                for idx in xrange(batch_idxs):
                    start = idx * batch_size
                    batch_images = train_data[start:start + batch_size]
                    batch_labels = train_label[start:start + batch_size]

                    counter += 1
                    _, err = self.sess.run([self.train_op, self.loss],
                                           feed_dict={self.images: batch_images,
                                                      self.labels: batch_labels})

                    if counter % 10 == 0:
                        print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]"
                              % ((ep + 1), counter, time.time() - start_time, err))

                    # NOTE(review): counter restarts at 0 after a restore, so
                    # checkpoint step numbers restart too.
                    if counter % 500 == 0:
                        self.save(config.checkpoint_dir, counter)

        else:
            print("Testing...")
            # The prediction only depends on the input patches.
            result = self.sess.run(self.pred, feed_dict={self.images: train_data})
            result = merge(result, [nx, ny])
            result = result.squeeze()
            image_path = os.path.join(os.getcwd(), config.sample_dir, "test_image.png")
            imsave(result, image_path)

    def model(self):
        """Build the forward pass: conv 9x9 + ReLU -> conv 1x1 + ReLU -> conv 5x5.

        tf.nn.conv2d takes an NHWC input and a
        [filter_height, filter_width, in_channels, out_channels] filter.
        Strides are [1, s, s, 1]; the first and last entries must be 1.
        With padding='VALID' no zero padding is added, so a k x k layer at
        stride 1 shrinks each spatial side by k - 1 (12 in total here).
        """
        strides = [1, 1, 1, 1]
        conv1 = tf.nn.relu(
            tf.nn.conv2d(self.images, self.weights['w1'], strides=strides, padding='VALID',
                         use_cudnn_on_gpu=True) + self.biases['b1'])
        conv2 = tf.nn.relu(
            tf.nn.conv2d(conv1, self.weights['w2'], strides=strides, padding='VALID',
                         use_cudnn_on_gpu=True) + self.biases['b2'])
        # The last layer is linear: it directly produces the reconstructed pixels.
        conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=strides, padding='VALID',
                             use_cudnn_on_gpu=True) + self.biases['b3']

        return conv3

    def save(self, checkpoint_dir, step):
        """Save a checkpoint to <checkpoint_dir>/srcnn_<label_size>/SRCNN.model-<step>.

        Args:
          checkpoint_dir: base checkpoint directory (created if missing).
          step: global step appended to the checkpoint file name.
        """
        model_name = "SRCNN.model"
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint from <checkpoint_dir>/srcnn_<label_size>.

        Args:
          checkpoint_dir: base checkpoint directory.

        Returns:
          True if a checkpoint was found and restored, False otherwise.
        """
        print(" [*] Reading checkpoints...")
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        # The 'checkpoint' state file records the latest checkpoint's path.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

        if ckpt and ckpt.model_checkpoint_path:
            # Keep only the file name and re-anchor it in checkpoint_dir.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        return False
  • utils.py:数据读取、预处理及辅助函数
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob
import h5py
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc
import scipy.ndimage
import numpy as np

import tensorflow as tf

try:
    xrange
except NameError:
    # Python 3 has no xrange; range is its lazy equivalent there.
    xrange = range

# Shared TensorFlow (absl) flag registry. Helpers below read flag values such
# as FLAGS.is_train from it; the flags are presumably defined by the entry
# script (main.py) before these helpers run — confirm for other callers.
FLAGS = tf.app.flags.FLAGS


def read_data(path):
    """Load input patches and labels from an HDF5 file.

    Args:
      path: path to the .h5 file containing 'data' and 'label' datasets.

    Returns:
      (data, label) as numpy arrays.

    Raises:
      KeyError: if either dataset is missing. Indexing is used instead of
        ``hf.get`` because ``get`` returns None, and ``np.array(None)`` would
        silently produce a useless object array.
    """
    with h5py.File(path, 'r') as hf:
        data = np.array(hf['data'])
        label = np.array(hf['label'])
    return data, label


def preprocess(path, scale=3):
    """Build one (low-resolution input, high-resolution label) pair from an image.

    Steps:
      (1) Read the image as YCbCr collapsed to one channel (imread with
          is_grayscale=True).
      (2) Crop it so height and width are multiples of `scale` (modcrop).
      (3) Divide by 255 to normalize; this is the label (assumes 8-bit pixel
          values — TODO confirm for other inputs).
      (4) Shrink the label by 1/scale and enlarge it back by `scale` with
          scipy.ndimage.zoom (default order-3 spline, prefilter=False). The
          result is the blurry input, with the same shape as the label.

    Args:
      path: image file path.
      scale: integer down/up-scaling factor.

    Returns:
      (input_, label_): float arrays of identical shape.
    """
    label_ = modcrop(imread(path, is_grayscale=True), scale) / 255.
    # scipy.ndimage.zoom is the canonical name; the scipy.ndimage.interpolation
    # namespace is deprecated.
    input_ = scipy.ndimage.zoom(label_, zoom=(1. / scale), prefilter=False)
    input_ = scipy.ndimage.zoom(input_, zoom=(scale / 1.), prefilter=False)

    return input_, label_


def prepare_data(sess, dataset):
    """Return the sorted list of .bmp image paths for training or testing.

    Args:
      sess: unused; kept for interface compatibility.
      dataset: data directory name relative to the current working directory
        ("Train" or "Test"). Training images are read from it directly; test
        images are read from its "Set5" subdirectory. Which of the two is used
        is decided by FLAGS.is_train.

    Returns:
      Sorted absolute paths, e.g. ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp'].

    Raises:
      OSError (FileNotFoundError on Python 3): if the image directory is missing.
    """
    import errno

    if FLAGS.is_train:
        data_dir = os.path.join(os.getcwd(), dataset)
    else:
        data_dir = os.path.join(os.getcwd(), dataset, "Set5")

    # Fail loudly instead of silently returning no images.
    if not os.path.isdir(data_dir):
        raise OSError(errno.ENOENT, "Image directory not found", data_dir)

    # glob returns files in arbitrary filesystem order. Sorting makes the order
    # deterministic, and with it the image input_setup picks for testing.
    return sorted(glob.glob(os.path.join(data_dir, "*.bmp")))


def make_data(sess, data, label):
    """Write input patches and labels to an HDF5 file.

    The file is <cwd>/<FLAGS.checkpoint_dir>/train.h5 when FLAGS.is_train is
    set, and <cwd>/<FLAGS.checkpoint_dir>/test.h5 otherwise. Honouring the
    checkpoint_dir flag keeps this path in sync with where SRCNN.train reads
    the data back; a hard-coded 'checkpoint' directory would diverge as soon
    as the flag is changed.

    Args:
      sess: unused; kept for interface compatibility.
      data: input patches, stored as dataset 'data'.
      label: label patches, stored as dataset 'label'.
    """
    filename = 'train.h5' if FLAGS.is_train else 'test.h5'
    save_dir = os.path.join(os.getcwd(), FLAGS.checkpoint_dir)
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    savepath = os.path.join(save_dir, filename)

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)


def imread(path, is_grayscale=True):
    """Read an image in YCbCr space and return it as a float64 array.

    This does what ``scipy.misc.imread(path, flatten=is_grayscale,
    mode='YCbCr')`` did: convert to 'YCbCr', and for grayscale additionally
    convert to single-channel float mode 'F'. It calls Pillow directly because
    scipy.misc.imread is gone from newer SciPy releases, and ``np.float`` is
    gone from newer NumPy.

    Args:
      path: image file path.
      is_grayscale: if True, return a single-channel (H, W) array; otherwise
        return the Y, Cb, Cr bands as an (H, W, 3) array.
    """
    with Image.open(path) as src:
        img = src
        if img.mode != 'YCbCr':
            img = img.convert('YCbCr')
        if is_grayscale:
            img = img.convert('F')
        # Build the array while the file is still open (PIL loads lazily).
        return np.array(img).astype(np.float64)


def modcrop(image, scale=3):
    """Crop an image so its height and width are exact multiples of `scale`.

    After cropping, shrinking by 1/scale and enlarging by `scale` gives back
    exactly the cropped size. Excess rows and columns are removed from the
    bottom and right edges; any trailing axes (e.g. channels) are kept as is.

    Args:
      image: array of shape [H, W] or [H, W, ...].
      scale: positive integer scale factor.

    Returns:
      A view of `image` with shape [H - H % scale, W - W % scale, ...].
    """
    h, w = image.shape[:2]
    # Slicing only the first two axes handles 2-D and N-D images alike.
    return image[:h - h % scale, :w - w % scale]


def _extract_patches(input_, label_, config):
    """Cut aligned (input, label) patch pairs out of one preprocessed image.

    Input patches are config.image_size square windows taken every
    config.stride pixels. The matching label patch is the config.label_size
    window of the label image, offset by half the size difference so it sits
    in the centre of the input window.

    Returns:
      (inputs, labels, nx, ny):
      - inputs, labels: lists of patches reshaped to [size, size, 1];
      - nx, ny: number of window positions along the height and the width,
        which merge() needs to stitch predictions back together.
    """
    h, w = input_.shape[:2]
    # Offset of the label window inside the input window (6 for 33 -> 21).
    offset = abs(config.image_size - config.label_size) // 2

    inputs = []
    labels = []
    nx = ny = 0
    for x in range(0, h - config.image_size + 1, config.stride):
        nx += 1
        ny = 0
        for y in range(0, w - config.image_size + 1, config.stride):
            ny += 1
            sub_input = input_[x:x + config.image_size, y:y + config.image_size]
            sub_label = label_[x + offset:x + offset + config.label_size,
                               y + offset:y + offset + config.label_size]
            # Add the trailing channel axis expected by the network (NHWC).
            inputs.append(sub_input.reshape([config.image_size, config.image_size, 1]))
            labels.append(sub_label.reshape([config.label_size, config.label_size, 1]))
    return inputs, labels, nx, ny


def input_setup(sess, config):
    """Cut the images into sub-images and save them in HDF5 format.

    Training uses every image returned by prepare_data(sess, "Train"). Testing
    uses only the second path returned by prepare_data(sess, "Test").

    Args:
      sess: passed through to prepare_data and make_data.
      config: flag object; reads is_train, scale, image_size, label_size, stride.

    Returns:
      (nx, ny), the patch-grid size of the test image, when not training;
      None when training.
    """
    if config.is_train:
        paths = prepare_data(sess, dataset="Train")
    else:
        paths = [prepare_data(sess, dataset="Test")[1]]

    sub_input_sequence = []
    sub_label_sequence = []
    nx = ny = 0
    for path in paths:
        # Low-resolution input and high-resolution label for this image.
        input_, label_ = preprocess(path, config.scale)
        inputs, labels, nx, ny = _extract_patches(input_, label_, config)
        sub_input_sequence.extend(inputs)
        sub_label_sequence.extend(labels)

    # Stack into [num_patches, image_size, image_size, 1] and
    # [num_patches, label_size, label_size, 1].
    arrdata = np.asarray(sub_input_sequence)
    arrlabel = np.asarray(sub_label_sequence)
    make_data(sess, arrdata, arrlabel)

    if not config.is_train:
        return nx, ny


def imsave(image, path):
    """Save `image` to `path` with scipy.misc.imsave and return its result.

    NOTE(review): scipy.misc.imsave is missing from newer SciPy releases;
    this relies on the pinned older SciPy version.
    """
    writer = scipy.misc.imsave
    return writer(path, image)


def merge(images, size):
    """Stitch a batch of equally sized patches into one image, row by row.

    Args:
      images: array of shape [N, h, w] or [N, h, w, C]. Patch `idx` goes to
        grid row idx // size[1] and grid column idx % size[1].
      size: (rows, cols) of the patch grid.

    Returns:
      A float64 array of shape (rows*h, cols*w) followed by the trailing
      channel dims of `images`. For example, single-channel 4-D input gives
      (rows*h, cols*w, 1).
    """
    h, w = images.shape[1], images.shape[2]
    cols = size[1]
    # Keep whatever channel axes the patches have instead of assuming one.
    img = np.zeros((h * size[0], w * cols) + tuple(images.shape[3:]))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w, ...] = image

    return img
  • 原图

  • 效果图

相关推荐
IT古董33 分钟前
第四章:大模型(LLM)】06.langchain原理-(3)LangChain Prompt 用法
java·人工智能·python
fantasy_arch5 小时前
pytorch例子计算两张图相似度
人工智能·pytorch·python
AndrewHZ7 小时前
【3D重建技术】如何基于遥感图像和DEM等数据进行城市级高精度三维重建?
图像处理·人工智能·深度学习·3d·dem·遥感图像·3d重建
WBluuue7 小时前
数学建模:智能优化算法
python·机器学习·数学建模·爬山算法·启发式算法·聚类·模拟退火算法
赴3357 小时前
矿物分类案列 (一)六种方法对数据的填充
人工智能·python·机器学习·分类·数据挖掘·sklearn·矿物分类
大模型真好玩7 小时前
一文深度解析OpenAI近期发布系列大模型:意欲一统大模型江湖?
人工智能·python·mcp
RPA+AI十二工作室8 小时前
亚马逊店铺绩效巡检_影刀RPA源码解读
chrome·python·rpa·影刀
nonono8 小时前
深度学习——常见的神经网络
人工智能·深度学习·神经网络
小艳加油8 小时前
Python机器学习与深度学习;Transformer模型/注意力机制/目标检测/语义分割/图神经网络/强化学习/生成式模型/自监督学习/物理信息神经网络等
python·深度学习·机器学习·transformer
钢铁男儿9 小时前
如何构建一个神经网络?从零开始搭建你的第一个深度学习模型
人工智能·深度学习·神经网络