作者:SkyXZ
CSDN:SkyXZ~-CSDN博客
博客园:SkyXZ - 博客园
一、获取数据集
FaceForensics++ 是一个取证数据集,由1000段原始视频序列组成,这些视频通过四种自动人脸操纵方法进行处理:Deepfakes、Face2Face、FaceSwap 和 NeuralTextures。数据来自 977 段 YouTube 视频,所有视频中都包含一张可跟踪的、主要为正面且没有遮挡的人脸,使得自动篡改方法能够生成逼真的伪造视频。同时,由于该数据集提供了二值掩码,这些数据可以用于图像和视频分类以及分割。此外,官方还提供了 1000 个 Deepfakes 模型,用于生成和扩充新数据。
FaceForensics++ 数据集无法直接下载,需要按照要求填写谷歌表单来申请获取: https://docs.google.com/forms/d/e/1FAIpQLSdRRR3L5zAv6tQ_CKxmK4W96tAab_pfBu2EKAgQbeDVhmXagg/viewform

等待几天之后会收到如下邮件,里面会附上数据集的下载Code,直接使用下载脚本下载即可获取:

python
#!/usr/bin/env python
""" Downloads FaceForensics++ and Deep Fake Detection public data release
Example usage:
see -h or https://github.com/ondyari/FaceForensics
"""
# -*- coding: utf-8 -*-
import argparse
import os
import urllib
import urllib.request
import tempfile
import time
import sys
import json
import random
from tqdm import tqdm
from os.path import join
# URLs and filenames
# Server-relative path of the JSON file holding the FaceForensics++ video pairs.
FILELIST_URL = 'misc/filelist.json'
# Server-relative path of the JSON file holding the DeepFakeDetection file names
# (read in main() through its 'actors' and 'DeepFakesDetection' keys).
DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'
# Files downloaded for every Deepfakes model folder.
DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]
# Parameters
# Maps every value accepted by --dataset to its path below the server base URL;
# the same relative path is reused for the local output folder.
DATASETS = {
    'original_youtube_videos': 'misc/downloaded_youtube_videos.zip',
    'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip',
    'original': 'original_sequences/youtube',
    'DeepFakeDetection_original': 'original_sequences/actors',
    'Deepfakes': 'manipulated_sequences/Deepfakes',
    'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection',
    'Face2Face': 'manipulated_sequences/Face2Face',
    'FaceShifter': 'manipulated_sequences/FaceShifter',
    'FaceSwap': 'manipulated_sequences/FaceSwap',
    'NeuralTextures': 'manipulated_sequences/NeuralTextures'
}
# Datasets processed when --dataset is 'all' (the two youtube zip archives are
# not part of this list).
ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes',
                'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap',
                'NeuralTextures']
# Allowed values of --compression, --type and --server respectively.
COMPRESSION = ['raw', 'c23', 'c40']
TYPE = ['videos', 'masks', 'models']
SERVERS = ['EU', 'EU2', 'CA']
def parse_args():
    """Parse the command line and attach the server URLs derived from it.

    Returns:
        argparse.Namespace: the parsed options, extended with ``tos_url``,
        ``base_url`` and ``deepfakes_model_url`` for the selected server.

    Raises:
        Exception: if the selected server has no known URL.
    """
    cli = argparse.ArgumentParser(
        description='Downloads FaceForensics v2 public data release.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    cli.add_argument('output_path', type=str, help='Output directory.')
    cli.add_argument(
        '-d', '--dataset', type=str, default='all',
        choices=list(DATASETS.keys()) + ['all'],
        help='Which dataset to download, either pristine or '
             'manipulated data or the downloaded youtube '
             'videos.',
    )
    cli.add_argument(
        '-c', '--compression', type=str, default='raw',
        choices=COMPRESSION,
        help='Which compression degree. All videos '
             'have been generated with h264 with a varying '
             'codec. Raw (c0) videos are lossless compressed.',
    )
    cli.add_argument(
        '-t', '--type', type=str, default='videos',
        choices=TYPE,
        help='Which file type, i.e. videos, masks, for our '
             'manipulation methods, models, for Deepfakes.',
    )
    cli.add_argument(
        '-n', '--num_videos', type=int, default=None,
        help='Select a number of videos number to '
             "download if you don't want to download the full"
             ' dataset.',
    )
    cli.add_argument(
        '--server', type=str, default='EU',
        choices=SERVERS,
        help='Server to download the data from. If you '
             'encounter a slow download speed, consider '
             'changing the server.',
    )
    args = cli.parse_args()

    # URLs: one entry per mirror name.
    mirror_urls = {
        'EU': 'http://canis.vc.in.tum.de:8100/',
        'EU2': 'http://kaldir.vc.in.tum.de/faceforensics/',
        'CA': 'http://falas.cmpt.sfu.ca:8100/',
    }
    server_url = mirror_urls.get(args.server)
    if server_url is None:
        raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS)))

    args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf'
    args.base_url = server_url + 'v3/'
    args.deepfakes_model_url = (
        server_url + 'v3/manipulated_sequences/' + 'Deepfakes/models/'
    )
    return args
def download_files(filenames, base_url, output_path, report_progress=True):
    """Download every file of ``filenames`` from ``base_url`` into ``output_path``.

    Args:
        filenames: iterable of file names, appended to ``base_url``.
        base_url: URL prefix of the files.
        output_path: local directory, created if missing.
        report_progress: wrap the iteration in a tqdm progress bar.
    """
    os.makedirs(output_path, exist_ok=True)
    pending = tqdm(filenames) if report_progress else filenames
    for name in pending:
        source_url = base_url + name
        target_file = join(output_path, name)
        download_file(source_url, target_file)
def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``.

    Writes a one-line, carriage-return-refreshed progress report to stdout.
    The first call (``count == 0``) only records the start time in the
    module-level ``start_time``.

    Args:
        count: number of blocks transferred so far.
        block_size: size of one block in bytes.
        total_size: total size in bytes; values <= 0 mean "unknown".
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    # A coarse clock can report zero elapsed time for the first blocks.
    speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
    if total_size > 0:
        # The last block usually overshoots the real size, so cap at 100 %.
        percent = min(100, int(progress_size * 100 / total_size))
    else:
        # Unknown size: no meaningful percentage can be computed.
        percent = 0
    sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()
def download_file(url, out_file, report_progress=False):
    """Download ``url`` to ``out_file`` unless that file already exists.

    The data is first written to a temporary file in the target directory and
    only moved to ``out_file`` once the transfer finished, so an interrupted
    download never leaves a truncated ``out_file`` behind.

    Args:
        url: source URL.
        out_file: destination path.
        report_progress: print transfer progress through ``reporthook``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the temporary file is
        removed before the error propagates.
    """
    if os.path.isfile(out_file):
        tqdm.write('WARNING: skipping download of existing file ' + out_file)
        return

    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Keep the temp file next to the target so the final move stays on one
    # filesystem.
    fh, out_file_tmp = tempfile.mkstemp(dir=out_dir or os.curdir)
    os.close(fh)
    try:
        if report_progress:
            urllib.request.urlretrieve(url, out_file_tmp,
                                       reporthook=reporthook)
        else:
            urllib.request.urlretrieve(url, out_file_tmp)
        os.replace(out_file_tmp, out_file)
    except BaseException:
        # Also covers CTRL-C: do not litter the output folder with partial
        # temp files.
        if os.path.exists(out_file_tmp):
            os.remove(out_file_tmp)
        raise
def _fetch_json(url):
    """Download ``url`` and return its UTF-8 encoded JSON content, parsed."""
    with urllib.request.urlopen(url) as response:
        return json.loads(response.read().decode("utf-8"))


