基于AnimeGANv2的照片动漫化

基于AnimeGANv2的照片动漫化

一、作者介绍

作者:陈辛怡,西安工程大学电子信息学院,2025级研究生

研究方向:可见光与SAR图像融合

联系邮箱:321424455@qq.com

作者:李逸超,西安工程大学电子信息学院,2025级研究生 张宏伟人工智能课题组

研究方向:机器视觉与人工智能

联系邮箱:2317314922@qq.com

二、AnimeGANv2核心原理

2.1图像风格迁移概述

图像风格迁移是计算机视觉领域的核心研究方向,目标是在保留输入图像内容结构 的前提下,将图像的视觉表现形式转换为目标艺术风格。

传统风格迁移依赖 CNN 提取内容与风格特征进行图优化,速度慢、效果不稳定;基于 GAN 的端到端风格迁移凭借生成质量高、推理速度快、风格一致性强等优势,成为照片动漫化的主流技术路线。

照片动漫化属于非成对图像转换 任务,无需配对数据,可将真实人像、风景、建筑、街景等自动转为日系二次元动漫风格,广泛用于短视频特效、图像创作、虚拟形象、文创设计等场景。

2.2AnimeGANv2概述

AnimeGANv2是一种轻量级生成对抗网络,专注于将真实世界高效转换为日系动漫风格图像。它在 AnimeGAN 基础上优化了网络结构、损失函数与训练策略,在保持动漫风格高保真的同时,大幅提升推理速度与内容保留能力。

该模型广泛应用于短视频特效、图片创作、虚拟形象生成、游戏素材制作等领域,因其风格纯正、细节清晰、轻量化 而成为照片动漫化的主流方案。

2.3生成对抗基础网络

GAN 由生成器与判别器构成:生成器学习真实数据分布并生成假样本;判别器判断输入为真实样本还是生成样本。二者通过对抗训练不断优化,最终生成器可输出逼近真实分布的图像。

AnimeGANv2 基于 GAN 架构,属于无配对图像风格迁移模型 ,无需照片 --- 动漫成对数据,只需分别收集真实照片与动漫截图即可训练,大幅降低数据成本。

2.4生成器(generator)

负责将真实照片映射为动漫风格图像:

  1. 编码器(下采样) :通过卷积提取高层内容特征,压缩空间尺寸
  2. 残差块 :增强深层特征表达,避免梯度消失,保留细节
  3. 解码器(上采样) :逐步恢复尺寸,输出动漫化结果

使用实例归一化 提升风格一致性与训练稳定性。

2.5判别器(Discriminator)

扮演"真伪鉴别"的角色,通过深度卷积网络学习真实动漫图片的特征分布。它不断接收生成器的输出与真实动漫图,输出图像为"真动漫"的概率,以此引导生成器优化生成策略,让生成的动漫图像在细节、色彩、笔触上更接近真实的动漫作品。

2.6损失函数设计

AnimeGANv2 的总损失由四部分加权组成,是保证高质量动漫化的关键:

  1. 对抗损失(Adversarial Loss)

基于 WGAN-GP,确保生成图像分布与真实动漫分布一致,提升真实感。

  1. 内容损失(Content Loss)

使用 VGG19 网络提取图像高层语义特征,保证生成图像与输入照片内容一致,避免轮廓变形。

  1. 风格损失(Style Loss)

包含 Gram 矩阵风格损失与灰度风格损失,强制学习动漫的色彩、纹理、笔触特征。

  1. 边缘梯度损失(Gradient Loss)

保留图像边缘轮廓,使动漫线条清晰流畅,无模糊断裂。

三、数据集介绍与预处理

3.1数据集构成

AnimeGANv2 训练需要三类数据:

  1. 真实照片数据集(Photo)

来源:网络风景、人像、街景照片;

数量:500-700 张;

尺寸:256×256;

用途:作为风格迁移输入源。

  1. 动漫风格数据集(Anime)

来源:日系动漫高清截图(如《起风了》、《你的名字&与你共度时光》、《辣椒粉》);

数量:500-700 张;

