FFN前馈网络C++实现

一、代码整体要求

FFN的完整计算公式y = max(0, xW1 + b1)W2 + b2,代码中固定特征维度D=512,确保各参数维度精准匹配,其中W1为512×2048(即D×4D)、W2为2048×512(即4D×D)、b1为1×2048(即4D)、b2为1×512(即D)。

实现ReLU激活函数,该函数支持矩阵输入,比如4×2048维度的矩阵,主要是将输入中的负数置为0,正数和0则保持不变。

FFN函数需实现完整功能,能够接收4×512维度的输入矩阵,经过非线性转换后,输出维度需与输入保持一致,仍为4×512,以此完成对自注意力输出特征的增强。

二、代码核心部分解释

1. 矩阵类核心解释

矩阵类是整个FFN实现的基础,主要封装了FFN运算过程中所需的各类矩阵操作,能够精准适配各参数的维度需求。

其属性包含矩阵的行数、列数,以及用于存储矩阵数据的二维向量,通过构造函数可完成矩阵维度和初始值的初始化,初始值支持自定义设置,默认情况下为0.0f。

矩阵乘法方法是核心功能之一,主要用于实现公式中的各类线性变换,比如输入x与W1的乘法、ReLU激活后的隐藏层特征与W2的乘法,在运算开始前会先检查维度,确保当前矩阵的列数与另一个矩阵的行数相等,运算过程通过三重循环完成矩阵元素的累加计算。

矩阵加法方法支持两种合法运算,一种是同维度矩阵之间的相加,另一种是矩阵与行向量的广播相加,这种广播相加能够很好地适配偏置b1和b2的运算需求,比如4×2048的矩阵与1×2048的行向量相。

复制代码
// 矩阵类定义,封装矩阵操作
class Matrix {
public:
    int rows;          // 矩阵行数
    int cols;          // 矩阵列数
    std::vector<std::vector<float>> data;  // 矩阵数据存储

    // 初始化矩阵维度
    Matrix(int r, int c, float init_val = 0.0f)
        : rows(r), cols(c), data(r, std::vector<float>(c, init_val)) {
    }

    // 矩阵乘法
    Matrix multiply(const Matrix& other) const {
        Matrix result(this->rows, other.cols, 0.0f);
        for (int i = 0; i < this->rows; ++i) {
            for (int j = 0; j < other.cols; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < this->cols; ++k) {
                    sum += this->data[i][k] * other.data[k][j];
                }
                result.data[i][j] = sum;
            }
        }
        return result;
    }

    // 矩阵加法:支持矩阵+矩阵、矩阵+行向量
    Matrix add(const Matrix& other) const {
        // 同维度矩阵相加
        if (this->rows == other.rows && this->cols == other.cols) {
            Matrix result(this->rows, this->cols, 0.0f);
            for (int i = 0; i < this->rows; ++i) {
                for (int j = 0; j < this->cols; ++j) {
                    result.data[i][j] = this->data[i][j] + other.data[i][j];
                }
            }
            return result;
        }
        // 矩阵 + 行向量
        else if (other.rows == 1 && this->cols == other.cols) {
            Matrix result(this->rows, this->cols, 0.0f);
            for (int i = 0; i < this->rows; ++i) {
                for (int j = 0; j < this->cols; ++j) {
                    result.data[i][j] = this->data[i][j] + other.data[0][j];
                }
            }
            return result;
        }
        throw std::invalid_argument("矩阵加法维度不匹配");
    }
};

// 重载运算符,简化矩阵运算语法
Matrix operator*(const Matrix& a, const Matrix& b) {
    return a.multiply(b);
}

Matrix operator+(const Matrix& a, const Matrix& b) {
    return a.add(b);
}

2. ReLU激活函数核心解释

该函数可接收任意维度的矩阵输入,输出矩阵的维度与输入完全一致,是为了确保不改变特征的维度结构。

代码通过双重循环逐元素来处理输入矩阵,使用max(),来实现负数置0、正数和0保持不变"。