def main(args):
    """Download the data selected on the command line.

    Asks for confirmation of the terms of use, then processes every selected
    dataset: the youtube zip archives are fetched directly, all other
    datasets are fetched file by file from lists published on the server.

    When ``--dataset all`` is used, dataset/type combinations that do not
    exist on the server (masks for original data or FaceShifter, models for
    anything but Deepfakes) are skipped so that the remaining datasets are
    still downloaded. When a single dataset was requested, such a combination
    aborts the run.

    Args:
        args: namespace returned by ``parse_args``.
    """
    # TOS
    print('By pressing any key to continue you confirm that you have agreed '\
          'to the FaceForensics terms of use as described at:')
    print(args.tos_url)
    print('***')
    print('Press any key to continue, or CTRL-C to exit.')
    _ = input('')

    # Extract arguments
    download_all = args.dataset == 'all'
    c_datasets = ALL_DATASETS if download_all else [args.dataset]
    c_type = args.type
    c_compression = args.compression
    num_videos = args.num_videos
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    # Check for special dataset cases
    for dataset in c_datasets:
        dataset_path = DATASETS[dataset]

        # Special cases
        if 'original_youtube_videos' in dataset:
            # Here we download the original youtube videos zip file
            print('Downloading original youtube videos.')
            if not 'info' in dataset_path:
                print('Please be patient, this may take a while (~40gb)')
                suffix = ''
            else:
                suffix = 'info'
            download_file(args.base_url + '/' + dataset_path,
                          out_file=join(output_path,
                                        'downloaded_videos{}.zip'.format(
                                            suffix)),
                          report_progress=True)
            return

        # Else: regular datasets
        print('Downloading {} of dataset "{}"'.format(
            c_type, dataset_path
        ))

        # Reject combinations that do not exist on the server before any
        # file list is fetched.
        unavailable = None
        if c_type == 'masks':
            if 'original' in dataset:
                unavailable = 'Only videos available for original data.'
            elif 'FaceShifter' in dataset:
                unavailable = 'Masks not available for FaceShifter.'
        elif c_type == 'models' and dataset != 'Deepfakes':
            unavailable = 'Models only available for Deepfakes.'
        if unavailable is not None:
            if download_all:
                # Keep going: later datasets may still offer this file type.
                print('{} Skipping {}.\n'.format(unavailable, dataset))
                continue
            print('{} Aborting.'.format(unavailable))
            return

        # Get filelists and video lenghts list from server
        if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path:
            filepaths = _fetch_json(args.base_url + '/' +
                                    DEEPFEAKES_DETECTION_URL)
            if 'actors' in dataset_path:
                filelist = filepaths['actors']
            else:
                filelist = filepaths['DeepFakesDetection']
        elif 'original' in dataset_path:
            # Load filelist from server; every id of a pair is one video
            file_pairs = _fetch_json(args.base_url + '/' + FILELIST_URL)
            filelist = []
            for pair in file_pairs:
                filelist += pair
        else:
            # Load filelist from server
            file_pairs = _fetch_json(args.base_url + '/' + FILELIST_URL)
            # Get filelist: manipulated files are named "<a>_<b>"; the
            # reversed pair is added as well, except for models
            filelist = []
            for pair in file_pairs:
                filelist.append('_'.join(pair))
                if c_type != 'models':
                    filelist.append('_'.join(pair[::-1]))

        # Maybe limit number of videos for download
        if num_videos is not None and num_videos > 0:
            print('Downloading the first {} videos'.format(num_videos))
            filelist = filelist[:num_videos]

        # Server and local paths
        dataset_videos_url = args.base_url + '{}/{}/{}/'.format(
            dataset_path, c_compression, c_type)
        dataset_mask_url = args.base_url + '{}/{}/videos/'.format(
            dataset_path, 'masks')

        if c_type == 'videos':
            dataset_output_path = join(output_path, dataset_path, c_compression,
                                       c_type)
            print('Output path: {}'.format(dataset_output_path))
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_videos_url, dataset_output_path)
        elif c_type == 'masks':
            dataset_output_path = join(output_path, dataset_path, c_type,
                                       'videos')
            print('Output path: {}'.format(dataset_output_path))
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_mask_url, dataset_output_path)
        # Else: models for deepfakes
        else:
            dataset_output_path = join(output_path, dataset_path, c_type)
            print('Output path: {}'.format(dataset_output_path))

            # Get Deepfakes models: one folder per entry of the file list
            for folder in tqdm(filelist):
                folder_filelist = DEEPFAKES_MODEL_NAMES

                # Folder paths
                folder_base_url = args.deepfakes_model_url + folder + '/'
                folder_dataset_output_path = join(dataset_output_path,
                                                  folder)
                download_files(folder_filelist, folder_base_url,
                               folder_dataset_output_path,
                               report_progress=False)  # already done


if __name__ == "__main__":
    args = parse_args()
    main(args)
接下来使用如下命令即可下载数据集
bash
python download-FaceForensics.py \
    <output path> \
    -d <dataset type, e.g., Face2Face, original or all> \
    -c <compression quality, e.g., c23 or raw> \
    -t <file type, e.g., videos, masks or models>
