【从零开始】9. RAG 应用调优-再续(番外篇)

书接上回,上节我们使用了 torch.multiprocessing 的 Pool(池化)技术实现了多进程并发处理。本来我以为可以到此为止了。但实际上我还忽略了一件事...

那就是并发线程数超过可提供进程时会存在大量等待。在资源争抢的情况下会出现显存不足最后爆出 CUDA out of memory 的错误。但我明明已经写了显存检测代码了,为什么不生效呢?代码如下:

python 复制代码
...

class cuda_tools:

    _instance = None
    _initialized = False 

    def __init__(self):
        if not cuda_tools._initialized:
            # 设置显存使用率 80% 为阈值
            self.threshold_percentage = 80
            cuda_tools._initialized = True

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def check_and_clean_gpu_memory(self):
        """
        检查 GPU 显存使用情况,并在使用率高于 self.threshold_percentage 时清理 GPU 显存。

        返回:
            bool: 如果使用率高于 self.threshold_percentage,则返回True,否则返回False。
        """

        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_total = torch.cuda.get_device_properties(device).total_memory
            # 计算显存使用百分比
            memory_used_percent = (memory_allocated / memory_total) * 100
            return memory_used_percent > self.threshold_percentage
        return False

看上去这段代码好像没有什么问题,但在大神的指点下发现其实会引发两个问题:

  • torch.cuda.memory_allocated() 只能获取到当前被 PyTorch 分配的显存,不包括 CUDA 缓存的显存;
  • 多进程情况下,每个进程都会独立加载模型,无法共享显存;

好吧,既然要改了那就稍微改得更精细一点吧,最起码要做到以下 5 点:

  1. 实时监控:使用单独的线程实时监控每个进程的显存使用情况
  2. 精确控制:使用 pynvml 获取更准确的显存使用信息
  3. 自动清理:当显存使用超过阈值时自动进行清理
  4. 进程隔离:每个进程独立控制自己的显存使用
  5. 安全退出:确保资源正确释放

以下是修改后的代码:

python 复制代码
...
class gpu_memory_manager:
    def __init__(self, check_interval: float = 60):
        """
        构造函数。

        参数:
            check_interval (float): 监控间隔,单位为秒,默认为60秒。

        属性:
            threshold_percentage (float): 内存阈值,单位为百分比,超过该值将清理内存。
            check_interval (float): 监控间隔,单位为秒。
            process_id (int): 当前进程ID。
            _stop_monitoring (threading.Event): 停止监控的事件。
            _monitor_thread (Optional[threading.Thread]): 监控线程。
            handle (nvmlDevice_t): NVML 设备句柄。
        """
        # 阈值
        self.threshold_percentage = 15
        # 检查间隔
        self.check_interval = check_interval
        # 进程ID
        self.process_id = os.getpid()
        # 停止监控
        self._stop_monitoring = threading.Event()
        # 监控线程
        self._monitor_thread: Optional[threading.Thread] = None
        
        # 初始化 NVML
        nvmlInit()
        # 获取设备句柄
        self.handle = nvmlDeviceGetHandleByIndex(0) 
        
    def _get_process_memory_info(self):
        """
        获取当前进程使用的 GPU 内存大小。
        该方法使用 NVML 库获取当前进程使用的 GPU 内存大小。

        返回:
            int: 当前进程使用的 GPU 内存大小,单位为字节。
        """
        try:
            # 获取当前进程的 GPU 内存使用情况
            processes = nvmlDeviceGetComputeRunningProcesses(self.handle)
            for process in processes:
                if process.pid == self.process_id:
                    return process.usedGpuMemory
            return 0
        except:
            return 0
            
    def _get_total_memory(self):
        """
        获取当前 GPU 的总内存大小。
        该方法使用 NVML 库获取当前 GPU 的总内存大小。

        返回:
            int: 当前 GPU 的总内存大小,单位为字节。
        """
        info = nvmlDeviceGetMemoryInfo(self.handle)
        return info.total
        
    def _monitor_memory(self):
        """
        一个循环线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        while not self._stop_monitoring.is_set():
            process_memory = self._get_process_memory_info()
            total_memory = self._get_total_memory()
            
            # 只在进程实际使用显存时检查
            if process_memory > 0:  
                usage_percentage = (process_memory / total_memory) * 100
                # 如果当前进程的 GPU 内存使用率超过了阈值,则执行 clean_memory
                if usage_percentage > self.threshold_percentage:
                    print(f"Process {self.process_id} memory usage ({usage_percentage:.2f}%) exceeded threshold ({self.threshold_percentage}%)")
                    self.clean_memory()
            time.sleep(self.check_interval)
            
    def start_monitoring(self):
        """
        启动监控线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        if self._monitor_thread is None:
            # 清空停止事件
            self._stop_monitoring.clear()
            # 启动监控线程
            self._monitor_thread = threading.Thread(target=self._monitor_memory, daemon=True)
            # 启动线程
            self._monitor_thread.start()
            
    def stop_monitoring(self):
        """
        停止监控线程,确保当前进程的 GPU 内存监控安全退出。
        如果监控线程正在运行,则设置停止事件,并等待线程结束。
        """
        if self._monitor_thread is not None:
            self._stop_monitoring.set()
            self._monitor_thread.join()
            self._monitor_thread = None
            
    def clean_memory(self):
        """
        清理当前进程的 GPU 内存缓存。

        该方法通过调用 PyTorch 的 `torch.cuda.empty_cache()` 来释放未使用的 GPU 内存缓存,
        并通过 `gc.collect()` 强制进行垃圾回收,以便更好地管理内存。

        注意:此方法不会影响已经分配给张量的内存,只会清除未使用缓存。
        """
        torch.cuda.empty_cache()
        gc.collect()
        
    def __del__(self):
        """
        析构函数,用于释放资源。
        该方法在对象销毁时被调用,用于停止监控线程,并释放 NVML 库的资源。
        """
        self.stop_monitoring()
        try:
            # 释放 NVML 资源
            nvmlShutdown()
        except:
            pass

