人脸伪造判别分类网络CNN&Transformer

作者:SkyXZ

CSDN:SkyXZ~-CSDN博客

博客园:SkyXZ - 博客园

一、获取数据集

FaceForensics++ 是一个取证数据集,由1000段原始视频序列组成,这些视频通过四种自动人脸操纵方法进行处理:Deepfakes、Face2Face、FaceSwap 和 NeuralTextures。数据来自 977 段 YouTube 视频,所有视频中都包含一张可跟踪的、主要为正面且没有遮挡的人脸,使得自动篡改方法能够生成逼真的伪造视频。同时,由于该数据集提供了二值掩码,这些数据可以用于图像和视频分类以及分割。此外,官方还提供了 1000 个 Deepfakes 模型,用于生成和扩充新数据。

FaceForensics++数据集无法直接下载,需要按照要求填写谷歌表单来申请获取https://docs.google.com/forms/d/e/1FAIpQLSdRRR3L5zAv6tQ_CKxmK4W96tAab_pfBu2EKAgQbeDVhmXagg/viewform

等待几天之后会收到如下邮件,里面会附上数据集的下载Code,直接使用下载脚本下载即可获取:

python 复制代码
#!/usr/bin/env python
""" Downloads FaceForensics++ and Deep Fake Detection public data release
Example usage:
    see -h or https://github.com/ondyari/FaceForensics
"""
# -*- coding: utf-8 -*-
import argparse
import os
import urllib
import urllib.request
import tempfile
import time
import sys
import json
import random
from tqdm import tqdm
from os.path import join

# URLs and filenames used on the FaceForensics++ download servers.
FILELIST_URL = 'misc/filelist.json'
# NOTE(review): "DEEPFEAKES" is a typo inherited from the official script;
# kept as-is since it is only an internal constant name.
DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'
# Per-video model files provided for the Deepfakes method only.
DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]

# Parameters
# Maps each user-facing dataset name to its relative path on the server.
DATASETS = {
    'original_youtube_videos': 'misc/downloaded_youtube_videos.zip',
    'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip',
    'original': 'original_sequences/youtube',
    'DeepFakeDetection_original': 'original_sequences/actors',
    'Deepfakes': 'manipulated_sequences/Deepfakes',
    'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection',
    'Face2Face': 'manipulated_sequences/Face2Face',
    'FaceShifter': 'manipulated_sequences/FaceShifter',
    'FaceSwap': 'manipulated_sequences/FaceSwap',
    'NeuralTextures': 'manipulated_sequences/NeuralTextures'
    }
# Datasets covered by '-d all' (excludes the raw youtube zip entries above).
ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes',
                'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap',
                'NeuralTextures']
# Compression levels: raw (lossless), c23 (high quality), c40 (strong compression).
COMPRESSION = ['raw', 'c23', 'c40']
# Downloadable file types; 'models' applies to Deepfakes only.
TYPE = ['videos', 'masks', 'models']
# Mirror servers; switch if the download is slow.
SERVERS = ['EU', 'EU2', 'CA']

def parse_args():
    """Parse command-line options and derive the download URLs.

    Returns the argparse namespace extended with ``tos_url``, ``base_url``
    and ``deepfakes_model_url`` computed from the selected mirror server.
    """
    parser = argparse.ArgumentParser(
        description='Downloads FaceForensics v2 public data release.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('output_path', type=str, help='Output directory.')
    parser.add_argument('-d', '--dataset', type=str, default='all',
                        choices=list(DATASETS.keys()) + ['all'],
                        help='Which dataset to download, either pristine or '
                             'manipulated data or the downloaded youtube '
                             'videos.')
    parser.add_argument('-c', '--compression', type=str, default='raw',
                        choices=COMPRESSION,
                        help='Which compression degree. All videos '
                             'have been generated with h264 with a varying '
                             'codec. Raw (c0) videos are lossless compressed.')
    parser.add_argument('-t', '--type', type=str, default='videos',
                        choices=TYPE,
                        help='Which file type, i.e. videos, masks, for our '
                             'manipulation methods, models, for Deepfakes.')
    parser.add_argument('-n', '--num_videos', type=int, default=None,
                        help='Select a number of videos number to '
                             "download if you don't want to download the full"
                             ' dataset.')
    parser.add_argument('--server', type=str, default='EU',
                        choices=SERVERS,
                        help='Server to download the data from. If you '
                             'encounter a slow download speed, consider '
                             'changing the server.')
    args = parser.parse_args()

    # Resolve the chosen mirror into its base URL.
    mirror_urls = {
        'EU': 'http://canis.vc.in.tum.de:8100/',
        'EU2': 'http://kaldir.vc.in.tum.de/faceforensics/',
        'CA': 'http://falas.cmpt.sfu.ca:8100/',
    }
    server_url = mirror_urls.get(args.server)
    if server_url is None:
        # Unreachable in practice: argparse already restricts --server choices.
        raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS)))
    args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf'
    args.base_url = server_url + 'v3/'
    args.deepfakes_model_url = (server_url +
                                'v3/manipulated_sequences/Deepfakes/models/')
    return args

def download_files(filenames, base_url, output_path, report_progress=True):
    """Fetch every file in *filenames* from *base_url* into *output_path*."""
    os.makedirs(output_path, exist_ok=True)
    iterator = tqdm(filenames) if report_progress else filenames
    for name in iterator:
        download_file(base_url + name, join(output_path, name))

def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``.

    Prints percentage, downloaded megabytes, speed and elapsed time on one
    line. The first call (``count == 0``) only records the start time in
    the module-level ``start_time``.

    Fixes over the original: guards against ``total_size <= 0`` (urlretrieve
    reports -1 or 0 when the server sends no usable Content-Length, which
    previously raised ZeroDivisionError or printed nonsense), guards against
    a zero ``duration`` on coarse clocks, and caps the percentage at 100.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return
    # Clamp to avoid division by zero on very fast/coarse-clock systems.
    duration = max(time.time() - start_time, 1e-6)
    progress_size = int(count * block_size)
    speed = int(progress_size / (1024 * duration))
    if total_size > 0:
        percent = min(int(count * block_size * 100 / total_size), 100)
    else:
        percent = 0  # unknown total size
    sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()

def download_file(url, out_file, report_progress=False):
    """Download *url* to *out_file* atomically.

    The payload is first written to a temporary file in the destination
    directory and then renamed, so an interrupted download never leaves a
    half-written file under the final name. Existing files are skipped.

    Fixes over the original: the temporary file is removed when the download
    fails (previously it leaked one mkstemp file per failed attempt), and
    the mkstemp descriptor is closed with ``os.close`` instead of wrapping
    it in a file object just to close it.
    """
    out_dir = os.path.dirname(out_file)
    if os.path.isfile(out_file):
        tqdm.write('WARNING: skipping download of existing file ' + out_file)
        return
    fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
    os.close(fh)  # only the path is needed; urlretrieve reopens it
    try:
        if report_progress:
            urllib.request.urlretrieve(url, out_file_tmp,
                                       reporthook=reporthook)
        else:
            urllib.request.urlretrieve(url, out_file_tmp)
    except Exception:
        # Do not leave orphaned temp files behind on network errors.
        if os.path.isfile(out_file_tmp):
            os.remove(out_file_tmp)
        raise
    os.rename(out_file_tmp, out_file)