这个函数主要是对xW1 + b1的运算结果进行非线性过滤,从而增强模型的表达能力,避免线性变换带来的局限性。

复制代码
// ReLU激活:负数置0,0和正数保持不变
Matrix relu(const Matrix& input) {
    Matrix output(input.rows, input.cols);
    for (int i = 0; i < input.rows; ++i) {
        for (int j = 0; j < input.cols; ++j) {
            output.data[i][j] = std::max(0.0f, input.data[i][j]);
        }
    }
    return output;
}

3. FFN核心函数解释

FFN核心函数严格遵循既定的数学公式,将线性变换和ReLU激活函数串联起来,逐步完成对输入特征的非线性转换。

运算第一步是计算xW1 + b1,先将输入x(4×512维度)与W1(512×2048维度)进行矩阵乘法,得到4×2048维度的中间结果,再将该中间结果与b1(1×2048维度)进行广播相加,完成第一次线性变换。

运算第二步是ReLU激活,将第一步得到的中间结果传入ReLU激活函数,通过逐元素处理过滤掉所有负数,得到经过非线性变换后的隐藏层特征,该特征的维度仍为4×2048。

运算第三步是计算hidden_relu×W2 + b2,将ReLU激活后的隐藏层特征(4×2048维度)与W2(2048×512维度)进行矩阵乘法,得到4×512维度的结果,再将该结果与b2(1×512维度)进行广播相加,完成第二次线性变换,最终得到FFN的输出结果,输出维度与输入维度一致,均为4×512。

复制代码
// 初始化权重/偏置
Matrix init_weight(int rows, int cols, float mean = 0.0f, float stddev = 0.01f) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<float> dist(mean, stddev);

    Matrix mat(rows, cols);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            mat.data[i][j] = dist(gen);
        }
    }
    return mat;
}

// FFN核心:y = max(0, xW1 + b1)W2 + b2
Matrix ffn(const Matrix& x, const Matrix& W1, const Matrix& W2, const Matrix& b1, const Matrix& b2) {
    Matrix xW1 = x * W1;
    Matrix hidden = xW1 + b1;
    Matrix hidden_relu = relu(hidden);
    Matrix hidden_relu_W2 = hidden_relu * W2;
    Matrix y = hidden_relu_W2 + b2;
    return y;
}

三、完整代码

复制代码
#include <iostream>
#include <vector>
#include <random>
#include <stdexcept>

// 矩阵类定义,封装矩阵操作
class Matrix {
public:
    int rows;          // 矩阵行数
    int cols;          // 矩阵列数
    std::vector<std::vector<float>> data;  // 矩阵数据存储

    // 初始化矩阵维度
    Matrix(int r, int c, float init_val = 0.0f)
        : rows(r), cols(c), data(r, std::vector<float>(c, init_val)) {
    }

    // 矩阵乘法
    Matrix multiply(const Matrix& other) const {
        Matrix result(this->rows, other.cols, 0.0f);
        for (int i = 0; i < this->rows; ++i) {
            for (int j = 0; j < other.cols; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < this->cols; ++k) {
                    sum += this->data[i][k] * other.data[k][j];
                }
                result.data[i][j] = sum;
            }
        }
        return result;
    }

    // 矩阵加法:支持矩阵+矩阵、矩阵+行向量
    Matrix add(const Matrix& other) const {
        // 同维度矩阵相加
        if (this->rows == other.rows && this->cols == other.cols) {
            Matrix result(this->rows, this->cols, 0.0f);
            for (int i = 0; i < this->rows; ++i) {
                for (int j = 0; j < this->cols; ++j) {
                    result.data[i][j] = this->data[i][j] + other.data[i][j];
                }
            }
            return result;
        }
        // 矩阵 + 行向量
        else if (other.rows == 1 && this->cols == other.cols) {
            Matrix result(this->rows, this->cols, 0.0f);
            for (int i = 0; i < this->rows; ++i) {
                for (int j = 0; j < this->cols; ++j) {
                    result.data[i][j] = this->data[i][j] + other.data[0][j];
                }
            }
            return result;
        }
        throw std::invalid_argument("矩阵加法维度不匹配");
    }
};

