【从零开始】9. RAG 应用调优-再续(番外篇)

书接上回,上节我们使用了 torch.multiprocessing 的 Pool(池化)技术实现了多进程并发处理。本来我以为可以到此为止了。但实际上我还忽略了一件事...

那就是并发线程数超过可提供进程时会存在大量等待。在资源争抢的情况下会出现显存不足最后爆出 CUDA out of memory 的错误。但我明明已经写了显存检测代码了,为什么不生效呢?代码如下:

python 复制代码
...

class cuda_tools:

    _instance = None
    _initialized = False 

    def __init__(self):
        if not cuda_tools._initialized:
            # 设置显存使用率 80% 为阈值
            self.threshold_percentage = 80
            cuda_tools._initialized = True

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def check_and_clean_gpu_memory(self):
        """
        检查 GPU 显存使用情况,并在使用率高于 self.threshold_percentage 时清理 GPU 显存。

        返回:
            bool: 如果使用率高于 self.threshold_percentage,则返回True,否则返回False。
        """

        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_total = torch.cuda.get_device_properties(device).total_memory
            # 计算显存使用百分比
            memory_used_percent = (memory_allocated / memory_total) * 100
            return memory_used_percent > self.threshold_percentage
        return False

看上去这段代码好像没有什么问题,但在大神的指点下发现其实会引发两个问题:

  • torch.cuda.memory_allocated() 只能获取到当前被 PyTorch 分配的显存,不包括 CUDA 缓存的显存;
  • 多进程情况下,每个进程都会独立加载模型,无法共享显存;

好吧,既然要改了那就稍微改得更精细一点吧,最起码要做到以下 5 点:

  1. 实时监控:使用单独的线程实时监控每个进程的显存使用情况
  2. 精确控制:使用 pynvml 获取更准确的显存使用信息
  3. 自动清理:当显存使用超过阈值时自动进行清理
  4. 进程隔离:每个进程独立控制自己的显存使用
  5. 安全退出:确保资源正确释放

以下是修改后的代码:

python 复制代码
...
class gpu_memory_manager:
    def __init__(self, check_interval: float = 60):
        """
        构造函数。

        参数:
            check_interval (float): 监控间隔,单位为秒,默认为60秒。

        属性:
            threshold_percentage (float): 内存阈值,单位为百分比,超过该值将清理内存。
            check_interval (float): 监控间隔,单位为秒。
            process_id (int): 当前进程ID。
            _stop_monitoring (threading.Event): 停止监控的事件。
            _monitor_thread (Optional[threading.Thread]): 监控线程。
            handle (nvmlDevice_t): NVML 设备句柄。
        """
        # 阈值
        self.threshold_percentage = 15
        # 检查间隔
        self.check_interval = check_interval
        # 进程ID
        self.process_id = os.getpid()
        # 停止监控
        self._stop_monitoring = threading.Event()
        # 监控线程
        self._monitor_thread: Optional[threading.Thread] = None
        
        # 初始化 NVML
        nvmlInit()
        # 获取设备句柄
        self.handle = nvmlDeviceGetHandleByIndex(0) 
        
    def _get_process_memory_info(self):
        """
        获取当前进程使用的 GPU 内存大小。
        该方法使用 NVML 库获取当前进程使用的 GPU 内存大小。

        返回:
            int: 当前进程使用的 GPU 内存大小,单位为字节。
        """
        try:
            # 获取当前进程的 GPU 内存使用情况
            processes = nvmlDeviceGetComputeRunningProcesses(self.handle)
            for process in processes:
                if process.pid == self.process_id:
                    return process.usedGpuMemory
            return 0
        except:
            return 0
            
    def _get_total_memory(self):
        """
        获取当前 GPU 的总内存大小。
        该方法使用 NVML 库获取当前 GPU 的总内存大小。

        返回:
            int: 当前 GPU 的总内存大小,单位为字节。
        """
        info = nvmlDeviceGetMemoryInfo(self.handle)
        return info.total
        
    def _monitor_memory(self):
        """
        一个循环线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        while not self._stop_monitoring.is_set():
            process_memory = self._get_process_memory_info()
            total_memory = self._get_total_memory()
            
            # 只在进程实际使用显存时检查
            if process_memory > 0:  
                usage_percentage = (process_memory / total_memory) * 100
                # 如果当前进程的 GPU 内存使用率超过了阈值,则执行 clean_memory
                if usage_percentage > self.threshold_percentage:
                    print(f"Process {self.process_id} memory usage ({usage_percentage:.2f}%) exceeded threshold ({self.threshold_percentage}%)")
                    self.clean_memory()
            time.sleep(self.check_interval)
            
    def start_monitoring(self):
        """
        启动监控线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        if self._monitor_thread is None:
            # 清空停止事件
            self._stop_monitoring.clear()
            # 启动监控线程
            self._monitor_thread = threading.Thread(target=self._monitor_memory, daemon=True)
            # 启动线程
            self._monitor_thread.start()
            
    def stop_monitoring(self):
        """
        停止监控线程,确保当前进程的 GPU 内存监控安全退出。
        如果监控线程正在运行,则设置停止事件,并等待线程结束。
        """
        if self._monitor_thread is not None:
            self._stop_monitoring.set()
            self._monitor_thread.join()
            self._monitor_thread = None
            
    def clean_memory(self):
        """
        清理当前进程的 GPU 内存缓存。

        该方法通过调用 PyTorch 的 `torch.cuda.empty_cache()` 来释放未使用的 GPU 内存缓存,
        并通过 `gc.collect()` 强制进行垃圾回收,以便更好地管理内存。

        注意:此方法不会影响已经分配给张量的内存,只会清除未使用缓存。
        """
        torch.cuda.empty_cache()
        gc.collect()
        
    def __del__(self):
        """
        析构函数,用于释放资源。
        该方法在对象销毁时被调用,用于停止监控线程,并释放 NVML 库的资源。
        """
        self.stop_monitoring()
        try:
            # 释放 NVML 资源
            nvmlShutdown()
        except:
            pass