使用的时候也比较简单,如下图:

python 复制代码
class xxx:
    ...
    
    def __init__(self):
        ...
        self.memory_manager = gpu_memory_manager()
        self.memory_manager.start_monitoring()

    def generate_msg():
        try:
            ...
        finally:
            # 确保清理显存
            self.memory_manager.clean_memory()

    def __del__(self):
        self.memory_manager.stop_monitoring()

好了显存监控已经做好了,赶紧实操一下吧。

还是用回之前的压测工具(pressure_util.py),不过这次将 5 线程改为 10 线程,让其超出进程负荷出现资源争抢的情况,结果如下:

bash 复制代码
- 压测完成,运行时间: 726.09秒
- 共完成 100 个任务
- 平均QPS: 0.14
- 清理完成

通过 nvtop 能够得知

进程执行完成后就会回收显存了

这个时间是多了 300 秒,这可能是因为频繁查询显存和清理缓存造成的吧,稍微多做几次调整成合适自己的检查间隔就好了,起码现在无论多少用户进来都不至于爆显存了。

相关推荐
AI英德西牛仔1 分钟前
deepseek导出word排版
人工智能·ai·chatgpt·deepseek·ds随心转
(; ̄ェ ̄)。1 分钟前
深度学习入门(十)RNN、LSTM、GRU
人工智能·rnn·深度学习
谁在黄金彼岸13 分钟前
构建一个多Agent系统(Multi-Agent System, MAS)方法论
人工智能
pandafeeder18 分钟前
Agent工具调用范式:ReAct 和Function Calling
人工智能
jinanwuhuaguo18 分钟前
OpenClaw字节跳动的三只不同的claw龙虾飞书妙搭 OpenClaw、ArkClaw、扣子 OpenClaw 核心区别深度解析
人工智能·语言模型·自然语言处理·visual studio code·openclaw
咚咚王者26 分钟前
人工智能之语言领域 自然语言处理 第十八章 Python NLP生态
人工智能·python·自然语言处理
yeflx26 分钟前
三维空间坐标转换早期笔记
人工智能·算法·机器学习
zzh9407727 分钟前
Gemini 3.1 Pro 2026年国内使用指南:技术解析与镜像站实测
人工智能
初学大模型27 分钟前
基于三层架构的自动驾驶系统设计:环境建模、标准驾驶与风险调制
人工智能
●VON27 分钟前
半小时从零开发鸿蒙记事本应用:AI辅助开发实战
人工智能·华为·harmonyos