一、代码整体要求
FFN的完整计算公式y = max(0, xW1 + b1)W2 + b2,代码中固定特征维度D=512,确保各参数维度精准匹配,其中W1为512×2048(即D×4D)、W2为2048×512(即4D×D)、b1为1×2048(即4D)、b2为1×512(即D)。
实现ReLU激活函数,该函数支持矩阵输入,比如4×2048维度的矩阵,主要是将输入中的负数置为0,正数和0则保持不变。
FFN函数需实现完整功能,能够接收4×512维度的输入矩阵,经过非线性转换后,输出维度需与输入保持一致,仍为4×512,以此完成对自注意力输出特征的增强。
二、代码核心部分解释
1. 矩阵类核心解释
矩阵类是整个FFN实现的基础,主要封装了FFN运算过程中所需的各类矩阵操作,能够精准适配各参数的维度需求。
其属性包含矩阵的行数、列数,以及用于存储矩阵数据的二维向量,通过构造函数可完成矩阵维度和初始值的初始化,初始值支持自定义设置,默认情况下为0.0f。
矩阵乘法方法是核心功能之一,主要用于实现公式中的各类线性变换,比如输入x与W1的乘法、ReLU激活后的隐藏层特征与W2的乘法,在运算开始前会先检查维度,确保当前矩阵的列数与另一个矩阵的行数相等,运算过程通过三重循环完成矩阵元素的累加计算。
矩阵加法方法支持两种合法运算,一种是同维度矩阵之间的相加,另一种是矩阵与行向量的广播相加,这种广播相加能够很好地适配偏置b1和b2的运算需求,比如4×2048的矩阵与1×2048的行向量相。
// 矩阵类定义,封装矩阵操作
class Matrix {
public:
int rows; // 矩阵行数
int cols; // 矩阵列数
std::vector<std::vector<float>> data; // 矩阵数据存储
// 初始化矩阵维度
Matrix(int r, int c, float init_val = 0.0f)
: rows(r), cols(c), data(r, std::vector<float>(c, init_val)) {
}
// 矩阵乘法
Matrix multiply(const Matrix& other) const {
Matrix result(this->rows, other.cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < other.cols; ++j) {
float sum = 0.0f;
for (int k = 0; k < this->cols; ++k) {
sum += this->data[i][k] * other.data[k][j];
}
result.data[i][j] = sum;
}
}
return result;
}
// 矩阵加法:支持矩阵+矩阵、矩阵+行向量
Matrix add(const Matrix& other) const {
// 同维度矩阵相加
if (this->rows == other.rows && this->cols == other.cols) {
Matrix result(this->rows, this->cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < this->cols; ++j) {
result.data[i][j] = this->data[i][j] + other.data[i][j];
}
}
return result;
}
// 矩阵 + 行向量
else if (other.rows == 1 && this->cols == other.cols) {
Matrix result(this->rows, this->cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < this->cols; ++j) {
result.data[i][j] = this->data[i][j] + other.data[0][j];
}
}
return result;
}
throw std::invalid_argument("矩阵加法维度不匹配");
}
};
// 重载运算符,简化矩阵运算语法
Matrix operator*(const Matrix& a, const Matrix& b) {
return a.multiply(b);
}
Matrix operator+(const Matrix& a, const Matrix& b) {
return a.add(b);
}
2. ReLU激活函数核心解释
该函数可接收任意维度的矩阵输入,输出矩阵的维度与输入完全一致,是为了确保不改变特征的维度结构。
代码通过双重循环逐元素来处理输入矩阵,使用max(),来实现负数置0、正数和0保持不变"。
这个函数主要是对xW1 + b1的运算结果进行非线性过滤,从而增强模型的表达能力,避免线性变换带来的局限性。
// ReLU激活:负数置0,0和正数保持不变
Matrix relu(const Matrix& input) {
Matrix output(input.rows, input.cols);
for (int i = 0; i < input.rows; ++i) {
for (int j = 0; j < input.cols; ++j) {
output.data[i][j] = std::max(0.0f, input.data[i][j]);
}
}
return output;
}
3. FFN核心函数解释
FFN核心函数严格遵循既定的数学公式,将线性变换和ReLU激活函数串联起来,逐步完成对输入特征的非线性转换。
运算第一步是计算xW1 + b1,先将输入x(4×512维度)与W1(512×2048维度)进行矩阵乘法,得到4×2048维度的中间结果,再将该中间结果与b1(1×2048维度)进行广播相加,完成第一次线性变换。
运算第二步是ReLU激活,将第一步得到的中间结果传入ReLU激活函数,通过逐元素处理过滤掉所有负数,得到经过非线性变换后的隐藏层特征,该特征的维度仍为4×2048。
运算第三步是计算hidden_relu×W2 + b2,将ReLU激活后的隐藏层特征(4×2048维度)与W2(2048×512维度)进行矩阵乘法,得到4×512维度的结果,再将该结果与b2(1×512维度)进行广播相加,完成第二次线性变换,最终得到FFN的输出结果,输出维度与输入维度一致,均为4×512。
// 初始化权重/偏置
Matrix init_weight(int rows, int cols, float mean = 0.0f, float stddev = 0.01f) {
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> dist(mean, stddev);
Matrix mat(rows, cols);
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
mat.data[i][j] = dist(gen);
}
}
return mat;
}
// FFN核心:y = max(0, xW1 + b1)W2 + b2
Matrix ffn(const Matrix& x, const Matrix& W1, const Matrix& W2, const Matrix& b1, const Matrix& b2) {
Matrix xW1 = x * W1;
Matrix hidden = xW1 + b1;
Matrix hidden_relu = relu(hidden);
Matrix hidden_relu_W2 = hidden_relu * W2;
Matrix y = hidden_relu_W2 + b2;
return y;
}
三、完整代码
#include <iostream>
#include <vector>
#include <random>
#include <stdexcept>
// 矩阵类定义,封装矩阵操作
class Matrix {
public:
int rows; // 矩阵行数
int cols; // 矩阵列数
std::vector<std::vector<float>> data; // 矩阵数据存储
// 初始化矩阵维度
Matrix(int r, int c, float init_val = 0.0f)
: rows(r), cols(c), data(r, std::vector<float>(c, init_val)) {
}
// 矩阵乘法
Matrix multiply(const Matrix& other) const {
Matrix result(this->rows, other.cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < other.cols; ++j) {
float sum = 0.0f;
for (int k = 0; k < this->cols; ++k) {
sum += this->data[i][k] * other.data[k][j];
}
result.data[i][j] = sum;
}
}
return result;
}
// 矩阵加法:支持矩阵+矩阵、矩阵+行向量
Matrix add(const Matrix& other) const {
// 同维度矩阵相加
if (this->rows == other.rows && this->cols == other.cols) {
Matrix result(this->rows, this->cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < this->cols; ++j) {
result.data[i][j] = this->data[i][j] + other.data[i][j];
}
}
return result;
}
// 矩阵 + 行向量
else if (other.rows == 1 && this->cols == other.cols) {
Matrix result(this->rows, this->cols, 0.0f);
for (int i = 0; i < this->rows; ++i) {
for (int j = 0; j < this->cols; ++j) {
result.data[i][j] = this->data[i][j] + other.data[0][j];
}
}
return result;
}
throw std::invalid_argument("矩阵加法维度不匹配");
}
};
// 重载运算符,简化矩阵运算语法
Matrix operator*(const Matrix& a, const Matrix& b) {
return a.multiply(b);
}
Matrix operator+(const Matrix& a, const Matrix& b) {
return a.add(b);
}
// ReLU激活:负数置0,0和正数保持不变
Matrix relu(const Matrix& input) {
Matrix output(input.rows, input.cols);
for (int i = 0; i < input.rows; ++i) {
for (int j = 0; j < input.cols; ++j) {
output.data[i][j] = std::max(0.0f, input.data[i][j]);
}
}
return output;
}
// 初始化权重/偏置
Matrix init_weight(int rows, int cols, float mean = 0.0f, float stddev = 0.01f) {
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> dist(mean, stddev);
Matrix mat(rows, cols);
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
mat.data[i][j] = dist(gen);
}
}
return mat;
}
// FFN核心:y = max(0, xW1 + b1)W2 + b2
Matrix ffn(const Matrix& x, const Matrix& W1, const Matrix& W2, const Matrix& b1, const Matrix& b2) {
Matrix xW1 = x * W1;
Matrix hidden = xW1 + b1;
Matrix hidden_relu = relu(hidden);
Matrix hidden_relu_W2 = hidden_relu * W2;
Matrix y = hidden_relu_W2 + b2;
return y;
}
int main() {
try {
const int D = 512;
Matrix W1 = init_weight(D, 4 * D);
Matrix W2 = init_weight(4 * D, D);
Matrix b1 = init_weight(1, 4 * D);
Matrix b2 = init_weight(1, D);
Matrix x(4, D);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
for (int i = 0; i < 4; ++i)
for (int j = 0; j < D; ++j)
x.data[i][j] = dist(gen);
// 执行FFN
Matrix y = ffn(x, W1, W2, b1, b2);
// 维度验证
std::cout << "x维度:" << x.rows << "×" << x.cols << std::endl;
std::cout << "y维度:" << y.rows << "×" << y.cols << std::endl;
// ReLU验证
Matrix hidden = x * W1 + b1;
Matrix hidden_relu = relu(hidden);
std::cout << "\nReLU前:" << std::endl;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j)
std::cout << hidden.data[i][j] << "\t";
std::cout << std::endl;
}
std::cout << "ReLU后:" << std::endl;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j)
std::cout << hidden_relu.data[i][j] << "\t";
std::cout << std::endl;
}
}
catch (const std::exception& e) {
std::cerr << "错误:" << e.what() << std::endl;
return 1;
}
return 0;
}