def main(args):
    """Drive the FaceForensics++ download according to the parsed arguments.

    Shows the terms of use, then for every requested dataset fetches the
    chosen file type (videos, masks, or Deepfakes models) from the server.

    Fix over the original: the body mixed tab- and space-indented lines
    (the ``suffix = 'info'`` and filelist branches), which raises TabError
    under Python 3; indentation is normalised to spaces throughout.
    """
    # Terms of service must be acknowledged before anything is fetched.
    print('By pressing any key to continue you confirm that you have agreed '\
          'to the FaceForensics terms of use as described at:')
    print(args.tos_url)
    print('***')
    print('Press any key to continue, or CTRL-C to exit.')
    _ = input('')

    # Extract arguments
    c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS
    c_type = args.type
    c_compression = args.compression
    num_videos = args.num_videos
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    for dataset in c_datasets:
        dataset_path = DATASETS[dataset]
        # Special case: the raw youtube downloads are one big zip file.
        if 'original_youtube_videos' in dataset:
            print('Downloading original youtube videos.')
            if not 'info' in dataset_path:
                print('Please be patient, this may take a while (~40gb)')
                suffix = ''
            else:
                suffix = 'info'
            download_file(args.base_url + '/' + dataset_path,
                          out_file=join(output_path,
                                        'downloaded_videos{}.zip'.format(
                                            suffix)),
                          report_progress=True)
            return

        # Else: regular datasets
        print('Downloading {} of dataset "{}"'.format(
            c_type, dataset_path
        ))

        # Fetch the filelist for this dataset from the server.
        if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path:
            filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' +
                DEEPFEAKES_DETECTION_URL).read().decode("utf-8"))
            if 'actors' in dataset_path:
                filelist = filepaths['actors']
            else:
                # NOTE(review): the server-side JSON key is spelled
                # 'DeepFakesDetection'; do not "fix" the spelling here.
                filelist = filepaths['DeepFakesDetection']
        elif 'original' in dataset_path:
            # Originals: every id of every pair is its own video.
            file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
                FILELIST_URL).read().decode("utf-8"))
            filelist = []
            for pair in file_pairs:
                filelist += pair
        else:
            # Manipulated sequences are named '<target>_<source>'; models
            # only exist for the forward pairing direction.
            file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
                FILELIST_URL).read().decode("utf-8"))
            filelist = []
            for pair in file_pairs:
                filelist.append('_'.join(pair))
                if c_type != 'models':
                    filelist.append('_'.join(pair[::-1]))

        # Maybe limit number of videos for download
        if num_videos is not None and num_videos > 0:
            print('Downloading the first {} videos'.format(num_videos))
            filelist = filelist[:num_videos]

        # Server and local paths
        dataset_videos_url = args.base_url + '{}/{}/{}/'.format(
            dataset_path, c_compression, c_type)
        # The original passed a spurious third .format argument; masks always
        # live under '<dataset>/masks/videos/'.
        dataset_mask_url = args.base_url + '{}/{}/videos/'.format(
            dataset_path, 'masks')

        if c_type == 'videos':
            dataset_output_path = join(output_path, dataset_path, c_compression,
                                       c_type)
            print('Output path: {}'.format(dataset_output_path))
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_videos_url, dataset_output_path)
        elif c_type == 'masks':
            dataset_output_path = join(output_path, dataset_path, c_type,
                                       'videos')
            print('Output path: {}'.format(dataset_output_path))
            if 'original' in dataset:
                if args.dataset != 'all':
                    print('Only videos available for original data. Aborting.')
                    return
                else:
                    print('Only videos available for original data. '
                          'Skipping original.\n')
                    continue
            if 'FaceShifter' in dataset:
                print('Masks not available for FaceShifter. Aborting.')
                return
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_mask_url, dataset_output_path)

        # Else: per-video model files, available for Deepfakes only.
        else:
            if dataset != 'Deepfakes' and c_type == 'models':
                print('Models only available for Deepfakes. Aborting')
                return
            dataset_output_path = join(output_path, dataset_path, c_type)
            print('Output path: {}'.format(dataset_output_path))

            # Get Deepfakes models
            for folder in tqdm(filelist):
                folder_filelist = DEEPFAKES_MODEL_NAMES

                # Folder paths
                folder_base_url = args.deepfakes_model_url + folder + '/'
                folder_dataset_output_path = join(dataset_output_path,
                                                  folder)
                download_files(folder_filelist, folder_base_url,
                               folder_dataset_output_path,
                               report_progress=False)   # already done

if __name__ == "__main__":
    args = parse_args()
    main(args)

接下来使用如下命令即可下载数据集

bash 复制代码
python download-FaceForensics.py
    <output path>
    -d <dataset type, e.g., Face2Face, original or all>
    -c <compression quality, e.g., c23 or raw>
    -t <file type, e.g., videos, masks or models>

<output path> 表示数据集的保存路径,即下载后的 FaceForensics++ 或 DeepFakeDetection 数据将被存放的位置。例如,可以设置为当前项目下的 ./data/,也可以设置为单独的数据盘路径,如 /mnt/data2/qi.xiong/Dataset/FaceForensics/。下载脚本会在该目录下自动构建对应的数据集层级结构.

-d 用于指定下载的数据类型(dataset type)。常见可选项包括 original、Face2Face、Deepfakes、FaceSwap、NeuralTextures、DeepFakeDetection 以及 all 等。其中,original 表示下载原始真实视频序列,通常对应 original_sequences/youtube;Face2Face、Deepfakes、FaceSwap、NeuralTextures 表示下载四种主要伪造方法生成的数据;DeepFakeDetection 表示下载 DeepFakeDetection 扩展数据;all 表示一次性下载全部可用数据。若仅用于常规 deepfake 检测实验,通常优先选择 original 与四种主流伪造类型。

-c 用于指定压缩等级(compression quality)。常用选项为 raw、c23 和 c40。其中,raw 表示原始或无损压缩版本,数据体积最大,但保留了最完整的图像细节;c23 表示较高质量压缩版本,是目前较常见、也较平衡的一种设置,既能保留较好的视觉质量,又显著降低存储开销;c40 表示压缩更强、质量更低的数据版本,更适合做强压缩场景下的鲁棒性测试。实际使用中,如果只是复现主流实验或进行预处理,通常推荐优先下载 c23 视频版本。

-t 用于指定文件类型(file type)。常见选项包括 videos、masks 和 models。其中,videos 表示下载视频文件,这是最常用的选项;masks 表示下载伪造区域的二值掩码,适用于伪造区域定位、分割或可解释性分析任务;models 主要与部分伪造方法相关,用于获取对应的生成模型文件。对于大多数 deepfake 分类或人脸抽帧任务,仅下载 videos 即可。

下载完成的数据集格式如下:

bash 复制代码
(xq) qi.xiong@instance-ujccspas:/mnt/data2/qi.xiong/Dataset/FaceForensics$ tree -L 3
.
├── manipulated_sequences
│   ├── DeepFakeDetection
│   │   ├── c23
│   │   └── masks
│   ├── Deepfakes
│   │   ├── c23
│   │   └── masks
│   ├── Face2Face
│   │   ├── c23
│   │   └── masks
│   ├── FaceShifter
│   │   └── c23
│   ├── FaceSwap
│   │   └── c23
│   └── NeuralTextures
│       └── c23
└── original_sequences
    ├── actors
    │   └── c23
    └── youtube
        └── c23

22 directories, 0 files

二、数据集预处理

