A Practical Guide to Socket Optimization in NestJS

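NestJS is an enterprise-grade Node.js framework. Its modular architecture, dependency injection, and decorators give Socket.io services a more structured and maintainable shape. This article ports the plain Socket.io code from the previous article to idiomatic NestJS and uses it to build a high-performance IM (instant messaging) service.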

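Project Structure and Dependencies

1.1 Project Initialization and Dependency Installation

Besides the NestJS core, the project needs the Socket.io integration packages. Initialize the project and install the required dependencies before starting: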

# Create the NestJS project (requires the Nest CLI: npm i -g @nestjs/cli)
nest new im-socket-service
cd im-socket-service

# Socket.io integration for NestJS
npm install @nestjs/platform-socket.io @nestjs/websockets socket.io

# Redis adapter for broadcasting across multiple instances
npm install @socket.io/redis-adapter ioredis

# Performance monitoring (Prometheus metrics)
npm install @willsoto/nestjs-prometheus prom-client

# Validation and transformation
npm install class-validator class-transformer

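The project is split into modules, and each module owns one functional domain. There are three groups:

- core business modules: users, messages, conversations
- infrastructure modules: cache, database
- the gateway module: Socket connections

Each module has its own controllers, services, and data-access layer. This layering keeps the code testable and maintainable. A dedicated config directory centralizes environment variables and global settings, which makes it easy to switch between deployment environments.

1.2 Environment Configuration and Constants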

// config/constants.ts
export const SOCKET_CONFIG = {
  // Heartbeat (ms)
  PING_TIMEOUT: 10000,
  PING_INTERVAL: 15000,

  // Messages
  MAX_MESSAGE_SIZE: 100 * 1024, // 100KB
  MESSAGE_COMPRESSION: true,

  // Connections
  MAX_CONNECTIONS_PER_IP: 5,
  CONNECTION_TIMEOUT: 30000,

  // Redis
  REDIS_PREFIX: 'im:socket:',
  REDIS_TTL: 3600,

  // Rate limiting
  RATE_LIMIT: {
    WINDOW_MS: 60000,
    MAX_MESSAGES: 120,
  },
} as const;

export const CONNECTION_TIERS = {
  CRITICAL: {
    name: 'critical',
    weight: 100,
    maxConnections: 10000,
    heartbeatInterval: 30000,
    messageBufferSize: 1000,
  },
  STANDARD: {
    name: 'standard',
    weight: 50,
    maxConnections: 100000,
    heartbeatInterval: 30000,
    messageBufferSize: 500,
  },
  DEGRADED: {
    name: 'degraded',
    weight: 10,
    maxConnections: 50000,
    heartbeatInterval: 60000,
    messageBufferSize: 100,
  },
} as const;
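The sketch below sanity-checks these numbers. It copies the relevant fields into a standalone file.

In Socket.io v4, the server sends a ping every `pingInterval` ms. It closes the connection if the pong does not arrive within `pingTimeout` ms. With the values above, a silently dropped client is therefore noticed after roughly 25 seconds.

The helper names `tierByName` and `deadPeerDetectionMs` are hypothetical. `tierByName` looks tiers up by their `name` field, which is what `ConnectionTierService` does later.

```typescript
// Trimmed copies of the constants above, so this sketch runs on its own.
const SOCKET_CONFIG = { PING_TIMEOUT: 10000, PING_INTERVAL: 15000 } as const;

const CONNECTION_TIERS = {
  CRITICAL: { name: 'critical', weight: 100, maxConnections: 10000 },
  STANDARD: { name: 'standard', weight: 50, maxConnections: 100000 },
  DEGRADED: { name: 'degraded', weight: 10, maxConnections: 50000 },
} as const;

type Tier = (typeof CONNECTION_TIERS)[keyof typeof CONNECTION_TIERS];

// Hypothetical helper: look a tier up by its `name` field, not by its object key.
function tierByName(name: string): Tier | undefined {
  return Object.values(CONNECTION_TIERS).find((t) => t.name === name);
}

// Hypothetical helper: worst-case time before the server drops a silent client.
// Up to one full ping interval passes before the next ping, then the pong timeout runs.
function deadPeerDetectionMs(): number {
  return SOCKET_CONFIG.PING_INTERVAL + SOCKET_CONFIG.PING_TIMEOUT;
}

// Sum of the per-tier connection caps.
let sumOfTierCaps = 0;
for (const tier of Object.values(CONNECTION_TIERS)) {
  sumOfTierCaps += tier.maxConnections;
}

console.log(tierByName('standard')?.maxConnections, deadPeerDetectionMs(), sumOfTierCaps); // 100000 25000 160000
```

With these values the three tier caps add up to 160,000. That is below the 200,000 global cap hard-coded in `ConnectionTierService`. Because every tier is capped individually, the global eviction path there can only trigger if the tier caps are raised.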

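NestJS WebSocket Gateway Implementation

2.1 Gateway Fundamentals

In NestJS, the WebSocket gateway is the core component that handles Socket connections. Three decorators do the wiring:

- @WebSocketGateway declares the gateway.
- @WebSocketServer injects the Server instance.
- @SubscribeMessage binds a handler to a message event.

Gateways also take part in dependency injection, so they can call service-layer code directly.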

// gateways/im.gateway.ts
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  OnGatewayInit,
  ConnectedSocket,
  MessageBody,
  WsException,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import { Logger, UseGuards, UsePipes, ValidationPipe } from '@nestjs/common';
import { AuthGuard } from '../guards/ws-auth.guard';
import { RateLimitGuard } from '../guards/rate-limit.guard';
import { ConnectionTierGuard } from '../guards/connection-tier.guard';
import { ImService } from '../services/im.service';
import { MessageDto, TypingDto, ReadReceiptDto } from '../dto';

@WebSocketGateway({
  namespace: '/im',
  cors: {
    origin: '*',
    credentials: true,
  },
  transports: ['websocket'],
  pingTimeout: 10000,
  pingInterval: 15000,
})
// Validate each @MessageBody() against its DTO class. Failures are reported as a
// WsException, so the default WebSocket exception filter passes them to the client.
@UsePipes(new ValidationPipe({ exceptionFactory: (errors) => new WsException(errors) }))
export class ImGateway
  implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
{
  @WebSocketServer()
  server: Server;

  private readonly logger = new Logger(ImGateway.name);

  constructor(private readonly imService: ImService) {}

  afterInit(server: Server): void {
    this.logger.log('WebSocket gateway initialized');
    this.imService.setServer(server);
  }

  async handleConnection(client: Socket): Promise<void> {
    try {
      // Read the credentials sent in the handshake
      const token = client.handshake.auth.token;
      const user = await this.imService.verifyToken(token);

      if (!user) {
        client.emit('auth_error', { message: 'Authentication failed' });
        client.disconnect(true);
        return;
      }

      // Enforce per-tier connection limits
      const tierResult = await this.imService.acceptConnection(client, user);

      if (!tierResult.accepted) {
        client.emit('connection_rejected', { reason: tierResult.reason });
        client.disconnect(true);
        return;
      }

      // Attach user info to the socket for later handlers and guards
      client.data.userId = user.userId;
      client.data.userTier = tierResult.tier;
      client.data.tierConfig = tierResult.config;
      client.data.connectedAt = Date.now();

      // Join the per-user and per-tier rooms (join() may return a Promise depending on the adapter)
      await client.join(`user:${user.userId}`);
      await client.join(`tier:${tierResult.tier}`);

      // Mark the user online
      await this.imService.setPresence(user.userId, 'online', {
        deviceId: user.deviceId,
        tier: tierResult.tier,
      });

      // Confirm the connection to the client
      client.emit('connected', {
        socketId: client.id,
        tier: tierResult.tier,
        serverTime: Date.now(),
      });

      this.logger.log(`User ${user.userId} connected (tier: ${tierResult.tier})`);
    } catch (error) {
      this.logger.error(`Connection handling failed: ${error.message}`);
      client.disconnect(true);
    }
  }

  async handleDisconnect(client: Socket): Promise<void> {
    const userId = client.data.userId;

    if (userId) {
      await this.imService.removeConnection(client.id, userId);
      await this.imService.setPresence(userId, 'offline');

      this.logger.log(`User ${userId} disconnected`);
    }
  }

  @UseGuards(AuthGuard, RateLimitGuard)
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: MessageDto,
  ): Promise<void> {
    const userId = client.data.userId;

    try {
      const result = await this.imService.handleMessage(userId, data);
      client.emit('message_sent', result);
    } catch (error) {
      client.emit('message_error', {
        clientMessageId: data.clientMessageId,
        error: error.message,
      });
    }
  }

  @SubscribeMessage('typing')
  async handleTyping(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: TypingDto,
  ): Promise<void> {
    const userId = client.data.userId;
    await this.imService.handleTyping(userId, data);
  }

  @SubscribeMessage('read_receipt')
  async handleReadReceipt(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ReadReceiptDto,
  ): Promise<void> {
    const userId = client.data.userId;
    await this.imService.handleReadReceipt(userId, data);
  }

  @SubscribeMessage('join_conversation')
  async handleJoinConversation(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: { conversationId: string },
  ): Promise<void> {
    const userId = client.data.userId;
    await this.imService.joinConversation(client, userId, data.conversationId);
  }

  @SubscribeMessage('leave_conversation')
  async handleLeaveConversation(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: { conversationId: string },
  ): Promise<void> {
    const userId = client.data.userId;
    await this.imService.leaveConversation(client, userId, data.conversationId);
  }
}

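2.2 Data Transfer Objects

NestJS recommends DTOs (Data Transfer Objects) for validating incoming data. With class-validator and class-transformer, payloads are transformed and validated automatically once a ValidationPipe is bound to the handler (for example with @UsePipes). Only well-formed data then reaches the gateway.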

// dto/index.ts
import { IsString, IsNotEmpty, IsOptional, IsNumber, IsBoolean, MaxLength } from 'class-validator';

export class MessageDto {
  @IsString()
  @IsNotEmpty()
  clientMessageId: string;

  @IsString()
  @IsNotEmpty()
  conversationId: string;

  @IsString()
  @IsNotEmpty()
  @MaxLength(10000)
  content: string;

  @IsOptional()
  @IsNumber()
  messageType?: number;

  @IsOptional()
  metadata?: Record<string, any>;
}

export class TypingDto {
  @IsString()
  @IsNotEmpty()
  conversationId: string;

  @IsBoolean()
  isTyping: boolean;
}

export class ReadReceiptDto {
  @IsString()
  @IsNotEmpty()
  conversationId: string;

  @IsString()
  @IsNotEmpty()
  messageId: string;
}

export class JoinConversationDto {
  @IsString()
  @IsNotEmpty()
  conversationId: string;
}

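2.3 Socket Authentication Guard

Guards are one of the building blocks of the NestJS request pipeline, alongside pipes and interceptors. The Socket authentication guard verifies the user's identity before a message handler runs. It reads credentials either from data already stored on the socket or from the handshake.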

// guards/ws-auth.guard.ts
import { CanActivate, ExecutionContext, Injectable, Logger } from '@nestjs/common';
import { WsException } from '@nestjs/websockets';
import { Socket } from 'socket.io';
import { JwtService } from '../services/jwt.service';

@Injectable()
export class AuthGuard implements CanActivate {
  private readonly logger = new Logger(AuthGuard.name);

  constructor(private readonly jwtService: JwtService) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const client: Socket = context.switchToWs().getClient();

    // Fast path: handleConnection() has already authenticated this socket
    if (client.data.userId) {
      return true;
    }

    // Otherwise verify the token carried in the handshake
    const token = this.extractToken(client);

    if (!token) {
      throw new WsException('No authentication token provided');
    }

    try {
      const user = await this.jwtService.verifySocketToken(token);
      client.data.userId = user.userId;
      client.data.userTier = user.tier;
      return true;
    } catch (error) {
      throw new WsException('Authentication failed');
    }
  }

  private extractToken(client: Socket): string | null {
    // Check the auth payload, then the Authorization header, then the query string
    return (
      client.handshake.auth.token ||
      client.handshake.headers.authorization?.replace('Bearer ', '') ||
      this.extractTokenFromQuery(client)
    );
  }

  private extractTokenFromQuery(client: Socket): string | null {
    const token = client.handshake.query.token;
    if (Array.isArray(token)) {
      return token[0];
    }
    return token || null;
  }
}
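The lookup order in `extractToken` matters when a client sends more than one credential. The sketch below mirrors that order over a plain object. It assumes only the three handshake fields the guard reads (`auth`, `headers.authorization`, `query`). `HandshakeLike` is a hypothetical stand-in, not Socket.io's `Handshake` type.

There is one deliberate difference from the guard. This version uses the header only when it actually starts with `Bearer `, whereas `replace('Bearer ', '')` passes any other header value through unchanged.

```typescript
// Hypothetical stand-in for the handshake fields AuthGuard.extractToken reads.
interface HandshakeLike {
  auth: Record<string, unknown>;
  headers: { authorization?: string };
  query: Record<string, string | string[] | undefined>;
}

// Precedence: auth payload, then "Authorization: Bearer <token>", then ?token=.
function extractToken(h: HandshakeLike): string | null {
  const fromAuth = h.auth.token;
  if (typeof fromAuth === 'string' && fromAuth.length > 0) {
    return fromAuth;
  }

  const header = h.headers.authorization;
  if (header !== undefined && header.startsWith('Bearer ')) {
    return header.slice('Bearer '.length);
  }

  // A repeated query key arrives as an array; take the first value.
  const fromQuery = h.query.token;
  if (Array.isArray(fromQuery)) {
    return fromQuery[0] ?? null;
  }
  return fromQuery || null;
}

const empty: HandshakeLike = { auth: {}, headers: {}, query: {} };
console.log(extractToken({ ...empty, headers: { authorization: 'Bearer abc' } })); // abc
```

On the client side, prefer the `auth` payload. It travels inside the Socket.io handshake packet, while query strings are often written to proxy and access logs.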

2.4 Rate-Limiting Guard

// guards/rate-limit.guard.ts
import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
import { WsException } from '@nestjs/websockets';
import { Socket } from 'socket.io';
import { RateLimiterService } from '../services/rate-limiter.service';

@Injectable()
export class RateLimitGuard implements CanActivate {
  constructor(private readonly rateLimiterService: RateLimiterService) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const client: Socket = context.switchToWs().getClient();
    const userId = client.data.userId;

    if (!userId) {
      return true; // AuthGuard handles unauthenticated sockets
    }

    const result = await this.rateLimiterService.check(userId);

    if (!result.allowed) {
      client.emit('rate_limited', {
        retryAfter: result.retryAfter,
        message: `Rate limit exceeded, please retry in ${result.retryAfter} seconds`,
      });
      throw new WsException('Rate limit exceeded');
    }

    return true;
  }
}

2.5 Connection Tier Guard

// guards/connection-tier.guard.ts
import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
import { WsException } from '@nestjs/websockets';
import { Socket } from 'socket.io';
import { ConnectionTierService } from '../services/connection-tier.service';

@Injectable()
export class ConnectionTierGuard implements CanActivate {
  constructor(private readonly tierService: ConnectionTierService) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const client: Socket = context.switchToWs().getClient();
    const tier = client.data.userTier;

    // Look up the limits that apply to this connection's tier
    const tierConfig = this.tierService.getTierConfig(tier);

    if (!tierConfig) {
      throw new WsException('Invalid connection tier');
    }

    // Attach the tier config to the socket
    client.data.tierConfig = tierConfig;

    return true;
  }
}

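Core Service Layer

3.1 IM Service

The IM service is the heart of the Socket business logic. It coordinates the sub-services, applies the business rules, and communicates with the gateway.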

// services/im.service.ts
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { Server, Socket } from 'socket.io';
import { MessageDto, TypingDto, ReadReceiptDto } from '../dto';
import { ConnectionTierService } from './connection-tier.service';
import { PresenceService } from './presence.service';
import { MessageService } from './message.service';
import { StateAggregatorService } from './state-aggregator.service';
import { MessageStormControllerService } from './message-storm-controller.service';
import { JwtService } from './jwt.service';
import { RedisService } from './redis.service';

export interface AuthUser {
  userId: string;
  tier: string;
  deviceId: string;
  isPremium: boolean;
  isAdmin: boolean;
}

export interface ConnectionResult {
  accepted: boolean;
  tier?: string;
  config?: any;
  reason?: string;
}

@Injectable()
export class ImService implements OnModuleInit {
  private readonly logger = new Logger(ImService.name);
  private server: Server;

  constructor(
    private readonly tierService: ConnectionTierService,
    private readonly presenceService: PresenceService,
    private readonly messageService: MessageService,
    private readonly stateAggregator: StateAggregatorService,
    private readonly stormController: MessageStormControllerService,
    private readonly jwtService: JwtService,
    private readonly redisService: RedisService,
  ) {}

  onModuleInit(): void {
    this.stateAggregator.setServer(this.server);
    this.stormController.setServer(this.server);
  }

  setServer(server: Server): void {
    this.server = server;
    this.stateAggregator.setServer(server);
    this.stormController.setServer(server);
  }

  async verifyToken(token: string): Promise<AuthUser | null> {
    return this.jwtService.verifySocketToken(token);
  }

  async acceptConnection(
    client: Socket,
    user: AuthUser,
  ): Promise<ConnectionResult> {
    return this.tierService.acceptConnection(client, user);
  }

  async removeConnection(socketId: string, userId: string): Promise<void> {
    await this.tierService.removeConnection(socketId, userId);
  }

  async setPresence(
    userId: string,
    status: string,
    metadata?: Record<string, any>,
  ): Promise<void> {
    await this.presenceService.setPresence(userId, status, metadata);

    // Queue the status change; StateAggregatorService batches the broadcast
    this.stateAggregator.recordUpdate(`presence:${userId}`, {
      type: 'presence',
      userId,
      status,
      timestamp: Date.now(),
      room: 'presence',
    });
  }

  async handleMessage(userId: string, data: MessageDto): Promise<any> {
    const message = await this.messageService.createMessage({
      ...data,
      senderId: userId,
    });

    // Fan the message out to every other member's personal room
    const members = await this.messageService.getConversationMembers(
      data.conversationId,
    );

    for (const memberId of members) {
      if (memberId !== userId) {
        this.stormController.enqueue({
          targetRoom: `user:${memberId}`,
          event: 'new_message',
          data: message,
          senderId: userId,
        });
      }
    }

    return {
      clientMessageId: data.clientMessageId,
      serverMessageId: message.id,
      timestamp: message.createdAt,
    };
  }

  async handleTyping(userId: string, data: TypingDto): Promise<void> {
    this.stateAggregator.recordUpdate(`typing:${data.conversationId}`, {
      type: 'typing',
      userId,
      conversationId: data.conversationId,
      isTyping: data.isTyping,
      timestamp: Date.now(),
      room: `conversation:${data.conversationId}`,
    });
  }

  async handleReadReceipt(
    userId: string,
    data: ReadReceiptDto,
  ): Promise<void> {
    await this.messageService.markAsRead(
      userId,
      data.conversationId,
      data.messageId,
    );

    // Broadcast the read receipt to everyone who has joined the conversation room
    this.server
      .to(`conversation:${data.conversationId}`)
      .emit('read_receipt', {
        conversationId: data.conversationId,
        messageId: data.messageId,
        readBy: userId,
        readAt: Date.now(),
      });
  }

  async joinConversation(
    client: Socket,
    userId: string,
    conversationId: string,
  ): Promise<void> {
    client.join(`conversation:${conversationId}`);
    await this.presenceService.subscribeToConversation(
      userId,
      conversationId,
    );

    // Fetch the members of this conversation who are currently online
    const onlineMembers = await this.presenceService.getConversationPresence(
      conversationId,
    );

    client.emit('conversation_members', {
      conversationId,
      onlineMembers,
    });
  }

  async leaveConversation(
    client: Socket,
    userId: string,
    conversationId: string,
  ): Promise<void> {
    client.leave(`conversation:${conversationId}`);
    await this.presenceService.unsubscribeFromConversation(
      userId,
      conversationId,
    );
  }

  async getOnlineUsers(userIds: string[]): Promise<Map<string, string>> {
    return this.presenceService.getMultiplePresence(userIds);
  }
}

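3.2 Connection Tier Service

The connection tier service assigns each connection to a tier and allocates resources per tier, so that higher-priority users receive better service. The counters live in process memory, so the limits apply per server instance.

// services/connection-tier.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { Socket } from 'socket.io';
import { CONNECTION_TIERS } from '../../config/constants';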

export interface ConnectionTierConfig {
  name: string;
  weight: number;
  maxConnections: number;
  heartbeatInterval: number;
  messageBufferSize: number;
}

export interface UserConnection {
  socket: Socket;
  tier: string;
  userId: string;
  connectedAt: number;
  config: ConnectionTierConfig;
}

@Injectable()
export class ConnectionTierService {
  private readonly logger = new Logger(ConnectionTierService.name);

  private connections: Map<string, UserConnection> = new Map();
  private tierConnections: Map<string, Set<string>> = new Map([
    ['critical', new Set()],
    ['standard', new Set()],
    ['degraded', new Set()],
  ]);

  private totalConnections = 0;
  private readonly maxTotalConnections = 200000;

  constructor() {
    // Make sure every tier in CONNECTION_TIERS has a connection set
    Object.entries(CONNECTION_TIERS).forEach(([, config]) => {
      this.tierConnections.set(config.name, new Set());
    });
  }

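  private determineTier(authData: any): string {
    if (authData.isPremium || authData.isAdmin) {
      return 'critical';
    }
    // Assumed fields: `isActive` (boolean) and `lastLoginAt` (epoch ms); AuthUser declares neither.
    const sevenDaysMs = 7 * 24 * 60 * 60 * 1000;
    const recentLogin =
      typeof authData.lastLoginAt === 'number' && Date.now() - authData.lastLoginAt < sevenDaysMs;
    if (authData.isActive || recentLogin) {
      return 'standard';
    }
    return 'degraded';
  }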

  // Public: ConnectionTierGuard also uses this lookup
  getTierConfig(tierName: string): ConnectionTierConfig | null {
    const tier = Object.values(CONNECTION_TIERS).find(
      (t) => t.name === tierName,
    );
    return tier || null;
  }

  async acceptConnection(
    client: Socket,
    authData: any,
  ): Promise<{ accepted: boolean; tier?: string; config?: any; reason?: string }> {
    const tierName = this.determineTier(authData);
    const tierConfig = this.getTierConfig(tierName);

    if (!tierConfig) {
      return { accepted: false, reason: 'Invalid connection tier' };
    }

    // Enforce the global connection cap
    if (this.totalConnections >= this.maxTotalConnections) {
      const evicted = await this.evictLowestPriority();
      if (!evicted) {
        return { accepted: false, reason: 'Server busy, please retry later' };
      }
    }

    const currentTierCount = this.tierConnections.get(tierName).size;

    // Enforce the cap of the requested tier
    if (currentTierCount >= tierConfig.maxConnections) {
      // Try one tier lower: critical -> standard, standard -> degraded
      const lowerTier = tierName === 'critical' ? 'standard' :
                        tierName === 'standard' ? 'degraded' : null;

      if (lowerTier) {
        const lowerConfig = this.getTierConfig(lowerTier);
        if (lowerConfig && this.tierConnections.get(lowerTier).size < lowerConfig.maxConnections) {
          return this.createConnection(client, authData, lowerTier, lowerConfig);
        }
      }

      return { accepted: false, reason: `${tierName} tier is full` };
    }

    return this.createConnection(client, authData, tierName, tierConfig);
  }

  private createConnection(
    client: Socket,
    authData: any,
    tierName: string,
    config: ConnectionTierConfig,
  ): { accepted: boolean; tier: string; config: any } {
    const connection: UserConnection = {
      socket: client,
      tier: tierName,
      userId: authData.userId,
      connectedAt: Date.now(),
      config,
    };

    this.connections.set(client.id, connection);
    this.tierConnections.get(tierName).add(client.id);
    this.totalConnections++;

    this.logger.debug(
      `Connection opened: ${authData.userId} -> ${tierName} (total: ${this.totalConnections})`,
    );

    return { accepted: true, tier: tierName, config };
  }

  // Disconnect the oldest degraded connection to free a slot
  private async evictLowestPriority(): Promise<boolean> {
    const degradedConnections = this.tierConnections.get('degraded');

    if (degradedConnections.size > 0) {
      let oldestSocketId: string | null = null;
      let oldestTime = Infinity;

      for (const socketId of degradedConnections) {
        const conn = this.connections.get(socketId);
        if (conn && conn.connectedAt < oldestTime) {
          oldestTime = conn.connectedAt;
          oldestSocketId = socketId;
        }
      }

      if (oldestSocketId) {
        const conn = this.connections.get(oldestSocketId);
        if (conn) {
          conn.socket.emit('forced_disconnect', {
            reason: 'Server resources are constrained',
          });
          conn.socket.disconnect(true);
          // Free the slot right away. removeConnection() ignores unknown socket IDs,
          // so the call from handleDisconnect() is harmless whichever runs first.
          await this.removeConnection(oldestSocketId, conn.userId);
          return true;
        }
      }
    }

    return false;
  }

  async removeConnection(socketId: string, userId: string): Promise<void> {
    const connection = this.connections.get(socketId);

    if (connection) {
      this.tierConnections.get(connection.tier).delete(socketId);
      this.connections.delete(socketId);
      this.totalConnections--;

      this.logger.debug(
        `Connection removed: ${userId} -> ${connection.tier} (total: ${this.totalConnections})`,
      );
    }
  }

  getStats(): any {
    const stats = {
      total: this.totalConnections,
      byTier: {} as Record<string, number>,
      limits: {} as Record<string, number>,
    };

    for (const [tier, sockets] of this.tierConnections.entries()) {
      stats.byTier[tier] = sockets.size;
      const config = this.getTierConfig(tier);
      stats.limits[tier] = config?.maxConnections || 0;
    }

    return stats;
  }
}
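The admission rule inside `acceptConnection` is easy to lose among the Map and Set bookkeeping. The sketch below isolates just the decision as a pure function over per-tier counts, which makes it testable without a live socket. `admit` is a hypothetical name, and the caps are copied from `CONNECTION_TIERS`.

As in the service above, a full tier falls back exactly one level. A full `critical` tier can spill into `standard` but never into `degraded`.

```typescript
type TierName = 'critical' | 'standard' | 'degraded';

// Per-tier caps, copied from CONNECTION_TIERS.
const TIER_CAPS: Record<TierName, number> = {
  critical: 10000,
  standard: 100000,
  degraded: 50000,
};

// One-level fallback, matching acceptConnection(); degraded has nowhere to go.
const FALLBACK: Record<TierName, TierName | null> = {
  critical: 'standard',
  standard: 'degraded',
  degraded: null,
};

type Admission =
  | { accepted: true; tier: TierName }
  | { accepted: false; reason: string };

// Hypothetical pure version of the tier decision:
// requested tier if it has room, else one tier lower, else reject.
function admit(requested: TierName, counts: Record<TierName, number>): Admission {
  if (counts[requested] < TIER_CAPS[requested]) {
    return { accepted: true, tier: requested };
  }
  const lower = FALLBACK[requested];
  if (lower !== null && counts[lower] < TIER_CAPS[lower]) {
    return { accepted: true, tier: lower };
  }
  return { accepted: false, reason: `${requested} tier is full` };
}

console.log(admit('critical', { critical: 10000, standard: 42, degraded: 0 }));
// { accepted: true, tier: 'standard' }
```

With the decision kept pure, the service only has to do the bookkeeping around it. A unit test can then cover every fallback path with plain numbers.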

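3.3 Message Storm Controller

The message storm controller keeps the system stable during traffic peaks. It caps the outbound message rate with a token bucket, and messages that find the bucket empty wait in a bounded queue.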

// services/message-storm-controller.service.ts
import { Injectable, Logger, OnModuleDestroy, OnModuleInit } from '@nestjs/common';
import { Server } from 'socket.io';

export interface QueuedMessage {
  targetRoom: string;
  event: string;
  data: any;
  senderId: string;
}

@Injectable()
export class MessageStormControllerService implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(MessageStormControllerService.name);
  private server: Server;

  private tokenBucket: TokenBucket;
  private messageQueue: QueuedMessage[] = [];
  private processingInterval: NodeJS.Timeout;

  private readonly maxMessagesPerSecond = 50000;
  private readonly maxBurstSize = 1000;
  private readonly batchSize = 100;
  private readonly maxQueueLength = 10000;

  onModuleInit(): void {
    this.tokenBucket = new TokenBucket({
      capacity: this.maxBurstSize,
      // TokenBucket measures time in milliseconds, so convert the per-second rate
      refillRate: this.maxMessagesPerSecond / 1000,
    });

    this.startProcessing();
  }

  onModuleDestroy(): void {
    clearInterval(this.processingInterval);
  }

  setServer(server: Server): void {
    this.server = server;
  }

  enqueue(message: QueuedMessage): void {
    // Send immediately only if nothing is waiting, so queued messages keep their order
    if (this.messageQueue.length === 0 && this.tokenBucket.tryConsume(1).consumed) {
      this.processMessage(message);
      return;
    }

    // Otherwise wait in the queue
    this.messageQueue.push(message);

    // Bound the queue: drop the oldest message once it overflows
    if (this.messageQueue.length > this.maxQueueLength) {
      this.messageQueue.shift();
      this.logger.warn('Message queue full, dropped the oldest message');
    }
  }

  private processMessage(message: QueuedMessage): void {
    const { targetRoom, event, data, senderId } = message;

    if (!this.server) {
      this.logger.error('Server not initialized');
      return;
    }

    // volatile: the event is not sent if the client's underlying connection is not ready
    this.server.to(targetRoom).volatile.emit(event, {
      ...data,
      senderId,
      sentAt: Date.now(),
    });
  }

  private startProcessing(): void {
    // Drain the queue every 100 ms; the token bucket refills continuously between ticks
    this.processingInterval = setInterval(() => this.processQueue(), 100);
  }

  private processQueue(): void {
    // Send at most batchSize queued messages per tick (about 1,000/s with batchSize = 100),
    // stopping when tokens run out. A message leaves the queue only once it has a token.
    let sent = 0;
    while (
      sent < this.batchSize &&
      this.messageQueue.length > 0 &&
      this.tokenBucket.tryConsume(1).consumed
    ) {
      this.processMessage(this.messageQueue.shift()!);
      sent++;
    }
  }

  getQueueSize(): number {
    return this.messageQueue.length;
  }

  getTokenBucketStatus(): { tokens: number; capacity: number } {
    return {
      tokens: this.tokenBucket.getTokens(),
      capacity: this.tokenBucket.getCapacity(),
    };
  }
}

// Token bucket: holds up to `capacity` tokens and refills at `refillRate` tokens per millisecond
class TokenBucket {
  private tokens: number;
  private lastRefill: number;
  private readonly capacity: number;
  private readonly refillRate: number;

  constructor(options: { capacity: number; refillRate: number }) {
    this.capacity = options.capacity;
    this.refillRate = options.refillRate;
    this.tokens = options.capacity;
    this.lastRefill = Date.now();
  }

  private refill(): void {
    const now = Date.now();
    const elapsed = now - this.lastRefill;
    const tokensToAdd = elapsed * this.refillRate;

    this.tokens = Math.min(this.capacity, this.tokens + tokensToAdd);
    this.lastRefill = now;
  }

  tryConsume(tokens: number): { consumed: boolean; remaining: number } {
    this.refill();

    if (this.tokens >= tokens) {
      this.tokens -= tokens;
      return { consumed: true, remaining: this.tokens };
    }

    return { consumed: false, remaining: this.tokens };
  }

  getTokens(): number {
    this.refill();
    return this.tokens;
  }

  getCapacity(): number {
    return this.capacity;
  }
}
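The numbers behind the controller are easier to see with a fake clock. The sketch below reimplements the same refill rule with the clock passed in. `ClockedTokenBucket` is a hypothetical name; the production class reads `Date.now()` directly.

With `maxBurstSize = 1000` and a refill rate of 50,000 messages per second (50 tokens per millisecond), a fully drained bucket regains half its capacity within 10 ms. Messages therefore reach the queue only during bursts larger than 1,000 or sustained traffic above 50,000 messages per second.

```typescript
// Same refill rule as TokenBucket above, with an injectable clock
// so the arithmetic can be checked deterministically.
class ClockedTokenBucket {
  private tokens: number;
  private lastRefill: number;
  private readonly capacity: number;
  private readonly refillPerMs: number;
  private readonly now: () => number;

  constructor(capacity: number, refillPerMs: number, now: () => number) {
    this.capacity = capacity;
    this.refillPerMs = refillPerMs;
    this.now = now;
    this.tokens = capacity; // start full: the whole burst is available
    this.lastRefill = now();
  }

  private refill(): void {
    const t = this.now();
    this.tokens = Math.min(this.capacity, this.tokens + (t - this.lastRefill) * this.refillPerMs);
    this.lastRefill = t;
  }

  tryConsume(n = 1): boolean {
    this.refill();
    if (this.tokens < n) return false;
    this.tokens -= n;
    return true;
  }

  available(): number {
    this.refill();
    return this.tokens;
  }
}

// maxBurstSize = 1000, maxMessagesPerSecond = 50000 -> 50 tokens per ms.
let clock = 0;
const bucket = new ClockedTokenBucket(1000, 50000 / 1000, () => clock);

let burst = 0;
while (bucket.tryConsume()) burst++; // drains the initial burst
clock += 10;
const afterTenMs = bucket.available(); // 10 ms * 50 tokens/ms
clock += 1000;
const afterIdle = bucket.available(); // refill is capped at capacity

console.log(burst, afterTenMs, afterIdle); // 1000 500 1000
```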

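3.4 State Aggregation Service

The state aggregation service buffers state updates for a short window (100 ms), merges redundant ones, and emits one batch per room. It trades a small amount of latency for fewer emits and less network traffic.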

// services/state-aggregator.service.ts
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { Server } from 'socket.io';

export interface StateUpdate {
  type: string;
  userId?: string;
  conversationId?: string;
  timestamp: number;
  room: string;
  [key: string]: any;
}

export interface AggregatedUpdate {
  type: string;
  updates: StateUpdate[];
  timestamp: number;
}

@Injectable()
export class StateAggregatorService implements OnModuleInit {
  private readonly logger = new Logger(StateAggregatorService.name);
  private server: Server;

  private pendingUpdates: Map<string, {
    updates: StateUpdate[];
    flushAt: number;
    scheduled: boolean;
  }> = new Map();

  private readonly flushInterval = 100; // 100 ms aggregation window

  onModuleInit(): void {
    // Periodically drop stale pending entries
    setInterval(() => {
      this.cleanupExpired();
    }, 60000);
  }

  setServer(server: Server): void {
    this.server = server;
  }

  recordUpdate(key: string, update: StateUpdate): void {
    if (!this.pendingUpdates.has(key)) {
      this.pendingUpdates.set(key, {
        updates: [],
        flushAt: Date.now() + this.flushInterval,
        scheduled: false,
      });
    }

    const pending = this.pendingUpdates.get(key);

    // For typing updates, keep only the latest state per user
    if (update.type === 'typing') {
      const existingIndex = pending.updates.findIndex(
        (u) => u.type === 'typing' && u.userId === update.userId,
      );

      if (existingIndex >= 0) {
        pending.updates[existingIndex] = {
          ...update,
          timestamp: Date.now(),
        };
        return;
      }
    }

    pending.updates.push(update);

    // Schedule a flush at the end of the aggregation window
    if (!pending.scheduled) {
      this.scheduleFlush(key);
    }
  }

  private scheduleFlush(key: string): void {
    const pending = this.pendingUpdates.get(key);
    if (!pending) return;

    pending.scheduled = true;

    setTimeout(() => {
      this.flushUpdates(key);
    }, pending.flushAt - Date.now());
  }

  private flushUpdates(key: string): void {
    const pending = this.pendingUpdates.get(key);

    if (!pending || pending.updates.length === 0) {
      this.pendingUpdates.delete(key);
      return;
    }

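    // Group the pending updates by target room
    const groupedUpdates = this.groupByRoom(pending.updates);

    // Emit one batch per room. Iterate the Map itself: Object.entries() on a Map returns [].
    for (const [room, updates] of groupedUpdates) {
      if (this.server) {
        const batch: AggregatedUpdate = {
          type: updates[0].type,
          updates,
          timestamp: Date.now(),
        };
        this.server.to(room).emit('state_batch', batch);
      }
    }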

    this.pendingUpdates.delete(key);
  }

  private groupByRoom(updates: StateUpdate[]): Map<string, StateUpdate[]> {
    const grouped = new Map<string, StateUpdate[]>();

    for (const update of updates) {
      const room = update.room || 'global';

      if (!grouped.has(room)) {
        grouped.set(room, []);
      }

      grouped.get(room).push(update);
    }

    return grouped;
  }

  private cleanupExpired(): void {
    const now = Date.now();

    for (const [key, pending] of this.pendingUpdates.entries()) {
      if (pending.flushAt < now - 60000) {
        this.pendingUpdates.delete(key);
      }
    }
  }

  flushAll(): void {
    for (const key of this.pendingUpdates.keys()) {
      this.flushUpdates(key);
    }
  }

  getPendingCount(): number {
    return this.pendingUpdates.size;
  }
}
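The core of the aggregator is a merge step followed by one emit per room. Below is a minimal, timer-free sketch of that step using hypothetical `mergeTyping` and `groupByRoom` helpers over simplified update objects.

It applies the same last-write-wins rule as `recordUpdate`: a later typing update from the same user replaces the earlier one in place. It also shows why the room map must be iterated directly. `Object.entries()` only sees own enumerable string-keyed properties, and a `Map` does not store its entries that way, so it returns an empty array.

```typescript
interface TypingUpdate {
  userId: string;
  room: string;
  isTyping: boolean;
  timestamp: number;
}

// Last write wins per user. Map.set on an existing key keeps its original position,
// like recordUpdate() overwriting the entry at existingIndex.
function mergeTyping(updates: TypingUpdate[]): TypingUpdate[] {
  const latest = new Map<string, TypingUpdate>();
  for (const u of updates) {
    latest.set(u.userId, u);
  }
  return Array.from(latest.values());
}

// One bucket per target room, so each room receives a single 'state_batch'.
function groupByRoom(updates: TypingUpdate[]): Map<string, TypingUpdate[]> {
  const grouped = new Map<string, TypingUpdate[]>();
  for (const u of updates) {
    const bucket = grouped.get(u.room) ?? [];
    bucket.push(u);
    grouped.set(u.room, bucket);
  }
  return grouped;
}

const windowUpdates: TypingUpdate[] = [
  { userId: 'u1', room: 'conversation:c1', isTyping: true, timestamp: 1 },
  { userId: 'u2', room: 'conversation:c1', isTyping: true, timestamp: 2 },
  { userId: 'u1', room: 'conversation:c1', isTyping: false, timestamp: 3 },
];

const merged = mergeTyping(windowUpdates);
const batches = groupByRoom(merged);

// Iterate the Map itself; Object.entries(batches) would be [].
for (const [room, entries] of batches) {
  console.log(room, entries.length); // conversation:c1 2
}
```

In this example, three raw events within one window collapse into a single `state_batch` emit that carries two entries.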

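3.5 Rate Limiting Service

The rate limiting service caps how often each user can send messages, guarding against spam floods and resource abuse. It uses a sliding-window log, which stores one timestamp per accepted message. The log lives in process memory, so each server instance enforces the limit independently.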

// services/rate-limiter.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { SOCKET_CONFIG } from '../../config/constants';

export interface RateLimitResult {
  allowed: boolean;
  remaining: number;
  currentCount: number;
  retryAfter?: number;
}

@Injectable()
export class RateLimiterService {
  private readonly logger = new Logger(RateLimiterService.name);

  private windowData: Map<string, number[]> = new Map();
  private readonly windowMs = SOCKET_CONFIG.RATE_LIMIT.WINDOW_MS;
  private readonly maxMessages = SOCKET_CONFIG.RATE_LIMIT.MAX_MESSAGES;

  constructor() {
    // Periodically purge expired entries
    setInterval(() => {
      this.cleanup();
    }, this.windowMs);
  }

  async check(userId: string): Promise<RateLimitResult> {
    const now = Date.now();
    const windowStart = now - this.windowMs;

    if (!this.windowData.has(userId)) {
      this.windowData.set(userId, []);
    }

    const timestamps = this.windowData.get(userId);

    // Drop timestamps that have slid out of the window
    while (timestamps.length > 0 && timestamps[0] < windowStart) {
      timestamps.shift();
    }

    if (timestamps.length >= this.maxMessages) {
      const oldestTimestamp = timestamps[0];
      const retryAfter = Math.ceil(
        (oldestTimestamp + this.windowMs - now) / 1000,
      );

      return {
        allowed: false,
        retryAfter,
        remaining: 0,
        currentCount: timestamps.length,
      };
    }

    timestamps.push(now);

    return {
      allowed: true,
      remaining: this.maxMessages - timestamps.length,
      currentCount: timestamps.length,
    };
  }

  private cleanup(): void {
    const now = Date.now();
    const windowStart = now - this.windowMs;

    for (const [userId, timestamps] of this.windowData.entries()) {
      // Drop expired timestamps
      while (timestamps.length > 0 && timestamps[0] < windowStart) {
        timestamps.shift();
      }

      // Forget users with no messages left in the window
      if (timestamps.length === 0) {
        this.windowData.delete(userId);
      }
    }
  }

  getStats(): { totalTracked: number; averageLoad: number } {
    const allTimestamps = Array.from(this.windowData.values());
    const totalCount = allTimestamps.reduce((sum, arr) => sum + arr.length, 0);

    return {
      totalTracked: this.windowData.size,
      averageLoad: allTimestamps.length > 0
        ? totalCount / allTimestamps.length
        : 0,
    };
  }
}
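To see where `retryAfter` comes from, here is the same sliding-log rule with the clock passed in as an argument. `SlidingWindowLog` is a hypothetical name; the service reads `Date.now()` itself.

A request is rejected while `max` timestamps lie inside the window. The wait is the time until the oldest one slides out: `oldest + windowMs - now`. Unlike the service, this sketch clamps the wait to at least one second, so a client is never told to retry in 0 seconds.

The test uses `max = 3` for brevity; the service's 120 messages per 60 s behaves the same way.

```typescript
interface Verdict {
  allowed: boolean;
  retryAfterSec?: number;
}

// Sliding-window log: one timestamp per accepted request, oldest first.
class SlidingWindowLog {
  private readonly hits = new Map<string, number[]>();
  private readonly windowMs: number;
  private readonly max: number;

  constructor(windowMs: number, max: number) {
    this.windowMs = windowMs;
    this.max = max;
  }

  check(key: string, now: number): Verdict {
    const windowStart = now - this.windowMs;
    const log = this.hits.get(key) ?? [];
    // Evict timestamps that have slid out of the window.
    while (log.length > 0 && log[0] < windowStart) {
      log.shift();
    }
    this.hits.set(key, log);

    if (log.length >= this.max) {
      // The oldest hit leaves the window at log[0] + windowMs; never report 0 s.
      const waitMs = log[0] + this.windowMs - now;
      return { allowed: false, retryAfterSec: Math.max(1, Math.ceil(waitMs / 1000)) };
    }

    log.push(now);
    return { allowed: true };
  }
}

const limiter = new SlidingWindowLog(60_000, 3);
const results = [0, 1_000, 2_000, 30_000, 60_001].map((t) => limiter.check('u1', t));
console.log(results.map((r) => r.allowed)); // [ true, true, true, false, true ]
```

This approach stores one timestamp per accepted message, so memory grows with `MAX_MESSAGES` for every active user.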

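3.6 Presence Service

The presence service tracks each user's online status. It answers presence queries for individual users and for the members of a conversation.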

// services/presence.service.ts
import { Injectable, Logger } from '@nestjs/common';

export interface PresenceState {
  userId: string;
  status: 'online' | 'away' | 'busy' | 'offline';
  lastSeen: number;
  metadata?: Record<string, any>;
}

@Injectable()
export class PresenceService {
  private readonly logger = new Logger(PresenceService.name);

  private presenceCache: Map<string, PresenceState> = new Map();
  private conversationSubscribers: Map<string, Set<string>> = new Map();

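  // Seconds without a presence update before a user counts as offline (5 min).
  // Only setPresence() refreshes lastSeen, so long-lived connections need a periodic
  // refresh (for example on heartbeat) or they will read as offline after ttl.
  private readonly ttl = 300;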

  constructor() {
    // Periodically mark timed-out users offline
    setInterval(() => {
      this.checkTimeouts();
    }, this.ttl * 1000 / 2);
  }

  async setPresence(
    userId: string,
    status: string,
    metadata?: Record<string, any>,
  ): Promise<void> {
    const state: PresenceState = {
      userId,
      status: status as PresenceState['status'],
      lastSeen: Date.now(),
      metadata,
    };

    this.presenceCache.set(userId, state);
  }

  getPresence(userId: string): PresenceState {
    const state = this.presenceCache.get(userId);

    if (!state) {
      return { userId, status: 'offline', lastSeen: 0 };
    }

    // Entries not refreshed within ttl count as offline
    if (Date.now() - state.lastSeen > this.ttl * 1000) {
      return { userId, status: 'offline', lastSeen: state.lastSeen };
    }

    return state;
  }

  async subscribeToConversation(
    userId: string,
    conversationId: string,
  ): Promise<void> {
    if (!this.conversationSubscribers.has(conversationId)) {
      this.conversationSubscribers.set(conversationId, new Set());
    }

    this.conversationSubscribers.get(conversationId).add(userId);
  }

  async unsubscribeFromConversation(
    userId: string,
    conversationId: string,
  ): Promise<void> {
    const subscribers = this.conversationSubscribers.get(conversationId);

    if (subscribers) {
      subscribers.delete(userId);

      if (subscribers.size === 0) {
        this.conversationSubscribers.delete(conversationId);
      }
    }
  }

  async getConversationPresence(
    conversationId: string,
  ): Promise<PresenceState[]> {
    const subscribers = this.conversationSubscribers.get(conversationId);

    if (!subscribers) {
      return [];
    }

    const presence: PresenceState[] = [];

    for (const userId of subscribers) {
      presence.push(this.getPresence(userId));
    }

    return presence;
  }

  async getMultiplePresence(
    userIds: string[],
  ): Promise<Map<string, string>> {
    const result = new Map<string, string>();

    for (const userId of userIds) {
      const state = this.getPresence(userId);
      result.set(userId, state.status);
    }

    return result;
  }

  private checkTimeouts(): void {
    const now = Date.now();

    for (const [userId, state] of this.presenceCache.entries()) {
      if (
        state.status !== 'offline' &&
        now - state.lastSeen > this.ttl * 1000
      ) {
        state.status = 'offline';
        this.logger.debug(`用户 ${userId} 状态变更为离线`);
      }
    }
  }

  getStats(): { totalOnline: number; conversations: number } {
    let onlineCount = 0;

    for (const state of this.presenceCache.values()) {
      if (state.status !== 'offline') {
        onlineCount++;
      }
    }

    return {
      totalOnline: onlineCount,
      conversations: this.conversationSubscribers.size,
    };
  }
}
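Presence in this service is heartbeat-driven: every setPresence call refreshes lastSeen, and getPresence reports a user as offline once more than ttl seconds pass without a refresh. The sketch below isolates that rule with an injectable clock so it can be checked without waiting five minutes. PresenceTracker, touch and statusOf are hypothetical names for illustration only.

```typescript
type Status = "online" | "away" | "busy" | "offline";

interface PresenceEntry {
  status: Status;
  lastSeen: number;
}

// Hypothetical stand-alone version of PresenceService's TTL rule.
class PresenceTracker {
  private readonly entries = new Map<string, PresenceEntry>();

  constructor(
    private readonly ttlMs: number,
    private readonly now: () => number,
  ) {}

  // Called on connect and on every heartbeat or user activity.
  touch(userId: string, status: Status = "online"): void {
    this.entries.set(userId, { status, lastSeen: this.now() });
  }

  // Unknown users, and users silent for longer than the TTL, read as offline.
  statusOf(userId: string): Status {
    const entry = this.entries.get(userId);
    if (!entry || this.now() - entry.lastSeen > this.ttlMs) {
      return "offline";
    }
    return entry.status;
  }
}

let clock = 0;
const tracker = new PresenceTracker(300000, () => clock); // 5-minute TTL, as in the service
tracker.touch("alice", "busy");
clock = 299000;
console.log(tracker.statusOf("alice")); // busy
clock = 300001;
console.log(tracker.statusOf("alice")); // offline
```

The client therefore has to send some activity more often than the TTL; with Socket.io's own ping interval far shorter than five minutes, a gateway could refresh presence on each application-level heartbeat event.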

NestJS Module Composition

4.1 Message Module

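The message module creates, stores and queries messages, and is the data core of the IM service. Because MessageService injects a TypeORM repository for the Message entity, the module has to register that entity with TypeOrmModule.forFeature.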

typescript 复制代码
// modules/message/message.module.ts
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { MessageService } from './message.service';
import { MessageController } from './message.controller';
import { Message } from './entities/message.entity';

@Module({
  imports: [
    // Registers Repository<Message> so MessageService can use @InjectRepository(Message).
    // CacheModule is @Global(), so CacheService is injectable without importing it here.
    TypeOrmModule.forFeature([Message]),
  ],
  controllers: [MessageController],
  providers: [MessageService],
  exports: [MessageService],
})
export class MessageModule {}
typescript 复制代码
// modules/message/message.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { Message } from './entities/message.entity';
import { CacheService } from '../cache/cache.service';

export interface CreateMessageDto {
  clientMessageId: string;
  conversationId: string;
  senderId: string;
  content: string;
  messageType?: number;
  metadata?: Record<string, any>;
}

@Injectable()
export class MessageService {
  private readonly logger = new Logger(MessageService.name);

  constructor(
    @InjectRepository(Message)
    private readonly messageRepository: Repository<Message>,
    private readonly cacheService: CacheService,
  ) {}

  async createMessage(dto: CreateMessageDto): Promise<Message> {
    const message = this.messageRepository.create({
      clientMessageId: dto.clientMessageId,
      conversationId: dto.conversationId,
      senderId: dto.senderId,
      content: dto.content,
      messageType: dto.messageType ?? 1, // default message type: 1
      metadata: dto.metadata,
      status: 'sent',
    });

    const saved = await this.messageRepository.save(message);

    // Increment the unread counter of every member except the sender
    const members = await this.getConversationMembers(dto.conversationId);
    await Promise.all(
      members
        .filter((memberId) => memberId !== dto.senderId)
        .map((memberId) =>
          this.incrementUnreadCount(dto.conversationId, memberId),
        ),
    );

    return saved;
  }

  async getConversationMembers(conversationId: string): Promise<string[]> {
    // Try the cache first
    const cacheKey = `conversation:members:${conversationId}`;
    const cached = await this.cacheService.get<string[]>(cacheKey);

    if (cached) {
      return cached;
    }

    // Cache miss: load from the database and cache for 5 minutes
    const members = await this.getMembersFromDatabase(conversationId);

    await this.cacheService.set(cacheKey, members, { ttl: 300 });

    return members;
  }

  async markAsRead(
    userId: string,
    conversationId: string,
    messageId: string,
  ): Promise<void> {
    // Record the read receipt on the message. A single read_by column only
    // keeps the latest reader; group chats need a separate receipt table.
    await this.messageRepository.update(
      { id: messageId },
      { readBy: userId, readAt: new Date() },
    );

    // Decrement the reader's unread counter
    await this.decrementUnreadCount(conversationId, userId);
  }

  private async getMembersFromDatabase(
    conversationId: string,
  ): Promise<string[]> {
    // Placeholder: a real implementation queries the conversation-membership table
    return [];
  }

  private async incrementUnreadCount(
    conversationId: string,
    userId: string,
  ): Promise<void> {
    const key = `unread:${conversationId}:${userId}`;
    await this.cacheService.incr(key);
  }

  private async decrementUnreadCount(
    conversationId: string,
    userId: string,
  ): Promise<void> {
    const key = `unread:${conversationId}:${userId}`;
    const remaining = await this.cacheService.decr(key);

    // Clamp at zero (e.g. when the same message is reported as read twice)
    if (remaining < 0) {
      await this.cacheService.set(key, 0);
    }
  }

  async getUnreadCount(
    conversationId: string,
    userId: string,
  ): Promise<number> {
    const key = `unread:${conversationId}:${userId}`;
    // Await before applying the default: `promise || 0` would return the Promise itself
    const count = await this.cacheService.get<number>(key);
    return count ?? 0;
  }
}
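The unread counters follow a simple rule: a new message increments the counter of every conversation member except the sender, and a read receipt decrements the reader's own counter without letting it drop below zero. The sketch below models that rule with an in-memory Map standing in for Redis INCR/DECR. UnreadCounters and its methods are hypothetical; only the key format unread:conversationId:userId is taken from MessageService.

```typescript
// Hypothetical in-memory stand-in for the Redis unread counters.
class UnreadCounters {
  private readonly counts = new Map<string, number>();

  private key(conversationId: string, userId: string): string {
    return `unread:${conversationId}:${userId}`;
  }

  // New message: every member except the sender gets one more unread message.
  onMessage(conversationId: string, senderId: string, members: string[]): void {
    for (const memberId of members) {
      if (memberId === senderId) continue;
      const k = this.key(conversationId, memberId);
      this.counts.set(k, (this.counts.get(k) ?? 0) + 1);
    }
  }

  // Read receipt: decrement the reader's counter, clamped at zero.
  onRead(conversationId: string, userId: string): void {
    const k = this.key(conversationId, userId);
    this.counts.set(k, Math.max(0, (this.counts.get(k) ?? 0) - 1));
  }

  get(conversationId: string, userId: string): number {
    return this.counts.get(this.key(conversationId, userId)) ?? 0;
  }
}

const unread = new UnreadCounters();
const members = ["alice", "bob", "carol"];
unread.onMessage("c1", "alice", members);
unread.onMessage("c1", "alice", members);
unread.onRead("c1", "bob");
console.log(unread.get("c1", "alice"), unread.get("c1", "bob"), unread.get("c1", "carol")); // 0 1 2
```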
typescript 复制代码
// modules/message/entities/message.entity.ts
import {
  Entity,
  PrimaryGeneratedColumn,
  Column,
  CreateDateColumn,
  Index,
} from 'typeorm';

@Entity('messages')
@Index(['conversationId', 'createdAt'])
export class Message {
  @PrimaryGeneratedColumn('uuid')
  id: string;

  @Column({ name: 'client_message_id' })
  clientMessageId: string;

  @Column({ name: 'conversation_id' })
  @Index()
  conversationId: string;

  @Column({ name: 'sender_id' })
  @Index()
  senderId: string;

  @Column('text')
  content: string;

  @Column({ name: 'message_type', default: 1 })
  messageType: number;

  @Column('jsonb', { nullable: true })
  metadata: Record<string, any>;

  @Column({ default: 'sent' })
  status: string;

  @Column({ name: 'read_by', nullable: true })
  readBy: string;

  @Column({ name: 'read_at', type: 'timestamp', nullable: true })
  readAt: Date;

  @CreateDateColumn({ name: 'created_at' })
  createdAt: Date;
}
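The composite index on (conversationId, createdAt) suits history queries that page backwards through a conversation by timestamp (keyset, or cursor, pagination) instead of using OFFSET, which has to skip over every earlier row. The query method itself is not part of this excerpt, so the sketch below shows the paging logic over an in-memory array, with the roughly equivalent PostgreSQL query in a comment. fetchHistoryPage, StoredMessage and Cursor are hypothetical; the message id is added as a tie-breaker for messages with equal timestamps.

```typescript
interface StoredMessage {
  id: string;
  conversationId: string;
  createdAt: number; // epoch milliseconds
  content: string;
}

interface Cursor {
  createdAt: number;
  id: string;
}

// Hypothetical keyset-pagination sketch. Roughly equivalent SQL:
//   SELECT * FROM messages
//   WHERE conversation_id = $1 AND (created_at, id) < ($2, $3)
//   ORDER BY created_at DESC, id DESC LIMIT $4;
function fetchHistoryPage(
  all: StoredMessage[],
  conversationId: string,
  limit: number,
  before?: Cursor,
): { page: StoredMessage[]; next?: Cursor } {
  const isBefore = (m: StoredMessage, c: Cursor): boolean =>
    m.createdAt < c.createdAt || (m.createdAt === c.createdAt && m.id < c.id);

  const page = all
    .filter((m) => m.conversationId === conversationId)
    .filter((m) => !before || isBefore(m, before))
    // Newest first; ties broken by id, descending
    .sort((a, b) => b.createdAt - a.createdAt || (a.id < b.id ? 1 : a.id > b.id ? -1 : 0))
    .slice(0, limit);

  const last = page[page.length - 1];
  // A short page means the history is exhausted, so no cursor is returned
  return {
    page,
    next: page.length === limit && last ? { createdAt: last.createdAt, id: last.id } : undefined,
  };
}

const msgs: StoredMessage[] = [1, 2, 3, 4, 5].map((n) => ({
  id: `m${n}`,
  conversationId: "c1",
  createdAt: n * 1000,
  content: `hello ${n}`,
}));
const first = fetchHistoryPage(msgs, "c1", 2);
const second = fetchHistoryPage(msgs, "c1", 2, first.next);
console.log(first.page.map((m) => m.id), second.page.map((m) => m.id)); // [ 'm5', 'm4' ] [ 'm3', 'm2' ]
```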

4.2 Redis Adapter Module

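The Redis module wraps the Socket.io Redis adapter. Events emitted on one node are published through Redis and delivered by every other node to its own connected sockets, which is what allows the service to run as a cluster.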

typescript 复制代码
// modules/adapters/redis-adapter.module.ts
import { Module, Global } from '@nestjs/common';
import { RedisService } from './redis.service';
import { RedisAdapterProvider } from './redis-adapter.provider';

@Global()
@Module({
  providers: [RedisService, RedisAdapterProvider],
  exports: [RedisService, RedisAdapterProvider],
})
export class RedisModule {}
typescript 复制代码
// modules/adapters/redis-adapter.provider.ts
import { Injectable, OnModuleDestroy } from '@nestjs/common';
import { createAdapter } from '@socket.io/redis-adapter';
import { Redis } from 'ioredis';

export const REDIS_ADAPTER = 'REDIS_ADAPTER';

@Injectable()
export class RedisAdapterProvider implements OnModuleDestroy {
  private pubClient: Redis;
  private subClient: Redis;
  private adapterCreator: ReturnType<typeof createAdapter>;

  async createAdapter(): Promise<ReturnType<typeof createAdapter>> {
    this.pubClient = new Redis({
      host: process.env.REDIS_HOST || 'localhost',
      port: parseInt(process.env.REDIS_PORT || '6379', 10),
      password: process.env.REDIS_PASSWORD,
      maxRetriesPerRequest: 3,
      lazyConnect: true,
    });

    // The subscriber needs its own connection: a connection in subscribe
    // mode cannot run regular commands. duplicate() copies the options above.
    this.subClient = this.pubClient.duplicate();

    await Promise.all([
      this.pubClient.connect(),
      this.subClient.connect(),
    ]);

    // createAdapter() returns a factory that is later passed to
    // server.adapter(); it is not called directly
    this.adapterCreator = createAdapter(this.pubClient, this.subClient, {
      requestsTimeout: 5000, // timeout for cross-node requests such as fetchSockets()
    });

    return this.adapterCreator;
  }

  async onModuleDestroy(): Promise<void> {
    // Close both connections on shutdown
    await Promise.all([this.pubClient?.quit(), this.subClient?.quit()]);
  }

  getPubClient(): Redis {
    return this.pubClient;
  }

  getSubClient(): Redis {
    return this.subClient;
  }
}
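Conceptually, the Redis adapter turns a room broadcast into two steps: the emitting node delivers the event to its own sockets in that room and publishes it on a Redis channel, and every other node, subscribed to the channel, delivers it to its local sockets in the same room. The sketch below simulates that flow with an in-memory bus in place of Redis. Bus and ServerNode are hypothetical and do not reflect the adapter's real wire format or internals.

```typescript
type Delivery = { socketId: string; event: string; payload: unknown };
type Handler = (room: string, event: string, payload: unknown, origin: string) => void;

// Hypothetical in-memory stand-in for Redis pub/sub.
class Bus {
  private readonly subscribers: Handler[] = [];

  subscribe(handler: Handler): void {
    this.subscribers.push(handler);
  }

  publish(room: string, event: string, payload: unknown, origin: string): void {
    for (const handler of this.subscribers) handler(room, event, payload, origin);
  }
}

class ServerNode {
  private readonly rooms = new Map<string, Set<string>>(); // room -> local socket ids
  readonly delivered: Delivery[] = [];

  constructor(readonly name: string, private readonly bus: Bus) {
    // Events published by other nodes are delivered to local sockets only.
    bus.subscribe((room, event, payload, origin) => {
      if (origin !== this.name) this.deliverLocally(room, event, payload);
    });
  }

  join(socketId: string, room: string): void {
    if (!this.rooms.has(room)) this.rooms.set(room, new Set());
    this.rooms.get(room)!.add(socketId);
  }

  // Rough equivalent of io.to(room).emit(event, payload) with a cluster adapter.
  broadcast(room: string, event: string, payload: unknown): void {
    this.deliverLocally(room, event, payload);
    this.bus.publish(room, event, payload, this.name);
  }

  private deliverLocally(room: string, event: string, payload: unknown): void {
    const members = this.rooms.get(room);
    if (!members) return;
    for (const socketId of members) {
      this.delivered.push({ socketId, event, payload });
    }
  }
}

const bus = new Bus();
const nodeA = new ServerNode("A", bus);
const nodeB = new ServerNode("B", bus);
nodeA.join("s1", "conv:42");
nodeB.join("s2", "conv:42");
nodeA.broadcast("conv:42", "message", { text: "hi" });
console.log(nodeA.delivered.length, nodeB.delivered.length); // 1 1
```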

4.3 Redis Service

typescript 复制代码
// modules/adapters/redis.service.ts
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { Redis } from 'ioredis';
import { SOCKET_CONFIG } from '../../config/constants';

@Injectable()
export class RedisService implements OnModuleDestroy {
  private readonly logger = new Logger(RedisService.name);
  private client: Redis;

  constructor() {
    this.client = new Redis({
      host: process.env.REDIS_HOST || 'localhost',
      port: parseInt(process.env.REDIS_PORT || '6379', 10),
      password: process.env.REDIS_PASSWORD,
      db: 0,
      maxRetriesPerRequest: 3,
      // Reconnect with a linearly growing delay, capped at 2 seconds
      retryStrategy: (times) => Math.min(times * 50, 2000),
    });

    this.client.on('connect', () => {
      this.logger.log('Redis connected');
    });

    this.client.on('error', (err) => {
      this.logger.error(`Redis connection error: ${err.message}`);
    });
  }

  async onModuleDestroy(): Promise<void> {
    await this.client.quit();
  }

  async get<T>(key: string): Promise<T | null> {
    const value = await this.client.get(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`);
    if (value) {
      try {
        return JSON.parse(value) as T;
      } catch {
        return value as unknown as T;
      }
    }
    return null;
  }

  async set(
    key: string,
    value: any,
    options?: { ttl?: number },
  ): Promise<void> {
    const prefixedKey = `${SOCKET_CONFIG.REDIS_PREFIX}${key}`;
    const serialized =
      typeof value === 'string' ? value : JSON.stringify(value);

    if (options?.ttl) {
      await this.client.setex(prefixedKey, options.ttl, serialized);
    } else {
      await this.client.set(prefixedKey, serialized);
    }
  }

  async del(key: string): Promise<void> {
    await this.client.del(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`);
  }

  async incr(key: string): Promise<number> {
    return this.client.incr(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`);
  }

  async decr(key: string): Promise<number> {
    return this.client.decr(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`);
  }

  async mget<T>(keys: string[]): Promise<(T | null)[]> {
    // MGET with zero keys is a Redis error, so short-circuit
    if (keys.length === 0) {
      return [];
    }

    const prefixedKeys = keys.map(
      (k) => `${SOCKET_CONFIG.REDIS_PREFIX}${k}`,
    );
    const values = await this.client.mget(...prefixedKeys);

    return values.map((v) => {
      if (v) {
        try {
          return JSON.parse(v) as T;
        } catch {
          return v as unknown as T;
        }
      }
      return null;
    });
  }

  async expire(key: string, seconds: number): Promise<void> {
    await this.client.expire(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`, seconds);
  }

  async hset(key: string, field: string, value: any): Promise<void> {
    const prefixedKey = `${SOCKET_CONFIG.REDIS_PREFIX}${key}`;
    const serialized =
      typeof value === 'string' ? value : JSON.stringify(value);
    await this.client.hset(prefixedKey, field, serialized);
  }

  async hget<T>(key: string, field: string): Promise<T | null> {
    const prefixedKey = `${SOCKET_CONFIG.REDIS_PREFIX}${key}`;
    const value = await this.client.hget(prefixedKey, field);

    if (value) {
      try {
        return JSON.parse(value) as T;
      } catch {
        return value as unknown as T;
      }
    }
    return null;
  }

  async hgetall<T>(key: string): Promise<Record<string, T>> {
    const prefixedKey = `${SOCKET_CONFIG.REDIS_PREFIX}${key}`;
    const hash = await this.client.hgetall(prefixedKey);
    const result: Record<string, T> = {};

    for (const [field, value] of Object.entries(hash)) {
      try {
        result[field] = JSON.parse(value) as T;
      } catch {
        result[field] = value as unknown as T;
      }
    }

    return result;
  }

  async sadd(key: string, ...members: string[]): Promise<number> {
    return this.client.sadd(
      `${SOCKET_CONFIG.REDIS_PREFIX}${key}`,
      ...members,
    );
  }

  async srem(key: string, ...members: string[]): Promise<number> {
    return this.client.srem(
      `${SOCKET_CONFIG.REDIS_PREFIX}${key}`,
      ...members,
    );
  }

  async smembers(key: string): Promise<string[]> {
    // Set members come back from Redis as raw strings
    return this.client.smembers(`${SOCKET_CONFIG.REDIS_PREFIX}${key}`);
  }
}
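All read methods in RedisService share one decode rule: try JSON.parse, and fall back to the raw string. That is what lets get() return the plain integers that INCR and DECR leave behind, but it also means a string written with set() that happens to be valid JSON (such as "123" or "true") is read back as a number or a boolean. The sketch below isolates the encode/decode pair to make that behavior visible; encodeValue and decodeValue are hypothetical helper names.

```typescript
// Mirrors RedisService.set: strings are stored verbatim, everything else as JSON.
function encodeValue(value: unknown): string {
  return typeof value === "string" ? value : JSON.stringify(value);
}

// Mirrors the decode step of RedisService.get/mget/hget.
function decodeValue<T>(raw: string | null): T | null {
  if (!raw) {
    return null;
  }
  try {
    return JSON.parse(raw) as T;
  } catch {
    // Not valid JSON: return the raw string unchanged
    return raw as unknown as T;
  }
}

console.log(decodeValue(encodeValue({ a: 1 }))); // { a: 1 }
console.log(decodeValue(encodeValue("hello"))); // hello
console.log(decodeValue(encodeValue("123"))); // 123 (a number, not the string "123")
console.log(decodeValue("7")); // 7 (what INCR leaves behind)
```

If string round-trips must be exact, always JSON.stringify on write; the fallback still decodes values written by INCR.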

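4.4 Cache Module

The cache module provides two-level caching: an in-process LRU cache (L1) in front of Redis as a shared distributed cache (L2). Because every node has its own L1, a write on one node does not clear the L1 copies on other nodes, so L1 entries should only live for a short time.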

typescript 复制代码
// modules/cache/cache.module.ts
import { Module, Global } from '@nestjs/common';
import { CacheService } from './cache.service';
import { RedisModule } from '../adapters/redis-adapter.module';

@Global()
@Module({
  imports: [RedisModule],
  providers: [CacheService],
  exports: [CacheService],
})
export class CacheModule {}
typescript 复制代码
// modules/cache/cache.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { RedisService } from '../adapters/redis.service';

const L1_MAX_SIZE = 10000;
const L1_TTL_MS = 60000; // L1 entries live for at most 1 minute

// Simple LRU cache built on Map insertion order, with a per-entry TTL
class LRUCache<K, V> {
  private cache = new Map<K, { value: V; expiresAt: number }>();

  constructor(
    private readonly maxSize: number,
    private readonly ttlMs: number,
  ) {}

  get(key: K): V | undefined {
    const entry = this.cache.get(key);
    if (entry === undefined) {
      return undefined;
    }
    // Expired entries are removed lazily on read
    if (Date.now() > entry.expiresAt) {
      this.cache.delete(key);
      return undefined;
    }
    // Re-insert at the end to mark the entry as most recently used
    this.cache.delete(key);
    this.cache.set(key, entry);
    return entry.value;
  }

  set(key: K, value: V): void {
    if (this.cache.has(key)) {
      this.cache.delete(key);
    } else if (this.cache.size >= this.maxSize) {
      // Evict the least recently used entry (first in insertion order)
      const oldest = this.cache.keys().next();
      if (!oldest.done) {
        this.cache.delete(oldest.value);
      }
    }
    this.cache.set(key, { value, expiresAt: Date.now() + this.ttlMs });
  }

  delete(key: K): boolean {
    return this.cache.delete(key);
  }

  clear(): void {
    this.cache.clear();
  }

  size(): number {
    return this.cache.size;
  }
}

@Injectable()
export class CacheService {
  private readonly logger = new Logger(CacheService.name);

  // L1: per-process LRU cache with a short TTL; L2: Redis, shared by all nodes
  private l1Cache = new LRUCache<string, any>(L1_MAX_SIZE, L1_TTL_MS);

  constructor(private readonly redisService: RedisService) {}

  async get<T>(key: string): Promise<T | null> {
    // L1 lookup
    const l1Value = this.l1Cache.get(key);
    if (l1Value !== undefined) {
      return l1Value as T;
    }

    // L2 lookup
    const l2Value = await this.redisService.get<T>(key);
    if (l2Value !== null) {
      // Back-fill L1
      this.l1Cache.set(key, l2Value);
    }

    return l2Value;
  }

  async set<T>(
    key: string,
    value: T,
    options?: { ttl?: number; l1Only?: boolean },
  ): Promise<void> {
    this.l1Cache.set(key, value);

    if (!options?.l1Only) {
      await this.redisService.set(key, value, {
        ttl: options?.ttl ?? 300, // default L2 TTL: 5 minutes
      });
    }
  }

  async invalidate(key: string): Promise<void> {
    this.l1Cache.delete(key);
    await this.redisService.del(key);
  }

  async mget<T>(keys: string[]): Promise<(T | null)[]> {
    const results: (T | null)[] = new Array<T | null>(keys.length).fill(null);
    // Remember which positions missed L1 so L2 results map back to the right keys
    const missingIndexes: number[] = [];

    keys.forEach((key, i) => {
      const value = this.l1Cache.get(key);
      if (value !== undefined) {
        results[i] = value as T;
      } else {
        missingIndexes.push(i);
      }
    });

    if (missingIndexes.length > 0) {
      const l2Results = await this.redisService.mget<T>(
        missingIndexes.map((i) => keys[i]),
      );

      missingIndexes.forEach((keyIndex, j) => {
        const l2Value = l2Results[j];
        results[keyIndex] = l2Value;
        // Back-fill L1
        if (l2Value !== null) {
          this.l1Cache.set(keys[keyIndex], l2Value);
        }
      });
    }

    return results;
  }

  async mset(
    entries: Array<{ key: string; value: any; ttl?: number }>,
  ): Promise<void> {
    for (const entry of entries) {
      await this.set(entry.key, entry.value, { ttl: entry.ttl });
    }
  }

  async incr(key: string): Promise<number> {
    // Counters live in Redis; drop the L1 copy so the next get() sees the new value
    this.l1Cache.delete(key);
    return this.redisService.incr(key);
  }

  async decr(key: string): Promise<number> {
    this.l1Cache.delete(key);
    return this.redisService.decr(key);
  }

  getL1Stats(): { size: number; maxSize: number } {
    return {
      size: this.l1Cache.size(),
      maxSize: L1_MAX_SIZE,
    };
  }
}
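An L1 layer like this combines two eviction rules: least-recently-used eviction when the cache is full, and a time-to-live so that a node cannot keep serving a value indefinitely after another node changes it in Redis. A compact stand-alone sketch with an injectable clock follows; the class name TtlLru is hypothetical.

```typescript
// Hypothetical LRU + TTL cache built on Map insertion order.
class TtlLru<K, V> {
  private readonly entries = new Map<K, { value: V; expiresAt: number }>();

  constructor(
    private readonly maxSize: number,
    private readonly ttlMs: number,
    private readonly now: () => number = () => Date.now(),
  ) {}

  get(key: K): V | undefined {
    const entry = this.entries.get(key);
    if (!entry) return undefined;
    if (this.now() > entry.expiresAt) {
      this.entries.delete(key); // expired: drop lazily
      return undefined;
    }
    this.entries.delete(key); // re-insert to mark as most recently used
    this.entries.set(key, entry);
    return entry.value;
  }

  set(key: K, value: V): void {
    this.entries.delete(key);
    if (this.entries.size >= this.maxSize) {
      // The first key in insertion order is the least recently used
      const oldest = this.entries.keys().next();
      if (!oldest.done) this.entries.delete(oldest.value);
    }
    this.entries.set(key, { value, expiresAt: this.now() + this.ttlMs });
  }
}

let t = 0;
const cache = new TtlLru<string, number>(2, 60000, () => t);
cache.set("a", 1);
cache.set("b", 2);
cache.get("a"); // "a" becomes most recently used
cache.set("c", 3); // evicts "b"
console.log(cache.get("b"), cache.get("a")); // undefined 1
t = 60001;
console.log(cache.get("a")); // undefined (expired)
```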

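4.5 Assembling the Application Module

AppModule wires all modules together. The code in this and the following sections relies on a few packages beyond those installed at the start: @nestjs/typeorm, typeorm and the PostgreSQL driver pg for persistence, @nestjs/jwt for token handling, and @nestjs/terminus for health checks.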

typescript 复制代码
// app.module.ts
import { Module } from '@nestjs/common';
import { TypeOrmModule } from '@nestjs/typeorm';
import { ImModule } from './modules/im/im.module';
import { MessageModule } from './modules/message/message.module';
import { CacheModule } from './modules/cache/cache.module';
import { RedisModule } from './modules/adapters/redis-adapter.module';
import { HealthModule } from './modules/health/health.module';
import { MetricsModule } from './modules/metrics/metrics.module';

@Module({
  imports: [
    TypeOrmModule.forRoot({
      type: 'postgres',
      host: process.env.DB_HOST || 'localhost',
      port: parseInt(process.env.DB_PORT || '5432', 10),
      username: process.env.DB_USER || 'postgres',
      password: process.env.DB_PASSWORD,
      database: process.env.DB_NAME || 'im_db',
      entities: [__dirname + '/**/*.entity{.ts,.js}'],
      // Auto-sync the schema only in development; use migrations elsewhere
      synchronize: process.env.NODE_ENV === 'development',
      logging: process.env.NODE_ENV === 'development',
    }),
    RedisModule,
    CacheModule,
    MessageModule,
    ImModule,
    // Health-check and Prometheus metrics endpoints
    HealthModule,
    MetricsModule,
  ],
})
export class AppModule {}
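AppModule reads its database settings straight from process.env with inline defaults. Moving that into a small typed function makes the defaults testable and keeps parsing details, such as reading the port as a base-10 integer and rejecting garbage, in one place. A minimal sketch, assuming the same variable names and defaults as above; loadDatabaseConfig is a hypothetical helper.

```typescript
interface DatabaseConfig {
  host: string;
  port: number;
  username: string;
  password?: string;
  database: string;
  synchronize: boolean;
}

type Env = Record<string, string | undefined>;

// Same variable names and defaults as the TypeOrmModule.forRoot call in AppModule.
function loadDatabaseConfig(env: Env): DatabaseConfig {
  const port = parseInt(env.DB_PORT || "5432", 10);
  if (!Number.isInteger(port) || port <= 0 || port > 65535) {
    throw new Error(`Invalid DB_PORT: ${env.DB_PORT}`);
  }
  return {
    host: env.DB_HOST || "localhost",
    port,
    username: env.DB_USER || "postgres",
    password: env.DB_PASSWORD,
    database: env.DB_NAME || "im_db",
    // Schema auto-sync only in development, as in AppModule
    synchronize: env.NODE_ENV === "development",
  };
}

console.log(loadDatabaseConfig({}).port); // 5432
console.log(loadDatabaseConfig({ DB_PORT: "6543", NODE_ENV: "development" }));
```

In AppModule the result could then be spread into the options object, for example TypeOrmModule.forRoot({ type: 'postgres', ...loadDatabaseConfig(process.env), entities: [...] }).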

4.6 IM Core Module

typescript 复制代码
// modules/im/im.module.ts
import { Module } from '@nestjs/common';
import { JwtModule } from '@nestjs/jwt';
import { ImGateway } from '../../gateways/im.gateway';
import { ImService } from './services/im.service';
import { ConnectionTierService } from './services/connection-tier.service';
import { PresenceService } from './services/presence.service';
import { StateAggregatorService } from './services/state-aggregator.service';
import { MessageStormControllerService } from './services/message-storm-controller.service';
import { RateLimiterService } from './services/rate-limiter.service';
import { JwtService } from './services/jwt.service';
import { CacheModule } from '../cache/cache.module';
import { MessageModule } from '../message/message.module';

@Module({
  imports: [
    CacheModule,
    MessageModule, // provides MessageService
    // Registers @nestjs/jwt's JwtService, which our own JwtService wraps.
    // JWT_SECRET is an assumed environment variable name.
    JwtModule.register({ secret: process.env.JWT_SECRET }),
  ],
  providers: [
    ImGateway,
    ImService,
    ConnectionTierService,
    PresenceService,
    StateAggregatorService,
    MessageStormControllerService,
    RateLimiterService,
    JwtService,
  ],
  // Also export the services that HealthModule and MetricsModule inject,
  // so they share these instances instead of needing their own
  exports: [
    ImService,
    ConnectionTierService,
    PresenceService,
    RateLimiterService,
  ],
})
export class ImModule {}

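JWT Authentication Service

This service wraps JwtService from @nestjs/jwt to verify the token a client presents when it connects. The underlying JwtService is only injectable in a module that registers JwtModule with a signing secret.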

typescript 复制代码
// modules/im/services/jwt.service.ts
import { Injectable, Logger } from '@nestjs/common';
import { JwtService as NestJwtService } from '@nestjs/jwt';

export interface JwtPayload {
  userId: string;
  tier: string;
  deviceId: string;
  isPremium: boolean;
  isAdmin: boolean;
  iat?: number;
  exp?: number;
}

@Injectable()
export class JwtService {
  private readonly logger = new Logger(JwtService.name);

  constructor(private readonly nestJwtService: NestJwtService) {}

  async verifySocketToken(token: string): Promise<JwtPayload | null> {
    try {
      const payload = await this.nestJwtService.verifyAsync<JwtPayload>(token);

      // A token without a user id cannot identify the socket's owner
      if (!payload.userId) {
        return null;
      }

      // Normalize optional claims to safe defaults
      return {
        userId: payload.userId,
        tier: payload.tier || 'standard',
        deviceId: payload.deviceId,
        isPremium: payload.isPremium ?? false,
        isAdmin: payload.isAdmin ?? false,
      };
    } catch (error) {
      this.logger.warn(`Token verification failed: ${(error as Error).message}`);
      return null;
    }
  }

  generateToken(payload: Omit<JwtPayload, 'iat' | 'exp'>): string {
    return this.nestJwtService.sign(payload);
  }

  // decode() only parses the token and does NOT verify its signature,
  // so its result must never be used for authorization decisions
  decodeToken(token: string): JwtPayload | null {
    try {
      return this.nestJwtService.decode(token) as JwtPayload | null;
    } catch {
      return null;
    }
  }
}
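verifySocketToken expects the raw token, so the gateway still has to pull it out of the Socket.io handshake. Clients commonly send it either in the auth payload of the Socket.io client, io(url, { auth: { token } }), or as an Authorization: Bearer <token> header. Browsers cannot attach custom headers to a WebSocket handshake, and main.ts restricts the transport to websocket, so for browser clients the auth payload is the practical choice. The helper below sketches that lookup against a minimal handshake shape; HandshakeLike is a simplified assumption rather than Socket.io's full Handshake type, and extractToken is hypothetical.

```typescript
// Simplified subset of the Socket.io handshake object (assumption).
interface HandshakeLike {
  auth?: Record<string, unknown>;
  headers: Record<string, string | string[] | undefined>;
}

// Prefer auth.token; fall back to an "Authorization: Bearer <token>" header.
function extractToken(handshake: HandshakeLike): string | null {
  const fromAuth = handshake.auth?.token;
  if (typeof fromAuth === "string" && fromAuth.length > 0) {
    return fromAuth;
  }

  // Node.js lower-cases incoming header names
  const header = handshake.headers["authorization"];
  const value = Array.isArray(header) ? header[0] : header;
  const match = value?.match(/^Bearer\s+(.+)$/i);
  return match ? match[1] : null;
}

console.log(extractToken({ auth: { token: "abc" }, headers: {} })); // abc
console.log(extractToken({ headers: { authorization: "Bearer xyz" } })); // xyz
console.log(extractToken({ headers: {} })); // null
```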

Health Checks and Monitoring

5.1 Health Check Module

typescript 复制代码
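// modules/health/health.module.ts
import { Module } from '@nestjs/common';
import { TerminusModule } from '@nestjs/terminus';
import { HealthController } from './health.controller';
import { ImModule } from '../im/im.module';

@Module({
  // HealthController injects ConnectionTierService, PresenceService and
  // RateLimiterService, so ImModule must be imported here and must export them
  imports: [TerminusModule, ImModule],
  controllers: [HealthController],
})
export class HealthModule {}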
typescript 复制代码
// modules/health/health.controller.ts
import { Controller, Get } from '@nestjs/common';
import {
  HealthCheck,
  HealthCheckService,
  TypeOrmHealthIndicator,
  DiskHealthIndicator,
  MemoryHealthIndicator,
} from '@nestjs/terminus';
import { ConnectionTierService } from '../im/services/connection-tier.service';
import { PresenceService } from '../im/services/presence.service';
import { RateLimiterService } from '../im/services/rate-limiter.service';

@Controller('health')
export class HealthController {
  constructor(
    private readonly health: HealthCheckService,
    private readonly db: TypeOrmHealthIndicator,
    private readonly disk: DiskHealthIndicator,
    private readonly memory: MemoryHealthIndicator,
    private readonly tierService: ConnectionTierService,
    private readonly presenceService: PresenceService,
    private readonly rateLimiterService: RateLimiterService,
  ) {}

  @Get()
  @HealthCheck()
  check() {
    return this.health.check([
      // Database connectivity
      () => this.db.pingCheck('database'),

      // Heap usage must stay below 1.5 GB
      () => this.memory.checkHeap('memory_heap', 1500 * 1024 * 1024),

      // Usage of the root filesystem must stay below 90% (path is required)
      () =>
        this.disk.checkStorage('disk', { path: '/', thresholdPercent: 0.9 }),

      // Custom Socket service check: report connection and presence statistics
      async () => {
        const tierStats = this.tierService.getStats();
        const presenceStats = this.presenceService.getStats();
        const rateLimitStats = this.rateLimiterService.getStats();

        return {
          socket: {
            status: 'up' as const,
            totalConnections: tierStats.total,
            byTier: tierStats.byTier,
            onlineUsers: presenceStats.totalOnline,
            rateLimitTracked: rateLimitStats.totalTracked,
          },
        };
      },
    ]);
  }

  @Get('ready')
  @HealthCheck()
  readiness() {
    return this.health.check([
      () => this.db.pingCheck('database'),
    ]);
  }

  @Get('live')
  @HealthCheck()
  liveness() {
    return this.health.check([
      async () => ({
        api: { status: 'up' as const },
      }),
    ]);
  }
}
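The custom socket indicator above always reports up. If the per-tier connection counts should also drive the status, one option is to compare them with maxConnections from CONNECTION_TIERS. The function below sketches that check as a pure function; the limits are copied from the constants defined earlier, treating any over-limit tier as down is an assumption rather than something the original specifies, and socketHealth is a hypothetical name. How the result is surfaced through Terminus is left out.

```typescript
type IndicatorStatus = "up" | "down";

// maxConnections per tier name, copied from CONNECTION_TIERS
const TIER_LIMITS = new Map<string, number>([
  ["critical", 10000],
  ["standard", 100000],
  ["degraded", 50000],
]);

// Hypothetical helper returning a Terminus-style { socket: { status, ...details } } object.
function socketHealth(byTier: Record<string, number>) {
  const overLimit = Object.entries(byTier)
    .filter(([tier, count]) => {
      const limit = TIER_LIMITS.get(tier);
      return limit !== undefined && count > limit;
    })
    .map(([tier]) => tier);

  const status: IndicatorStatus = overLimit.length === 0 ? "up" : "down";
  return { socket: { status, byTier, overLimit } };
}

console.log(socketHealth({ critical: 120, standard: 5000 }).socket.status); // up
console.log(socketHealth({ critical: 10001, standard: 5000 }).socket.overLimit); // [ 'critical' ]
```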

Exporting Performance Metrics

typescript 复制代码
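// modules/metrics/metrics.module.ts
import { Module } from '@nestjs/common';
import {
  PrometheusModule,
  makeCounterProvider,
  makeGaugeProvider,
  makeHistogramProvider,
} from '@willsoto/nestjs-prometheus';
import { MetricsService } from './metrics.service';
import { ImModule } from '../im/im.module';

@Module({
  imports: [
    // Serves the default prom-client registry at /metrics, including Node.js default metrics
    PrometheusModule.register({ path: '/metrics', defaultMetrics: { enabled: true } }),
    // MetricsService injects ConnectionTierService, which ImModule provides
    ImModule,
  ],
  providers: [
    MetricsService,
    // Every @InjectMetric(name) in MetricsService needs a provider with the same name
    makeGaugeProvider({ name: 'im_active_connections', help: 'Active connections by tier', labelNames: ['tier'] }),
    makeCounterProvider({ name: 'im_messages_sent', help: 'Messages sent' }),
    makeCounterProvider({ name: 'im_messages_received', help: 'Messages received' }),
    makeHistogramProvider({ name: 'im_connection_latency_seconds', help: 'Connection setup latency in seconds' }),
    makeCounterProvider({ name: 'im_rate_limit_hits', help: 'Messages rejected by the rate limiter' }),
  ],
  exports: [MetricsService],
})
export class MetricsModule {}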
typescript 复制代码
// modules/metrics/metrics.service.ts
import { Injectable, OnModuleDestroy, OnModuleInit } from '@nestjs/common';
import { Counter, Gauge, Histogram, Registry, register } from 'prom-client';
import { InjectMetric } from '@willsoto/nestjs-prometheus';
import { ConnectionTierService } from '../im/services/connection-tier.service';

@Injectable()
export class MetricsService implements OnModuleInit, OnModuleDestroy {
  private refreshTimer?: ReturnType<typeof setInterval>;

  // Custom metrics; each name must match a provider registered in MetricsModule
  constructor(
    @InjectMetric('im_active_connections')
    private readonly activeConnections: Gauge<string>,
    @InjectMetric('im_messages_sent')
    private readonly messagesSent: Counter<string>,
    @InjectMetric('im_messages_received')
    private readonly messagesReceived: Counter<string>,
    @InjectMetric('im_connection_latency_seconds')
    private readonly connectionLatency: Histogram<string>,
    @InjectMetric('im_rate_limit_hits')
    private readonly rateLimitHits: Counter<string>,
    private readonly tierService: ConnectionTierService,
  ) {}

  onModuleInit(): void {
    // Refresh the connection gauges every 5 seconds
    this.refreshTimer = setInterval(() => this.updateMetrics(), 5000);
  }

  onModuleDestroy(): void {
    clearInterval(this.refreshTimer);
  }

  private updateMetrics(): void {
    const tierStats = this.tierService.getStats();

    // One series for the total plus one per tier
    this.activeConnections.set({ tier: 'total' }, tierStats.total);
    for (const [tier, count] of Object.entries(tierStats.byTier)) {
      this.activeConnections.set({ tier }, count as number);
    }
  }

  // Deliberately no conversation-id label: one time series per conversation
  // would make metric cardinality grow without bound
  recordMessageSent(): void {
    this.messagesSent.inc();
  }

  recordMessageReceived(): void {
    this.messagesReceived.inc();
  }

  // Duration in seconds, matching prom-client's default histogram buckets
  recordConnectionLatency(durationSeconds: number): void {
    this.connectionLatency.observe(durationSeconds);
  }

  recordRateLimitHit(): void {
    this.rateLimitHits.inc();
  }

  // PrometheusModule exposes prom-client's default registry, so return that one
  getRegistry(): Registry {
    return register;
  }
}
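What Prometheus scrapes from /metrics is the text exposition format: a HELP and a TYPE line per metric, followed by one line per label combination. Rendering the connection gauge by hand shows why label choice matters: every distinct label value becomes its own time series, so a tier label yields a handful of series, while a conversation-id label would yield one per conversation. A minimal renderer sketch; renderGauge is hypothetical, skips label-value escaping, and prom-client produces this output for you from the registry.

```typescript
// Renders one gauge in the Prometheus text exposition format (simplified:
// label values are not escaped).
function renderGauge(
  name: string,
  help: string,
  samples: Array<{ labels: Record<string, string>; value: number }>,
): string {
  const lines = [`# HELP ${name} ${help}`, `# TYPE ${name} gauge`];
  for (const { labels, value } of samples) {
    const labelText = Object.entries(labels)
      .map(([k, v]) => `${k}="${v}"`)
      .join(",");
    lines.push(labelText ? `${name}{${labelText}} ${value}` : `${name} ${value}`);
  }
  return lines.join("\n") + "\n";
}

const text = renderGauge("im_active_connections", "Active connections by tier", [
  { labels: { tier: "total" }, value: 3 },
  { labels: { tier: "critical" }, value: 1 },
  { labels: { tier: "standard" }, value: 2 },
]);
console.log(text);
```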

Application Entry Point

typescript 复制代码
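// main.ts
import { NestFactory } from '@nestjs/core';
import { INestApplicationContext, Logger, ValidationPipe } from '@nestjs/common';
import { IoAdapter } from '@nestjs/platform-socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { AppModule } from './app.module';
import { RedisAdapterProvider } from './modules/adapters/redis-adapter.provider';
import { SOCKET_CONFIG } from './config/constants';

// Socket.io adapter that installs the Redis adapter on every server Nest creates
class RedisIoAdapter extends IoAdapter {
  constructor(
    app: INestApplicationContext,
    private readonly adapterConstructor: ReturnType<typeof createAdapter>,
  ) {
    super(app);
  }

  createIOServer(port: number, options?: any): any {
    const server = super.createIOServer(port, {
      ...options,
      cors: {
        // Browsers reject credentialed requests when the allowed origin is '*',
        // so reflect the request origin instead; use an allowlist in production
        origin: true,
        credentials: true,
      },
      transports: ['websocket'], // WebSocket only, no HTTP long-polling fallback
      pingTimeout: SOCKET_CONFIG.PING_TIMEOUT,
      pingInterval: SOCKET_CONFIG.PING_INTERVAL,
      maxHttpBufferSize: SOCKET_CONFIG.MAX_MESSAGE_SIZE, // cap the size of a single incoming message
    });

    // createAdapter() returns a factory that must be installed with
    // server.adapter(); createIOServer() must still return the server itself
    server.adapter(this.adapterConstructor);
    return server;
  }
}

async function bootstrap() {
  const logger = new Logger('Bootstrap');
  const app = await NestFactory.create(AppModule);

  // Global validation: transform payloads into DTO classes, reject unknown fields
  app.useGlobalPipes(
    new ValidationPipe({
      transform: true,
      whitelist: true,
      forbidNonWhitelisted: true,
    }),
  );

  // CORS for the HTTP API; '*' cannot be combined with credentials, so reflect the origin
  app.enableCors({
    origin: true,
    credentials: true,
  });

  // Connect the Redis clients and build the adapter factory
  const redisAdapterProvider = app.get(RedisAdapterProvider);
  const redisAdapter = await redisAdapterProvider.createAdapter();

  // Use the Redis-backed Socket.io adapter for all gateways
  app.useWebSocketAdapter(new RedisIoAdapter(app, redisAdapter));

  const port = process.env.PORT || 3000;
  await app.listen(port);

  logger.log(`IM Socket service listening on port ${port}`);
}

bootstrap();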
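Reflecting every origin, or using '*', is convenient during development, but a production deployment usually wants an explicit allowlist. The origin option of both app.enableCors and the Socket.io cors settings accepts a callback of the form (origin, callback), as in the cors package; the sketch below builds such a callback from a comma-separated list. The environment variable name CORS_ORIGINS and the helper makeOriginCheck are hypothetical.

```typescript
type OriginCallback = (err: Error | null, allow?: boolean) => void;

// Hypothetical helper: builds a cors-style origin callback from a comma-separated allowlist.
function makeOriginCheck(allowlist: string | undefined) {
  const allowed = new Set(
    (allowlist ?? "")
      .split(",")
      .map((o) => o.trim())
      .filter((o) => o.length > 0),
  );

  return (origin: string | undefined, callback: OriginCallback): void => {
    // Non-browser clients (server-to-server, curl) send no Origin header; let them through
    if (!origin) {
      callback(null, true);
      return;
    }
    callback(null, allowed.has(origin));
  };
}

const check = makeOriginCheck("https://app.example.com, https://admin.example.com");
check("https://app.example.com", (_err, allow) => console.log(allow)); // true
check("https://evil.example.net", (_err, allow) => console.log(allow)); // false
```

The same function could then be passed as origin: makeOriginCheck(process.env.CORS_ORIGINS) in both enableCors and the createIOServer options, so HTTP and WebSocket traffic follow one rule.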

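Summary

NestJS gives a Socket.io service a more structured, enterprise-grade implementation. The modular architecture separates business logic into modules that each own one functional domain, dependency injection keeps the collaboration between services simple and explicit, and the AOP-style building blocks (guards, pipes and interceptors) handle cross-cutting concerns such as authentication, rate limiting and validation cleanly.

For IM messaging and collaborative-document scenarios, the optimizations covered here are tiered connection management to allocate resources by priority, a message-storm controller to keep the system from overloading at peak times, a state aggregation service to cut redundant network traffic, and multi-level caching to shield the database. NestJS lets all of them be expressed in a more declarative, maintainable way, and the resulting architecture combines high performance with good extensibility and testability, making it a suitable foundation for an enterprise IM service.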
