Deep Learning in Practice: the ESPCN Super-Resolution Algorithm (TensorFlow)

For the theory behind ESPCN, please refer to the previous post on the paper; this post focuses on the implementation.
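As a quick refresher on the idea (the derivation is in that post): all convolutions run at low resolution, the last layer outputs r^2 channels per low-resolution pixel, and a periodic rearrangement (the "sub-pixel convolution", a.k.a. pixel shuffle) turns an (H, W, r^2) feature map into an (rH, rW, 1) image. The `shuffle` / `my_anti_shuffle` helpers in prepare_data.py below implement exactly this rearrangement (generalized to several channels) and its inverse. A minimal standalone NumPy sketch of the mapping, for illustration only:

```python
import numpy as np

# Pixel shuffle for a single channel: (H, W, r^2) -> (r*H, r*W, 1).
# Channel k of low-resolution cell (i // r, j // r) supplies the high-resolution
# pixel at offset (i % r, j % r), i.e. k = (i % r) * r + (j % r).
def pixel_shuffle(features, r):
    h, w, c = features.shape
    assert c == r * r
    out = np.zeros((h * r, w * r, 1), dtype=features.dtype)
    for i in range(h * r):
        for j in range(w * r):
            out[i, j, 0] = features[i // r, j // r, (i % r) * r + (j % r)]
    return out

demo = np.arange(2 * 2 * 4).reshape(2, 2, 4)   # r = 2, so 4 channels per LR pixel
print(pixel_shuffle(demo, 2)[:, :, 0])
# [[ 0  1  4  5]
#  [ 2  3  6  7]
#  [ 8  9 12 13]
#  [10 11 14 15]]
```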

The dataset is as follows; any set of images of equal size will do.
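Every script below reads its configuration from ./params.json. That file is not shown here, so the following is only a sketch, written out with a few lines of Python: the key names are exactly the ones the scripts access, while all of the values and the directory layout are assumptions to adapt to your own dataset. Note that `edge` has to equal the total number of pixels the VALID convolutions trim away, i.e. sum(filter_size - 1), which is 8 for filters of size [5, 3, 3], so that the label patches match the network output.

```python
# Writes an illustrative params.json; only the key names are taken from what
# prepare_data.py, train.py, generate.py and reader.py actually read.
import json

params = {
    "ratio": 2,                    # upscaling factor r
    "lr_size": 17,                 # low-resolution patch size fed to the network
    "lr_stride": 13,               # stride used when cutting LR patches
    "edge": 8,                     # pixels removed by the VALID convs: (5-1)+(3-1)+(3-1)
    "filters_size": [5, 3, 3],     # one entry per conv layer
    "channels": [64, 32],          # hidden channels; the last layer always has r^2
    "training_num": 400,           # how many source images go to each split
    "validation_num": 50,
    "test_num": 50,
    "image_dir": "images/",                        # raw JPEG dataset
    "training_image_dir": "images_{}x/training/",  # {} is filled with the ratio
    "validation_image_dir": "images_{}x/validation/",
    "test_image_dir": "images_{}x/test/",
    "training_dir": "data_{}x/training",           # binary patch files end up here
    "validation_dir": "data_{}x/validation",
    "test_dir": "data_{}x/test",
}

with open("params.json", "w") as f:
    json.dump(params, f, indent=4)
```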

  • Code to generate training samples from the dataset
  • prepare_data.py
```python
import imageio
from scipy import ndimage
import numpy as np
import imghdr
import shutil
import os
import json

mat = np.array(
    [[ 65.481, 128.553, 24.966 ],
     [-37.797, -74.203, 112.0  ],
     [  112.0, -93.786, -18.214]])
mat_inv = np.linalg.inv(mat)
offset = np.array([16, 128, 128])

def rgb2ycbcr(rgb_img):
    ycbcr_img = np.zeros(rgb_img.shape, dtype=np.uint8)
    for x in range(rgb_img.shape[0]):
        for y in range(rgb_img.shape[1]):
            ycbcr_img[x, y, :] = np.round(np.dot(mat, rgb_img[x, y, :] * 1.0 / 255) + offset)
    return ycbcr_img

def ycbcr2rgb(ycbcr_img):
    rgb_img = np.zeros(ycbcr_img.shape, dtype=np.uint8)
    for x in range(ycbcr_img.shape[0]):
        for y in range(ycbcr_img.shape[1]):
            [r, g, b] = ycbcr_img[x,y,:]
            rgb_img[x, y, :] = np.maximum(0, np.minimum(255, np.round(np.dot(mat_inv, ycbcr_img[x, y, :] - offset) * 255.0)))
    return rgb_img

def my_anti_shuffle(input_image, ratio):
    shape = input_image.shape
    ori_height = int(shape[0])
    ori_width = int(shape[1])
    ori_channels = int(shape[2])
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        print("Error! Height and width must be divided by ratio!")
        return
    height = ori_height // ratio
    width = ori_width // ratio
    channels = ori_channels * ratio * ratio
    anti_shuffle = np.zeros((height, width, channels), dtype=np.uint8)
    for c in range(0, ori_channels):
        for x in range(0, ratio):
            for y in range(0, ratio):
                anti_shuffle[:,:,c * ratio * ratio + x * ratio + y] = input_image[x::ratio, y::ratio, c]
    return anti_shuffle

def shuffle(input_image, ratio):
    shape = input_image.shape
    height = int(shape[0]) * ratio
    width = int(shape[1]) * ratio
    channels = int(shape[2]) // ratio // ratio
    shuffled = np.zeros((height, width, channels), dtype=np.uint8)
    for i in range(0, height):
        for j in range(0, width):
            for k in range(0, channels):
                shuffled[i,j,k] = input_image[i // ratio, j // ratio, k * ratio * ratio + (i % ratio) * ratio + (j % ratio)]
    return shuffled

def prepare_images(params):
    ratio, training_num, lr_stride, lr_size = params['ratio'], params['training_num'], params['lr_stride'], params['lr_size']
    hr_stride = lr_stride * ratio
    hr_size = lr_size * ratio

    # first clear old images and create new directories
    for ele in ['training', 'validation', 'test']:
        new_dir = params[ele + '_image_dir'].format(ratio)
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        for sub_dir in ['hr', 'lr']:
            os.makedirs(os.path.join(new_dir, sub_dir))

    image_num = 0
    folder = params['training_image_dir'].format(ratio)
    for root, dirnames, filenames in os.walk(params['image_dir']):
        for filename in filenames:
            path = os.path.join(root, filename)
            if imghdr.what(path) != 'jpeg':
                continue
                
            hr_image = imageio.imread(path)
            height = hr_image.shape[0]
            new_height = height - height % ratio
            width = hr_image.shape[1]
            new_width = width - width % ratio
            hr_image = hr_image[0:new_height,0:new_width]
            blurred = ndimage.gaussian_filter(hr_image, sigma=(1, 1, 0))
            lr_image = blurred[::ratio,::ratio,:]

            height = hr_image.shape[0]
            width = hr_image.shape[1]
            vertical_number = height / hr_stride - 1
            horizontal_number = width / hr_stride - 1
            image_num = image_num + 1
            if image_num % 10 == 0:
                print ("Finished image: {}".format(image_num))
            if image_num > training_num and image_num <= training_num + params['validation_num']:
                folder = params['validation_image_dir'].format(ratio)
            elif image_num > training_num + params['validation_num']:
                folder = params['test_image_dir'].format(ratio)
            #misc.imsave(folder + 'hr_full/' + filename[0:-4] + '.png', hr_image)
            #misc.imsave(folder + 'lr_full/' + filename[0:-4] + '.png', lr_image)
            for x in range(0, int(horizontal_number)):
                for y in range(0, int(vertical_number)):
                    hr_sub_image = hr_image[y * hr_stride : y * hr_stride + hr_size, x * hr_stride : x * hr_stride + hr_size]
                    lr_sub_image = lr_image[y * lr_stride : y * lr_stride + lr_size, x * lr_stride : x * lr_stride + lr_size]
                    imageio.imwrite("{}hr/{}_{}_{}.png".format(folder, filename[0:-4], y, x), hr_sub_image)
                    imageio.imwrite("{}lr/{}_{}_{}.png".format(folder, filename[0:-4], y, x), lr_sub_image)
            if image_num >= training_num + params['validation_num'] + params['test_num']:
                break
        else:
            continue
        break

def prepare_data(params):
    ratio = params['ratio']
    params['hr_stride'] = params['lr_stride'] * ratio
    params['hr_size'] = params['lr_size'] * ratio

    for ele in ['training', 'validation', 'test']:
        new_dir = params[ele + '_dir'].format(ratio)
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir)

    ratio, lr_size, edge = params['ratio'], params['lr_size'], params['edge']
    image_dirs = [d.format(ratio) for d in [params['training_image_dir'], params['validation_image_dir'], params['test_image_dir']]]
    data_dirs = [d.format(ratio) for d in [params['training_dir'], params['validation_dir'], params['test_dir']]]
    hr_start_idx = ratio * edge // 2
    hr_end_idx = hr_start_idx + (lr_size - edge) * ratio
    sub_hr_size = (lr_size - edge) * ratio
    for dir_idx, image_dir in enumerate(image_dirs):
        data_dir = data_dirs[dir_idx]
        print ("Creating {}".format(data_dir))
        for root, dirnames, filenames in os.walk(image_dir + "/lr"):
            for filename in filenames:
                lr_path = os.path.join(root, filename)
                hr_path = image_dir + "/hr/" + filename
                lr_image = imageio.imread(lr_path)
                hr_image = imageio.imread(hr_path)
                # convert to Ycbcr color space
                lr_image_y = rgb2ycbcr(lr_image)
                hr_image_y = rgb2ycbcr(hr_image)
                lr_data = lr_image_y.reshape((lr_size * lr_size * 3))
                sub_hr_image_y = hr_image_y[int(hr_start_idx):int(hr_end_idx):1,int(hr_start_idx):int(hr_end_idx):1]
                hr_data = my_anti_shuffle(sub_hr_image_y, ratio).reshape(sub_hr_size * sub_hr_size * 3)
                data = np.concatenate([lr_data, hr_data])
                data.astype('uint8').tofile(data_dir + "/" + filename[0:-4])

def remove_images(params):
    # Don't need old image folders
    for ele in ['training', 'validation', 'test']:
        rm_dir = params[ele + '_image_dir'].format(params['ratio'])
        if os.path.isdir(rm_dir):
            shutil.rmtree(rm_dir)


if __name__ == '__main__':
    with open("./params.json", 'r') as f:
        params = json.load(f)

    print("Preparing images with scaling ratio: {}".format(params['ratio']))
    print ("If you want a different ratio change 'ratio' in params.json")
    print ("Splitting images (1/3)")
    prepare_images(params)

    print ("Preparing data, this may take a while (2/3)")
    prepare_data(params)

    print ("Cleaning up split images (3/3)")
    remove_images(params)
    print("Done, you can now train the model!")
```

generate.py

```python
import argparse
from PIL import Image
import imageio
import tensorflow as tf
import numpy as np
from prepare_data import *
from psnr import psnr
import json
import pdb

from espcn import ESPCN

def get_arguments():
    parser = argparse.ArgumentParser(description='EspcnNet generation script')
    parser.add_argument('--checkpoint', type=str,
                        help='Which model checkpoint to generate from',default="logdir_2x/train")
    parser.add_argument('--lr_image', type=str,
                        help='The low-resolution image to be processed.', default="images/butterfly_GT.jpg")
    parser.add_argument('--hr_image', type=str,
                        help='The high-resolution image which is used to calculate PSNR.')
    parser.add_argument('--out_path', type=str,
                        help='The output path for the super-resolution image',default="result/butterfly_HR")
    return parser.parse_args()

def check_params(args, params):
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater then the length of 'channels' by 1.")
        return False
    return True

def generate():
    args = get_arguments()

    with open("./params.json", 'r') as f:
        params = json.load(f)

    if check_params(args, params) == False:
        return

    sess = tf.Session()

    net = ESPCN(filters_size=params['filters_size'],
                   channels=params['channels'],
                   ratio=params['ratio'],
                   batch_size=1,
                   lr_size=params['lr_size'],
                   edge=params['edge'])

    loss, images, labels = net.build_model()

    lr_image = tf.placeholder(tf.uint8)
    lr_image_data = imageio.imread(args.lr_image)
    lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
    lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
    lr_image_cb_data = lr_image_ycbcr_data[:, :, 1:2]
    lr_image_cr_data = lr_image_ycbcr_data[:, :, 2:3]
    lr_image_batch = np.zeros((1,) + lr_image_y_data.shape)
    lr_image_batch[0] = lr_image_y_data

    sr_image = net.generate(lr_image)

    saver = tf.train.Saver()
    try:
        model_loaded = net.load(sess, saver, args.checkpoint)
    except:
        raise Exception("Failed to load model, does the ratio in params.json match the ratio you trained your checkpoint with?")

    if model_loaded:
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")
        return

    sr_image_y_data = sess.run(sr_image, feed_dict={lr_image: lr_image_batch})

    sr_image_y_data = shuffle(sr_image_y_data[0], params['ratio'])
    # PIL's resize expects a (width, height) tuple, while the array shape is (height, width)
    sr_image_ycbcr_data = np.array(Image.fromarray(lr_image_ycbcr_data).resize(
        (params['ratio'] * lr_image_data.shape[1], params['ratio'] * lr_image_data.shape[0]),
        Image.BICUBIC))


    # integer edge so it can be used directly for array cropping below
    edge = params['edge'] * params['ratio'] // 2

    sr_image_ycbcr_data = np.concatenate((sr_image_y_data, sr_image_ycbcr_data[int(edge):int(-edge),int(edge):int(-edge),1:3]), axis=2)
    sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)

    imageio.imwrite(args.out_path + '.png', sr_image_data)

    if args.hr_image is not None:
        hr_image_data = imageio.imread(args.hr_image)
        model_psnr = psnr(hr_image_data, sr_image_data, edge)
        print('PSNR of the model: {:.2f}dB'.format(model_psnr))

        # scipy.misc.imresize/imsave were removed from SciPy; use PIL + imageio instead
        sr_image_bicubic_data = np.array(Image.fromarray(lr_image_data).resize(
            (params['ratio'] * lr_image_data.shape[1], params['ratio'] * lr_image_data.shape[0]),
            Image.BICUBIC))
        imageio.imwrite(args.out_path + '_bicubic.png', sr_image_bicubic_data)
        bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
        print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))


if __name__ == '__main__':
    generate()

```

train.py
```python
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time
import json

import tensorflow as tf
from reader import create_inputs
from espcn import ESPCN

import pdb


try:
    xrange
except Exception as e:
    xrange = range
# batch size
BATCH_SIZE = 32
# epochs
NUM_EPOCHS = 100
# learning rate
LEARNING_RATE = 0.0001
# logdir
LOGDIR_ROOT = './logdir_{}x'

def get_arguments():

    parser = argparse.ArgumentParser(description='EspcnNet example network')
    # checkpoint (model weights) to restore
    parser.add_argument('--checkpoint', type=str,
                        help='Which model checkpoint to load from', default=None)
    # batch_size
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many image files to process at once.')
    # epochs
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS,
                        help='Number of epochs.')
    # learning rate
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training.')
    # logdir_root
    parser.add_argument('--logdir_root', type=str, default=LOGDIR_ROOT,
                        help='Root directory to place the logging '
                        'output and generated model. These are stored '
                        'under the dated subdirectory of --logdir_root. '
                        'Cannot use with --logdir.')
    # parse and return the arguments
    return parser.parse_args()

def check_params(args, params):
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater then the length of 'channels' by 1.")
        return False
    return True

def train():

    args = get_arguments()
    # load json
    with open("./params.json", 'r') as f:
        params = json.load(f)
    # validate the parameters
    if check_params(args, params) == False:
        return

    logdir_root = args.logdir_root # ./logdir
    if logdir_root == LOGDIR_ROOT:
        logdir_root = logdir_root.format(params['ratio']) # ./logdir_{RATIO}x
    logdir = os.path.join(logdir_root, 'train') # ./logdir_{RATIO}x/train

    # Load training data as np arrays
    # load the training data
    lr_images, hr_labels = create_inputs(params)
    # build the network model
    net = ESPCN(filters_size=params['filters_size'],
                   channels=params['channels'],
                   ratio=params['ratio'],
                   batch_size=args.batch_size,
                   lr_size=params['lr_size'],
                   edge=params['edge'])

    loss, images, labels = net.build_model()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # set up logging for tensorboard
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # set up session
    sess = tf.Session()

    # saver for storing/restoring checkpoints of the model
    saver = tf.train.Saver()

    init = tf.global_variables_initializer()
    sess.run(init)

    if net.load(sess, saver, logdir):
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")

    try:
        steps, start_average, end_average = 0, 0, 0
        start_time = time.time()
        for ep in xrange(1, args.epochs + 1):
            batch_idxs = len(lr_images) // args.batch_size
            batch_average = 0
            for idx in xrange(0, batch_idxs):
                # On the fly batch generation instead of Queue to optimize GPU usage
                batch_images = lr_images[idx * args.batch_size : (idx + 1) * args.batch_size]
                batch_labels = hr_labels[idx * args.batch_size : (idx + 1) * args.batch_size]
                
                steps += 1
                summary, loss_value, _ = sess.run([summaries, loss, optim], feed_dict={images: batch_images, labels: batch_labels})
                writer.add_summary(summary, steps)
                batch_average += loss_value

            # Compare loss of first 20% and last 20%
            batch_average = float(batch_average) / batch_idxs
            if ep < (args.epochs * 0.2):
                start_average += batch_average
            elif ep >= (args.epochs * 0.8):
                end_average += batch_average

            duration = time.time() - start_time
            print('Epoch: {}, step: {:d}, loss: {:.9f}, ({:.3f} sec/epoch)'.format(ep, steps, batch_average, duration))
            start_time = time.time()
            net.save(sess, saver, logdir, steps)
    except KeyboardInterrupt:
        print()
    finally:
        start_average = float(start_average) / (args.epochs * 0.2)
        end_average = float(end_average) / (args.epochs * 0.2)
        print("Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]" \
          % (start_average, end_average, 100 - (100*end_average/start_average)))

if __name__ == '__main__':
    train()

```

espcn.py: the model, implemented in TensorFlow

```python
import tensorflow as tf
import os
import sys
import pdb

def create_variable(name, shape):
    '''Create a convolution filter variable with the specified name and shape,
    and initialize it using Xavier initialization.'''
    initializer = tf.contrib.layers.xavier_initializer_conv2d()
    variable = tf.Variable(initializer(shape=shape), name=name)
    return variable

def create_bias_variable(name, shape):
    '''Create a bias variable with the specified name and shape and initialize
    it to zero.'''
    initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
    return tf.Variable(initializer(shape=shape), name=name)

class ESPCN:
    def __init__(self, filters_size, channels, ratio, batch_size, lr_size, edge):
        self.filters_size = filters_size
        self.channels = channels
        self.ratio = ratio
        self.batch_size = batch_size
        self.lr_size = lr_size
        self.edge = edge
        self.variables = self.create_variables()

    def create_variables(self):
        var = dict()
        var['filters'] = list()
        # the input layer
        var['filters'].append(
            create_variable('filter',
                            [self.filters_size[0],
                             self.filters_size[0],
                             1,
                             self.channels[0]]))
        # the hidden layers
        for idx in range(1, len(self.filters_size) - 1):
            var['filters'].append(
                create_variable('filter', 
                                [self.filters_size[idx],
                                 self.filters_size[idx],
                                 self.channels[idx - 1],
                                 self.channels[idx]]))
        # the output layer
        var['filters'].append(
            create_variable('filter',
                            [self.filters_size[-1],
                             self.filters_size[-1],
                             self.channels[-1],
                             self.ratio**2]))

        var['biases'] = list()
        for channel in self.channels:
            var['biases'].append(create_bias_variable('bias', [channel]))
        # the bias shape must be an integer, not a float
        var['biases'].append(create_bias_variable('bias', [self.ratio ** 2]))


        image_shape = (self.batch_size, self.lr_size, self.lr_size, 3)
        var['images'] = tf.placeholder(tf.uint8, shape=image_shape, name='images')
        label_shape = (self.batch_size, self.lr_size - self.edge, self.lr_size - self.edge, 3 * self.ratio**2)
        var['labels'] = tf.placeholder(tf.uint8, shape=label_shape, name='labels')

        return var

    def build_model(self):
        images, labels = self.variables['images'], self.variables['labels']
        input_images, input_labels = self.preprocess([images, labels])
        output = self.create_network(input_images)
        reduced_loss = self.loss(output, input_labels)
        return reduced_loss, images, labels

    def save(self, sess, saver, logdir, step):
        # print('[*] Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()

        if not os.path.exists(logdir):
            os.makedirs(logdir)

        checkpoint = os.path.join(logdir, "model.ckpt")
        saver.save(sess, checkpoint, global_step=step)
        # print('[*] Done saving checkpoint.')

    def load(self, sess, saver, logdir):
        print("[*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(logdir)

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(logdir, ckpt_name))
            return True
        else:
            return False

    def preprocess(self, input_data):
        # cast to float32 and normalize the data
        input_list = list()
        for ele in input_data:
            if ele is None:
                continue
            ele = tf.cast(ele, tf.float32) / 255.0
            input_list.append(ele)

        input_images, input_labels = input_list[0][:,:,:,0:1], None
        # Generate doesn't use input_labels
        ratioSquare = self.ratio * self.ratio
        if input_data[1] is not None:
            input_labels = input_list[1][:,:,:,0:ratioSquare]
        return input_images, input_labels

    def create_network(self, input_images):
        '''The default structure of the network is:

        input (1 channel, the Y component) ---> 5 * 5 conv (64 channels) ---> 3 * 3 conv (32 channels) ---> 3 * 3 conv (r^2 channels)

        Where `conv` is a 2d convolution with a tanh activation after every layer except the last.
        '''
        current_layer = input_images

        for idx in range(len(self.filters_size)):
            conv = tf.nn.conv2d(current_layer, self.variables['filters'][idx], [1, 1, 1, 1], padding='VALID')
            with_bias = tf.nn.bias_add(conv, self.variables['biases'][idx])
            if idx == len(self.filters_size) - 1:
                current_layer = with_bias
            else:
                current_layer = tf.nn.tanh(with_bias)
        return current_layer

    def loss(self, output, input_labels):
        residual = output - input_labels
        loss = tf.square(residual)
        reduced_loss = tf.reduce_mean(loss)
        tf.summary.scalar('loss', reduced_loss)
        return reduced_loss

    def generate(self, lr_image):
        lr_image = self.preprocess([lr_image, None])[0]
        sr_image = self.create_network(lr_image)
        sr_image = sr_image * 255.0
        sr_image = tf.cast(sr_image, tf.int32)
        sr_image = tf.maximum(sr_image, 0)
        sr_image = tf.minimum(sr_image, 255)
        sr_image = tf.cast(sr_image, tf.uint8)
        return sr_image
```

  • reader.py: reading the prepared training data
```python
import tensorflow as tf
import numpy as np
import os
import pdb

def create_inputs(params):
    """
    Loads prepared training files and appends them as np arrays to a list.
    This approach is better because a FIFOQueue with a reader can't utilize
    the GPU while this approach can.
    """

    lr_images, hr_labels = [], []
    training_dir = params['training_dir'].format(params['ratio'])

    # Raise an exception if the user has not run prepare_data.py yet
    if not os.path.isdir(training_dir):
        raise Exception("You must first run prepare_data.py before you can train")

    lr_shape = (params['lr_size'], params['lr_size'], 3)
    hr_shape = output_shape = (params['lr_size'] - params['edge'], params['lr_size'] - params['edge'], 3 * params['ratio']**2)
    for file in os.listdir(training_dir):
        train_file = open("{}/{}".format(training_dir, file), "rb")
        train_data = np.fromfile(train_file, dtype=np.uint8)

        # use lr_size from params instead of a hard-coded 17x17 patch size
        lr_bytes = params['lr_size'] * params['lr_size'] * 3
        lr_image = train_data[:lr_bytes].reshape(lr_shape)
        lr_images.append(lr_image)

        hr_label = train_data[lr_bytes:].reshape(hr_shape)
        hr_labels.append(hr_label)

    return lr_images, hr_labels

```

psnr.py: PSNR computation

```python
import numpy as np
import math

def psnr(hr_image, sr_image, hr_edge):
    #assume RGB image
    hr_image_data = np.array(hr_image)
    if hr_edge > 0:
        hr_image_data = hr_image_data[hr_edge:-hr_edge, hr_edge:-hr_edge].astype('float32')

    sr_image_data = np.array(sr_image).astype('float32')
    
    diff = sr_image_data - hr_image_data
    diff = diff.flatten('C')
    rmse = math.sqrt( np.mean(diff ** 2.) )
    return 20*math.log10(255.0/rmse)

```

Known issue: as originally written, training printed a bias-related "not supported" warning (the bias variable's name was passed positionally to tf.Variable and its shape was a float), but the model still trained and learned.
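As a quick sanity check of the data-pipeline helpers, the sketch below assumes the first listing is saved as prepare_data.py and the PSNR listing as psnr.py in the same directory; it verifies that the anti-shuffle/shuffle pair is a lossless round trip and exercises the psnr function:

```python
import numpy as np
from prepare_data import my_anti_shuffle, shuffle
from psnr import psnr

ratio = 2
hr = np.random.randint(0, 256, size=(8, 8, 1), dtype=np.uint8)

# (H, W, 1) -> (H/r, W/r, r^2) and back again should reproduce the input exactly
packed = my_anti_shuffle(hr, ratio)
assert packed.shape == (4, 4, ratio ** 2)
assert np.array_equal(shuffle(packed, ratio), hr)

# PSNR between the patch and a slightly noisy copy (identical inputs would
# divide by zero inside psnr, so perturb the copy first)
noisy = np.clip(hr.astype(np.int16) + np.random.randint(-5, 6, size=hr.shape), 0, 255).astype(np.uint8)
print("PSNR: {:.2f} dB".format(psnr(hr, noisy, 0)))
```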
