深度学习之超分辨率算法——SRCNN

  • 网络为基础卷积层

  • tensorflow 1.14

  • scipy 1.2.1

  • numpy 1.16

  • 大概意思就是针对数据,我们先把图片按缩小因子照整数倍进行缩减为小图片,再针对小图片进行插值算法,获得还原后的低分辨率的图片作为标签。

  • main.py 配置文件

python 复制代码
from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os

flags = tf.app.flags
# 设置轮次
flags.DEFINE_integer("epoch", 1000, "Number of epoch [1000]")
# 设置批次
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
# 设置image大小
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
# 设置label
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
# 学习率
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
# 图像颜色的尺寸
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
# 对输入图像进行预处理的比例因子大小
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# 步长
flags.DEFINE_integer("stride", 14, "The size of stride to apply input image [14]")
# 权重位置
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
# 样本目录
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
# 训练还是测试
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [True]")
FLAGS = flags.FLAGS

# 格式化打印
pp = pprint.PrettyPrinter()

def main(_):
    #   打印参数
    pp.pprint(flags.FLAGS.__flags)

    # 没有就新建~
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Session提供了Operation执行和Tensor求值的环境;
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        srcnn.train(FLAGS)
    
if __name__ == '__main__':
  tf.app.run()
python 复制代码
from utils import (
    read_data,
    input_setup,
    imsave,
    merge
)
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

try:
    xrange
except:
    xrange = range


