[OpenGL]使用 Compute Shader 实现矩阵点乘

一、简介

本文介绍了如何使用 OpenGL 中的 compute shader 进行矩阵相乘的并行运算。代码目标是,输入两个大小为 10*10 的矩阵 A 和 B,计算 A*B 的结果并存储到矩阵 C 中。

二、代码

0. 代码逻辑

1. 初始化 glfw, glad, 窗口
2. 初始化 compute shader
3. 准备输入数据
4. 运行 compute shader
5. 读取结果并打印
6. 释放资源

1. main.cpp

cpp 复制代码
#include <glad/glad.h>
#include <GLFW/glfw3.h>
#include "ComputeShader.hpp"

#include <windows.h>
#include <cstdint>
#include <iostream>
#include <iostream>

// 用于处理窗口大小改变的回调函数
void framebuffer_size_callback(GLFWwindow *window, int width, int height);
void window_close_callback(GLFWwindow *window);

// 用于处理用户输入的函数
void processInput(GLFWwindow *window);

// 指定窗口默认width和height像素大小
unsigned int SCR_WIDTH = 800;
unsigned int SCR_HEIGHT = 600;

/************************************/

long long GetCurrentTimeMicros()
{
    auto now = std::chrono::system_clock::now();
    auto duration = now.time_since_epoch();
    auto micros = std::chrono::duration_cast<std::chrono::microseconds>(duration).count();
    return micros;
}

int main()
{
    /****** 1. 初始化 glfw, glad, 窗口 *******/
    // glfw 初始化 + 配置 glfw 参数
    glfwInit();
    glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
    glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);

    // glfw 生成窗口
    GLFWwindow *window = glfwCreateWindow(SCR_WIDTH, SCR_HEIGHT, "LearnOpenGL", NULL, NULL);
    if (window == NULL)
    {
        // 检查是否成功生成窗口,如果没有成功打印出错信息并且退出
        std::cout << "Failed to create GLFW window" << std::endl;
        glfwTerminate();
        return -1;
    }

    // 设置窗口window的上下文
    glfwMakeContextCurrent(window);
    // 配置window变化时的回调函数
    glfwSetFramebufferSizeCallback(window, framebuffer_size_callback);
    // 设置窗口关闭回调
    glfwSetWindowCloseCallback(window, window_close_callback);
    // 使用 glad 加载 OpenGL 中的各种函数
    if (!gladLoadGLLoader((GLADloadproc)glfwGetProcAddress))
    {
        std::cout << "Failed to initialize GLAD" << std::endl;
        return -1;
    }

    /************************************/

    /****** 2. 初始化 compute shader ******/

    ComputeShader computeShader("../../resources/Compute.comp");

    /************************************/

    /****** 3. 准备输入数据 ******/
    // 输入矩阵A
    float A[10][10];
    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 10; j++)
        {
            A[i][j] = 1.0f * i;
        }
    }
    // 输入矩阵B
    float B[10][10];
    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 10; j++)
        {
            B[i][j] = 1.0f * i;
        }
    }

    GLuint inputMatrixA, inputMatrixB, outputMatrixC;

    // 创建并设置矩阵 inputMatrixA 纹理
    glGenTextures(1, &inputMatrixA);
    glBindTexture(GL_TEXTURE_2D, inputMatrixA);
    glTexImage2D(
        GL_TEXTURE_2D, 0, GL_R32F, 10, 10, 0, GL_RED, GL_FLOAT,
        A); // inputMatrixA 为 10x10 的矩阵数据, 矩阵只包含一个通道,internal format 为 GL_R32F, format 为 GL_RED
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);

    // 创建并设置矩阵 inputMatrixB 纹理
    glGenTextures(1, &inputMatrixB);
    glBindTexture(GL_TEXTURE_2D, inputMatrixB);
    glTexImage2D(
        GL_TEXTURE_2D, 0, GL_R32F, 10, 10, 0, GL_RED, GL_FLOAT,
        B); // inputMatrixB 为 10x10 的矩阵数据, 矩阵只包含一个通道,internal format 为 GL_R32F, format 为 GL_RED
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);

    // 创建输出矩阵 outputMatrixC 纹理
    glGenTextures(1, &outputMatrixC);
    glBindTexture(GL_TEXTURE_2D, outputMatrixC);
    glTexImage2D(GL_TEXTURE_2D, 0, GL_R32F, 10, 10, 0, GL_RED, GL_FLOAT,
                 NULL); // outputMatrixC, 只包含一个通道,internal format 为 GL_R32F, format 为 GL_RED
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);

    computeShader.use();

    // 绑定输入纹理 inputMatrixA 和 inputMatrixB 到相应的 纹理单元0, 和 纹理单元1
    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, inputMatrixA);
    computeShader.setInt("inputMatrixA", 0);

    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, inputMatrixB);
    computeShader.setInt("inputMatrixB", 1);

    // 绑定输出矩阵 outputMatrixC 到相应的 图像单元2
    glBindImageTexture(2, outputMatrixC, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R32F);
    /************************************/

    /****** 4. 运行 compute shader ******/
    // 运行 compute shader, 分为 10*10*1 个 workgroup, 每个 workgroup 计算 C 矩阵中的一个元素值
    glDispatchCompute((unsigned int)10, (unsigned int)10, 1);
    glMemoryBarrier(GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);
    /************************************/

    /****** 5. 读取结果并打印 ******/
    // 绑定纹理 outputMatrixC
    glBindTexture(GL_TEXTURE_2D, outputMatrixC);
    float C[10][10];
    // 使用 glGetTexImage 从纹理中读取数据
    // GL_RED 是 颜色格式, format, format 描述了像素数据在客户端内存中的格式(连同 type 参数)
    // GL_R32F 是 内部格式, internal format, internal format 描述了纹理在 GPU 中的存储方式
    glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_FLOAT, C);

    // 打印矩阵 C 的部分数据
    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 10; j++)
        {
            std::cout << C[i][j] << " ";
        }
        std::cout << std::endl;
    }
    /************************************/

    /****** 6.释放资源 ******/
    // glfw 释放 glfw使用的所有资源
    glfwTerminate();
    /************************************/
    return 0;
}