尺寸:256×256;

用途:提供目标动漫风格。

  1. 动漫平滑数据集(Anime_Smooth)

来源:对动漫截图进行高斯模糊处理;

用途:帮助判别器区分动漫纹理与噪声,减少伪影。

3.2数据集来源

该数据集来自谷歌网,如图3.1 AnimeGANv2检索,图3.2 AnimeGANv2检索结果展示,图3.3 数据集展示。

3.1 AnimeGANv2的谷歌检索

图3.2 AnimeGANv2检索结果展示

图3.3 数据集展示

3.3数据集预处理

  1. 尺寸统一 :将所有图像缩放到 256×256,保持训练稳定;
  2. 归一化 :像素值从 [0,255] 归一化到 [-1,1],适配网络输入;
  3. 数据增强 :随机水平翻转、旋转、裁剪、亮度调整,提升泛化能力;
  4. 去重与清洗 :删除模糊、水印、低质量图像,保证训练数据纯度。

3.4数据集目录结构

AnimeGANv2/

├── dataset/

│ ├── train_photo/ 真实照片

│ ├── train_anime/ 动漫图像

│ ├── test/ 测试照片

│ └── smooth/ 动漫平滑图

├── models/ 模型权重

├── results/ 生成结果

└── main.py 主程序

四、算法流程

4.1训练流程

  1. 加载真实照片、动漫图像、动漫平滑图;
  2. 初始化生成器与判别器;
  3. 固定生成器,训练判别器区分三类图像;
  4. 固定判别器,训练生成器最小化总损失;
  5. 周期性保存模型权重,验证生成效果;
  6. 损失收敛后停止训练。

4.2推理流程

  1. 加载预训练模型权重;
  2. 读取输入照片并预处理;
  3. 送入生成器前向推理;
  4. 后处理恢复像素值范围;
  5. 保存动漫化输出图像。

五、代码实现

5.1环境配置

python = 7

pytorch >= 1.8

opencv-python

pillow

numpy

tqdm

matplotlib

安装命令:

pip install torch torchvision opencv-python pillow numpy tqdm matplotlib

5.2代码

1.测试代码

python 复制代码
import os

import warnings



# 屏蔽 TensorFlow 警告信息

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 只显示 ERROR,忽略 INFO 和 WARNING

warnings.filterwarnings('ignore')  # 屏蔽 Python warnings



# 屏蔽 Python 警告

warnings.filterwarnings('ignore')



# 屏蔽 TensorFlow INFO/WARNING

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0 = all, 1 = INFO, 2 = WARNING, 3 = ERROR



# 屏蔽 TF 2.x 的兼容性警告

import logging

logging.getLogger('tensorflow').setLevel(logging.ERROR)



import tensorflow as tf

tf.get_logger().setLevel('ERROR')

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)  # 对 TF1.x 的 API 警告也屏蔽

import argparse

from tools.utils import *

import os

import warnings



# 屏蔽 TensorFlow 警告信息

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 只显示 ERROR,忽略 INFO 和 WARNING

warnings.filterwarnings('ignore')  # 屏蔽 Python warnings

from tqdm import tqdm

from glob import glob

import time

import numpy as np

from net import generator

os.environ["CUDA_VISIBLE_DEVICES"] = "0"



def parse_args():

    desc = "AnimeGANv2"

    parser = argparse.ArgumentParser(description=desc)



    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint/'+'generator_Shinkai_weight',

                        help='Directory name to save the checkpoints')

    parser.add_argument('--test_dir', type=str, default='test_photo',

                        help='Directory name of test photos')

    parser.add_argument('--save_dir', type=str, default='Shinkai/t',

                        help='what style you want to get')

    parser.add_argument('--if_adjust_brightness', type=bool, default=True,

                        help='adjust brightness by the real photo')



    """checking arguments"""



    return parser.parse_args()



def stats_graph(graph):

flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())

# params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())

    print('FLOPs: {}'.format(flops.total_float_ops))



