基于OpenCV进行表情识别

一、程序概述

该程序实现了基于深度学习的实时表情识别功能。通过摄像头捕获视频流,检测其中的人脸,并对检测到的人脸进行表情分类,最终在图像上显示出识别出的表情标签。

二、依赖库

  1. cv2(OpenCV):用于计算机视觉任务,如读取摄像头视频流、图像的基本处理(如转换为灰度图、绘制矩形和文本等)以及人脸检测。
  2. numpy:提供了高效的数值计算功能,用于处理图像数据等。
  3. torch:PyTorch 深度学习框架,用于构建和训练深度学习模型,以及进行模型的推理计算。
  4. torchvision.transforms:包含了对图像进行预处理的各种转换操作,如调整大小、归一化等。
  5. torchvision.models:提供了预训练的深度学习模型,如 ResNet18。
  6. torch.nn.functional :包含了 PyTorch 中常用的神经网络函数,如softmax函数,用于计算概率分布。

三、代码详解

1. 加载模型

python

复制代码
model = resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 7)  # 修改输出层为 7 类(7 种表情)
model.eval()
  • torchvision.models中加载预训练的 ResNet18 模型,pretrained=True表示加载在 ImageNet 上预训练的权重。
  • 将模型的全连接层(fc)修改为输出 7 个类别的层,因为这里要识别 7 种表情(生气、厌恶、恐惧、快乐、悲伤、惊讶、中性)。
  • 将模型设置为评估模式(eval()),以关闭一些在训练时使用的操作,如随机失活(dropout)和批量归一化(batch normalization)的特定行为。

2. 表情标签定义

python

复制代码
emotion_labels = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

定义了一个列表,包含了 7 种表情的英文标签,用于将模型预测的类别索引转换为对应的表情名称。

3. 图像预处理

python

复制代码
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
  • 使用transforms.Compose组合了一系列的图像转换操作。
  • ToPILImage():将输入的图像数据转换为PIL(Python Imaging Library)图像格式。
  • Resize((224, 224)):将图像大小调整为 224x224 像素,这是 ResNet18 模型输入图像的要求尺寸。
  • ToTensor():将PIL图像转换为 PyTorch 的张量格式。
  • Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):对图像张量进行归一化处理,使用给定的均值和标准差。

4. 人脸检测

python

复制代码
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

加载 Haar 级联分类器,用于检测图像中的人脸。haarcascade_frontalface_default.xml是一个预训练的 Haar 级联模型文件,包含了用于人脸检测的特征数据。

5. 打开摄像头

python

复制代码
cap = cv2.VideoCapture(0)

使用 OpenCV 的VideoCapture函数打开默认的摄像头(索引为 0),用于捕获视频流。

6. 设备选择

python

复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

检查 GPU 是否可用,如果可用,则将设备设置为cuda,否则设置为cpu。然后将模型移动到选定的设备上,以便在计算时使用相应的硬件资源。

7. 主循环

python