我们前面下载得到的数据集仍然是视频格式,因此在正式用于 deepfake 检测之前,还需要先进行预处理。通常来说,这类任务不会直接将整段视频输入模型,而是先从视频中抽取若干具有代表性的帧,再从每一帧中提取对应的人脸区域。这样做一方面可以明显降低后续数据处理和模型训练的开销,另一方面也能让模型更聚焦于真正有用的面部伪造信息。FaceForensics++ 官方文档中也提到,通常更推荐先下载压缩后的视频,再自行完成帧提取。本文这里采用一种比较简化且实用的处理方式:从每个视频中均匀抽取固定数量的帧,然后使用 RetinaFace 对这些帧进行人脸检测,并将检测到的人脸区域裁剪保存。相比一些传统方法,RetinaFace 在检测精度和鲁棒性方面通常更有优势,尤其是在侧脸、光照变化较大或者人脸尺度变化明显的情况下,检测结果往往更加稳定。需要说明的是,本文这里的预处理目标比较明确,即只做人脸抽帧和人脸裁剪,不额外涉及关键点对齐、伪造区域掩码生成等更复杂的步骤,因此整个流程会更加清晰,也更适合作为 FaceForensics++ 数据预处理的基础版本。

bash 复制代码
git clone https://github.com/ternaus/retinaface.git
cd retinaface
pip install -v -e .

我们配置好了retinaface之后,即可使用如下脚本继续转换:

python 复制代码
from glob import glob
import os
import cv2
from tqdm import tqdm
import numpy as np
import argparse
from retinaface.pre_trained_models import get_model
import torch