def test(checkpoint_dir, style_name, test_dir, if_adjust_brightness, img_size=[256,256]):

    # tf.reset_default_graph()

    result_dir = 'results/'+style_name

    check_folder(result_dir)

    test_files = glob('{}/*.*'.format(test_dir))



    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')



    with tf.variable_scope("generator", reuse=False):

        test_generated = generator.G_net(test_real).fake

    saver = tf.train.Saver()



    gpu_options = tf.GPUOptions(allow_growth=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess:

        # tf.global_variables_initializer().run()

        # load model

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  # checkpoint file information

        if ckpt and ckpt.model_checkpoint_path:

            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # first line

            saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))

            print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))

        else:

            print(" [*] Failed to find a checkpoint")

            return

        # stats_graph(tf.get_default_graph())



        begin = time.time()

        for sample_file  in tqdm(test_files) :

            # print('Processing image: ' + sample_file)

            sample_image = np.asarray(load_test_data(sample_file, img_size))

            image_path = os.path.join(result_dir,'{0}'.format(os.path.basename(sample_file)))

            fake_img = sess.run(test_generated, feed_dict = {test_real : sample_image})

            if if_adjust_brightness:

                save_images(fake_img, image_path, sample_file)

            else:

                save_images(fake_img, image_path, None)

        end = time.time()

        print(f'test-time: {end-begin} s')

        print(f'one image test time : {(end-begin)/len(test_files)} s')

if __name__ == '__main__':

    arg = parse_args()

    print(arg.checkpoint_dir)

test(arg.checkpoint_dir, arg.save_dir, arg.test_dir, arg.if_adjust_brightness)



2.训练代码



from AnimeGANv2 import AnimeGANv2

import argparse

from tools.utils import *

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"



"""parsing and configuration"""



def parse_args():

    desc = "AnimeGANv2"

    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument('--dataset', type=str, default='Hayao', help='dataset_name')



    parser.add_argument('--epoch', type=int, default=101, help='The number of epochs to run')

    parser.add_argument('--init_epoch', type=int, default=10, help='The number of epochs for weight initialization')

    parser.add_argument('--batch_size', type=int, default=12, help='The size of batch size') # if light : batch_size = 20

    parser.add_argument('--save_freq', type=int, default=1, help='The number of ckpt_save_freq')



    parser.add_argument('--init_lr', type=float, default=2e-4, help='The learning rate')

    parser.add_argument('--g_lr', type=float, default=2e-5, help='The learning rate')

    parser.add_argument('--d_lr', type=float, default=4e-5, help='The learning rate')

    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')



    parser.add_argument('--g_adv_weight', type=float, default=300.0, help='Weight about GAN')

    parser.add_argument('--d_adv_weight', type=float, default=300.0, help='Weight about GAN')

    parser.add_argument('--con_weight', type=float, default=1.5, help='Weight about VGG19')# 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai

    # ------ the follow weight used in AnimeGAN

    parser.add_argument('--sty_weight', type=float, default=2.5, help='Weight about style')# 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai

    parser.add_argument('--color_weight', type=float, default=10., help='Weight about color') # 15. for Hayao, 50. for Paprika, 10. for Shinkai

    parser.add_argument('--tv_weight', type=float, default=1., help='Weight about tv')# 1. for Hayao, 0.1 for Paprika, 1. for Shinkai

    # ---------------------------------------------

    parser.add_argument('--training_rate', type=int, default=1, help='training rate about G & D')

    parser.add_argument('--gan_type', type=str, default='lsgan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge')



    parser.add_argument('--img_size', type=list, default=[256,256], help='The size of image: H and W')

    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')



    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')

    parser.add_argument('--n_dis', type=int, default=3, help='The number of discriminator layer')

    parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')





    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',

                        help='Directory name to save the checkpoints')

    parser.add_argument('--log_dir', type=str, default='logs',

                        help='Directory name to save training logs')

    parser.add_argument('--sample_dir', type=str, default='samples',

                        help='Directory name to save the samples on training')



    return check_args(parser.parse_args())



"""checking arguments"""