class SRCNN(object):
    # 模型初始化
    def __init__(self,
                 sess,
                 image_size=33,
                 label_size=21,
                 batch_size=128,
                 c_dim=1,
                 checkpoint_dir=None,
                 sample_dir=None):

        self.sess = sess
        # 判断灰度图
        self.is_grayscale = (c_dim == 1)

        self.image_size = image_size
        self.label_size = label_size
        self.batch_size = batch_size

        self.c_dim = c_dim

        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir

        self.build_model()

    def build_model(self):
        # tf.placeholder(
        # dtype,
        # shape = None,
        # name = None
        # )
        # 定义image,labels 输入形式 N W H C
        self.images = tf.placeholder(dtype=tf.float32, shape=[None, self.image_size, self.image_size, self.c_dim], name='images')
        self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
        # tf.Variable(initializer, name), 参数initializer是初始化参数,name是可自定义的变量名称,
        # shape为[filter_height, filter_width, in_channel, out_channels]
        # 构建模型参数
        self.weights = {
            'w1': tf.Variable(initial_value=tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
            'w2': tf.Variable(initial_value=tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
            'w3': tf.Variable(initial_value=tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
        }
        # the dim of bias== c_dim
        self.biases = {
            'b1': tf.Variable(tf.zeros([64]), name='b1'),
            'b2': tf.Variable(tf.zeros([32]), name='b2'),
            'b3': tf.Variable(tf.zeros([1]), name='b3')
        }
        # 构建模型 返回MHWC
        self.pred = self.model()

        # Loss function (MSE)
        self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
        # 保存和加载模型
        # 如果只想保留最新的4个模型,并希望每2个小时保存一次,
        self.saver = tf.train.Saver(max_to_keep=4,keep_checkpoint_every_n_hours=2)

    def train(self, config):

        if config.is_train:
            # 训练状态
            input_setup(self.sess, config)
        else:
            nx, ny = input_setup(self.sess, config)


        if config.is_train:

            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")

        else:

            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")


        train_data, train_label = read_data(data_dir)

        # Stochastic gradient descent with the standard backpropagation
        self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

        tf.initialize_all_variables().run()

        counter = 0
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        if config.is_train:
            print("Training...")

            for ep in xrange(config.epoch):
                # Run by batch images
                batch_idxs = len(train_data) // config.batch_size
                for idx in xrange(0, batch_idxs):
                    batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
                    batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]

                    counter += 1
                    _, err = self.sess.run([self.train_op, self.loss],
                                           feed_dict={self.images: batch_images, self.labels: batch_labels})

                    if counter % 10 == 0:
                        print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
                              % ((ep + 1), counter, time.time() - start_time, err))

                    if counter % 500 == 0:
                        self.save(config.checkpoint_dir, counter)

        else:
            print("Testing...")
            # print(train_data.shape)
            # print(train_label.shape)
            # print("---------")
            result = self.pred.eval({self.images: train_data, self.labels: train_label})
            # print(result.shape)
            result = merge(result, [nx, ny])
            result = result.squeeze()
            image_path = os.path.join(os.getcwd(), config.sample_dir)
            image_path = os.path.join(image_path, "test_image.png")
            imsave(result, image_path)

    def model(self):
        # input : 输入的要做卷积的图片,要求为一个张量,shape为 [ batch, in_height, in_width, in_channel ],其中batch为图片的数量,in_height 为图片高度,in_width 为图片宽度,in_channel 为图片的通道数,灰度图该值为1,彩色图为3。(也可以用其它值,但是具体含义不是很理解)
        # filter: 卷积核,要求也是一个张量,shape为 [ filter_height, filter_width, in_channel, out_channels ],其中 filter_height 为卷积核高度,filter_width 为卷积核宽度,in_channel 是图像通道数 ,和 input 的 in_channel 要保持一致,out_channel 是卷积核数量。
        # strides: 卷积时在图像每一维的步长,这是一个一维的向量,[ 1, strides, strides, 1],第一位和最后一位固定必须是1
        # padding: string类型,值为"SAME" 和 "VALID",表示的是卷积的形式,是否考虑边界。"SAME"是考虑边界,不足的时候用0去填充周围,"VALID"则不考虑
        # use_cudnn_on_gpu:  bool类型,是否使用cudnn加速,默认为true
        # padding = "SAME"输入和输出大小关系如下:输出大小等于输入大小除以步长向上取整,s是步长大小;
        # padding = "VALID"输入和输出大小关系如下:输出大小等于输入大小减去滤波器大小加上1,最后再除以步长(f为滤波器的大小,s是步长大小)。

        conv1 = tf.nn.relu(
            tf.nn.conv2d(self.images, self.weights['w1'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b1'])
        conv2 = tf.nn.relu(
            tf.nn.conv2d(conv1, self.weights['w2'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b2'])
        conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b3']

        return conv3

    def save(self, checkpoint_dir, step):
        model_name = "SRCNN.model"
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        # 目录
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        # 不存在就新建
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # 保存
        # 参数
        '''
        sess,
        save_path,
        global_step=None,
        latest_filename=None,
        meta_graph_suffix="meta",
        write_meta_graph=True,
        write_state=True,
        strip_default_attrs=False,
        save_debug_info=False)
        '''
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        # 加载模型
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        # 通过checkpoint文件找到模型文件名
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

        if ckpt and ckpt.model_checkpoint_path:
            # 返回path最后的文件名。如果path以/或\结尾,那么就会返回空值。即os.path.split(path)的第二个元素。
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # 加载成功
            return True
        else:
            # 加载失败
            return False
python 复制代码
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob
import h5py
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc
import scipy.ndimage
import numpy as np

import tensorflow as tf

try:
    xrange
except:
    xrange = range

FLAGS = tf.app.flags.FLAGS


def read_data(path):
    """
    Read h5 format data file

    Args:
      path: file path of desired file
      data: '.h5' file format that contains train data values
      label: '.h5' file format that contains train label values
    """
    with h5py.File(path, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        return data, label


def preprocess(path, scale=3):
    """
    Preprocess single image file
      (1) Read original image as YCbCr format (and grayscale as default)
      (2) Normalize
      (3) Apply image file with bicubic interpolation

    Args:
      path: file path of desired file
      input_: image applied bicubic interpolation (low-resolution)
      label_: image with original resolution (high-resolution)
    """
    # 读取灰度图
    image = imread(path, is_grayscale=True)
    label_ = modcrop(image, scale)

    # Must be normalized
    # 归一化
    image = image / 255.
    label_ = label_ / 255.
    # zoom:类型为float或sequence,沿轴的缩放系数。 如果float,每个轴的缩放是相同的。 如果sequence,zoom应包含每个轴的一个值。
    # output:放置输出的数组,或返回数组的dtype
    # order:样条插值的顺序,默认为3.顺序必须在0-5范围内。
    # prefilter: bool, optional 。参数预滤波器确定输入是否在插值之前使用spline_filter进行预过滤(对于 > 1
    # 的样条插值所必需的)。 如果为False,则假定输入已被过滤。 默认为True。
    input_ = scipy.ndimage.interpolation.zoom(input=label_,zoom=(1. / scale), prefilter=False)
    input_ = scipy.ndimage.interpolation.zoom(input=input_,zoom=(scale / 1.), prefilter=False)

    return input_, label_


def prepare_data(sess, dataset):
    """
    Args:
      dataset: choose train dataset or test dataset
      For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
    dataset:
        "Train" or "Test":to choose the data is train or test
    """
    if FLAGS.is_train:
        filenames = os.listdir(dataset)
        #  获取数据目录
        data_dir = os.path.join(os.getcwd(), dataset)
        data = glob.glob(os.path.join(data_dir, "*.bmp"))
    else:
        # 获取测试集路径
        data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
        data = glob.glob(os.path.join(data_dir, "*.bmp"))
    # 返回文件目录
    return data


def make_data(sess, data, label):
    """
    Make input data as h5 file format
    Depending on 'is_train' (flag value), savepath would be changed.
    """
    if FLAGS.is_train:
        savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
    else:
        savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)


def imread(path, is_grayscale=True):
    """
    Read image using its path.
    Default value is gray-scale, and image is read by YCbCr format as the paper said.
    """
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
    else:
        return scipy.misc.imread(path, mode='YCbCr').astype(np.float)


def modcrop(image, scale=3):
    """
    To scale down and up the original image, first thing to do is to have no remainder while scaling operation.

    We need to find modulo of height (and width) and scale factor.
    Then, subtract the modulo from height (and width) of original image size.
    There would be no remainder even after scaling operation.
    要缩小和放大原始图像,首先要做的是在缩放操作时没有剩余。
    我们需要找到高度(和宽度)和比例因子的模。
    然后,从原始图像的高度(和宽度)中减去模。
    即使经过缩放操作,也不会有余数。
    """
    if len(image.shape) == 3:
        # 取整
        h, w, _ = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w, :]
    else:
        h, w = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w]
    return image


def input_setup(sess, config):
    """
    Read image files and make their sub-images and saved them as a h5 file format.
    """
    # Load data path
    if config.is_train:

        data = prepare_data(sess, dataset="Train")
    else:
        data = prepare_data(sess, dataset="Test")

    sub_input_sequence = []
    sub_label_sequence = []
    # 计算padding
    padding = abs(config.image_size - config.label_size) / 2  # 6

    if config.is_train:
        for i in xrange(len(data)):
            # TODO 获取原图和低分辨率还原标签
            input_, label_ = preprocess(data[i], config.scale)
            if len(input_.shape) == 3:
                h, w, _ = input_.shape
            else:
                h, w = input_.shape

            for x in range(0, h - config.image_size + 1, config.stride):
                for y in range(0, w - config.image_size + 1, config.stride):
                    sub_input = input_[x:x + config.image_size, y:y + config.image_size]  # [33 x 33]
                    sub_label = label_[x + int(padding):x + int(padding) + config.label_size,
                                y + int(padding):y + int(padding) + config.label_size]  # [21 x 21]

                    # Make channel value
                    sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
                    sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

                    sub_input_sequence.append(sub_input)
                    sub_label_sequence.append(sub_label)
    else:
        input_, label_ = preprocess(data[1], config.scale)
        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape

        # Numbers of sub-images in height and width of image are needed to compute merge operation.
        nx = ny = 0
        for x in range(0, h - config.image_size + 1, config.stride):
            # 保存索引
            nx += 1
            ny = 0
            for y in range(0, w - config.image_size + 1, config.stride):
                ny += 1
                sub_input = input_[x:x + config.image_size, y:y + config.image_size]  # [33 x 33]
                sub_label = label_[x + int(padding):x + int(padding) + config.label_size,
                            y + int(padding):y + int(padding) + config.label_size]  # [21 x 21]

                sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
                sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

                sub_input_sequence.append(sub_input)
                sub_label_sequence.append(sub_label)

    """
    len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
    (sub_input_sequence[0]).shape : (33, 33, 1)
    """
    # Make list to numpy array. With this transform
    arrdata = np.asarray(sub_input_sequence)  # [?, 33, 33, 1]
    arrlabel = np.asarray(sub_label_sequence)  # [?, 21, 21, 1]
    make_data(sess, arrdata, arrlabel)

    if not config.is_train:
        return nx, ny


def imsave(image, path):
    return scipy.misc.imsave(path, image)


def merge(images, size):
    # 合并图片
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 1))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h:j * h + h, i * w:i * w + w, :] = image

    return img
  • 原图

  • 效果图

相关推荐
martian66521 分钟前
第17篇:python进阶:详解数据分析与处理
开发语言·python
无码不欢的我25 分钟前
使用vscode在本地和远程服务器端运行和调试Python程序的方法总结
ide·vscode·python
五味香26 分钟前
Java学习,查找List最大最小值
android·java·开发语言·python·学习·golang·kotlin
AIGC大时代27 分钟前
方法建议ChatGPT提示词分享
人工智能·深度学习·chatgpt·aigc·ai写作
金融OG32 分钟前
99.8 金融难点通俗解释:净资产收益率(ROE)
大数据·python·线性代数·机器学习·数学建模·金融·矩阵
fmdpenny1 小时前
Django的安装
后端·python·django
小爬菜1 小时前
Django学习笔记(启动项目)-03
前端·笔记·python·学习·django
陈钇钇1 小时前
持续升级《在线写python》小程序的功能,文章页增加一键复制功能,并自动去掉html标签
python·小程序·html
didiplus1 小时前
告别手动编辑:如何用Python快速创建Ansible hosts文件?
网络·python·ansible·hosts