def facecrop(model, org_path, save_path, num_frames=10):
    """Uniformly sample frames from a video, run RetinaFace on each and
    save the cropped face regions as PNG files.

    Crops are written to ``<save_path>/frames_retina/<video_name>/`` with
    names ``frame_<frame#>_face_<face#>.png``.
    """
    cap = cv2.VideoCapture(org_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if total_frames <= 0:
        print(f"Invalid video: {org_path}")
        cap.release()
        return

    # Evenly spaced frame indices covering the whole clip.
    wanted = set(np.linspace(0, total_frames - 1, num_frames,
                             endpoint=True, dtype=int).tolist())

    for frame_no in range(total_frames):
        ok, frame_bgr = cap.read()
        if not ok or frame_bgr is None:
            continue
        if frame_no not in wanted:
            continue

        # RetinaFace expects RGB input; crops below are taken from the BGR
        # frame because cv2.imwrite expects BGR.
        detections = model.predict_jsons(
            cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        if len(detections) == 0:
            continue

        out_dir = os.path.join(
            save_path, 'frames_retina', os.path.basename(org_path).replace('.mp4', '')
        )
        os.makedirs(out_dir, exist_ok=True)

        for det_idx, det in enumerate(detections):
            box = det.get('bbox', None)
            if box is None or len(box) < 4:
                continue

            left, top, right, bottom = map(int, box[:4])
            # Clamp the box to the frame and drop degenerate boxes.
            left, top = max(0, left), max(0, top)
            right = min(frame_bgr.shape[1], right)
            bottom = min(frame_bgr.shape[0], bottom)
            if right <= left or bottom <= top:
                continue

            crop = frame_bgr[top:bottom, left:right]
            cv2.imwrite(
                os.path.join(out_dir, f'frame_{frame_no}_face_{det_idx}.png'),
                crop,
            )

    cap.release()

if __name__ == '__main__':
    # CLI: -d dataset subset, -c compression level, -n frames per video.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        dest='dataset',
        choices=[
            'Original',
            'DeepFakeDetection_original',
            'DeepFakeDetection',
            'Deepfakes',
            'Face2Face',
            'FaceShifter',
            'FaceSwap',
            'NeuralTextures'
        ]
    )
    parser.add_argument('-c', dest='comp', choices=['raw', 'c23', 'c40'], default='raw')
    parser.add_argument('-n', dest='num_frames', type=int, default=20)
    args = parser.parse_args()

    # Map the subset name onto its directory inside FaceForensics++.
    if args.dataset == 'Original':
        dataset_path = 'data/FaceForensics++/original_sequences/youtube/{}/'.format(args.comp)
    elif args.dataset == 'DeepFakeDetection_original':
        dataset_path = 'data/FaceForensics++/original_sequences/actors/{}/'.format(args.comp)
    elif args.dataset in ['DeepFakeDetection', 'FaceShifter', 'Face2Face', 'Deepfakes', 'FaceSwap', 'NeuralTextures']:
        dataset_path = 'data/FaceForensics++/manipulated_sequences/{}/{}/'.format(args.dataset, args.comp)
    else:
        raise NotImplementedError

    # RetinaFace detector, forced onto CPU.
    device = torch.device('cpu')
    model = get_model("resnet50_2020-07-20", max_size=2048, device=device)
    model.eval()

    video_files = sorted(glob(dataset_path + 'videos/' + '*.mp4'))
    print("{} : videos are exist in {}".format(len(video_files), args.dataset))

    for video_file in tqdm(video_files):
        facecrop(model, video_file, save_path=dataset_path, num_frames=args.num_frames)

在具体使用时,我们主要关心三个参数:-d 用于指定要处理的数据子集,例如 Original 表示原始真实视频,Deepfakes、Face2Face、FaceSwap、NeuralTextures 表示不同伪造方法生成的视频;-c 用于指定压缩等级,常见取值包括 raw、c23 和 c40,其中 c23 是较为常用的一种设置;-n 表示每个视频需要抽取的帧数,例如 -n 20 表示从一个视频中均匀抽取 20 帧进行处理。

如果只以 FaceForensics++ 中的原始真实视频为例,并采用 c23 压缩版本,那么待处理的视频通常位于如下目录中:

bash 复制代码
data/FaceForensics++/original_sequences/youtube/c23/videos/

当脚本运行完成后,处理结果会保存在对应目录下新生成的 frames_retina 文件夹中。例如,如果处理的是 Original 的 c23 数据,那么输出目录通常为:

python 复制代码
data/FaceForensics++/original_sequences/youtube/c23/frames_retina/

接下来我们需要划分train、val和test数据集,我们按照官方的比例来划分,使用如下代码即可:

python 复制代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
将 frames_retina 组织为 fakefacecls 所需结构

用法:
  python setup_ffpp_dataset.py
  python setup_ffpp_dataset.py --data_root /path/to/data/FaceForensics++

输出:
  data/FaceForensics++/ffpp/
  ├── train.json, val.json, test.json
  ├── Origin/c23/larger_images/     -> symlinks to frames_retina
  ├── Deepfakes/c23/larger_images/
  ├── Face2Face/c23/larger_images/
  ├── FaceSwap/c23/larger_images/
  └── NeuralTextures/c23/larger_images/
"""
import argparse
import json
import os
from pathlib import Path

# Official FF++ train split (from https://github.com/ondyari/FaceForensics);
# each entry is a [target_id, source_id] video pair.
TRAIN_JSON = [
    ["071", "054"], ["087", "081"], ["881", "856"], ["187", "234"], ["645", "688"],
    ["754", "758"], ["811", "920"], ["710", "788"], ["628", "568"], ["312", "021"],
    ["950", "836"], ["059", "050"], ["524", "580"], ["751", "752"], ["918", "934"],
    ["604", "703"], ["296", "293"], ["518", "131"], ["536", "540"], ["969", "897"],
    ["372", "413"], ["357", "432"], ["809", "799"], ["092", "098"], ["302", "323"],
    ["981", "985"], ["512", "495"], ["088", "060"], ["795", "907"], ["535", "587"],
    ["297", "270"], ["838", "810"], ["850", "764"], ["476", "400"], ["268", "269"],
    ["033", "097"], ["226", "491"], ["784", "769"], ["195", "442"], ["678", "460"],
    ["320", "328"], ["451", "449"], ["409", "382"], ["556", "588"], ["027", "009"],
    ["196", "310"], ["241", "210"], ["295", "099"], ["043", "110"], ["753", "789"],
    ["716", "712"], ["508", "831"], ["005", "010"], ["276", "185"], ["498", "433"],
    ["294", "292"], ["105", "180"], ["984", "967"], ["318", "334"], ["356", "324"],
    ["344", "020"], ["289", "228"], ["022", "489"], ["137", "165"], ["095", "053"],
    ["999", "960"], ["481", "469"], ["534", "490"], ["543", "559"], ["150", "153"],
    ["598", "178"], ["475", "265"], ["671", "677"], ["204", "230"], ["863", "853"],
    ["561", "998"], ["163", "031"], ["655", "444"], ["038", "125"], ["735", "774"],
    ["184", "205"], ["499", "539"], ["717", "684"], ["878", "866"], ["127", "129"],
    ["286", "267"], ["032", "944"], ["681", "711"], ["236", "237"], ["989", "993"],
    ["537", "563"], ["814", "871"], ["509", "525"], ["221", "206"], ["808", "829"],
    ["696", "686"], ["431", "447"], ["737", "719"], ["609", "596"], ["408", "424"],
    ["976", "954"], ["156", "243"], ["434", "438"], ["627", "658"], ["025", "067"],
    ["635", "642"], ["523", "541"], ["572", "554"], ["215", "208"], ["651", "835"],
    ["975", "978"], ["792", "903"], ["931", "936"], ["846", "845"], ["899", "914"],
    ["209", "016"], ["398", "457"], ["797", "844"], ["360", "437"], ["738", "804"],
    ["694", "767"], ["790", "014"], ["657", "644"], ["374", "407"], ["728", "673"],
    ["193", "030"], ["876", "891"], ["553", "545"], ["331", "260"], ["873", "872"],
    ["109", "107"], ["121", "093"], ["143", "140"], ["778", "798"], ["983", "113"],
    ["504", "502"], ["709", "390"], ["940", "941"], ["894", "848"], ["311", "387"],
    ["562", "626"], ["330", "162"], ["112", "892"], ["765", "867"], ["124", "085"],
    ["665", "679"], ["414", "385"], ["555", "516"], ["072", "037"], ["086", "090"],
    ["202", "348"], ["341", "340"], ["333", "377"], ["082", "103"], ["569", "921"],
    ["750", "743"], ["211", "177"], ["770", "791"], ["329", "327"], ["613", "685"],
    ["007", "132"], ["304", "300"], ["860", "905"], ["986", "994"], ["378", "368"],
    ["761", "766"], ["232", "248"], ["136", "285"], ["601", "653"], ["693", "698"],
    ["359", "317"], ["246", "258"], ["500", "592"], ["776", "676"], ["262", "301"],
    ["307", "365"], ["600", "505"], ["833", "826"], ["361", "448"], ["473", "366"],
    ["885", "802"], ["277", "335"], ["667", "446"], ["522", "337"], ["018", "019"],
    ["430", "459"], ["886", "877"], ["456", "435"], ["239", "218"], ["771", "849"],
    ["065", "089"], ["654", "648"], ["151", "225"], ["152", "149"], ["229", "247"],
    ["624", "570"], ["290", "240"], ["011", "805"], ["461", "250"], ["251", "375"],
    ["639", "841"], ["602", "397"], ["028", "068"], ["338", "336"], ["964", "174"],
    ["782", "787"], ["478", "506"], ["313", "283"], ["659", "749"], ["690", "689"],
    ["893", "913"], ["197", "224"], ["253", "183"], ["373", "394"], ["803", "017"],
    ["305", "513"], ["051", "332"], ["238", "282"], ["621", "546"], ["401", "395"],
    ["510", "528"], ["410", "411"], ["049", "946"], ["663", "231"], ["477", "487"],
    ["252", "266"], ["952", "882"], ["315", "322"], ["216", "164"], ["061", "080"],
    ["603", "575"], ["828", "830"], ["723", "704"], ["870", "001"], ["201", "203"],
    ["652", "773"], ["108", "052"], ["272", "396"], ["040", "997"], ["988", "966"],
    ["281", "474"], ["077", "100"], ["146", "256"], ["972", "718"], ["303", "309"],
    ["582", "172"], ["222", "168"], ["884", "968"], ["217", "117"], ["118", "120"],
    ["242", "182"], ["858", "861"], ["101", "096"], ["697", "581"], ["763", "930"],
    ["839", "864"], ["542", "520"], ["122", "144"], ["687", "615"], ["544", "532"],
    ["721", "715"], ["179", "212"], ["591", "605"], ["275", "887"], ["996", "056"],
    ["825", "074"], ["530", "594"], ["757", "573"], ["611", "760"], ["189", "200"],
    ["392", "339"], ["734", "699"], ["977", "075"], ["879", "963"], ["910", "911"],
    ["889", "045"], ["962", "929"], ["515", "519"], ["062", "066"], ["937", "888"],
    ["199", "181"], ["785", "736"], ["079", "076"], ["155", "576"], ["748", "355"],
    ["819", "786"], ["577", "593"], ["464", "463"], ["439", "441"], ["574", "547"],
    ["747", "854"], ["403", "497"], ["965", "948"], ["726", "713"], ["943", "942"],
    ["160", "928"], ["496", "417"], ["700", "813"], ["756", "503"], ["213", "083"],
    ["039", "058"], ["781", "806"], ["620", "619"], ["351", "346"], ["959", "957"],
    ["264", "271"], ["006", "002"], ["391", "406"], ["631", "551"], ["501", "326"],
    ["412", "274"], ["641", "662"], ["111", "094"], ["166", "167"], ["130", "139"],
    ["938", "987"], ["055", "147"], ["990", "008"], ["013", "883"], ["614", "616"],
    ["772", "708"], ["840", "800"], ["415", "484"], ["287", "426"], ["680", "486"],
    ["057", "070"], ["590", "034"], ["194", "235"], ["291", "874"], ["902", "901"],
    ["343", "363"], ["279", "298"], ["393", "405"], ["674", "744"], ["244", "822"],
    ["133", "148"], ["636", "578"], ["637", "427"], ["041", "063"], ["869", "780"],
    ["733", "935"], ["259", "345"], ["069", "961"], ["783", "916"], ["191", "188"],
    ["526", "436"], ["123", "119"], ["207", "908"], ["796", "740"], ["815", "730"],
    ["173", "171"], ["383", "353"], ["458", "722"], ["533", "450"], ["618", "629"],
    ["646", "643"], ["531", "549"], ["428", "466"], ["859", "843"], ["692", "610"],
]

# Official FF++ validation split (70 [target_id, source_id] pairs).
VAL_JSON = [
    ["720", "672"], ["939", "115"], ["284", "263"], ["402", "453"], ["820", "818"],
    ["762", "832"], ["834", "852"], ["922", "898"], ["104", "126"], ["106", "198"],
    ["159", "175"], ["416", "342"], ["857", "909"], ["599", "585"], ["443", "514"],
    ["566", "617"], ["472", "511"], ["325", "492"], ["816", "649"], ["583", "558"],
    ["933", "925"], ["419", "824"], ["465", "482"], ["565", "589"], ["261", "254"],
    ["992", "980"], ["157", "245"], ["571", "746"], ["947", "951"], ["926", "900"],
    ["493", "538"], ["468", "470"], ["915", "895"], ["362", "354"], ["440", "364"],
    ["640", "638"], ["827", "817"], ["793", "768"], ["837", "890"], ["004", "982"],
    ["192", "134"], ["745", "777"], ["299", "145"], ["742", "775"], ["586", "223"],
    ["483", "370"], ["779", "794"], ["971", "564"], ["273", "807"], ["991", "064"],
    ["664", "668"], ["823", "584"], ["656", "666"], ["557", "560"], ["471", "455"],
    ["042", "084"], ["979", "875"], ["316", "369"], ["091", "116"], ["023", "923"],
    ["702", "612"], ["904", "046"], ["647", "622"], ["958", "956"], ["606", "567"],
    ["632", "548"], ["927", "912"], ["350", "349"], ["595", "597"], ["727", "729"],
]

# Official FF++ test split (70 [target_id, source_id] pairs).
TEST_JSON = [
    ["953", "974"], ["012", "026"], ["078", "955"], ["623", "630"], ["919", "015"],
    ["367", "371"], ["847", "906"], ["529", "633"], ["418", "507"], ["227", "169"],
    ["389", "480"], ["821", "812"], ["670", "661"], ["158", "379"], ["423", "421"],
    ["352", "319"], ["579", "701"], ["488", "399"], ["695", "422"], ["288", "321"],
    ["705", "707"], ["306", "278"], ["865", "739"], ["995", "233"], ["755", "759"],
    ["467", "462"], ["314", "347"], ["741", "731"], ["970", "973"], ["634", "660"],
    ["494", "445"], ["706", "479"], ["186", "170"], ["176", "190"], ["380", "358"],
    ["214", "255"], ["454", "527"], ["425", "485"], ["388", "308"], ["384", "932"],
    ["035", "036"], ["257", "420"], ["924", "917"], ["114", "102"], ["732", "691"],
    ["550", "452"], ["280", "249"], ["842", "714"], ["625", "650"], ["024", "073"],
    ["044", "945"], ["896", "128"], ["862", "047"], ["607", "683"], ["517", "521"],
    ["682", "669"], ["138", "142"], ["552", "851"], ["376", "381"], ["000", "003"],
    ["048", "029"], ["724", "725"], ["608", "675"], ["386", "154"], ["220", "219"],
    ["801", "855"], ["161", "141"], ["949", "868"], ["880", "135"], ["429", "404"],
]

# Path templates: location of the extracted frames_retina directories
# relative to the FaceForensics++ root, parameterised by codec/method.
ORIGIN_FRAMES = "original_sequences/youtube/{codec}/frames_retina"
MANIPULATED_FRAMES = "manipulated_sequences/{method}/{codec}/frames_retina"

# Manipulation methods to link by default; 'DeepFakeDetection' may be added.
METHODS = ["Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures", "FaceShifter"]

def main():
    """Write the official split JSONs and build the ffpp/ symlink layout."""
    default_root = Path(__file__).resolve().parent / "FaceForensics++"
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", default=str(default_root), help="FaceForensics++ 根目录")
    parser.add_argument("--codec", default="c23")
    parser.add_argument("--methods", nargs="+", default=METHODS)
    args = parser.parse_args()

    data_root = Path(args.data_root)
    codec = args.codec
    ffpp = data_root / "ffpp"
    ffpp.mkdir(parents=True, exist_ok=True)

    # 1. Dump the official splits next to the symlink tree.
    splits = {"train": TRAIN_JSON, "val": VAL_JSON, "test": TEST_JSON}
    for name, pairs in splits.items():
        f = ffpp / f"{name}.json"
        with open(f, "w") as fp:
            json.dump(pairs, fp, indent=2)
        print(f"  {f}")

    def link_children(src_dir, dst_dir):
        # Symlink every video directory of src_dir into dst_dir; return count.
        count = 0
        for vid in sorted(src_dir.iterdir()):
            if not vid.is_dir():
                continue
            dst = dst_dir / vid.name
            if not dst.exists():
                dst.symlink_to(vid.resolve())
            count += 1
        return count

    # 2. Real videos: Origin/<codec>/larger_images/<id> -> frames_retina/<id>
    origin_frames = data_root / ORIGIN_FRAMES.format(codec=codec)
    origin_larger = ffpp / "Origin" / codec / "larger_images"
    origin_larger.mkdir(parents=True, exist_ok=True)
    if origin_frames.exists():
        link_children(origin_frames, origin_larger)
        print(f"  Origin: {origin_larger} ({len(list(origin_larger.iterdir()))} videos)")
    else:
        print(f"  [skip] Origin {origin_frames} not found")

    # 3. Fakes: <method>/<codec>/larger_images/<id1_id2> -> frames_retina/...
    for method in args.methods:
        man_frames = data_root / MANIPULATED_FRAMES.format(method=method, codec=codec)
        man_larger = ffpp / method / codec / "larger_images"
        man_larger.mkdir(parents=True, exist_ok=True)
        if man_frames.exists():
            n = link_children(man_frames, man_larger)
            print(f"  {method}: {man_larger} ({n} videos)")
        else:
            print(f"  [skip] {method} {man_frames} not found")

    print(f"\n完成: ffpp 目录 -> {ffpp}")
    print("\n使用方式:")
    print("  1. fakefacecls: export FFPP_ROOT=" + str(ffpp.resolve()))
    print("  2. multiple-attention: 在 datasets/data.py 中设置 ffpproot = '" + str(ffpp.resolve()) + "/'")

if __name__ == "__main__":
    main()

三、人脸分类网络

我们接下来直接使用 Timm 库来验证 CNN 和 Transformer 作为 Backbone 对人脸伪造分类的识别性能。我们将支持两种分类方式:二分类和五分类。二分类即单纯判断真/假;五分类则在区分真假的基础上,进一步识别具体的人脸伪造方法。

所有代码已上传至GitHub:https://github.com/xiongqi123123/fakefaceclsnet

数据集加载及数据增强代码如下:

python 复制代码
import os
import random
import torch
import cv2
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose

from .augmentations import augmentations
from . import data

class DeepfakeDataset(Dataset):
    """Frame-level deepfake classification dataset.

    ``self.dataset`` holds ``[video_dir, label]`` entries; every video
    contributes ``imgs_per_video`` logical samples, and the concrete frame
    for an index is chosen deterministically from the index,
    ``frame_interval`` and the current ``epoch``. Unreadable samples are
    retried on other videos (train/val) or replaced by a zero tensor with
    label -1 (test).
    """

    def __init__(
        self,
        phase='train',
        datalabel='',
        resize=(224, 224),
        imgs_per_video=30,
        min_frames=0,
        normalize=None,
        frame_interval=10,
        max_frames=300,
        augment='augment0',
    ):
        """Build the sample index and the preprocessing pipeline.

        Args:
            phase: one of 'train' / 'val' / 'test'.
            datalabel: dataset selector string (e.g. 'ff-all-c23') or an
                already-built list of [path, label] pairs.
            resize: output (height, width) of every sample.
            imgs_per_video: logical samples drawn per video.
            min_frames: minimum frame count a video dir must contain;
                defaults to 30% of max_frames when falsy.
            normalize: dict with 'mean'/'std'; defaults to 0.5/0.5.
            frame_interval: frame stride between consecutive samples of
                the same video.
            max_frames: cap on the usable frame pool per video.
            augment: key into the shared ``augmentations`` table.
        """
        assert phase in ['train', 'val', 'test']
        normalize = normalize or dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.datalabel = datalabel
        self.phase = phase
        self.imgs_per_video = imgs_per_video
        self.frame_interval = frame_interval
        self.epoch = 0
        self.max_frames = max_frames
        self.min_frames = min_frames if min_frames else max_frames * 0.3
        self.aug = augmentations.get(augment, augmentations['augment0'])
        self.resize = resize
        self.trans = Compose([
            A.Resize(resize[0], resize[1]),  # small crops (e.g. 19x14) must be resized first; CenterCrop would fail on them
            A.Normalize(mean=normalize['mean'], std=normalize['std']),
            A.ToTensorV2(),
        ])
        self.dataset = self._build_dataset()
        self._frame_cache = {}  # cache os.listdir results so directories are not re-read for every frame

    def _build_dataset(self):
        """Resolve ``self.datalabel`` into a list of [video_dir, label] pairs."""
        if isinstance(self.datalabel, (list, tuple)):
            return self.datalabel
        if 'ff-5' in self.datalabel:
            # Five-way task: Origin + four manipulation methods, each its own class id.
            codec = self.datalabel.split('-')[2]
            out = []
            for idx, tag in enumerate(['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']):
                for item in data.FF_dataset(tag, codec, self.phase):
                    out.append([item[0], idx])
            return out
        if 'ff-all' in self.datalabel:
            # Binary task over all methods; real/fake balanced except at test time.
            codec = self.datalabel.split('-')[2]
            out = []
            for tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']:
                out.extend(data.FF_dataset(tag, codec, self.phase))
            if self.phase != 'test':
                out = data.make_balance(out)
            return out
        if 'ff' in self.datalabel:
            # Single method vs Origin, e.g. 'ff-Deepfakes-c23'.
            parts = self.datalabel.split('-')
            codec = parts[2]
            tag = parts[1]
            return data.FF_dataset(tag, codec, self.phase) + data.FF_dataset('Origin', codec, self.phase)
        if 'celeb' in self.datalabel:
            return data.Celeb_test
        if 'deeper' in self.datalabel:
            codec = self.datalabel.split('-')[1]
            return data.deeperforensics_dataset(self.phase) + data.FF_dataset('Origin', codec, self.phase)
        if 'dfdc' in self.datalabel:
            return data.dfdc_dataset(self.phase)
        raise ValueError(f'Unknown datalabel: {self.datalabel}')

    def next_epoch(self):
        """Advance the epoch counter, rotating which frame __getitem__ picks."""
        self.epoch += 1

    def __getitem__(self, item):
        """Return ``(image_tensor, label)``; see the class docstring for the fallback policy."""
        for _ in range(len(self.dataset)):  # bounded retry loop instead of unbounded recursion
            try:
                vid = self.dataset[item // self.imgs_per_video]
                vid_path = vid[0]
                if vid_path not in self._frame_cache:
                    self._frame_cache[vid_path] = sorted(os.listdir(vid_path))
                vd = self._frame_cache[vid_path]
                if len(vd) < self.min_frames:
                    raise ValueError(f"frames {len(vd)} < min_frames {self.min_frames}")
                # Deterministic frame choice, shifted by epoch and wrapped at the frame pool size.
                idx = (item % self.imgs_per_video * self.frame_interval + self.epoch) % min(len(vd), self.max_frames)
                fname = vd[idx]
                img = cv2.imread(os.path.join(vid[0], fname))
                if img is None:
                    raise ValueError(f"cv2.imread failed: {fname}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                if self.phase == 'train':
                    img = self.aug(image=img)['image']
                return self.trans(image=img)['image'], vid[1]
            except Exception as e:
                if os.environ.get('DEBUG_DATASET') == '1' and not getattr(self, '_debug_printed', False):
                    import traceback
                    vp = self.dataset[item // self.imgs_per_video][0] if item < len(self) else '?'
                    print(f'[DEBUG] item={item} path={vp} err={e}')
                    traceback.print_exc()
                    self._debug_printed = True  # print debug info only once
                if self.phase == 'test':
                    # Test samples must keep their position: emit a placeholder with label -1.
                    return torch.zeros(3, self.resize[0], self.resize[1]), -1
                item = (item + self.imgs_per_video) % len(self)
        return torch.zeros(3, self.resize[0], self.resize[1]), -1  # placeholder when every candidate fails

    def __len__(self):
        """Total logical samples: number of videos times imgs_per_video."""
        return len(self.dataset) * self.imgs_per_video
python 复制代码
import os
import json
import random

# Data roots: FFPP_ROOT env var, or the default FFDeepFake/data/FaceForensics++/ffpp.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_FFDEEPFAKE_ROOT = os.path.dirname(os.path.dirname(_SCRIPT_DIR))  # two levels up from this script
_FFDEEPFAKE_ROOT = os.path.dirname(_FFDEEPFAKE_ROOT)  # NOTE(review): a third dirname — both inline comments originally claimed "FFDeepFake"; confirm the intended directory depth
_data_root = os.path.join(_FFDEEPFAKE_ROOT, 'data')
_DEFAULT_FFPP = os.path.join(_data_root, 'FaceForensics++', 'ffpp')

# FF++ root; a trailing separator is appended (cosmetic for os.path.join, but
# kept for any direct string concatenation elsewhere).
ffpproot = os.environ.get('FFPP_ROOT', _DEFAULT_FFPP)
if ffpproot and not ffpproot.endswith(os.sep):
    ffpproot += os.sep
dfdcroot = os.path.join(_data_root, 'dfdc')
celebroot = os.path.join(_data_root, 'celebDF')
deeperforensics_root = os.path.join(_data_root, 'deeper')

def load_json(name):
    """Load and return the JSON document at path *name*, decoded as UTF-8.

    Explicit encoding matters: without it, open() uses the locale encoding,
    which fails on UTF-8 metadata files on e.g. Windows (cp936/cp1252).
    """
    with open(name, encoding='utf-8') as f:
        return json.load(f)

def FF_dataset(tag='Origin', codec='c0', part='train'):
    """Build a list of [directory_path, label] pairs for one FF++ split.

    Label is 0 for 'Origin' (real) videos and 1 for any manipulation method.
    codec='all' and part='all' recursively concatenate every codec / split.
    """
    assert tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face', 'FaceShifter']
    assert codec in ['c0', 'c23', 'c40', 'all']
    assert part in ['train', 'val', 'test', 'all']
    if part == 'all':
        return sum((FF_dataset(tag, codec, p) for p in ('train', 'val', 'test')), [])
    if codec == 'all':
        return sum((FF_dataset(tag, c, part) for c in ('c0', 'c23', 'c40')), [])
    base = os.path.join(ffpproot, tag, codec, 'larger_images')
    pairs = load_json(os.path.join(ffpproot, part + '.json'))
    entries = []
    for pair in pairs:
        a, b = pair[0], pair[1]
        if tag == 'Origin':
            # Real videos: each id of the pair is its own directory, label 0.
            entries.append([os.path.join(base, a), 0])
            entries.append([os.path.join(base, b), 0])
        else:
            # Manipulated videos are named source_target; both directions exist.
            entries.append([os.path.join(base, a + '_' + b), 1])
            entries.append([os.path.join(base, b + '_' + a), 1])
    return entries

def make_balance(data):
    """Oversample the minority class so both labels appear equally often.

    *data* is a list of [path, label] pairs with label 0 (real) or 1 (fake).
    The minority class is repeated ``len(major) // len(minor)`` times, plus a
    random sample covering the remainder, and concatenated with the majority
    class. If either class is absent there is nothing to balance against, so
    a shallow copy of *data* is returned (previously this path crashed with
    ZeroDivisionError).
    """
    tr = [x for x in data if x[1] == 0]
    tf = [x for x in data if x[1] == 1]
    if not tr or not tf:
        return list(data)  # one class missing: avoid dividing by zero below
    if len(tr) > len(tf):
        tr, tf = tf, tr  # keep `tr` as the minority list
    rate = len(tf) // len(tr)
    res = len(tf) - rate * len(tr)
    tr = tr * rate + random.sample(tr, res)
    return tr + tf

def dfdc_dataset(part='train'):
    """Return [path, label] pairs for one DFDC split ('train'/'val'/'test').

    The train split is class-balanced via make_balance; val/test are used as-is.
    """
    assert part in ['train', 'val', 'test']
    meta = load_json(os.path.join(dfdcroot, 'DFDC.json'))
    subdir = {'train': 'dfdc', 'val': 'dfdc-val', 'test': 'dfdc-test'}[part]
    base = os.path.join(dfdcroot, subdir)
    entries = make_balance(meta['train']) if part == 'train' else meta[part]
    return [[os.path.join(base, e[0]), e[1]] for e in entries]

def deeperforensics_dataset(part='train'):
    """Return [path, 1] pairs for DeeperForensics videos matching an FF++ split.

    Directory names under deeperforensics_root start with an FF++ video id
    (text before the first underscore), so every id appearing in the split's
    metadata maps to exactly one fake-labelled directory.
    """
    by_id = {entry.split('_')[0]: entry for entry in os.listdir(deeperforensics_root)}
    metafile = load_json(os.path.join(ffpproot, part + '.json'))
    files = []
    for pair in metafile:
        for vid_id in (pair[0], pair[1]):
            files.append([os.path.join(deeperforensics_root, by_id[vid_id]), 1])
    return files
try:
    # Celeb-DF test list. The `1 - label` flips celeb.json's labels to this
    # project's convention (0 = real, 1 = fake) — presumably celeb.json uses
    # 1 = real; confirm against the file's producer.
    Celeb_test = [
        [os.path.join(celebroot, entry[0]), 1 - entry[1]]
        for entry in load_json(os.path.join(celebroot, 'celeb.json'))
    ]
except Exception:
    Celeb_test = []  # dataset not present on this machine
python 复制代码
import albumentations as A

# Augmentation presets of increasing strength; one is selected by name from
# the `augmentations` dict at the bottom. Each Compose fires with p=1; the
# individual ops carry their own probabilities.

# Preset 0: horizontal flip only.
augment0 = A.Compose([A.HorizontalFlip()], p=1)
# Preset 1: flip plus mild colour jitter.
augment1 = A.Compose([
    A.HorizontalFlip(),
    A.HueSaturationValue(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
], p=1)
# Preset 2: preset 1 plus noise, one of motion-blur / gaussian-blur / JPEG
# compression artefacts, and occasional greyscale.
augment2 = A.Compose([
    A.HorizontalFlip(),
    A.HueSaturationValue(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.OneOf([A.GaussNoise()], p=0.3),
    A.OneOf([
        A.MotionBlur(),
        A.GaussianBlur(),
        # NOTE(review): `quality_range` needs albumentations >= 1.4; older
        # releases use quality_lower/quality_upper — confirm pinned version.
        A.ImageCompression(quality_range=(65, 80)),
    ], p=0.3),
    A.ToGray(p=0.1),
], p=1)

# Name -> pipeline lookup used by the dataset/config layer.
augmentations = {'augment0': augment0, 'augment1': augment1, 'augment2': augment2}

网络直接使用Timm的预置模型:

python 复制代码
####### __init__.py
from .cnn import build_cnn_model

def build_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs):
    """Thin factory that forwards to build_cnn_model.

    Any timm model name is valid, e.g. resnet50, vit_base_patch16_224,
    deit_small_patch16_224.
    """
    model = build_cnn_model(
        backbone,
        num_classes,
        pretrained,
        dropout,
        **kwargs,
    )
    return model
    
####### backbone.py

"""
基于 timm 的 backbone + 分类头
支持 CNN (resnet50, efficientnet_b0, ...) 和 ViT (vit_base_patch16_224, deit_base_patch16_224, ...)
"""
import torch.nn as nn
import timm

def build_cnn_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs):
    """Instantiate a CNNClassifier around a timm backbone.

    Args:
        backbone: timm model name, e.g. resnet50, efficientnet_b0, convnext_tiny
        num_classes: 2 (real/fake) or 5 (Origin/Deepfakes/NeuralTextures/FaceSwap/Face2Face)
        pretrained: whether to load ImageNet-pretrained weights
        dropout: dropout probability of the classification head
    """
    classifier = CNNClassifier(backbone, num_classes, pretrained, dropout, **kwargs)
    return classifier

class CNNClassifier(nn.Module):
    """timm backbone (CNN or ViT) + a replaceable dropout/linear head."""

    def __init__(self, backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, in_chans=3, **kwargs):
        """
        Args:
            backbone: timm model name, e.g. resnet50, vit_base_patch16_224
            num_classes: number of output classes
            pretrained: load ImageNet weights. timm's `pretrained` parameter
                is a bool; the previous code passed the string 'imagenet' (or
                None), which only worked by truthiness.
            dropout: dropout probability before the final linear layer
            in_chans: number of input channels
        """
        super().__init__()
        self.num_classes = num_classes
        self.backbone = timm.create_model(
            backbone,
            pretrained=bool(pretrained),  # timm expects a bool, not a weight-set name
            num_classes=0,  # strip timm's own classifier; our head is attached below
            global_pool='avg',
            in_chans=in_chans,
            **kwargs
        )
        feat_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(feat_dim, num_classes),
        )

    def forward(self, x):
        """Map an image batch (B, in_chans, H, W) to logits (B, num_classes)."""
        feat = self.backbone(x)  # pooled feature vector (B, num_features)
        return self.head(feat)

然后就是训练的代码

python 复制代码
"""
训练与验证脚本
用法:
  python train.py
  python train.py --backbone resnet50 --num_classes 5 --datalabel ff-5-c23
"""
import argparse
import logging
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import TrainConfig
from datasets import DeepfakeDataset
from models import build_model

def get_args():
    """Parse and return the command-line arguments for training."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--backbone', default='resnet50', help='timm backbone')
    add('--num_classes', type=int, default=2, choices=[2, 5])
    add('--datalabel', default='ff-all-c23', help='ff-all-c23 | ff-5-c23')
    add('--epochs', type=int, default=20)
    add('--batch_size', type=int, default=64)
    add('--lr', type=float, default=1e-5)
    add('--name', default='')
    add('--no_pretrained', action='store_true')
    add('--resume', default='')
    add('--save_every', type=int, default=5, help='每隔多少轮保存一次 ckpt')
    add('--log_interval', type=int, default=100, help='每隔多少 batch 打印一次 log')
    add('--val_every_steps', type=int, default=500, help='每隔多少 step 在 val 上验证一次,0=仅每 epoch 结束')
    add('--test_every_epoch', action='store_true', help='每个 epoch 结束在 test 上评估(仅监控,不参与选 best)')
    add('--workers', type=int, default=0, help='DataLoader workers,0=用 config 默认')
    add('--no_amp', action='store_true', help='禁用混合精度')
    return parser.parse_args()

def main():
    """Train, validate every epoch, checkpoint, and run a final test evaluation.

    Fix: latest.pth is now written AFTER best_acc is updated. Previously it
    stored the pre-update (stale) best_acc, so resuming right after an
    improvement could let a worse model later overwrite best.pth.
    """
    args = get_args()
    cfg = TrainConfig(
        backbone=args.backbone,
        num_classes=args.num_classes,
        pretrained=not args.no_pretrained,
        datalabel=args.datalabel,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        name=args.name if args.name else None,
    )

    run_dir = os.path.join('runs', cfg.name)
    os.makedirs(run_dir, exist_ok=True)
    log_path = os.path.join(run_dir, 'train.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
        force=True,  # override any handlers a library installed earlier
    )
    logging.info(f'Run dir: {run_dir}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_model(
        backbone=cfg.backbone,
        num_classes=cfg.num_classes,
        pretrained=cfg.pretrained,
        dropout=cfg.dropout,
    ).to(device)

    workers = args.workers if args.workers > 0 else cfg.workers
    use_amp = not args.no_amp and torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    if use_amp:
        logging.info('Using AMP (mixed precision)')

    train_ds = DeepfakeDataset(**cfg.dataset_kwargs('train'))
    val_ds = DeepfakeDataset(**cfg.dataset_kwargs('val'))
    test_ds = DeepfakeDataset(**cfg.dataset_kwargs('test'))
    dl_kw = dict(batch_size=cfg.batch_size, pin_memory=torch.cuda.is_available())
    if workers > 0:
        dl_kw.update(num_workers=workers, persistent_workers=True, prefetch_factor=4)
    else:
        dl_kw['num_workers'] = 0
    train_loader = DataLoader(train_ds, shuffle=True, **dl_kw)
    val_loader = DataLoader(val_ds, shuffle=False, **dl_kw)
    test_loader = DataLoader(test_ds, shuffle=False, **dl_kw)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)
    start_epoch = 0
    best_acc = 0.0

    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(ckpt.get('model', ckpt), strict=False)
        start_epoch = ckpt.get('epoch', 0) + 1
        best_acc = ckpt.get('best_acc', 0.0)
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
        if 'scheduler' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler'])
        logging.info(f'Resumed from epoch {start_epoch}, best_acc={best_acc:.4f}')

    global_step = start_epoch * len(train_loader)
    for epoch in range(start_epoch, cfg.epochs):
        train_ds.next_epoch()  # rotate which frames the train set samples
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, device,
            log_interval=args.log_interval,
            scaler=scaler,
            val_loader=val_loader if args.val_every_steps > 0 else None,
            val_every_steps=args.val_every_steps,
            global_step=global_step,
        )
        global_step += len(train_loader)
        val_loss, val_acc = validate(model, val_loader, device)
        scheduler.step()

        logging.info(f'E{epoch} train loss={train_loss:.4f} acc={train_acc:.4f} | val loss={val_loss:.4f} acc={val_acc:.4f}')

        # Update best BEFORE writing latest.pth so a resume from latest.pth
        # carries the up-to-date best_acc.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc}, os.path.join(run_dir, 'best.pth'))
            logging.info(f'  -> best acc={best_acc:.4f} saved')

        ckpt = {
            'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc,
            'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
        }
        torch.save(ckpt, os.path.join(run_dir, 'latest.pth'))

        if args.save_every > 0 and (epoch + 1) % args.save_every == 0:
            torch.save({'model': model.state_dict(), 'epoch': epoch}, os.path.join(run_dir, f'ep{epoch}.pth'))

        if args.test_every_epoch:
            # Monitoring only — test accuracy never influences best-model selection.
            test_loss, test_acc = validate(model, test_loader, device, desc='test')
            logging.info(f'E{epoch} test loss={test_loss:.4f} acc={test_acc:.4f}')

    # Training done: final evaluation on test with best.pth (fall back to latest.pth).
    for ckpt_name in ('best.pth', 'latest.pth'):
        ckpt_path = os.path.join(run_dir, ckpt_name)
        if os.path.exists(ckpt_path):
            ckpt = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(ckpt.get('model', ckpt), strict=False)
            test_loss, test_acc = validate(model, test_loader, device, desc='test')
            logging.info(f'[FINAL] {ckpt_name} on test: loss={test_loss:.4f} acc={test_acc:.4f}')
            break