// 用于处理用户输入的函数
void processInput(GLFWwindow *window)
{
    // 当按下 Esc 按键时调用 glfwSetWindowShouldClose() 函数,关闭窗口
    if (glfwGetKey(window, GLFW_KEY_ESCAPE) == GLFW_PRESS)
    {
        glfwSetWindowShouldClose(window, true);
    }
}

// 在使用 OpenGL 和 GLFW 库时,处理窗口大小改变的回调函数
// 当窗口大小发生变化时,确保 OpenGL 渲染的内容能够适应新的窗口大小,避免图像被拉伸、压缩或出现其他比例失真的问题
void framebuffer_size_callback(GLFWwindow *window, int width, int height)
{
    SCR_WIDTH = width;
    SCR_HEIGHT = height;
    glViewport(0, 0, width, height);
}
void window_close_callback(GLFWwindow *window)
{
    // 这里可以做一些额外的清理工作
    // 例如释放资源、记录日志等
    std::cout << "Window is closing..." << std::endl;
}

2. ComputeShader 类

cpp 复制代码
#ifndef COMPUTESHADER_H
#define COMPUTESHADER_H

#include <glad/glad.h>
#include <glm/glm.hpp>

#include <string>
#include <fstream>
#include <sstream>
#include <iostream>

class ComputeShader
{
  public:
    unsigned int ID;
    // constructor generates the shader on the fly
    // ------------------------------------------------------------------------
    ComputeShader() {};

    ComputeShader(const char *computePath)
    {
        // 1. retrieve the vertex/fragment source code from filePath
        std::string computeCode;

        std::ifstream cShaderFile;

        // ensure ifstream objects can throw exceptions:
        cShaderFile.exceptions(std::ifstream::failbit | std::ifstream::badbit);
        try
        {
            // open files
            cShaderFile.open(computePath);
            std::stringstream cShaderStream;
            // read file's buffer contents into streams
            cShaderStream << cShaderFile.rdbuf();
            // close file handlers
            cShaderFile.close();
            // convert stream into string
            computeCode = cShaderStream.str();
        }
        catch (std::ifstream::failure &e)
        {
            std::cout << "ERROR::SHADER::FILE_NOT_SUCCESSFULLY_READ: " << e.what() << std::endl;
        }

        const char *cShaderCode = computeCode.c_str();

        // 2. compile shaders
        unsigned int compute;
        // compute shader
        compute = glCreateShader(GL_COMPUTE_SHADER);
        glShaderSource(compute, 1, &cShaderCode, NULL);
        glCompileShader(compute);

        checkCompileErrors(compute, "COMPUTE");
        // shader Program
        ID = glCreateProgram();
        glAttachShader(ID, compute);

        glLinkProgram(ID);
        checkCompileErrors(ID, "PROGRAM");
        // delete the shaders as they're linked into our program now and no longer necessary
        glDeleteShader(compute);
    }

    // activate the shader
    // ------------------------------------------------------------------------
    void use() const
    {
        glUseProgram(ID);
    }
    // ------------------------------------------------------------------------
    void setInt(const std::string &name, int value) const
    {
        glUniform1i(glGetUniformLocation(ID, name.c_str()), value);
    }
  private:
    // utility function for checking shader compilation/linking errors.
    // ------------------------------------------------------------------------
    void checkCompileErrors(GLuint shader, std::string type)
    {
        GLint success;
        GLchar infoLog[1024];
        if (type != "PROGRAM")
        {
            glGetShaderiv(shader, GL_COMPILE_STATUS, &success);
            if (!success)
            {
                glGetShaderInfoLog(shader, 1024, NULL, infoLog);
                std::cout << "ERROR::SHADER_COMPILATION_ERROR of type: " << type << "\n"
                          << infoLog << "\n -- --------------------------------------------------- -- " << std::endl;
            }
        }
        else
        {
            glGetProgramiv(shader, GL_LINK_STATUS, &success);
            if (!success)
            {
                glGetProgramInfoLog(shader, 1024, NULL, infoLog);
                std::cout << "ERROR::PROGRAM_LINKING_ERROR of type: " << type << "\n"
                          << infoLog << "\n -- --------------------------------------------------- -- " << std::endl;
            }
        }
    }
};
#endif

