AI风控之生成图像鉴伪实战
字数 1237 2025-08-20 18:18:24
AI风控之生成图像鉴伪实战教学文档
1. 前言与背景
AI生成的虚假图像检测(AI鉴伪)是当前AI安全领域的重要课题。随着生成对抗网络(GANs)等技术的发展,AI已能生成高度逼真的虚假图像,如人脸生成网站"https://thispersondoesnotexist.com/"每次刷新都会产生一个不存在的人脸图像。
Deepfake技术是这类技术的典型代表:
- 基于深度学习的图像/视频合成技术
- 使用生成对抗网络(GANs)不断优化模型
- 应用领域广泛:娱乐、教育、安全防护等
- 潜在风险:假新闻、政治宣传、隐私侵犯等
2. 技术思路
核心思想:用深度学习检测深度学习生成的图像
采用端到端(End-to-End)深度学习方法:
- 直接学习真实图像与AI生成图像之间的差异特征
- 无需人工设计特征提取过程
- 自动捕捉数据中的高级抽象信息
3. 实战环境准备
3.1 所需库导入
import numpy as np
import pandas as pd
from keras.applications.mobilenet import MobileNet, preprocess_input
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, Dense,BatchNormalization, Flatten, MaxPool2D
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from keras.layers import Conv2D, Reshape
from keras.utils import Sequence
from keras.backend import epsilon
import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import cv2
from tqdm.notebook import tqdm_notebook as tqdm
import os
3.2 数据集准备
假设数据集结构如下:
/real_and_fake_face/
/training_real/ # 真实人脸图像
/training_fake/ # AI生成的人脸图像
加载路径:
real = "/real_and_fake_face/training_real/"
fake = "/real_and_fake_face/training_fake/"
real_path = os.listdir(real)
fake_path = os.listdir(fake)
3.3 图像加载辅助函数
def load_img(path):
image = cv2.imread(path)
image = cv2.resize(image,(224, 224)) # 调整为224×224大小
return image[...,::-1] # BGR转RGB
4. 数据预处理
4.1 数据增强(Data Augmentation)
data_with_aug = ImageDataGenerator(
horizontal_flip=True, # 水平翻转
vertical_flip=False, # 不垂直翻转
rescale=1./255, # 归一化到[0,1]
validation_split=0.2 # 20%数据作为验证集
)
数据增强技术的作用:
- 增加数据多样性
- 防止过拟合
- 提高模型泛化能力
4.2 可视化真实与虚假图像
# 可视化真实人脸
fig = plt.figure(figsize=(10, 10))
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(load_img(real + real_path[i]), cmap='gray')
plt.suptitle("Real faces",fontsize=20)
plt.axis('off')
plt.show()
# 可视化虚假人脸(类似代码)
5. 模型构建
5.1 使用MobileNetV2作为基础模型
mnet = MobileNetV2(
include_top=False, # 不包含顶层全连接层
weights="imagenet", # 使用ImageNet预训练权重
input_shape=(96,96,3) # 输入图像尺寸
)
MobileNetV2特点:
- 轻量级卷积神经网络
- 适合移动设备部署
- 在准确率和模型大小间取得平衡
5.2 构建完整模型
tf.keras.backend.clear_session() # 清除之前的会话
model = Sequential([
mnet, # MobileNetV2基础模型
GlobalAveragePooling2D(), # 全局平均池化
Dense(512, activation="relu"),
BatchNormalization(),
Dropout(0.3), # 30% dropout
Dense(128, activation="relu"),
Dropout(0.1), # 10% dropout
Dense(2, activation="softmax") # 二分类输出
])
model.layers[0].trainable = False # 冻结MobileNetV2权重
5.3 模型编译
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics="accuracy"
)
model.summary() # 打印模型结构
6. 训练配置
6.1 学习率调度器
def scheduler(epoch):
if epoch <= 2:
return 0.001
elif 2 < epoch <= 15:
return 0.0001
else:
return 0.00001
lr_callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler)
学习率调度策略:
- 前2个epoch:0.001
- 3-15个epoch:0.0001
- 之后:0.00001
6.2 训练过程
hist = model.fit(
train_data,
epochs=20,
validation_data=val_data,
callbacks=[lr_callbacks]
)
7. 结果分析与可视化
7.1 训练曲线可视化
epochs = 20
train_loss = hist.history['loss']
val_loss = hist.history['val_loss']
train_acc = hist.history['accuracy']
val_acc = hist.history['val_accuracy']
xc = range(epochs)
# 绘制损失曲线
plt.figure(1,figsize=(7,5))
plt.plot(xc,train_loss)
plt.plot(xc,val_loss)
plt.xlabel('num of Epochs')
plt.ylabel('loss')
plt.title('train_loss vs val_loss')
plt.grid(True)
plt.legend(['train','val'])
plt.style.use(['classic'])
# 绘制准确率曲线
plt.figure(2,figsize=(7,5))
plt.plot(xc,train_acc)
plt.plot(xc,val_acc)
plt.xlabel('num of Epochs')
plt.ylabel('accuracy')
plt.title('train_acc vs val_acc')
plt.grid(True)
plt.legend(['train','val'],loc=4)
plt.style.use(['classic'])
7.2 预测结果可视化
val_path = "real-and-fake-face-detection/real_and_fake_face/"
plt.figure(figsize=(15,15))
start_index = 250
for i in range(16):
plt.subplot(4,4, i+1)
plt.grid(False)
plt.xticks([])
plt.yticks([])
preds = np.argmax(predictions[[start_index+i]]) # 获取预测结果
gt = val.filenames[start_index+i][9:13] # 获取真实标签
if gt == "fake":
gt = 0
else:
gt = 1
# 预测错误标红,正确标绿
if preds != gt:
col ="r"
else:
col = "g"
plt.xlabel('i={}, pred={}, gt={}'.format(start_index+i,preds,gt),color=col)
plt.imshow(load_img(val_path+val.filenames[start_index+i]))
plt.tight_layout()
plt.show()
8. 关键点总结
- 数据准备:需要平衡的真实图像和AI生成图像数据集
- 数据增强:提高模型泛化能力的关键技术
- 模型选择:轻量级MobileNetV2适合此类任务
- 训练技巧:
- 使用预训练权重
- 动态学习率调整
- 适当的Dropout防止过拟合
- 评估方法:准确率和损失曲线监控,可视化预测结果
9. 参考资料
- Deepfake识别指南:https://sosafe-awareness.com/blog/how-to-spot-a-deepfake/
- 端到端学习解释:https://ai.stackexchange.com/questions/16575/what-does-end-to-end-training-mean
- AI人脸生成示例:https://thispersondoesnotexist.com/
- MobileNetV2论文解析:https://towardsdatascience.com/review-mobilenetv2-light-weight-model-image-classification-8febb490e61c
- 数据集来源:https://www.kaggle.com/