def train_epoch(model, loader, optimizer, device, log_interval=50, scaler=None,
                val_loader=None, val_every_steps=0, global_step=0):
    """Run one training epoch and return (mean_loss, mean_acc) over all samples.

    When *scaler* is given, forward/loss run under CUDA autocast and the
    backward pass goes through the GradScaler (AMP). When *val_loader* is
    given, validation also runs every *val_every_steps* global steps; the
    model is switched back to train mode afterwards.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    progress = tqdm(loader, desc='train', leave=False)
    for batch_idx, (inputs, targets) in enumerate(progress):
        step = global_step + batch_idx + 1
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        if scaler is None:
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)
            loss.backward()
            optimizer.step()
        else:
            with torch.amp.autocast('cuda'):
                logits = model(inputs)
                loss = F.cross_entropy(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        batch = inputs.size(0)
        acc = (logits.argmax(1) == targets).float().mean().item()
        loss_sum += loss.item() * batch
        acc_sum += acc * batch
        seen += batch
        progress.set_postfix(loss=f'{loss.item():.4f}', acc=f'{acc:.4f}')
        if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
            logging.info(f'  batch {batch_idx+1}/{len(loader)} loss={loss.item():.4f} acc={acc:.4f}')
        if val_loader is not None and val_every_steps > 0 and step % val_every_steps == 0:
            val_loss, val_acc = validate(model, val_loader, device)
            logging.info(f'  [step {step}] val loss={val_loss:.4f} acc={val_acc:.4f}')
            model.train()  # validate() left the model in eval mode
    return loss_sum / seen, acc_sum / seen

def validate(model, loader, device, desc='val'):
    """Evaluate *model* on *loader*; return (mean_loss, mean_acc).

    Samples with label -1 (dataset placeholders for failed loads) are
    filtered out. Returns (0.0, 0.0) when no valid sample was seen.

    Bug fix: the previous `return total_loss / n, total_acc / n if n > 0
    else (0.0, 0.0)` applied the conditional only to the SECOND tuple
    element (conditional-expression precedence), so n == 0 still raised
    ZeroDivisionError on total_loss / n.
    """
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=desc, leave=False)
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            # Filter invalid labels (dataset returns -1 when a frame failed to load).
            valid = y >= 0
            if valid.sum() == 0:
                continue
            x, y = x[valid], y[valid]
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            acc = (logits.argmax(1) == y).float().mean().item()
            total_loss += loss.item() * x.size(0)
            total_acc += acc * x.size(0)
            n += x.size(0)
            pbar.set_postfix(loss=f'{loss.item():.4f}', acc=f'{acc:.4f}')
    if n == 0:
        return 0.0, 0.0
    return total_loss / n, total_acc / n

# Script entry point: run training when executed directly.
if __name__ == '__main__':
    main()

可以看到训练效果非常好,基本一个 Epoch 就可以在 Test 测试集上达到 0.8 以上的正确率;并且可以观察到,以 Transformer 作为 Backbone 的效果明显优于 CNN。

相关推荐
飞Link3 小时前
深度解析 LSTM 神经网络架构与实战指南
人工智能·深度学习·神经网络·lstm
love530love3 小时前
Windows 11 源码编译 vLLM 0.16 完全指南(RTX 3090 / CUDA 12.8 / PyTorch 2.7.1)
人工智能·pytorch·windows·python·深度学习·vllm·vs 2022
放下华子我只抽RuiKe53 小时前
机器学习全景指南-基石篇——预测连续值的线性回归
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·线性回归
带娃的IT创业者3 小时前
专栏系列3.3《时序关联学习:r=0.733 背后的记忆形成》
人工智能·深度学习·神经网络·时序学习·nct·神经调质
快乐非自愿3 小时前
NIO核心原理深度解析:非阻塞I/O的块式设计与高并发实现逻辑
人工智能·深度学习·nio
八月瓜科技3 小时前
擎策·知海全球专利数据库 技术赋能检索 让科技创新少走弯路
大数据·数据库·人工智能·科技·深度学习·娱乐
兴通扫码设备3 小时前
ocr工业场景适配升级:深圳市兴通物联XTC8501智能相机接口与环境适应性技术解析
数据库·人工智能·深度学习·数码相机·计算机视觉
小陈phd3 小时前
多模态大模型学习笔记(十六)——Transformer 学习之 Decoder Only
人工智能·笔记·深度学习·学习·自然语言处理·transformer
Takoony3 小时前
OpenClaw 深度拆解:下一代自主智能体架构全面解析
人工智能·深度学习·算法·机器学习·架构·openclaw