def check_args(args):

    # --checkpoint_dir

    check_folder(args.checkpoint_dir)



    # --log_dir

    check_folder(args.log_dir)



    # --sample_dir

    check_folder(args.sample_dir)



    # --epoch

    try:

        assert args.epoch >= 1

    except:

        print('number of epochs must be larger than or equal to one')



    # --batch_size

    try:

        assert args.batch_size >= 1

    except:

        print('batch size must be larger than or equal to one')

    return args





"""main"""

def main():

    # parse arguments

    args = parse_args()

    if args is None:

      exit()



    # open session

    gpu_options = tf.GPUOptions(allow_growth=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,inter_op_parallelism_threads=8,

                               intra_op_parallelism_threads=8,gpu_options=gpu_options)) as sess:

        gan = AnimeGANv2(sess, args)



        # build graph

        gan.build_model()



        # show network architecture

        show_all_variables()



        gan.train()

        print(" [*] Training finished!")





if __name__ == '__main__':

    main()

六、问题与分析

6.1 pytorch环境适配问题

  1. 问题:由于电脑问题pytorch环境搭建不上,如图6.1。

图6.1 pytorch环境搭建问题图

  1. 解决方法:

(1)询问AI,按AI给出的方案去搭建环境,由于电脑问题,搭建还是未能搭建出来。

(2)寻找懂pytorch的同学帮忙搭建。

6.2 python版本问题

问题:要求python 7版本的环境依赖问题,如图6.2。

图6.2 测试错误图

解决方法:寻找师兄帮忙解决。

6.3 生成图像模糊或出现伪影

现象与原因:画面整体发虚、有马赛克或色彩断层;通常是输入原图分辨率过低、光线不足,或预处理时的降噪操作过度丢失了纹理细节。如图6.3为原图,图6.4为生成模糊或有伪影的图片。图6.5是解决之后所输出的高清动漫图。

图6.3 原图 图6.4出现伪影/模糊 图6.5高清动漫图

解决方案:确保输入照片清晰且光照均匀;尝试更换不同的预训练模型(如Realistic Vision);生成后可通过Photoshop等工具进行轻微的"高反差保留"锐化处理。如图6.5是解决之后所输出的高清动漫图。

6.4 风格差异不显著与原图差异小

现象与原因:生成图只是简单的滤镜效果,缺乏艺术感,如图6.6为原图,图6.7为生成图;原因是选择的模型风格与输入内容的属性不匹配。

6.6原图 6.6 生成简单滤镜效果图 6.7高清动漫图

解决方案:"对症下药"选择模型:风景照用新海诚(Shinkai)或宫崎骏风格模型,人像用face_paint或二次元动漫模型,建筑用赛博朋克风格模型,如图6.7。

七、参考链接

https://github.com/TachibanaYoshino/AnimeGANv2

相关推荐
茉莉玫瑰花茶2 小时前
LangGraph 入门教程:构建 AI 工作流 [ 案例三 ]
前端·人工智能·python
辰尘_星启2 小时前
【ROS2】 Python 节点的开发流程
开发语言·python·机器人·系统·控制·ros2
m0_624578592 小时前
SQL数据更新时如何减少锁表时间_合理控制事务边界与并发
jvm·数据库·python
曲幽2 小时前
让 FastAPI Agent 思考不阻塞:手把手教你实现异步任务与后台处理方案
redis·python·agent·fastapi·web·async·celery·ai agent·backgroundtask
2401_867623982 小时前
如何提取SQL日期中的月份_使用MONTH函数快速过滤
jvm·数据库·python
ㄟ留恋さ寂寞2 小时前
JavaScript中箭头函数在大括号省略时的隐式返回机制
jvm·数据库·python
WangN22 小时前
【SONIC】Isaac Lab 系统入门指南
人工智能·python·机器人·自动驾驶·仿真
2501_901200532 小时前
Laravel 大批量数据填充时的内存泄漏与性能优化指南
jvm·数据库·python
APIshop2 小时前
俄罗斯电商 Ozon 平台:ozon.item_get 商品详情接口深度技术解析
python