使用的时候也比较简单,如下图:

python 复制代码
class xxx:
    ...
    
    def __init__(self):
        ...
        self.memory_manager = gpu_memory_manager()
        self.memory_manager.start_monitoring()

    def generate_msg():
        try:
            ...
        finally:
            # 确保清理显存
            self.memory_manager.clean_memory()

    def __del__(self):
        self.memory_manager.stop_monitoring()

好了显存监控已经做好了,赶紧实操一下吧。

还是用回之前的压测工具(pressure_util.py),不过这次将 5 线程改为 10 线程,让其超出进程负荷出现资源争抢的情况,结果如下:

bash 复制代码
- 压测完成,运行时间: 726.09秒
- 共完成 100 个任务
- 平均QPS: 0.14
- 清理完成

通过 nvtop 能够得知

进程执行完成后就会回收显存了

这个时间是多了 300 秒,这可能是因为频繁查询显存和清理缓存造成的吧,稍微多做几次调整成合适自己的检查间隔就好了,起码现在无论多少用户进来都不至于爆显存了。

相关推荐
新智元14 分钟前
Ilya震撼发声!OpenAI前主管亲证:AGI已觉醒,人类还在装睡
人工智能·openai
朱昆鹏23 分钟前
如何通过sessionKey 登录 Claude
前端·javascript·人工智能
汉堡go29 分钟前
1、机器学习与深度学习
人工智能·深度学习·机器学习
只是懒得想了1 小时前
使用 Gensim 进行主题建模(LDA)与词向量训练(Word2Vec)的完整指南
人工智能·自然语言处理·nlp·word2vec·gensim
johnny2331 小时前
OpenAI系列模型介绍、API使用
人工智能
KKKlucifer1 小时前
生成式 AI 冲击下,网络安全如何破局?
网络·人工智能·web安全
LiJieNiub1 小时前
基于 PyTorch 实现 MNIST 手写数字识别
pytorch·深度学习·学习
ARM+FPGA+AI工业主板定制专家1 小时前
基于JETSON ORIN/RK3588+AI相机:机器人-多路视觉边缘计算方案
人工智能·数码相机·机器人
文火冰糖的硅基工坊2 小时前
[创业之路-691]:历史与现实的镜鉴:从三国纷争到华为铁三角的系统性启示
人工智能·科技·华为·重构·架构·创业
chxin140162 小时前
Transformer注意力机制——动手学深度学习10
pytorch·rnn·深度学习·transformer