复制代码
while True:
    ret, frame = cap.read()
    if not ret:
        break

    # 转换为灰度图像
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # 检测人脸
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

    for (x, y, w, h) in faces:
        # 提取人脸区域
        face = frame[y:y+h, x:x+w]
        face = transform(face).unsqueeze(0)  # 预处理并添加批次维度
        face = face.to(device)  # 将数据移动到 GPU(如果可用)

        # 预测表情
        with torch.no_grad():
            outputs = model(face)
            probabilities = F.softmax(outputs, dim=1)
            emotion_index = torch.argmax(probabilities, dim=1).item()
            emotion = emotion_labels[emotion_index]

        # 在图像上绘制结果
        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
        cv2.putText(frame, emotion, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    # 显示图像
    cv2.imshow('Facial Expression Recognition', frame)

    # 按 'q' 键退出
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
  • cap.read()从摄像头读取一帧视频图像,返回一个布尔值ret(表示是否成功读取)和读取到的图像frame
  • 将读取到的彩色图像转换为灰度图像,用于人脸检测。
  • 使用face_cascade.detectMultiScale方法检测灰度图像中的人脸,返回一个包含人脸矩形框坐标(x, y, w, h)的列表。
  • 对于检测到的每个人脸:
    • 提取人脸区域,并对其进行预处理,包括转换为张量、调整大小、归一化等,并添加一个批次维度(因为模型输入需要是一个批次)。
    • 将处理后的人脸数据移动到选定的设备上。
    • 使用模型进行预测,计算预测结果的概率分布,通过softmax函数得到每个类别的概率。
    • 找到概率最大的类别索引,将其转换为对应的表情标签。
    • 在原始图像上绘制人脸矩形框和表情标签。
  • 使用cv2.imshow显示处理后的图像。
  • 使用cv2.waitKey等待用户按键,如果按下的键是q,则退出循环。

8. 释放资源

python

复制代码
cap.release()
cv2.destroyAllWindows()

释放摄像头资源,并关闭所有由 OpenCV 打开的窗口。

四、注意事项

  1. 确保安装了所有必要的依赖库,并且版本兼容。
  2. Haar 级联分类器可能在复杂背景或光线不佳的情况下检测效果不佳,可以尝试使用其他更先进的人脸检测方法。
  3. 模型的性能可能受到训练数据、模型结构和超参数的影响,可以根据实际情况进行调整和优化。
  4. 确保摄像头设备正常工作并且程序具有访问摄像头的权限。

完整代码

python 复制代码
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.nn import functional as F

# 加载 ResNet18 模型
model = resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 7)  # 修改输出层为 7 类(7 种表情)
model.eval()

# 表情标签
emotion_labels = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# 图像预处理
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载 Haar 级联分类器用于人脸检测
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

# 打开摄像头
cap = cv2.VideoCapture(0)

# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # 转换为灰度图像
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # 检测人脸
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

    for (x, y, w, h) in faces:
        # 提取人脸区域
        face = frame[y:y+h, x:x+w]
        face = transform(face).unsqueeze(0)  # 预处理并添加批次维度
        face = face.to(device)  # 将数据移动到 GPU(如果可用)

        # 预测表情
        with torch.no_grad():
            outputs = model(face)
            probabilities = F.softmax(outputs, dim=1)
            emotion_index = torch.argmax(probabilities, dim=1).item()
            emotion = emotion_labels[emotion_index]

        # 在图像上绘制结果
        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
        cv2.putText(frame, emotion, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    # 显示图像
    cv2.imshow('Facial Expression Recognition', frame)

    # 按 'q' 键退出
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# 释放摄像头并关闭窗口
cap.release()
cv2.destroyAllWindows()
相关推荐
moxiaoran575312 分钟前
python学习笔记--实现简单的爬虫(二)
python
seetimee28 分钟前
Milvus WeightedRanker 对比 RRF 重排机制
人工智能·python·milvus
*星星之火*39 分钟前
【GPT入门】第27课 Jupyter 感知到通过命令行生成的内核
python·gpt·jupyter
weixin_3077791340 分钟前
基于Azure Delta Lake和Databricks的安全数据共享(Delta Sharing)
python·安全·spark·云计算·azure
数字化转型20251 小时前
股票量化交易开发 Yfinance
开发语言·python
钢铁男儿2 小时前
Python 用户账户(创建用户账户)
数据库·python·sqlite
进击的六角龙2 小时前
【Python数据分析+可视化项目案例】:亚马逊平台用户订单数据分析
开发语言·爬虫·python·数据分析·网络爬虫·数据可视化
蹦蹦跳跳真可爱5892 小时前
Python---数据分析(Pandas九:二维数组DataFrame数据操作二: 数据排序,数据筛选,数据拼接)
python·信息可视化·数据分析·pandas
G皮T3 小时前
【Python Cookbook】字符串和文本(一)
python·正则表达式·字符串·查找
云空3 小时前
《Gradio Python 客户端入门》
服务器·python