3. compute shader (Compute.comp)

cpp 复制代码
#version 430

// 输入纹理(矩阵 A 和矩阵 B)
// 只进行读取操作,使用 sampler2D
layout(binding = 0) uniform sampler2D inputMatrixA; // 输入矩阵 A
layout(binding = 1) uniform sampler2D inputMatrixB; // 输入矩阵 B

// 输出图像(矩阵 C)
// 进行写入操作,使用 image2D
layout(binding = 2, r32f) uniform image2D outputMatrixC; // 输出矩阵 C

layout(local_size_x = 1,
       local_size_y = 1) in; // 每个 workgroup item 计算 C 的一个元素

void main() {
  // 获取当前 workgroup item 的全局位置
  uint row = gl_GlobalInvocationID.x;
  uint col = gl_GlobalInvocationID.y;

  // 确保不会越界
  if (row >= 10 || col >= 10) {
    return;
  }

  // 从矩阵 A 和矩阵 B 中读取数据
  float valueA = 0.0f;
  float valueB = 0.0f;

  // 计算矩阵 C 中对应的元素
  float result = 0.0;
  for (int k = 0; k < 10; k++) {
    valueA = texelFetch(inputMatrixA, ivec2(row, k), 0).r;
    valueB = texelFetch(inputMatrixB, ivec2(k, col), 0).r;
    result += valueA * valueB; // 矩阵乘法
  }
  // 将结果写入输出图像矩阵 C
  imageStore(outputMatrixC, ivec2(row, col),
             vec4(result, 0.0, 0.0, 0.0)); // 存储结果
}

4. 运行结果

cpp 复制代码
Input A:
0 0 0 0 0 0 0 0 0 0 
1 1 1 1 1 1 1 1 1 1 
2 2 2 2 2 2 2 2 2 2 
3 3 3 3 3 3 3 3 3 3 
4 4 4 4 4 4 4 4 4 4 
5 5 5 5 5 5 5 5 5 5 
6 6 6 6 6 6 6 6 6 6 
7 7 7 7 7 7 7 7 7 7 
8 8 8 8 8 8 8 8 8 8 
9 9 9 9 9 9 9 9 9 9 
Input B:
0 0 0 0 0 0 0 0 0 0 
1 1 1 1 1 1 1 1 1 1 
2 2 2 2 2 2 2 2 2 2 
3 3 3 3 3 3 3 3 3 3 
4 4 4 4 4 4 4 4 4 4 
5 5 5 5 5 5 5 5 5 5 
6 6 6 6 6 6 6 6 6 6 
7 7 7 7 7 7 7 7 7 7 
8 8 8 8 8 8 8 8 8 8 
9 9 9 9 9 9 9 9 9 9 
Result C=A*B:
0 0 0 0 0 0 0 0 0 0 
45 45 45 45 45 45 45 45 45 45 
90 90 90 90 90 90 90 90 90 90 
135 135 135 135 135 135 135 135 135 135 
180 180 180 180 180 180 180 180 180 180 
225 225 225 225 225 225 225 225 225 225 
270 270 270 270 270 270 270 270 270 270 
315 315 315 315 315 315 315 315 315 315 
360 360 360 360 360 360 360 360 360 360 
405 405 405 405 405 405 405 405 405 405 

三、参考

[1]LearnOpenGL-Guest Articles-2022-Compute Shaders

相关推荐
code04号2 小时前
C++练习:图论的两种遍历方式
开发语言·c++·图论
煤泥做不到的!4 小时前
挑战一个月基本掌握C++(第十一天)进阶文件,异常处理,动态内存
开发语言·c++
F-2H4 小时前
C语言:指针4(常量指针和指针常量及动态内存分配)
java·linux·c语言·开发语言·前端·c++
axxy20004 小时前
leetcode之hot100---24两两交换链表中的节点(C++)
c++·leetcode·链表
若亦_Royi5 小时前
C++ 的大括号的用法合集
开发语言·c++
ragnwang9 小时前
C++ Eigen常见的高级用法 [学习笔记]
c++·笔记·学习
lqqjuly12 小时前
特殊的“Undefined Reference xxx“编译错误
c语言·c++
冰红茶兑滴水12 小时前
云备份项目--工具类编写
linux·c++
酒鬼猿13 小时前
C++进阶(二)--面向对象--继承
java·开发语言·c++