// 重载运算符,简化矩阵运算语法
Matrix operator*(const Matrix& a, const Matrix& b) {
    return a.multiply(b);
}

Matrix operator+(const Matrix& a, const Matrix& b) {
    return a.add(b);
}

// ReLU激活:负数置0,0和正数保持不变
Matrix relu(const Matrix& input) {
    Matrix output(input.rows, input.cols);
    for (int i = 0; i < input.rows; ++i) {
        for (int j = 0; j < input.cols; ++j) {
            output.data[i][j] = std::max(0.0f, input.data[i][j]);
        }
    }
    return output;
}

// 初始化权重/偏置
Matrix init_weight(int rows, int cols, float mean = 0.0f, float stddev = 0.01f) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<float> dist(mean, stddev);

    Matrix mat(rows, cols);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            mat.data[i][j] = dist(gen);
        }
    }
    return mat;
}

// FFN核心:y = max(0, xW1 + b1)W2 + b2
Matrix ffn(const Matrix& x, const Matrix& W1, const Matrix& W2, const Matrix& b1, const Matrix& b2) {
    Matrix xW1 = x * W1;
    Matrix hidden = xW1 + b1;
    Matrix hidden_relu = relu(hidden);
    Matrix hidden_relu_W2 = hidden_relu * W2;
    Matrix y = hidden_relu_W2 + b2;
    return y;
}

int main() {
    try {
        const int D = 512;
        Matrix W1 = init_weight(D, 4 * D);
        Matrix W2 = init_weight(4 * D, D);
        Matrix b1 = init_weight(1, 4 * D);
        Matrix b2 = init_weight(1, D);

        Matrix x(4, D);
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < D; ++j)
                x.data[i][j] = dist(gen);

        // 执行FFN
        Matrix y = ffn(x, W1, W2, b1, b2);

        // 维度验证
        std::cout << "x维度:" << x.rows << "×" << x.cols << std::endl;
        std::cout << "y维度:" << y.rows << "×" << y.cols << std::endl;

        // ReLU验证
        Matrix hidden = x * W1 + b1;
        Matrix hidden_relu = relu(hidden);
        std::cout << "\nReLU前:" << std::endl;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 3; ++j)
                std::cout << hidden.data[i][j] << "\t";
            std::cout << std::endl;
        }
        std::cout << "ReLU后:" << std::endl;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 3; ++j)
                std::cout << hidden_relu.data[i][j] << "\t";
            std::cout << std::endl;
        }

    }
    catch (const std::exception& e) {
        std::cerr << "错误:" << e.what() << std::endl;
        return 1;
    }
    return 0;
}
相关推荐
wukangjupingbb7 小时前
AI在靶点识别(Target Identification)中的关键作用与开源工具生态
人工智能·开源
多恩Stone7 小时前
【3D AICG 系列-8】PartUV 流程图详解
人工智能·算法·3d·aigc·流程图
aiguangyuan7 小时前
基于BiLSTM-CRF的命名实体识别模型:原理剖析与实现详解
人工智能·python·nlp
恣逍信点7 小时前
《凌微经 · 理悖相涵》第七章 形性一体——本然如是之元观
人工智能·科技·学习·程序人生·生活·交友·哲学
stars-he7 小时前
AI工具配置学习笔记
人工智能·笔记·学习
Master_oid7 小时前
机器学习32:机器终生学习(Life Long Learning)
人工智能·学习·机器学习
芷栀夏7 小时前
CANN ops-math:为上层 AI 算子库提供核心支撑的基础计算模块深度拆解
人工智能·深度学习·transformer·cann
袁气满满~_~7 小时前
深度学习笔记三
人工智能·笔记·深度学习
风象南7 小时前
OpenSpec 与 Spec Kit 使用对比:规范驱动开发该选哪个?
人工智能