<output path> 表示数据集的保存路径,即下载后的 FaceForensics++ 或 DeepFakeDetection 数据将被存放的位置。例如,可以设置为当前项目下的 ./data/,也可以设置为单独的数据盘路径,如 /mnt/data2/qi.xiong/Dataset/FaceForensics/。下载脚本会在该目录下自动构建对应的数据集层级结构.
-d 用于指定下载的数据类型(dataset type)。常见可选项包括 original、Face2Face、Deepfakes、FaceSwap、NeuralTextures、DeepFakeDetection 以及 all 等。其中,original 表示下载原始真实视频序列,通常对应 original_sequences/youtube;Face2Face、Deepfakes、FaceSwap 和 NeuralTextures 表示下载四种主要伪造方法生成的数据;DeepFakeDetection 表示下载 DeepFakeDetection 扩展数据;all 表示一次性下载全部可用数据。若仅用于常规 deepfake 检测实验,通常优先选择 original 与四种主流伪造类型。
-c 用于指定压缩等级(compression quality)。常用选项为 raw、c23 和 c40。其中,raw 表示原始或无损压缩版本,数据体积最大,但保留了最完整的图像细节;c23 表示较高质量压缩版本,是目前较常见、也较平衡的一种设置,既能保留较好的视觉质量,又显著降低存储开销;c40 表示压缩更强、质量更低的数据版本,更适合做强压缩场景下的鲁棒性测试。实际使用中,如果只是复现主流实验或进行预处理,通常推荐优先下载 c23 视频版本。
-t 用于指定文件类型(file type)。常见选项包括 videos、masks 和 models。其中,videos 表示下载视频文件,这是最常用的选项;masks 表示下载伪造区域的二值掩码,适用于伪造区域定位、分割或可解释性分析任务;models 主要与部分伪造方法相关,用于获取对应的生成模型文件。对于大多数 deepfake 分类或人脸抽帧任务,仅下载 videos 即可。
下载完成的数据集格式如下:
bash
(xq) qi.xiong@instance-ujccspas:/mnt/data2/qi.xiong/Dataset/FaceForensics$ tree -L 3
.
├── manipulated_sequences
│ ├── DeepFakeDetection
│ │ ├── c23
│ │ └── masks
│ ├── Deepfakes
│ │ ├── c23
│ │ └── masks
│ ├── Face2Face
│ │ ├── c23
│ │ └── masks
│ ├── FaceShifter
│ │ └── c23
│ ├── FaceSwap
│ │ └── c23
│ └── NeuralTextures
│ └── c23
└── original_sequences
├── actors
│ └── c23
└── youtube
└── c23
22 directories, 0 files
二、数据集预处理
我们前面下载得到的数据集仍然是视频格式,因此在正式用于 deepfake 检测之前,还需要先进行预处理。通常来说,这类任务不会直接将整段视频输入模型,而是先从视频中抽取若干具有代表性的帧,再从每一帧中提取对应的人脸区域。这样做一方面可以明显降低后续数据处理和模型训练的开销,另一方面也能让模型更聚焦于真正有用的面部伪造信息。FaceForensics++ 官方文档中也提到,通常更推荐先下载压缩后的视频,再自行完成帧提取。本文这里采用一种比较简化且实用的处理方式:从每个视频中均匀抽取固定数量的帧,然后使用 RetinaFace 对这些帧进行人脸检测,并将检测到的人脸区域裁剪保存。相比一些传统方法,RetinaFace 在检测精度和鲁棒性方面通常更有优势,尤其是在侧脸、光照变化较大或者人脸尺度变化明显的情况下,检测结果往往更加稳定。需要说明的是,本文这里的预处理目标比较明确,即只做人脸抽帧和人脸裁剪,不额外涉及关键点对齐、伪造区域掩码生成等更复杂的步骤,因此整个流程会更加清晰,也更适合作为 FaceForensics++ 数据预处理的基础版本。
bash
git clone https://github.com/ternaus/retinaface.git
cd retinaface
pip install -v -e .
我们配置好了retinaface之后,即可使用如下脚本继续转换:
python
from glob import glob
import os
import cv2
from tqdm import tqdm
import numpy as np
import argparse
from retinaface.pre_trained_models import get_model
import torch
def facecrop(model, org_path, save_path, num_frames=10):
    """Crop the detected faces from evenly spaced frames of one video.

    ``num_frames`` frame indices are spread evenly over the video. For each
    of those frames every detection returned by ``model.predict_jsons`` is
    clipped to the image bounds and written as
    ``<save_path>/frames_retina/<video name>/frame_<i>_face_<j>.png``.

    Args:
        model: detector exposing ``predict_jsons(rgb_image)``; each result is
            read through ``.get('bbox')`` as ``[x0, y0, x1, y1, ...]``
            (presumably the RetinaFace model from ``get_model`` - confirm).
        org_path: path of the input video.
        save_path: directory below which ``frames_retina`` is created.
        num_frames: number of frames to sample from the video.
    """
    cap_org = cv2.VideoCapture(org_path)
    try:
        frame_count_org = int(cap_org.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count_org <= 0:
            print(f"Invalid video: {org_path}")
            return

        frame_idxs = np.linspace(0, frame_count_org - 1, num_frames, endpoint=True, dtype=int)
        frame_idxs = set(frame_idxs.tolist())

        # Strip only the file extension; a plain replace('.mp4', '') would
        # also remove that text from the middle of a name.
        video_name = os.path.splitext(os.path.basename(org_path))[0]
        save_path_frames = os.path.join(save_path, 'frames_retina', video_name)

        for cnt_frame in range(frame_count_org):
            # Frames are read sequentially; unselected ones are discarded.
            ret_org, frame_org = cap_org.read()
            if not ret_org or frame_org is None:
                continue
            if cnt_frame not in frame_idxs:
                continue

            # The detector is fed RGB; the crop is taken from the BGR frame
            # because cv2.imwrite expects BGR.
            frame = cv2.cvtColor(frame_org, cv2.COLOR_BGR2RGB)
            faces = model.predict_jsons(frame)
            if len(faces) == 0:
                continue

            os.makedirs(save_path_frames, exist_ok=True)
            for face_idx, face in enumerate(faces):
                bbox = face.get('bbox', None)
                if bbox is None or len(bbox) < 4:
                    continue
                x0, y0, x1, y1 = map(int, bbox[:4])
                # Clip the box to the image and drop degenerate boxes.
                x0 = max(0, x0)
                y0 = max(0, y0)
                x1 = min(frame_org.shape[1], x1)
                y1 = min(frame_org.shape[0], y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                cropped_face = frame_org[y0:y1, x0:x1]
                face_image_path = os.path.join(
                    save_path_frames, f'frame_{cnt_frame}_face_{face_idx}.png'
                )
                cv2.imwrite(face_image_path, cropped_face)
    finally:
        # Release the capture even when detection or writing raises.
        cap_org.release()
if __name__ == '__main__':
    # Command line: -d dataset subset, -c compression level, -n frames per video.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        dest='dataset',
        # Without a dataset there is nothing to process; let argparse report
        # it instead of failing later with a bare NotImplementedError.
        required=True,
        choices=[
            'Original',
            'DeepFakeDetection_original',
            'DeepFakeDetection',
            'Deepfakes',
            'Face2Face',
            'FaceShifter',
            'FaceSwap',
            'NeuralTextures'
        ]
    )
    parser.add_argument('-c', dest='comp', choices=['raw', 'c23', 'c40'], default='raw')
    parser.add_argument('-n', dest='num_frames', type=int, default=20)
    args = parser.parse_args()

    # Resolve the dataset folder relative to the current working directory.
    if args.dataset == 'Original':
        dataset_path = 'data/FaceForensics++/original_sequences/youtube/{}/'.format(args.comp)
    elif args.dataset == 'DeepFakeDetection_original':
        dataset_path = 'data/FaceForensics++/original_sequences/actors/{}/'.format(args.comp)
    elif args.dataset in ['DeepFakeDetection', 'FaceShifter', 'Face2Face', 'Deepfakes', 'FaceSwap', 'NeuralTextures']:
        dataset_path = 'data/FaceForensics++/manipulated_sequences/{}/{}/'.format(args.dataset, args.comp)
    else:
        raise NotImplementedError

    device = torch.device('cpu')
    model = get_model("resnet50_2020-07-20", max_size=2048, device=device)
    model.eval()

    movies_path = dataset_path + 'videos/'
    movies_path_list = sorted(glob(movies_path + '*.mp4'))
    print("{} : videos are exist in {}".format(len(movies_path_list), args.dataset))

    # Crops are written next to the videos, below <dataset_path>/frames_retina/.
    for movie_path in tqdm(movies_path_list):
        facecrop(model, movie_path, save_path=dataset_path, num_frames=args.num_frames)
在具体使用时,我们主要关心三个参数:-d 用于指定要处理的数据子集,例如 Original 表示原始真实视频,Deepfakes、Face2Face、FaceSwap 和 NeuralTextures 表示不同伪造方法生成的视频;-c 用于指定压缩等级,常见取值包括 raw、c23 和 c40,其中 c23 是较为常用的一种设置;-n 表示每个视频需要抽取的帧数,例如 -n 20 表示从一个视频中均匀抽取 20 帧进行处理。
如果只以 FaceForensics++ 中的原始真实视频为例,并采用 c23 压缩版本,那么待处理的视频通常位于如下目录中:
bash
data/FaceForensics++/original_sequences/youtube/c23/videos/
当脚本运行完成后,处理结果会保存在对应目录下新生成的 frames_retina 文件夹中。例如,如果处理的是 Original 的 c23 数据,那么输出目录通常为:
python
data/FaceForensics++/original_sequences/youtube/c23/frames_retina/


接下来我们需要划分train、val和test数据集,我们按照官方的比例来划分,使用如下代码即可:
python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
将 frames_retina 组织为 fakefacecls 所需结构
用法:
python setup_ffpp_dataset.py
python setup_ffpp_dataset.py --data_root /path/to/data/FaceForensics++
输出:
data/FaceForensics++/ffpp/
├── train.json, val.json, test.json
├── Origin/c23/larger_images/ -> symlinks to frames_retina
├── Deepfakes/c23/larger_images/
├── Face2Face/c23/larger_images/
├── FaceSwap/c23/larger_images/
└── NeuralTextures/c23/larger_images/
"""
import argparse
import json
import os
from pathlib import Path
# FF++ 官方划分 (来自 https://github.com/ondyari/FaceForensics)
TRAIN_JSON = [
["071", "054"], ["087", "081"], ["881", "856"], ["187", "234"], ["645", "688"],
["754", "758"], ["811", "920"], ["710", "788"], ["628", "568"], ["312", "021"],
["950", "836"], ["059", "050"], ["524", "580"], ["751", "752"], ["918", "934"],
["604", "703"], ["296", "293"], ["518", "131"], ["536", "540"], ["969", "897"],
["372", "413"], ["357", "432"], ["809", "799"], ["092", "098"], ["302", "323"],
["981", "985"], ["512", "495"], ["088", "060"], ["795", "907"], ["535", "587"],
["297", "270"], ["838", "810"], ["850", "764"], ["476", "400"], ["268", "269"],
["033", "097"], ["226", "491"], ["784", "769"], ["195", "442"], ["678", "460"],
["320", "328"], ["451", "449"], ["409", "382"], ["556", "588"], ["027", "009"],
["196", "310"], ["241", "210"], ["295", "099"], ["043", "110"], ["753", "789"],
["716", "712"], ["508", "831"], ["005", "010"], ["276", "185"], ["498", "433"],
["294", "292"], ["105", "180"], ["984", "967"], ["318", "334"], ["356", "324"],
["344", "020"], ["289", "228"], ["022", "489"], ["137", "165"], ["095", "053"],
["999", "960"], ["481", "469"], ["534", "490"], ["543", "559"], ["150", "153"],
["598", "178"], ["475", "265"], ["671", "677"], ["204", "230"], ["863", "853"],
["561", "998"], ["163", "031"], ["655", "444"], ["038", "125"], ["735", "774"],
["184", "205"], ["499", "539"], ["717", "684"], ["878", "866"], ["127", "129"],
["286", "267"], ["032", "944"], ["681", "711"], ["236", "237"], ["989", "993"],
["537", "563"], ["814", "871"], ["509", "525"], ["221", "206"], ["808", "829"],
["696", "686"], ["431", "447"], ["737", "719"], ["609", "596"], ["408", "424"],
["976", "954"], ["156", "243"], ["434", "438"], ["627", "658"], ["025", "067"],
["635", "642"], ["523", "541"], ["572", "554"], ["215", "208"], ["651", "835"],
["975", "978"], ["792", "903"], ["931", "936"], ["846", "845"], ["899", "914"],
["209", "016"], ["398", "457"], ["797", "844"], ["360", "437"], ["738", "804"],
["694", "767"], ["790", "014"], ["657", "644"], ["374", "407"], ["728", "673"],
["193", "030"], ["876", "891"], ["553", "545"], ["331", "260"], ["873", "872"],
["109", "107"], ["121", "093"], ["143", "140"], ["778", "798"], ["983", "113"],
["504", "502"], ["709", "390"], ["940", "941"], ["894", "848"], ["311", "387"],
["562", "626"], ["330", "162"], ["112", "892"], ["765", "867"], ["124", "085"],
["665", "679"], ["414", "385"], ["555", "516"], ["072", "037"], ["086", "090"],
["202", "348"], ["341", "340"], ["333", "377"], ["082", "103"], ["569", "921"],
["750", "743"], ["211", "177"], ["770", "791"], ["329", "327"], ["613", "685"],
["007", "132"], ["304", "300"], ["860", "905"], ["986", "994"], ["378", "368"],
["761", "766"], ["232", "248"], ["136", "285"], ["601", "653"], ["693", "698"],
["359", "317"], ["246", "258"], ["500", "592"], ["776", "676"], ["262", "301"],
["307", "365"], ["600", "505"], ["833", "826"], ["361", "448"], ["473", "366"],
["885", "802"], ["277", "335"], ["667", "446"], ["522", "337"], ["018", "019"],
["430", "459"], ["886", "877"], ["456", "435"], ["239", "218"], ["771", "849"],
["065", "089"], ["654", "648"], ["151", "225"], ["152", "149"], ["229", "247"],
["624", "570"], ["290", "240"], ["011", "805"], ["461", "250"], ["251", "375"],
["639", "841"], ["602", "397"], ["028", "068"], ["338", "336"], ["964", "174"],
["782", "787"], ["478", "506"], ["313", "283"], ["659", "749"], ["690", "689"],
["893", "913"], ["197", "224"], ["253", "183"], ["373", "394"], ["803", "017"],
["305", "513"], ["051", "332"], ["238", "282"], ["621", "546"], ["401", "395"],
["510", "528"], ["410", "411"], ["049", "946"], ["663", "231"], ["477", "487"],
["252", "266"], ["952", "882"], ["315", "322"], ["216", "164"], ["061", "080"],
["603", "575"], ["828", "830"], ["723", "704"], ["870", "001"], ["201", "203"],
["652", "773"], ["108", "052"], ["272", "396"], ["040", "997"], ["988", "966"],
["281", "474"], ["077", "100"], ["146", "256"], ["972", "718"], ["303", "309"],
["582", "172"], ["222", "168"], ["884", "968"], ["217", "117"], ["118", "120"],
["242", "182"], ["858", "861"], ["101", "096"], ["697", "581"], ["763", "930"],
["839", "864"], ["542", "520"], ["122", "144"], ["687", "615"], ["544", "532"],
["721", "715"], ["179", "212"], ["591", "605"], ["275", "887"], ["996", "056"],
["825", "074"], ["530", "594"], ["757", "573"], ["611", "760"], ["189", "200"],
["392", "339"], ["734", "699"], ["977", "075"], ["879", "963"], ["910", "911"],
["889", "045"], ["962", "929"], ["515", "519"], ["062", "066"], ["937", "888"],
["199", "181"], ["785", "736"], ["079", "076"], ["155", "576"], ["748", "355"],
["819", "786"], ["577", "593"], ["464", "463"], ["439", "441"], ["574", "547"],
["747", "854"], ["403", "497"], ["965", "948"], ["726", "713"], ["943", "942"],
["160", "928"], ["496", "417"], ["700", "813"], ["756", "503"], ["213", "083"],
["039", "058"], ["781", "806"], ["620", "619"], ["351", "346"], ["959", "957"],
["264", "271"], ["006", "002"], ["391", "406"], ["631", "551"], ["501", "326"],
["412", "274"], ["641", "662"], ["111", "094"], ["166", "167"], ["130", "139"],
["938", "987"], ["055", "147"], ["990", "008"], ["013", "883"], ["614", "616"],
["772", "708"], ["840", "800"], ["415", "484"], ["287", "426"], ["680", "486"],
["057", "070"], ["590", "034"], ["194", "235"], ["291", "874"], ["902", "901"],
["343", "363"], ["279", "298"], ["393", "405"], ["674", "744"], ["244", "822"],
["133", "148"], ["636", "578"], ["637", "427"], ["041", "063"], ["869", "780"],
["733", "935"], ["259", "345"], ["069", "961"], ["783", "916"], ["191", "188"],
["526", "436"], ["123", "119"], ["207", "908"], ["796", "740"], ["815", "730"],
["173", "171"], ["383", "353"], ["458", "722"], ["533", "450"], ["618", "629"],
["646", "643"], ["531", "549"], ["428", "466"], ["859", "843"], ["692", "610"],
]
VAL_JSON = [
["720", "672"], ["939", "115"], ["284", "263"], ["402", "453"], ["820", "818"],
["762", "832"], ["834", "852"], ["922", "898"], ["104", "126"], ["106", "198"],
["159", "175"], ["416", "342"], ["857", "909"], ["599", "585"], ["443", "514"],
["566", "617"], ["472", "511"], ["325", "492"], ["816", "649"], ["583", "558"],
["933", "925"], ["419", "824"], ["465", "482"], ["565", "589"], ["261", "254"],
["992", "980"], ["157", "245"], ["571", "746"], ["947", "951"], ["926", "900"],
["493", "538"], ["468", "470"], ["915", "895"], ["362", "354"], ["440", "364"],
["640", "638"], ["827", "817"], ["793", "768"], ["837", "890"], ["004", "982"],
["192", "134"], ["745", "777"], ["299", "145"], ["742", "775"], ["586", "223"],
["483", "370"], ["779", "794"], ["971", "564"], ["273", "807"], ["991", "064"],
["664", "668"], ["823", "584"], ["656", "666"], ["557", "560"], ["471", "455"],
["042", "084"], ["979", "875"], ["316", "369"], ["091", "116"], ["023", "923"],
["702", "612"], ["904", "046"], ["647", "622"], ["958", "956"], ["606", "567"],
["632", "548"], ["927", "912"], ["350", "349"], ["595", "597"], ["727", "729"],
]
TEST_JSON = [
["953", "974"], ["012", "026"], ["078", "955"], ["623", "630"], ["919", "015"],
["367", "371"], ["847", "906"], ["529", "633"], ["418", "507"], ["227", "169"],
["389", "480"], ["821", "812"], ["670", "661"], ["158", "379"], ["423", "421"],
["352", "319"], ["579", "701"], ["488", "399"], ["695", "422"], ["288", "321"],
["705", "707"], ["306", "278"], ["865", "739"], ["995", "233"], ["755", "759"],
["467", "462"], ["314", "347"], ["741", "731"], ["970", "973"], ["634", "660"],
["494", "445"], ["706", "479"], ["186", "170"], ["176", "190"], ["380", "358"],
["214", "255"], ["454", "527"], ["425", "485"], ["388", "308"], ["384", "932"],
["035", "036"], ["257", "420"], ["924", "917"], ["114", "102"], ["732", "691"],
["550", "452"], ["280", "249"], ["842", "714"], ["625", "650"], ["024", "073"],
["044", "945"], ["896", "128"], ["862", "047"], ["607", "683"], ["517", "521"],
["682", "669"], ["138", "142"], ["552", "851"], ["376", "381"], ["000", "003"],
["048", "029"], ["724", "725"], ["608", "675"], ["386", "154"], ["220", "219"],
["801", "855"], ["161", "141"], ["949", "868"], ["880", "135"], ["429", "404"],
]
# Path mapping: (method, codec) -> frames_retina path relative to the data root.
ORIGIN_FRAMES = "original_sequences/youtube/{codec}/frames_retina"
MANIPULATED_FRAMES = "manipulated_sequences/{method}/{codec}/frames_retina"
# Default value of --methods.
METHODS = ["Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures", "FaceShifter"]  # optional: DeepFakeDetection
def _link_video_dirs(src_root, dst_root):
    """Symlink every sub-directory of ``src_root`` into ``dst_root``.

    Existing entries are kept. A dangling symlink (target moved or deleted)
    is replaced, because ``Path.exists()`` reports it as missing and
    ``symlink_to`` would then fail with ``FileExistsError``.

    Returns:
        int: number of sub-directories found in ``src_root``.
    """
    count = 0
    for vid in sorted(src_root.iterdir()):
        if not vid.is_dir():
            continue
        dst = dst_root / vid.name
        if dst.is_symlink() and not dst.exists():
            dst.unlink()
        if not dst.exists():
            dst.symlink_to(vid.resolve())
        count += 1
    return count


def main():
    """Create the ``ffpp`` folder layout from the extracted ``frames_retina`` crops.

    Writes ``train.json``, ``val.json`` and ``test.json`` and creates, for
    the original videos and for every selected manipulation method, a
    ``<name>/<codec>/larger_images`` folder containing one symlink per video
    folder found in the matching ``frames_retina`` directory.
    """
    root = Path(__file__).resolve().parent / "FaceForensics++"
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", default=str(root), help="FaceForensics++ 根目录")
    parser.add_argument("--codec", default="c23")
    parser.add_argument("--methods", nargs="+", default=METHODS)
    args = parser.parse_args()

    data_root = Path(args.data_root)
    codec = args.codec
    ffpp = data_root / "ffpp"
    ffpp.mkdir(parents=True, exist_ok=True)

    # 1. Save the split files
    for name, pairs in [("train", TRAIN_JSON), ("val", VAL_JSON), ("test", TEST_JSON)]:
        f = ffpp / f"{name}.json"
        with open(f, "w", encoding="utf-8") as fp:
            json.dump(pairs, fp, indent=2)
        print(f" {f}")

    # 2. Origin: larger_images/{id} -> symlink to frames_retina/xxx
    origin_frames = data_root / ORIGIN_FRAMES.format(codec=codec)
    origin_larger = ffpp / "Origin" / codec / "larger_images"
    origin_larger.mkdir(parents=True, exist_ok=True)
    if origin_frames.exists():
        _link_video_dirs(origin_frames, origin_larger)
        print(f" Origin: {origin_larger} ({len(list(origin_larger.iterdir()))} videos)")
    else:
        print(f" [skip] Origin {origin_frames} not found")

    # 3. Manipulated: larger_images/{id1_id2} -> symlink to frames_retina/xxx
    for method in args.methods:
        man_frames = data_root / MANIPULATED_FRAMES.format(method=method, codec=codec)
        man_larger = ffpp / method / codec / "larger_images"
        man_larger.mkdir(parents=True, exist_ok=True)
        if man_frames.exists():
            n = _link_video_dirs(man_frames, man_larger)
            print(f" {method}: {man_larger} ({n} videos)")
        else:
            print(f" [skip] {method} {man_frames} not found")

    print(f"\n完成: ffpp 目录 -> {ffpp}")
    print("\n使用方式:")
    print(" 1. fakefacecls: export FFPP_ROOT=" + str(ffpp.resolve()))
    print(" 2. multiple-attention: 在 datasets/data.py 中设置 ffpproot = '" + str(ffpp.resolve()) + "/'")


if __name__ == "__main__":
    main()
三、人脸分类网络
我们接下来直接使用 Timm 库来验证 CNN 和 Transformer 作为 Backbone 对人脸伪造分类的识别性能。我们将支持两种分类方式,分别是二分类和五分类:二分类只区分真实(Real)与伪造(Fake);五分类则在区分真伪的基础上,进一步识别具体的伪造方法(Origin/Deepfakes/NeuralTextures/FaceSwap/Face2Face)。
所有代码已上传至GitHub:https://github.com/xiongqi123123/fakefaceclsnet
数据集加载及数据增强代码如下:
python
import os
import random
import torch
import cv2
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose
from .augmentations import augmentations
from . import data
class DeepfakeDataset(Dataset):
    """Frame-level dataset built from per-video folders of face crops.

    ``self.dataset`` is a list of ``[folder, label]`` entries and every entry
    contributes ``imgs_per_video`` samples. A sample that cannot be loaded is
    replaced by a sample of another video (train/val) or by an all-zero
    tensor with label ``-1`` (test, or when every retry failed).
    """

    def __init__(
        self,
        phase='train',
        datalabel='',
        resize=(224, 224),
        imgs_per_video=30,
        min_frames=0,
        normalize=None,
        frame_interval=10,
        max_frames=300,
        augment='augment0',
    ):
        """Store the sampling options and build the list of video folders.

        Args:
            phase: one of 'train', 'val', 'test'.
            datalabel: dataset selector string (see ``_build_dataset``) or an
                already built list/tuple of ``[folder, label]`` entries.
            resize: (height, width) passed to ``A.Resize``.
            imgs_per_video: number of samples drawn per video folder.
            min_frames: minimum number of files a folder must contain; a
                falsy value selects ``max_frames * 0.3``.
            normalize: dict with 'mean' and 'std'; a falsy value selects 0.5
                for every channel.
            frame_interval: stride between the samples of one video.
            max_frames: upper bound of the file index used per folder.
            augment: key into ``augmentations``; unknown keys fall back to
                'augment0'.
        """
        assert phase in ['train', 'val', 'test']
        if not normalize:
            normalize = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        self.datalabel = datalabel
        self.phase = phase
        self.imgs_per_video = imgs_per_video
        self.frame_interval = frame_interval
        self.epoch = 0
        self.max_frames = max_frames
        self.min_frames = min_frames if min_frames else max_frames * 0.3
        self.aug = augmentations.get(augment, augmentations['augment0'])
        self.resize = resize
        height, width = resize[0], resize[1]
        self.trans = Compose([
            # Resize first: tiny crops (e.g. 19x14) would make a CenterCrop fail.
            A.Resize(height, width),
            A.Normalize(mean=normalize['mean'], std=normalize['std']),
            A.ToTensorV2(),
        ])
        self.dataset = self._build_dataset()
        # Cache of sorted directory listings, so a folder is listed only once.
        self._frame_cache = {}

    def _build_dataset(self):
        """Translate ``self.datalabel`` into a list of ``[folder, label]`` entries.

        Raises:
            ValueError: if the selector string matches no known dataset.
        """
        selector = self.datalabel
        if isinstance(selector, (list, tuple)):
            return selector

        phase = self.phase
        ff_tags = ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']

        if 'ff-5' in selector:
            # Five-class setup: the label is the position of the tag in ff_tags.
            codec = selector.split('-')[2]
            return [
                [entry[0], class_idx]
                for class_idx, tag in enumerate(ff_tags)
                for entry in data.FF_dataset(tag, codec, phase)
            ]

        if 'ff-all' in selector:
            codec = selector.split('-')[2]
            samples = []
            for tag in ff_tags:
                samples.extend(data.FF_dataset(tag, codec, phase))
            if phase != 'test':
                samples = data.make_balance(samples)
            return samples

        if 'ff' in selector:
            fields = selector.split('-')
            codec = fields[2]
            tag = fields[1]
            fake_part = data.FF_dataset(tag, codec, phase)
            return fake_part + data.FF_dataset('Origin', codec, phase)

        if 'celeb' in selector:
            return data.Celeb_test

        if 'deeper' in selector:
            codec = selector.split('-')[1]
            fake_part = data.deeperforensics_dataset(phase)
            return fake_part + data.FF_dataset('Origin', codec, phase)

        if 'dfdc' in selector:
            return data.dfdc_dataset(phase)

        raise ValueError(f'Unknown datalabel: {self.datalabel}')

    def next_epoch(self):
        """Advance the epoch counter, which shifts the frame index by one."""
        self.epoch = self.epoch + 1

    def __getitem__(self, item):
        """Return ``(image_tensor, label)`` for sample ``item``.

        On any loading error the test phase returns an all-zero tensor with
        label ``-1``; the other phases retry with the same slot of the next
        video, at most ``len(self.dataset)`` times.
        """
        # Bounded number of retries instead of unbounded recursion.
        for _ in range(len(self.dataset)):
            try:
                vid = self.dataset[item // self.imgs_per_video]
                folder = vid[0]
                frames = self._frame_cache.get(folder)
                if frames is None:
                    frames = sorted(os.listdir(folder))
                    self._frame_cache[folder] = frames
                if len(frames) < self.min_frames:
                    raise ValueError(f"frames {len(frames)} < min_frames {self.min_frames}")

                # Slot inside the video, strided by frame_interval and
                # shifted by the epoch, wrapped to the usable frame range.
                offset = (item % self.imgs_per_video) * self.frame_interval + self.epoch
                fname = frames[offset % min(len(frames), self.max_frames)]

                img = cv2.imread(os.path.join(vid[0], fname))
                if img is None:
                    raise ValueError(f"cv2.imread failed: {fname}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                if self.phase == 'train':
                    img = self.aug(image=img)['image']
                return self.trans(image=img)['image'], vid[1]
            except Exception as e:
                debug_on = os.environ.get('DEBUG_DATASET') == '1'
                if debug_on and not getattr(self, '_debug_printed', False):
                    import traceback
                    vp = self.dataset[item // self.imgs_per_video][0] if item < len(self) else '?'
                    print(f'[DEBUG] item={item} path={vp} err={e}')
                    traceback.print_exc()
                    self._debug_printed = True  # print the first failure only
                if self.phase == 'test':
                    return torch.zeros(3, self.resize[0], self.resize[1]), -1
                # Same slot of the next video.
                item = (item + self.imgs_per_video) % len(self)
        # Every attempt failed: return a placeholder sample.
        return torch.zeros(3, self.resize[0], self.resize[1]), -1

    def __len__(self):
        """Number of samples: videos times ``imgs_per_video``."""
        return self.imgs_per_video * len(self.dataset)
python
import os
import json
import random
# Data root: taken from FFPP_ROOT, otherwise FFDeepFake/data/FaceForensics++/ffpp
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_FFDEEPFAKE_ROOT = os.path.dirname(os.path.dirname(_SCRIPT_DIR)) # fakefacecls/ -> FFDeepFake
_FFDEEPFAKE_ROOT = os.path.dirname(_FFDEEPFAKE_ROOT) # FFDeepFake
# NOTE(review): in total this climbs three levels above this file's directory;
# confirm that matches the real repository layout.
_data_root = os.path.join(_FFDEEPFAKE_ROOT, 'data')
_DEFAULT_FFPP = os.path.join(_data_root, 'FaceForensics++', 'ffpp')
ffpproot = os.environ.get('FFPP_ROOT', _DEFAULT_FFPP)
# Make sure the root ends with a path separator.
if ffpproot and not ffpproot.endswith(os.sep):
    ffpproot += os.sep
# Roots of the other datasets, all below the same data directory.
dfdcroot = os.path.join(_data_root, 'dfdc')
celebroot = os.path.join(_data_root, 'celebDF')
deeperforensics_root = os.path.join(_data_root, 'deeper')
def load_json(name):
    """Return the parsed content of the JSON file at path ``name``.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII content loads the same way on every system.
    """
    with open(name, encoding='utf-8') as f:
        return json.load(f)
def FF_dataset(tag='Origin', codec='c0', part='train'):
    """List the FaceForensics++ video folders of one method, codec and split.

    Args:
        tag: 'Origin' or the name of a manipulation method.
        codec: 'c0', 'c23', 'c40', or 'all' for the three concatenated.
        part: 'train', 'val', 'test', or 'all' for the three concatenated.

    Returns:
        list: ``[folder, label]`` entries with label 0 for 'Origin' and 1
        otherwise. Each pair of the split file yields two folders: both ids
        for 'Origin', ``a_b`` and ``b_a`` for a manipulation method.
    """
    assert tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face', 'FaceShifter']
    assert codec in ['c0', 'c23', 'c40', 'all']
    assert part in ['train', 'val', 'test', 'all']

    if part == 'all':
        return [
            entry
            for split in ('train', 'val', 'test')
            for entry in FF_dataset(tag, codec, split)
        ]
    if codec == 'all':
        return [
            entry
            for level in ('c0', 'c23', 'c40')
            for entry in FF_dataset(tag, level, part)
        ]

    path = os.path.join(ffpproot, tag, codec, 'larger_images')
    metafile = load_json(os.path.join(ffpproot, part + '.json'))

    if tag == 'Origin':
        return [
            [os.path.join(path, video_id), 0]
            for pair in metafile
            for video_id in (pair[0], pair[1])
        ]
    return [
        [os.path.join(path, folder), 1]
        for pair in metafile
        for folder in (pair[0] + '_' + pair[1], pair[1] + '_' + pair[0])
    ]
def make_balance(data):
    """Oversample the smaller class so both labels occur equally often.

    Args:
        data: list of entries whose second element is the label 0 or 1;
            entries with any other label are dropped.

    Returns:
        list: the oversampled minority class followed by the majority class.
        The minority entries are repeated as often as they fit and the
        remainder is drawn with ``random.sample``. If one class is empty the
        other one is returned unchanged (previously a ZeroDivisionError).
    """
    real = [x for x in data if x[1] == 0]
    fake = [x for x in data if x[1] == 1]
    if len(real) > len(fake):
        minority, majority = fake, real
    else:
        minority, majority = real, fake
    if not minority:
        # Nothing to oversample from.
        return minority + majority
    rate, res = divmod(len(majority), len(minority))
    minority = minority * rate + random.sample(minority, res)
    return minority + majority
def dfdc_dataset(part='train'):
    """List the DFDC samples of one split as ``[path, label]`` entries.

    The entries come from ``DFDC.json``; the training split is passed through
    ``make_balance``. Every split has its own sub-folder below ``dfdcroot``.
    """
    assert part in ['train', 'val', 'test']
    meta = load_json(os.path.join(dfdcroot, 'DFDC.json'))
    if part == 'train':
        subdir, entries = 'dfdc', make_balance(meta['train'])
    elif part == 'test':
        subdir, entries = 'dfdc-test', meta['test']
    else:
        subdir, entries = 'dfdc-val', meta['val']
    base = os.path.join(dfdcroot, subdir)
    return [[os.path.join(base, entry[0]), entry[1]] for entry in entries]
def deeperforensics_dataset(part='train'):
    """List the DeeperForensics folders of one FaceForensics++ split.

    Folders below ``deeperforensics_root`` are indexed by the text before
    their first underscore; both ids of every pair in the split file are
    looked up in that index.

    Returns:
        list: ``[folder, 1]`` entries, two per pair.
    """
    folder_by_id = {}
    for entry in os.listdir(deeperforensics_root):
        folder_by_id[entry.split('_')[0]] = entry
    pairs = load_json(os.path.join(ffpproot, part + '.json'))
    return [
        [os.path.join(deeperforensics_root, folder_by_id[video_id]), 1]
        for pair in pairs
        for video_id in (pair[0], pair[1])
    ]
# Celeb-DF test list, loaded at import time; any loading error leaves it empty.
# The stored label is inverted (1 - label) - presumably celeb.json uses the
# opposite convention; verify against the file.
try:
    Celeb_test = [
        [os.path.join(celebroot, entry[0]), 1 - entry[1]]
        for entry in load_json(os.path.join(celebroot, 'celeb.json'))
    ]
except Exception:
    Celeb_test = []
python
import albumentations as A
# augment0: lightest preset, a horizontal flip only (flip probability is left
# at the library default). The Compose wrapper itself always runs (p=1).
augment0 = A.Compose([A.HorizontalFlip()], p=1)
# augment1: flip plus colour jitter; the hue/saturation/value shift and the
# brightness/contrast change are each applied with probability 0.5.
augment1 = A.Compose([
A.HorizontalFlip(),
A.HueSaturationValue(p=0.5),
A.RandomBrightnessContrast(p=0.5),
], p=1)
# augment2: everything in augment1 plus degradations.
augment2 = A.Compose([
A.HorizontalFlip(),
A.HueSaturationValue(p=0.5),
A.RandomBrightnessContrast(p=0.5),
# Gaussian noise, gated at p=0.3 by the single-element OneOf wrapper.
A.OneOf([A.GaussNoise()], p=0.3),
# With p=0.3, exactly one of: motion blur, Gaussian blur, or image
# compression with quality drawn from the 65-80 range.
A.OneOf([
A.MotionBlur(),
A.GaussianBlur(),
A.ImageCompression(quality_range=(65, 80)),
], p=0.3),
# Occasional grayscale conversion.
A.ToGray(p=0.1),
], p=1)
# Name -> pipeline lookup table (presumably selected by name from the dataset
# or training configuration; confirm against callers).
augmentations = {'augment0': augment0, 'augment1': augment1, 'augment2': augment2}
网络直接使用Timm的预置模型:
python
####### __init__.py
from .cnn import build_cnn_model
def build_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs):
    """Build a classifier from a timm model name.

    Thin public entry point that forwards every argument to
    ``build_cnn_model``. Any timm model name is accepted, e.g. ``resnet50``,
    ``vit_base_patch16_224`` or ``deit_small_patch16_224``.
    """
    return build_cnn_model(
        backbone=backbone,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        **kwargs,
    )
####### backbone.py
"""
基于 timm 的 backbone + 分类头
支持 CNN (resnet50, efficientnet_b0, ...) 和 ViT (vit_base_patch16_224, deit_base_patch16_224, ...)
"""
import torch.nn as nn
import timm
def build_cnn_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs):
    """Construct a ``CNNClassifier`` for the given timm backbone.

    Args:
        backbone: timm model name, e.g. ``resnet50``, ``efficientnet_b0``,
            ``convnext_tiny``.
        num_classes: 2 (real/fake) or 5
            (Origin/Deepfakes/NeuralTextures/FaceSwap/Face2Face).
        pretrained: Whether to load ImageNet-pretrained weights.
        dropout: Dropout probability of the classification head.
        **kwargs: Extra keyword arguments forwarded to ``CNNClassifier``.

    Returns:
        The constructed ``CNNClassifier`` instance.
    """
    classifier = CNNClassifier(
        backbone=backbone,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        **kwargs,
    )
    return classifier
class CNNClassifier(nn.Module):
    """timm backbone followed by a replaceable dropout + linear head."""

    def __init__(self, backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, in_chans=3, **kwargs):
        """Create the backbone and the classification head.

        Args:
            backbone: timm model name passed to ``timm.create_model``.
            num_classes: Output size of the linear head.
            pretrained: Whether timm should load pretrained weights.
            dropout: Dropout probability applied before the linear layer.
            in_chans: Number of input channels for the backbone.
            **kwargs: Extra keyword arguments forwarded to
                ``timm.create_model``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.backbone = timm.create_model(
            backbone,
            # timm expects a boolean flag here, not a weight-set name.
            pretrained=bool(pretrained),
            num_classes=0,  # drop timm's own classifier; we attach our head
            global_pool='avg',
            in_chans=in_chans,
            **kwargs
        )
        # Width of the pooled feature vector produced by the backbone.
        feat_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(feat_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits: backbone features passed through the head."""
        feat = self.backbone(x)
        return self.head(feat)
然后就是训练的代码
python
"""
训练与验证脚本
用法:
python train.py
python train.py --backbone resnet50 --num_classes 5 --datalabel ff-5-c23
"""
import argparse
import logging
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import TrainConfig
from datasets import DeepfakeDataset
from models import build_model
def get_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Model / data selection.
    parser.add_argument('--backbone', default='resnet50', help='timm backbone')
    parser.add_argument('--num_classes', type=int, default=2, choices=[2, 5])
    parser.add_argument('--datalabel', default='ff-all-c23', help='ff-all-c23 | ff-5-c23')
    # Optimisation hyper-parameters.
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-5)
    # Run bookkeeping.
    parser.add_argument('--name', default='')
    parser.add_argument('--no_pretrained', action='store_true')
    parser.add_argument('--resume', default='')
    parser.add_argument('--save_every', type=int, default=5, help='每隔多少轮保存一次 ckpt')
    parser.add_argument('--log_interval', type=int, default=100, help='每隔多少 batch 打印一次 log')
    parser.add_argument('--val_every_steps', type=int, default=500, help='每隔多少 step 在 val 上验证一次,0=仅每 epoch 结束')
    parser.add_argument('--test_every_epoch', action='store_true', help='每个 epoch 结束在 test 上评估(仅监控,不参与选 best)')
    # Runtime options.
    parser.add_argument('--workers', type=int, default=0, help='DataLoader workers,0=用 config 默认')
    parser.add_argument('--no_amp', action='store_true', help='禁用混合精度')
    return parser.parse_args(argv)
def main():
    """Train, validate, checkpoint and finally evaluate a classifier.

    Steps:
        1. Parse CLI options and build the ``TrainConfig``.
        2. Set up the run directory under ``runs/<cfg.name>`` and logging to
           both ``train.log`` and the console.
        3. Build the model, the train/val/test data loaders, an AdamW
           optimizer and a cosine-annealing LR schedule.
        4. Optionally resume from ``--resume``.
        5. For every epoch: train, validate, step the scheduler, write
           ``latest.pth``, update ``best.pth`` when validation accuracy
           improves, and optionally write periodic snapshots / evaluate on
           the test split.
        6. Reload ``best.pth`` (or ``latest.pth`` if no best exists) and
           report the final test metrics.
    """
    args = get_args()
    cfg = TrainConfig(
        backbone=args.backbone,
        num_classes=args.num_classes,
        pretrained=not args.no_pretrained,
        datalabel=args.datalabel,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        name=args.name if args.name else None,
    )

    run_dir = os.path.join('runs', cfg.name)
    os.makedirs(run_dir, exist_ok=True)
    log_path = os.path.join(run_dir, 'train.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
        force=True,
    )
    logging.info(f'Run dir: {run_dir}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_model(
        backbone=cfg.backbone,
        num_classes=cfg.num_classes,
        pretrained=cfg.pretrained,
        dropout=cfg.dropout,
    ).to(device)

    # A positive --workers overrides the config value.
    workers = args.workers if args.workers > 0 else cfg.workers
    # Mixed precision is only enabled when CUDA is available.
    use_amp = not args.no_amp and torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    if use_amp:
        logging.info('Using AMP (mixed precision)')

    train_ds = DeepfakeDataset(**cfg.dataset_kwargs('train'))
    val_ds = DeepfakeDataset(**cfg.dataset_kwargs('val'))
    test_ds = DeepfakeDataset(**cfg.dataset_kwargs('test'))

    dl_kw = dict(batch_size=cfg.batch_size, pin_memory=torch.cuda.is_available())
    if workers > 0:
        # persistent_workers / prefetch_factor are only valid with workers > 0.
        dl_kw.update(num_workers=workers, persistent_workers=True, prefetch_factor=4)
    else:
        dl_kw['num_workers'] = 0
    train_loader = DataLoader(train_ds, shuffle=True, **dl_kw)
    val_loader = DataLoader(val_ds, shuffle=False, **dl_kw)
    test_loader = DataLoader(test_ds, shuffle=False, **dl_kw)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

    start_epoch = 0
    best_acc = 0.0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        # Accept both a full checkpoint dict and a bare state dict.
        model.load_state_dict(ckpt.get('model', ckpt), strict=False)
        # Continue after the stored epoch; a checkpoint without epoch
        # information (e.g. bare weights) starts from epoch 0 instead of
        # silently skipping it.
        start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
        best_acc = ckpt.get('best_acc', 0.0)
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
        if 'scheduler' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler'])
        logging.info(f'Resumed from epoch {start_epoch}, best_acc={best_acc:.4f}')

    global_step = start_epoch * len(train_loader)
    for epoch in range(start_epoch, cfg.epochs):
        train_ds.next_epoch()
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, device,
            log_interval=args.log_interval,
            scaler=scaler,
            val_loader=val_loader if args.val_every_steps > 0 else None,
            val_every_steps=args.val_every_steps,
            global_step=global_step,
        )
        global_step += len(train_loader)
        val_loss, val_acc = validate(model, val_loader, device)
        scheduler.step()
        logging.info(f'E{epoch} train loss={train_loss:.4f} acc={train_acc:.4f} | val loss={val_loss:.4f} acc={val_acc:.4f}')

        # Update the best score BEFORE writing latest.pth so that a resumed
        # run sees this epoch's result and cannot overwrite best.pth with a
        # worse model.
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        ckpt = {
            'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc,
            'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
        }
        torch.save(ckpt, os.path.join(run_dir, 'latest.pth'))
        if is_best:
            torch.save({'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc}, os.path.join(run_dir, 'best.pth'))
            logging.info(f' -> best acc={best_acc:.4f} saved')
        if args.save_every > 0 and (epoch + 1) % args.save_every == 0:
            torch.save({'model': model.state_dict(), 'epoch': epoch}, os.path.join(run_dir, f'ep{epoch}.pth'))
        if args.test_every_epoch:
            # Monitoring only: the test score never influences best.pth.
            test_loss, test_acc = validate(model, test_loader, device, desc='test')
            logging.info(f'E{epoch} test loss={test_loss:.4f} acc={test_acc:.4f}')

    # Training finished: final test evaluation with best.pth, falling back to
    # latest.pth when no best checkpoint exists.
    for ckpt_name in ('best.pth', 'latest.pth'):
        ckpt_path = os.path.join(run_dir, ckpt_name)
        if os.path.exists(ckpt_path):
            ckpt = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(ckpt.get('model', ckpt), strict=False)
            test_loss, test_acc = validate(model, test_loader, device, desc='test')
            logging.info(f'[FINAL] {ckpt_name} on test: loss={test_loss:.4f} acc={test_acc:.4f}')
            break
def train_epoch(model, loader, optimizer, device, log_interval=50, scaler=None,
                val_loader=None, val_every_steps=0, global_step=0):
    """Train the model for one pass over ``loader``.

    Args:
        model: Network to optimise; called as ``model(x)`` and expected to
            return class logits.
        loader: Iterable yielding ``(x, y)`` batches.
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.
        log_interval: Log the running batch metrics every this many batches
            (0 disables batch logging).
        scaler: Gradient scaler; when given, the forward pass runs under CUDA
            autocast and the scaler drives backward/step.
        val_loader: Optional loader used for step-based validation.
        val_every_steps: Validate every this many global steps (0 disables).
        global_step: Number of steps completed before this epoch.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by the number of samples
        actually trained on, or ``(0.0, 0.0)`` if no sample was usable.
    """
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0
    pbar = tqdm(loader, desc='train', leave=False)
    for i, (x, y) in enumerate(pbar):
        step = global_step + i + 1
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        # Same filtering as validate(): negative labels mark samples that
        # failed to load and are not valid cross-entropy targets.
        valid = y >= 0
        if valid.any():
            if not valid.all():
                x, y = x[valid], y[valid]
            optimizer.zero_grad()
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    logits = model(x)
                    loss = F.cross_entropy(logits, y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits, y)
                loss.backward()
                optimizer.step()
            loss_value = loss.item()
            acc = (logits.argmax(1) == y).float().mean().item()
            batch_size = x.size(0)
            total_loss += loss_value * batch_size
            total_acc += acc * batch_size
            n += batch_size
            pbar.set_postfix(loss=f'{loss_value:.4f}', acc=f'{acc:.4f}')
            if log_interval > 0 and (i + 1) % log_interval == 0:
                logging.info(f' batch {i+1}/{len(loader)} loss={loss_value:.4f} acc={acc:.4f}')
        if val_loader is not None and val_every_steps > 0 and step % val_every_steps == 0:
            val_loss, val_acc = validate(model, val_loader, device)
            logging.info(f' [step {step}] val loss={val_loss:.4f} acc={val_acc:.4f}')
            # validate() switches to eval mode; restore training mode.
            model.train()
    if n == 0:
        # Empty loader or no usable sample: avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / n, total_acc / n
def validate(model, loader, device, desc='val'):
    """Evaluate the model on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; called as ``model(x)`` and expected to
            return class logits.
        loader: Iterable yielding ``(x, y)`` batches.
        device: Device the batches are moved to.
        desc: Label shown on the progress bar.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by the number of valid
        samples, or ``(0.0, 0.0)`` when no valid sample was seen.
    """
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=desc, leave=False)
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            # Drop invalid labels (the dataset returns -1 when loading fails).
            valid = y >= 0
            if valid.sum() == 0:
                continue
            x, y = x[valid], y[valid]
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            acc = (logits.argmax(1) == y).float().mean().item()
            total_loss += loss.item() * x.size(0)
            total_acc += acc * x.size(0)
            n += x.size(0)
            pbar.set_postfix(loss=f'{loss.item():.4f}', acc=f'{acc:.4f}')
    # Explicit guard: a one-line `a / n, b / n if n > 0 else (...)` applies
    # the condition to the second element only and still divides by zero.
    if n == 0:
        return 0.0, 0.0
    return total_loss / n, total_acc / n
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
main()
可以看到训练效果非常好:基本只需一个 Epoch 就可以在 Test 测试集上达到 0.8 以上的正确率,并且可以观察到以 Transformer 作为 Backbone 的效果远好于 CNN。
