Temporal Global Rate Limiting Summary (Part 2): Designing a Third-Party Rate Limiting Service

Detailed Design of the Third-Party Rate Limiting Service

1. Overall Service Architecture

1.1 System Architecture Diagram

┌─────────────────────────────────────────────────────────────┐
│                    Temporal Cluster                         │
│                                                             │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐    │
│  │   Worker 1  │     │   Worker 2  │     │   Worker N  │    │
│  └──────┬──────┘     └──────┬──────┘     └──────┬──────┘    │
│         │                   │                   │           │
│         └───────────────────┼───────────────────┘           │
│                             │                               │
│                  ┌──────────▼──────────┐                    │
│                  │   Activity impl     │                    │
│                  │  (callback setup)   │                    │
│                  └──────────┬──────────┘                    │
└─────────────────────────────┼───────────────────────────────┘
                              │ register / heartbeat
                              ▼
                    ┌─────────────────────┐
                    │   Rate Limit Service│  ←─ third-party rate limiting service
                    │  (external service) │
                    └─────────┬───────────┘
                              │ callback notification
                              ▼
                    ┌─────────────────────┐
                    │   Temporal Callback │  ←─ Temporal callback endpoint
                    │     Endpoint        │
                    └─────────────────────┘
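The diagram describes an asynchronous-completion handshake: the Activity hands its task token to the rate limiting service, returns without completing, and the callback endpoint completes it later. In the Temporal Java SDK the Activity side would call `Activity.getExecutionContext().doNotCompleteOnReturn()`, and the endpoint would finish the Activity with `ActivityCompletionClient.complete(taskToken, result)`. Because the service retries failed callbacks, the endpoint can receive the same grant twice, so it should be idempotent. The sketch below models only that endpoint logic without the SDK; `CallbackEndpointSketch` and the `Completer` interface are hypothetical stand-ins, and the Base64 `taskToken` and `requestId` fields are assumed to match the callback body built in section 2.4.

```java
import java.util.Base64;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical stand-in for Temporal's ActivityCompletionClient.complete(taskToken, result). */
interface Completer {
    void complete(byte[] taskToken, String result);
}

/** Minimal model of the callback endpoint: decode the token, complete each request once. */
class CallbackEndpointSketch {
    private final Completer completer;
    // requestIds already completed; a real endpoint would keep this in shared storage
    private final Set<String> completed = ConcurrentHashMap.newKeySet();

    CallbackEndpointSketch(Completer completer) {
        this.completer = completer;
    }

    /** Returns true if this call completed the Activity, false for a duplicate delivery. */
    boolean onGrant(String requestId, String base64TaskToken) {
        // Decode first so a malformed token does not mark the request as handled
        byte[] taskToken = Base64.getDecoder().decode(base64TaskToken);
        if (!completed.add(requestId)) {
            return false; // retried callback: acknowledge without completing again
        }
        try {
            completer.complete(taskToken, "granted");
            return true;
        } catch (RuntimeException e) {
            completed.remove(requestId); // let a later retry attempt completion again
            throw e;
        }
    }
}
```

Returning a 2xx for duplicates matters: if the endpoint answered a duplicate with an error, the service would keep retrying a grant that has already been delivered.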

1.2 Service Component Design

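/**
 * Core components of the third-party rate limiting service
 */
@Configuration
public class RateLimitServiceArchitecture {
    
    /**
     * Core components:
     * 
     * 1. RateLimitGateway: entry gateway that handles registration requests
     * 2. RateLimitRegistry: registry that tracks every waiting request
     * 3. TokenBucketManager: token bucket manager that enforces the rate
     * 4. CallbackDispatcher: callback dispatcher that notifies Temporal
     * 5. MonitorDashboard: monitoring dashboard for real-time status
     * 6. ConfigurationManager: configuration management for adjusting parameters at runtime
     * 
     * Data storage:
     * - Redis: caches the waiting queues and token bucket state
     * - MySQL: persists request records and configuration
     * - Elasticsearch: stores logs and monitoring data
     */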
}
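The RateLimitRegistry has to move each waiting request through a small lifecycle, and several actors (the dispatcher, the timeout checker, the cancel endpoint) may try to change the same request concurrently. Section 2.2 uses the statuses PENDING, CALLBACK_SENT, FAILED and TIMEOUT; CANCELLED below is an assumed addition for the cancel endpoint. This is a minimal in-memory sketch, assuming PENDING is the only non-terminal status, that applies transitions with compare-and-set so a late timeout check cannot overwrite a request whose callback was already sent. In the real service the same guard would presumably be a conditional update in MySQL or Redis (for example `... WHERE status = 'PENDING'`), which is an assumption about the schema.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Statuses used in section 2.2, plus CANCELLED (an assumption for the cancel endpoint). */
enum RequestStatus {
    PENDING, CALLBACK_SENT, FAILED, TIMEOUT, CANCELLED;

    /** PENDING is the only non-terminal status; it may move to any other status. */
    boolean canTransitionTo(RequestStatus next) {
        return this == PENDING && next != PENDING;
    }
}

/** Applies status transitions with compare-and-set so concurrent writers cannot clobber each other. */
class StatusTracker {
    private final AtomicReference<RequestStatus> status = new AtomicReference<>(RequestStatus.PENDING);

    /** Returns true if the transition was applied, false if the current status forbids it. */
    boolean transition(RequestStatus next) {
        RequestStatus current = status.get();
        while (current.canTransitionTo(next)) {
            if (status.compareAndSet(current, next)) {
                return true;
            }
            current = status.get(); // lost a race: re-check against the winner's status
        }
        return false;
    }

    RequestStatus current() {
        return status.get();
    }
}
```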

2. Detailed Service Design

2.1 API Design (REST + gRPC)

/**
 * Rate limiting service API
 */
@RestController
@RequestMapping("/api/v1/rate-limit")
@Validated
@RequiredArgsConstructor
@Slf4j
public class RateLimitApiController {
    
    private final RateLimitService rateLimitService;
    private final RequestValidator requestValidator;
    private final RateLimitConfigProvider configProvider;
    
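    /**
     * 1. Register a rate limiting request
     * 
     * Request parameters:
     * - activityId: unique ID of the Temporal Activity
     * - taskToken: Temporal task token, used to complete the Activity on callback
     * - rateKey: rate limiting key (e.g. "send_message")
     * - callbackUrl: Temporal callback endpoint
     * - priority: priority (0-10)
     * - timeoutSeconds: timeout in seconds
     * - metadata: additional metadata
     */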
    @PostMapping("/register")
    public ResponseEntity<RateLimitResponse> register(
            @Valid @RequestBody RateLimitRequest request) {
        
        try {
            // Validate the request
            ValidationResult validation = requestValidator.validate(request);
            if (!validation.isValid()) {
                return ResponseEntity.badRequest()
                    .body(RateLimitResponse.error(validation.getErrors()));
            }
            
            // Look up the rate limit config
            RateLimitConfig config = configProvider.getConfig(request.getRateKey());
            if (config == null) {
                return ResponseEntity.status(HttpStatus.NOT_FOUND)
                    .body(RateLimitResponse.error("Rate limit config not found"));
            }
            
            // Register the request (it is queued; the grant arrives later via callback)
            String requestId = rateLimitService.register(request, config);
            
            // Return the request ID immediately
            return ResponseEntity.ok(RateLimitResponse.success(requestId));
            
        } catch (RateLimitException e) {
            log.error("Rate limit registration failed", e);
            return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS)
                .body(RateLimitResponse.error(e.getMessage()));
        } catch (Exception e) {
            log.error("Internal server error", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(RateLimitResponse.error("Internal server error"));
        }
    }
    
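    /**
     * 2. Heartbeat endpoint (keeps the request alive)
     * 
     * The Activity can send heartbeats periodically to show it is still waiting;
     * the service can extend the timeout based on those heartbeats.
     */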
    @PostMapping("/heartbeat/{requestId}")
    public ResponseEntity<HeartbeatResponse> heartbeat(
            @PathVariable String requestId,
            @RequestBody HeartbeatRequest heartbeat) {
        
        try {
            boolean success = rateLimitService.heartbeat(requestId, heartbeat);
            
            if (!success) {
                // The request may have timed out or been cancelled
                return ResponseEntity.status(HttpStatus.NOT_FOUND)
                    .body(HeartbeatResponse.notFound());
            }
            
            // Return the remaining wait time
            long remaining = rateLimitService.getRemainingTime(requestId);
            return ResponseEntity.ok(HeartbeatResponse.success(remaining));
            
        } catch (Exception e) {
            log.error("Heartbeat failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(HeartbeatResponse.error());
        }
    }
    
    /**
     * 3. Cancel a request
     * 
     * Used when the Activity needs to stop waiting early
     */
    @DeleteMapping("/cancel/{requestId}")
    public ResponseEntity<CancelResponse> cancel(
            @PathVariable String requestId,
            @RequestParam(required = false) String reason) {
        
        try {
            boolean cancelled = rateLimitService.cancel(requestId, reason);
            
            if (!cancelled) {
                return ResponseEntity.status(HttpStatus.NOT_FOUND)
                    .body(CancelResponse.notFound());
            }
            
            return ResponseEntity.ok(CancelResponse.success());
            
        } catch (Exception e) {
            log.error("Cancel failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(CancelResponse.error());
        }
    }
    
    /**
     * 4. Query request status
     */
    @GetMapping("/status/{requestId}")
    public ResponseEntity<StatusResponse> getStatus(@PathVariable String requestId) {
        try {
            RateLimitStatus status = rateLimitService.getStatus(requestId);
            
            if (status == null) {
                return ResponseEntity.status(HttpStatus.NOT_FOUND)
                    .body(StatusResponse.notFound());
            }
            
            return ResponseEntity.ok(StatusResponse.success(status));
            
        } catch (Exception e) {
            log.error("Get status failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(StatusResponse.error());
        }
    }
    
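    /**
     * 5. Batch registration (fewer round trips)
     *
     * @Valid on the list elements (together with @Validated on the class)
     * validates each request; @Valid on the List itself would not.
     */
    @PostMapping("/batch-register")
    public ResponseEntity<BatchResponse> batchRegister(
            @RequestBody List<@Valid RateLimitRequest> requests) {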
        
        try {
            List<BatchResult> results = rateLimitService.batchRegister(requests);
            return ResponseEntity.ok(BatchResponse.success(results));
            
        } catch (Exception e) {
            log.error("Batch register failed", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(BatchResponse.error());
        }
    }
    
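    /**
     * 6. gRPC endpoint (high-throughput scenarios)
     *
     * Declared static so the gRPC starter can instantiate it as its own bean;
     * in a real project it would normally live in a separate file.
     */
    @GrpcService
    @RequiredArgsConstructor
    public static class RateLimitGrpcService extends RateLimitServiceGrpc.RateLimitServiceImplBase {
        
        private final RateLimitService rateLimitService;
        private final RateLimitConfigProvider configProvider;
        
        @Override
        public void register(RateLimitProto.RegisterRequest request,
                           StreamObserver<RateLimitProto.RegisterResponse> responseObserver) {
            
            try {
                // Convert to the internal request object (mapping helper not shown)
                RateLimitRequest internalRequest = convertToInternal(request);
                
                // Look up the config, as the REST endpoint does
                RateLimitConfig config = configProvider.getConfig(internalRequest.getRateKey());
                if (config == null) {
                    responseObserver.onError(Status.NOT_FOUND
                        .withDescription("Rate limit config not found")
                        .asRuntimeException());
                    return;
                }
                
                // Register the request
                String requestId = rateLimitService.register(internalRequest, config);
                
                // Send the response
                RateLimitProto.RegisterResponse response = RateLimitProto.RegisterResponse.newBuilder()
                    .setRequestId(requestId)
                    .setSuccess(true)
                    .build();
                
                responseObserver.onNext(response);
                responseObserver.onCompleted();
                
            } catch (Exception e) {
                responseObserver.onError(Status.INTERNAL.withDescription(e.getMessage()).asRuntimeException());
            }
        }
    }
}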
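The register endpoint documents its parameter constraints (required IDs, a priority of 0-10, a timeout, a callback URL) but leaves `RequestValidator` unspecified. A minimal sketch of the checks it might perform follows, assuming a simplified record in place of the real `RateLimitRequest`. `RegisterRequest` and `RequestValidatorSketch` are hypothetical names; the one-hour maximum timeout and the http(s)-only rule for callback URLs are illustrative assumptions, not requirements from the design above.

```java
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

/** Simplified stand-in for RateLimitRequest (assumption: only the validated fields). */
record RegisterRequest(String activityId, String taskToken, String rateKey,
                       String callbackUrl, int priority, int timeoutSeconds) {}

/** Hypothetical validator implementing the constraints listed for /register. */
class RequestValidatorSketch {
    static final int MAX_TIMEOUT_SECONDS = 3600; // assumed upper bound, not from the source

    /** Returns a list of error messages; empty means the request is valid. */
    List<String> validate(RegisterRequest r) {
        List<String> errors = new ArrayList<>();
        if (isBlank(r.activityId())) errors.add("activityId is required");
        if (isBlank(r.taskToken())) errors.add("taskToken is required");
        if (isBlank(r.rateKey())) errors.add("rateKey is required");
        if (r.priority() < 0 || r.priority() > 10) errors.add("priority must be between 0 and 10");
        if (r.timeoutSeconds() <= 0 || r.timeoutSeconds() > MAX_TIMEOUT_SECONDS) {
            errors.add("timeoutSeconds must be in 1.." + MAX_TIMEOUT_SECONDS);
        }
        if (!isHttpUrl(r.callbackUrl())) errors.add("callbackUrl must be an absolute http(s) URL");
        return errors;
    }

    private static boolean isBlank(String s) {
        return s == null || s.isBlank();
    }

    private static boolean isHttpUrl(String s) {
        if (isBlank(s)) return false;
        try {
            URI uri = URI.create(s);
            String scheme = uri.getScheme();
            return uri.getHost() != null
                && ("http".equalsIgnoreCase(scheme) || "https".equalsIgnoreCase(scheme));
        } catch (IllegalArgumentException e) {
            return false; // malformed URI syntax
        }
    }
}
```

Collecting every error instead of stopping at the first matches the controller's `RateLimitResponse.error(validation.getErrors())`, which returns a list of problems in one 400 response.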

2.2 Core Service Implementation

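/**
 * Core implementation of the rate limiting service
 */
@Service
@Slf4j
@RequiredArgsConstructor
public class RateLimitServiceImpl implements RateLimitService {
    
    private final RequestQueueManager queueManager;
    private final TokenBucketService tokenBucketService;
    private final CallbackService callbackService;
    private final RequestRepository requestRepository;
    private final PriorityResolver priorityResolver;
    private final RateLimitConfigCache configCache;
    private final CircuitBreaker circuitBreaker;
    private final MetricsCollector metricsCollector;
    private final AlertService alertService;
    private final RedisTemplate<String, String> redisTemplate;
    // Project-specific async wrapper around Elasticsearch (indexAsync is not a Spring Data API)
    private final SearchIndexer elasticsearchTemplate;
    
    // Thread pools (initialized here, so Lombok's @RequiredArgsConstructor skips them)
    private final ExecutorService registrationExecutor = Executors.newFixedThreadPool(
        50, 
        new ThreadFactoryBuilder() // Guava
            .setNameFormat("rate-limit-registration-%d")
            .setDaemon(true)
            .build()
    );
    
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(
        10,
        new ThreadFactoryBuilder()
            .setNameFormat("rate-limit-scheduler-%d")
            .setDaemon(true)
            .build()
    );
    
    /**
     * Start the background tasks once all dependencies are injected
     * (scheduling from the constructor would let tasks see a half-built object).
     */
    @PostConstruct
    void init() {
        startCleanupTask();
    }
    
    @PreDestroy
    void shutdown() {
        scheduler.shutdown();
        registrationExecutor.shutdown();
    }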
    
    /**
     * Register a rate limiting request
     */
    @Override
    public String register(RateLimitRequest request, RateLimitConfig config) {
        long startTime = System.currentTimeMillis();
        
        try {
            // 1. Generate a unique request ID
            String requestId = generateRequestId(request);
            
            // 2. Resolve the priority
            int priority = priorityResolver.resolvePriority(request, config);
            
            // 3. Build the request record (read the clock once so createdAt and expiresAt agree)
            long now = System.currentTimeMillis();
            RateLimitRecord record = RateLimitRecord.builder()
                .requestId(requestId)
                .activityId(request.getActivityId())
                .taskToken(request.getTaskToken())
                .rateKey(request.getRateKey())
                .callbackUrl(request.getCallbackUrl())
                .priority(priority)
                .status(RequestStatus.PENDING)
                .createdAt(now)
                .expiresAt(now + request.getTimeoutSeconds() * 1000L)
                .metadata(request.getMetadata())
                .config(config)
                .build();
            
            // 4. Persist asynchronously (Redis cache first, then the database)
            saveRequestAsync(record);
            
            // 5. Put the record on the appropriate queue
            String queueName = determineQueue(request, config);
            queueManager.enqueue(queueName, record, priority);
            
            // 6. Trigger queue processing
            triggerQueueProcessing(queueName);
            
            // 7. Record metrics
            metricsCollector.recordRegistration(
                request.getRateKey(),
                priority,
                System.currentTimeMillis() - startTime
            );
            
            // 8. Return the request ID
            return requestId;
            
        } catch (Exception e) {
            metricsCollector.recordRegistrationError(request.getRateKey());
            throw new RateLimitException("Registration failed", e);
        }
    }
    
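    /**
     * Queue processing trigger.
     *
     * Processing is handed to the scheduler so the registering thread returns
     * immediately; token checks happen per record inside processQueue.
     */
    private void triggerQueueProcessing(String queueName) {
        scheduler.execute(() -> processQueue(queueName));
    }
    
    /**
     * Process the requests waiting in a queue.
     *
     * Each released request consumes one token from its rateKey's bucket, so a
     * batch of 100 records cannot slip through on a single token.
     */
    private void processQueue(String queueName) {
        // Dequeue in batches for efficiency
        List<RateLimitRecord> batch = queueManager.dequeueBatch(queueName, 100);
        
        if (batch.isEmpty()) {
            return;
        }
        
        // Keep only the records that obtain a token; return the rest to the queue
        List<RateLimitRecord> granted = new ArrayList<>();
        boolean outOfTokens = false;
        for (RateLimitRecord record : batch) {
            if (!outOfTokens && tokenBucketService.tryAcquire(record.getRateKey(), 1)) {
                granted.add(record);
            } else {
                outOfTokens = true;
                queueManager.enqueue(queueName, record, record.getPriority());
            }
        }
        if (outOfTokens) {
            // Try again once the bucket is expected to have refilled
            scheduler.schedule(() -> processQueue(queueName),
                calculateNextProcessDelay(queueName),
                TimeUnit.MILLISECONDS);
        }
        
        // Group by callback URL so each endpoint receives one batch request
        Map<String, List<RateLimitRecord>> groupedByCallback = granted.stream()
            .collect(Collectors.groupingBy(RateLimitRecord::getCallbackUrl));
        
        // Dispatch the groups concurrently
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        
        for (Map.Entry<String, List<RateLimitRecord>> entry : groupedByCallback.entrySet()) {
            futures.add(CompletableFuture.runAsync(
                () -> processBatchForCallback(entry.getKey(), entry.getValue()),
                registrationExecutor));
        }
        
        // Wait for all groups; failures are logged rather than propagated
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
            .exceptionally(ex -> {
                log.error("Batch processing failed", ex);
                return null;
            })
            .join();
    }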
    
    /**
     * Process the batch for a single callback URL
     */
    private void processBatchForCallback(String callbackUrl, List<RateLimitRecord> batch) {
        try {
            // Build the batch callback request
            BatchCallbackRequest batchRequest = new BatchCallbackRequest();
            batchRequest.setCallbackUrl(callbackUrl);
            batchRequest.setRecords(batch);
            
            // Send the batch callback
            boolean success = callbackService.sendBatchCallback(batchRequest);
            
            if (success) {
                // Mark the records as CALLBACK_SENT
                batch.forEach(record -> {
                    record.setStatus(RequestStatus.CALLBACK_SENT);
                    record.setCallbackTime(System.currentTimeMillis());
                    updateRecordAsync(record);
                });
                
                metricsCollector.recordBatchCallbackSuccess(
                    callbackUrl, 
                    batch.size()
                );
            } else {
                // Callback failed: re-enqueue, up to 3 attempts
                batch.forEach(record -> {
                    record.incrementRetryCount();
                    if (record.getRetryCount() < 3) {
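                        // Re-enqueue (determineQueue(record) is assumed to be an overload
                        // that derives the queue name from the stored record)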
                        queueManager.enqueue(
                            determineQueue(record),
                            record,
                            record.getPriority()
                        );
                    } else {
                        // Max retries reached: mark as FAILED
                        record.setStatus(RequestStatus.FAILED);
                        updateRecordAsync(record);
                        
                        // Raise an alert
                        alertService.sendAlert(
                            AlertLevel.ERROR,
                            "Callback failed after max retries",
                            record
                        );
                    }
                });
                
                metricsCollector.recordBatchCallbackFailure(
                    callbackUrl,
                    batch.size()
                );
            }
            
        } catch (Exception e) {
            log.error("Failed to process batch for callback: {}", callbackUrl, e);
            metricsCollector.recordBatchCallbackError(callbackUrl);
        }
    }
    
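    /**
     * Generate a request ID (pluggable)
     */
    private String generateRequestId(RateLimitRequest request) {
        // Option 1: UUID
        // return UUID.randomUUID().toString();
        
        // Option 2: Snowflake ID (distributed ID)
        // return snowflakeIdGenerator.nextId();
        
        // Option 3: composite ID. Readable, but two registrations for the same
        // activity within the same millisecond would collide; prefer option 1 or 2
        // if that can happen.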
        return String.format("%s-%s-%d",
            request.getRateKey(),
            request.getActivityId(),
            System.currentTimeMillis()
        );
    }
    
    /**
     * Save the request record asynchronously
     */
    private void saveRequestAsync(RateLimitRecord record) {
        CompletableFuture.runAsync(() -> {
            try {
                // 1. Write the Redis cache first (fast)
                redisTemplate.opsForValue().set(
                    buildRedisKey(record.getRequestId()),
                    serialize(record),
                    Duration.ofMinutes(30)
                );
                
                // 2. Write to the database asynchronously (durable)
                requestRepository.saveAsync(record);
                
                // 3. Index the record for search queries
                elasticsearchTemplate.indexAsync(record);
                
            } catch (Exception e) {
                log.error("Failed to save request record", e);
                metricsCollector.recordPersistenceError();
            }
        }, registrationExecutor);
    }
    
    /**
     * Handle a heartbeat
     */
    @Override
    public boolean heartbeat(String requestId, HeartbeatRequest heartbeat) {
        try {
            // 1. Load the request record
            RateLimitRecord record = getRecord(requestId);
            
            if (record == null) {
                return false;
            }
            
            // 2. Update the heartbeat timestamp
            record.setLastHeartbeatTime(System.currentTimeMillis());
            
            // 3. Optionally extend the timeout
            if (heartbeat.isExtendTimeout()) {
                record.setExpiresAt(record.getExpiresAt() + 
                    heartbeat.getExtensionSeconds() * 1000L);
            }
            
            // 4. Update asynchronously
            updateRecordAsync(record);
            
            metricsCollector.recordHeartbeat(requestId);
            
            return true;
            
        } catch (Exception e) {
            log.error("Heartbeat processing failed", e);
            return false;
        }
    }
    
    /**
     * Start the background cleanup tasks
     */
    private void startCleanupTask() {
        // Purge old requests every 5 minutes
        scheduler.scheduleAtFixedRate(() -> {
            try {
                cleanupExpiredRequests();
            } catch (Exception e) {
                log.error("Cleanup task failed", e);
            }
        }, 5, 5, TimeUnit.MINUTES);
        
        // Check for timed-out requests every 30 seconds
        scheduler.scheduleAtFixedRate(() -> {
            try {
                checkTimeoutRequests();
            } catch (Exception e) {
                log.error("Timeout check failed", e);
            }
        }, 30, 30, TimeUnit.SECONDS);
    }
    
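    /**
     * Purge expired requests
     */
    private void cleanupExpiredRequests() {
        long currentTime = System.currentTimeMillis();
        long cleanupBefore = currentTime - 24L * 60 * 60 * 1000; // 24 hours ago
        
        try {
            // 1. Clean up Redis.
            // Caution: KEYS blocks Redis while it walks the whole keyspace; prefer SCAN in
            // production. Records written by saveRequestAsync also carry a 30-minute TTL,
            // so most of them will already have expired by the time this runs.
            String pattern = "rate_limit:record:*";
            Set<String> keys = redisTemplate.keys(pattern);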
            
            if (keys != null) {
                for (String key : keys) {
                    RateLimitRecord record = deserialize(redisTemplate.opsForValue().get(key));
                    if (record != null && record.getCreatedAt() < cleanupBefore) {
                        redisTemplate.delete(key);
                    }
                }
            }
            
            // 2. Clean up the database (soft delete)
            requestRepository.cleanupExpired(cleanupBefore);
            
            metricsCollector.recordCleanup();
            
        } catch (Exception e) {
            log.error("Cleanup failed", e);
        }
    }
    
    /**
     * Check for timed-out requests
     */
    private void checkTimeoutRequests() {
        long currentTime = System.currentTimeMillis();
        
        try {
            // Find requests that will time out within the next 60 seconds
            List<RateLimitRecord> expiringSoon = requestRepository
                .findExpiringSoon(currentTime + 60000);
            
            for (RateLimitRecord record : expiringSoon) {
                // Warn about the upcoming timeout (repeats on every 30-second run until the request resolves)
                if (record.getStatus() == RequestStatus.PENDING) {
                    alertService.sendAlert(
                        AlertLevel.WARNING,
                        "Request will timeout soon",
                        record
                    );
                    
                    // Optional: notify the callback endpoint ahead of the timeout
                    if (record.getConfig().isNotifyBeforeTimeout()) {
                        sendTimeoutWarning(record);
                    }
                }
            }
            
            // Handle requests that have already timed out
            List<RateLimitRecord> timedOut = requestRepository
                .findTimedOut(currentTime);
            
            for (RateLimitRecord record : timedOut) {
                handleTimeout(record);
            }
            
        } catch (Exception e) {
            log.error("Timeout check failed", e);
        }
    }
    
    /**
     * Handle a timed-out request
     */
    private void handleTimeout(RateLimitRecord record) {
        // Remove from the queue first so the request cannot be dispatched after timing out
        queueManager.remove(record.getRequestId());
        
        // Mark as timed out
        record.setStatus(RequestStatus.TIMEOUT);
        record.setEndTime(System.currentTimeMillis());
        
        // Update asynchronously
        updateRecordAsync(record);
        
        // Send the timeout callback
        sendTimeoutCallback(record);
        
        metricsCollector.recordTimeout(record.getRateKey());
    }
}
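The service relies on `RequestQueueManager.enqueue(queue, record, priority)` and `dequeueBatch(queue, n)` without showing how they order requests. The sketch below shows one plausible policy on a single node: highest priority first, FIFO within the same priority, and a request leaves the queue only after the token source grants it, so no token is spent on an empty queue and no request is released without a token. `PriorityDispatchSketch`, `Waiting` and the `BooleanSupplier` token source are hypothetical; in the service the queue would live in Redis (for example a sorted set) and the token source would be `TokenBucketService.tryAcquire`. The sketch is not thread-safe.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.function.BooleanSupplier;

/** A waiting request; seq preserves arrival order within a priority level. */
record Waiting(String requestId, int priority, long seq) {}

/** Single-threaded sketch: priority-ordered queue released one token at a time. */
class PriorityDispatchSketch {
    // Higher priority first, then lower sequence number (earlier arrival) first
    private final PriorityQueue<Waiting> queue = new PriorityQueue<>(
        Comparator.comparingInt(Waiting::priority).reversed()
            .thenComparingLong(Waiting::seq));
    private long nextSeq = 0;

    void enqueue(String requestId, int priority) {
        queue.add(new Waiting(requestId, priority, nextSeq++));
    }

    /** Releases up to maxBatch requests, consuming one token per request. */
    List<String> dispatch(int maxBatch, BooleanSupplier tokenSource) {
        List<String> released = new ArrayList<>();
        while (released.size() < maxBatch && !queue.isEmpty()) {
            if (!tokenSource.getAsBoolean()) {
                break; // out of tokens: leave the rest queued for the next scheduled run
            }
            released.add(queue.poll().requestId());
        }
        return released;
    }

    int size() {
        return queue.size();
    }
}
```

Checking for an empty queue before asking for a token is the design choice that keeps the bucket honest: the Redis bucket would otherwise lose tokens to polls that release nothing.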

2.3 Token Bucket Service Implementation

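/**
 * Distributed token bucket service
 */
@Service
@Slf4j
@RequiredArgsConstructor
public class DistributedTokenBucketService implements TokenBucketService {
    
    private final RedisTemplate<String, String> redisTemplate;
    private final RateLimitConfigCache configCache;
    private final ClusterManager clusterManager;
    
    /**
     * Lua script: atomic token bucket update.
     * The whole refill-check-deduct sequence runs inside Redis, so callers on
     * different nodes cannot spend the same token twice.
     */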
    private final String tokenBucketLuaScript = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local tokensRequested = tonumber(ARGV[2])
        local capacity = tonumber(ARGV[3])
        local fillRate = tonumber(ARGV[4])
        local fillInterval = tonumber(ARGV[5])
        
        -- Load the current bucket state
        local bucket = redis.call('HMGET', key, 'tokens', 'lastRefill', 'lastAccess')
        
        local currentTokens = capacity
        local lastRefill = now
        local lastAccess = now
        
        if bucket[1] then
            -- The bucket already exists
            currentTokens = tonumber(bucket[1])
            lastRefill = tonumber(bucket[2])
            lastAccess = tonumber(bucket[3])
            
            -- Work out how many whole refill intervals have passed
            local timePassed = now - lastRefill
            local fills = math.floor(timePassed / fillInterval)
            
            if fills > 0 then
                local tokensToAdd = fills * fillRate
                currentTokens = math.min(currentTokens + tokensToAdd, capacity)
                lastRefill = lastRefill + fills * fillInterval
            end
        else
            -- Initialize a new bucket (it starts full)
            redis.call('HMSET', key, 
                'tokens', capacity,
                'lastRefill', now,
                'lastAccess', now
            )
            redis.call('EXPIRE', key, 86400)
        end
        
        -- Check whether enough tokens are available
        local result = 0
        if currentTokens >= tokensRequested then
            -- Deduct the tokens
            currentTokens = currentTokens - tokensRequested
            result = 1
            
            -- Persist the bucket state
            redis.call('HMSET', key,
                'tokens', currentTokens,
                'lastRefill', lastRefill,
                'lastAccess', now
            )
            
            -- Set a TTL so idle buckets do not accumulate
            redis.call('EXPIRE', key, 86400) -- 24 hours
        end
        
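        -- Return {granted flag, remaining tokens, ms since last access};
        -- Redis converts Lua numbers to integer replies, so Java receives Longs
        return {result, currentTokens, now - lastAccess}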
        """;
    
    /**
     * Try to acquire tokens
     */
    @Override
    public boolean tryAcquire(String rateKey, int tokens) {
        return tryAcquireWithDetails(rateKey, tokens).isSuccess();
    }
    
    /**
     * Try to acquire tokens, returning details
     */
    @Override
    public TokenAcquisitionResult tryAcquireWithDetails(String rateKey, int tokens) {
        String bucketKey = buildBucketKey(rateKey);
        RateLimitConfig config = configCache.getConfig(rateKey);
        
        if (config == null) {
            throw new RateLimitException("Rate limit config not found: " + rateKey);
        }
        
        long now = System.currentTimeMillis();
        
        try {
            // Run the Lua script (atomic)
            List<Object> results = redisTemplate.execute(
                new DefaultRedisScript<>(
                    tokenBucketLuaScript,
                    List.class
                ),
                Collections.singletonList(bucketKey),
                String.valueOf(now),
                String.valueOf(tokens),
                String.valueOf(config.getCapacity()),
                String.valueOf(config.getRefillRate()),
                String.valueOf(config.getRefillIntervalMillis())
            );
            
            if (results == null || results.size() < 3) {
                log.error("Token bucket script returned invalid result");
                return TokenAcquisitionResult.error();
            }
            
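            // Script replies arrive as Long, so a (Double) cast would throw ClassCastException
            long success = ((Number) results.get(0)).longValue();
            double remainingTokens = ((Number) results.get(1)).doubleValue();
            long timeSinceLastAccess = ((Number) results.get(2)).longValue();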
            
            TokenAcquisitionResult result = new TokenAcquisitionResult();
            result.setSuccess(success == 1);
            result.setRemainingTokens(remainingTokens);
            result.setTimeSinceLastAccess(timeSinceLastAccess);
            
            // Estimate the wait (an upper bound: the next refill may arrive sooner than a full interval)
            if (!result.isSuccess()) {
                double tokensNeeded = tokens - remainingTokens;
                double refillsNeeded = Math.ceil(tokensNeeded / config.getRefillRate());
                long waitTime = (long) (refillsNeeded * config.getRefillIntervalMillis());
                
                result.setNextAvailableIn(waitTime);
                result.setNextAvailableAt(now + waitTime);
            }
            
            return result;
            
        } catch (Exception e) {
            log.error("Failed to acquire token for key: {}", rateKey, e);
            
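            // Degradation: if Redis is unavailable, fall back to a per-node local limiter
            // (the effective global rate then scales with the number of nodes)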
            if (config.isDegradeOnFailure()) {
                return fallbackToLocal(rateKey, tokens, config);
            }
            
            throw new RateLimitException("Token acquisition failed", e);
        }
    }
    
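    /**
     * Warm up the token buckets (prefill at startup).
     * Note: tryAcquireWithDetails already creates missing buckets full, so for a
     * brand-new bucket this sets it to 80% of capacity, and for an existing
     * bucket it adds 80% of capacity (capped at capacity).
     */
    @PostConstruct
    public void warmupBuckets() {
        // Fetch every configured rate key
        Set<String> rateKeys = configCache.getAllRateKeys();
        
        for (String rateKey : rateKeys) {
            RateLimitConfig config = configCache.getConfig(rateKey);
            if (config == null) {
                continue; // config removed between the two calls
            }
            
            // Warm up to 80% of capacity
            int warmupTokens = (int) (config.getCapacity() * 0.8);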
            
            // Warm up asynchronously
            CompletableFuture.runAsync(() -> {
                try {
                    prefillBucket(rateKey, warmupTokens);
                    log.info("Warmed up token bucket for {}: {} tokens", 
                        rateKey, warmupTokens);
                } catch (Exception e) {
                    log.warn("Failed to warmup bucket for {}", rateKey, e);
                }
            });
        }
    }
    
    /**
     * Prefill a token bucket
     */
    private void prefillBucket(String rateKey, int tokens) {
        String bucketKey = buildBucketKey(rateKey);
        RateLimitConfig config = configCache.getConfig(rateKey);
        
        long now = System.currentTimeMillis();
        
        String prefillingLuaScript = """
            local key = KEYS[1]
            local now = tonumber(ARGV[1])
            local tokensToAdd = tonumber(ARGV[2])
            local capacity = tonumber(ARGV[3])
            
            local currentTokens = 0
            local bucket = redis.call('HMGET', key, 'tokens')
            
            if bucket[1] then
                currentTokens = tonumber(bucket[1])
            end
            
            -- Add tokens without exceeding capacity
            local newTokens = math.min(currentTokens + tokensToAdd, capacity)
            
            -- Update the bucket
            redis.call('HMSET', key,
                'tokens', newTokens,
                'lastRefill', now,
                'lastAccess', now
            )
            
            redis.call('EXPIRE', key, 86400)
            
            return newTokens
            """;
        
        redisTemplate.execute(
            new DefaultRedisScript<>(prefillingLuaScript, Long.class),
            Collections.singletonList(bucketKey),
            String.valueOf(now),
            String.valueOf(tokens),
            String.valueOf(config.getCapacity())
        );
    }
}
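`tryAcquireWithDetails` calls `fallbackToLocal` when Redis is unavailable but does not show it. One plausible shape, sketched below under that assumption, is an in-process bucket with the same discrete refill arithmetic as the Lua script: count whole elapsed intervals (`fills = floor(elapsed / fillInterval)`), add `fills * fillRate`, cap at `capacity`, and advance `lastRefill` by whole intervals only so partial intervals are not lost. `LocalTokenBucket` is a hypothetical name, and the current time is passed in so the behaviour is deterministic. Keep in mind that with N service nodes, a local fallback admits up to N times the configured rate.

```java
/**
 * In-process token bucket mirroring the Lua script's refill logic.
 * Thread-safe via synchronized; the caller supplies the current time in ms.
 */
class LocalTokenBucket {
    private final long capacity;
    private final long fillRate;       // tokens added per interval
    private final long fillIntervalMs; // interval length in ms
    private long tokens;
    private long lastRefill;

    LocalTokenBucket(long capacity, long fillRate, long fillIntervalMs, long nowMs) {
        this.capacity = capacity;
        this.fillRate = fillRate;
        this.fillIntervalMs = fillIntervalMs;
        this.tokens = capacity; // a new bucket starts full, like the Lua script
        this.lastRefill = nowMs;
    }

    synchronized boolean tryAcquire(long requested, long nowMs) {
        long fills = (nowMs - lastRefill) / fillIntervalMs; // whole intervals elapsed
        if (fills > 0) {
            tokens = Math.min(tokens + fills * fillRate, capacity);
            lastRefill += fills * fillIntervalMs; // keep the partial interval for next time
        }
        if (tokens >= requested) {
            tokens -= requested;
            return true;
        }
        return false;
    }

    synchronized long availableTokens() {
        return tokens;
    }
}
```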

2.4 Callback Service Implementation

/**
 * Callback service implementation
 */
@Service
@Slf4j
public class CallbackServiceImpl implements CallbackService {
    
    private final RestTemplate restTemplate;
    private final RetryTemplate retryTemplate;
    private final CircuitBreakerFactory circuitBreakerFactory;
    private final MetricsCollector metricsCollector;
    
    public CallbackServiceImpl(CircuitBreakerFactory circuitBreakerFactory,
                               MetricsCollector metricsCollector) {
        this.circuitBreakerFactory = circuitBreakerFactory;
        this.metricsCollector = metricsCollector;
        
        // HTTP client configuration
        this.restTemplate = new RestTemplateBuilder()
            .setConnectTimeout(Duration.ofSeconds(5))
            .setReadTimeout(Duration.ofSeconds(30))
            .build();
        
        // Retry policy: up to 3 attempts, waiting 1s then 2s (exponential backoff, capped at 10s)
        this.retryTemplate = new RetryTemplateBuilder()
            .maxAttempts(3)
            .exponentialBackoff(1000, 2, 10000)
            .retryOn(HttpServerErrorException.class)
            .retryOn(ResourceAccessException.class)
            .build();
    }
    
    /**
     * Send a single callback
     */
    @Override
    public boolean sendCallback(CallbackRequest request) {
        long startTime = System.currentTimeMillis();
        
        try {
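            // Guard the call with a per-URL circuit breaker. With the Resilience4j implementation,
            // raise its default 1-second TimeLimiter so it covers the retries and the 30s read timeout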
            CircuitBreaker circuitBreaker = circuitBreakerFactory
                .create(request.getCallbackUrl());
            
            return circuitBreaker.run(() -> {
                // Build the HTTP request
                HttpHeaders headers = new HttpHeaders();
                headers.setContentType(MediaType.APPLICATION_JSON);
                headers.set("X-Request-ID", request.getRequestId());
                headers.set("X-Rate-Limit-Service", "v1");
                
                Map<String, Object> body = new HashMap<>();
                body.put("requestId", request.getRequestId());
                body.put("activityId", request.getActivityId());
                body.put("taskToken", Base64.getEncoder().encodeToString(request.getTaskToken()));
                body.put("grantedAt", System.currentTimeMillis());
                body.put("rateKey", request.getRateKey());
                
                HttpEntity<Map<String, Object>> entity = new HttpEntity<>(body, headers);
                
                // Send the request; retryTemplate retries 5xx responses and I/O errors
                ResponseEntity<String> response = retryTemplate.execute(ctx ->
                    restTemplate.exchange(
                        request.getCallbackUrl(),
                        HttpMethod.POST,
                        entity,
                        String.class
                    ));
                
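                // Check the response (the default RestTemplate error handler already throws on
                // 4xx/5xx; those exceptions end up in the fallback below)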
                if (response.getStatusCode().is2xxSuccessful()) {
                    metricsCollector.recordCallbackSuccess(
                        request.getRateKey(),
                        System.currentTimeMillis() - startTime
                    );
                    return true;
                } else {
                    log.warn("Callback returned non-2xx status: {}", response.getStatusCode());
                    metricsCollector.recordCallbackFailure(request.getRateKey());
                    return false;
                }
                
            }, throwable -> {
                // Fallback
                log.error("Callback failed, will retry later", throwable);
                metricsCollector.recordCallbackError(request.getCallbackUrl());
                return false;
            });
            
        } catch (Exception e) {
            log.error("Callback execution failed", e);
            return false;
        }
    }
    
    /**
     * Send a batch callback (one HTTP request carrying all records)
     */
    @Override
    public boolean sendBatchCallback(BatchCallbackRequest batchRequest) {
        if (batchRequest.getRecords() == null || batchRequest.getRecords().isEmpty()) {
            return true;
        }
        
        try {
            // Build the batch request body
            List<Map<String, Object>> batchBody = new ArrayList<>();
            
            for (RateLimitRecord record : batchRequest.getRecords()) {
                Map<String, Object> item = new HashMap<>();
                item.put("requestId", record.getRequestId());
                item.put("activityId", record.getActivityId());
                item.put("taskToken", Base64.getEncoder().encodeToString(record.getTaskToken()));
                item.put("grantedAt", System.currentTimeMillis());
                item.put("rateKey", record.getRateKey());
                item.put("priority", record.getPriority());
                
                batchBody.add(item);
            }
            
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            headers.set("X-Batch-Size", String.valueOf(batchBody.size()));
            
            HttpEntity<List<Map<String, Object>>> entity = new HttpEntity<>(batchBody, headers);
            
            // Send the batch request
            ResponseEntity<String> response = restTemplate.exchange(
                batchRequest.getCallbackUrl(),
                HttpMethod.POST,
                entity,
                String.class
            );
            
            boolean success = response.getStatusCode().is2xxSuccessful();
            
            if (success) {
                metricsCollector.recordBatchCallbackSuccess(
                    batchRequest.getCallbackUrl(),
                    batchBody.size()
                );
            } else {
                metricsCollector.recordBatchCallbackFailure(
                    batchRequest.getCallbackUrl(),
                    batchBody.size()
                );
            }
            
            return success;
            
        } catch (Exception e) {
            log.error("Batch callback failed", e);
            metricsCollector.recordBatchCallbackError(batchRequest.getCallbackUrl());
            return false;
        }
    }
    
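    /**
     * Asynchronous callback (does not block the caller).
     * The {@code @Async} annotation already runs this method on the "callbackExecutor"
     * pool, so the result is wrapped with completedFuture; supplyAsync would hop to
     * the common ForkJoinPool a second time and bypass the configured executor.
     */
    @Async("callbackExecutor")
    @Override
    public CompletableFuture<Boolean> sendCallbackAsync(CallbackRequest request) {
        return CompletableFuture.completedFuture(sendCallback(request));
    }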
}
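
The fallback above logs "Callback failed, will retry later", but the retry scheduler itself is not shown. A minimal sketch of one common policy, exponential backoff with full jitter, could look like the following. The class name `CallbackRetryPolicy` and the parameters (1 s base delay, 5 min cap, 8 attempts) are assumptions chosen for illustration, not values taken from the design.

```java
import java.util.concurrent.ThreadLocalRandom;

/**
 * Hypothetical retry policy for failed callbacks: exponential backoff with full jitter.
 * The delay for a 0-based attempt is drawn uniformly from [0, min(cap, base * 2^attempt)].
 */
final class CallbackRetryPolicy {

    static final long BASE_DELAY_MS = 1_000;   // assumed base delay: 1 s
    static final long MAX_DELAY_MS = 300_000;  // assumed cap: 5 min
    static final int MAX_ATTEMPTS = 8;         // assumed attempt limit before giving up

    private CallbackRetryPolicy() {}

    /** Upper bound of the backoff window for a 0-based attempt number. */
    static long backoffCeilingMillis(int attempt) {
        if (attempt < 0) {
            throw new IllegalArgumentException("attempt must be >= 0");
        }
        // Stop shifting well before a long overflow; the cap applies long before that anyway
        if (attempt >= 30) {
            return MAX_DELAY_MS;
        }
        return Math.min(MAX_DELAY_MS, BASE_DELAY_MS << attempt);
    }

    /** Randomized delay before the next attempt; jitter spreads out synchronized retries. */
    static long nextDelayMillis(int attempt) {
        return ThreadLocalRandom.current().nextLong(backoffCeilingMillis(attempt) + 1);
    }

    /** Whether another attempt is allowed; afterwards the record should be marked failed. */
    static boolean shouldRetry(int attemptsSoFar) {
        return attemptsSoFar < MAX_ATTEMPTS;
    }
}
```

For example, `nextDelayMillis(2)` returns a delay between 0 and 4,000 ms. Full jitter trades a less predictable schedule for avoiding retry storms when many callbacks to the same endpoint fail at once.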

2.5 Temporal Callback Endpoint Implementation

java
/**
 * Temporal callback endpoint implementation
 */
@RestController
@RequestMapping("/temporal/callback")
@Slf4j
@RequiredArgsConstructor  // Lombok: constructor injection for the final fields below
public class TemporalCallbackController {
    
    private final TemporalClient temporalClient;
    private final RateLimitService rateLimitService;
    private final CallbackValidator validator;
    
    /**
     * Receives the "rate limit granted" callback from the rate limit service
     */
    @PostMapping("/rate-limit-granted")
    public ResponseEntity<Map<String, Object>> handleRateLimitCallback(
            @RequestBody RateLimitGrantedRequest request) {
        
        try {
            log.info("Received rate limit callback: {}", request.getRequestId());
            
            // 1. Verify the request signature (optional)
            if (!validator.validateSignature(request)) {
                log.warn("Invalid signature for callback: {}", request.getRequestId());
                return ResponseEntity.status(HttpStatus.UNAUTHORIZED).build();
            }
            
            // 2. Check the request status: only PENDING requests can be completed
            RequestStatus status = rateLimitService.getStatus(request.getRequestId());
            if (status != RequestStatus.PENDING) {   // also rejects status == null
                log.warn("Invalid status for callback: {} (status={})", request.getRequestId(), status);
                return ResponseEntity.badRequest().body(Map.of(
                    "error", "Invalid request status"
                ));
            }
            
            // 3. Decode the task token
            byte[] taskToken = Base64.getDecoder().decode(request.getTaskToken());
            
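            // 4. Complete the Temporal activity in the background
            //    (runAsync without an executor argument uses the common ForkJoinPool)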
            CompletableFuture.runAsync(() -> {
                try {
                    // Complete the activity through the Temporal client wrapper
                    temporalClient.completeActivity(taskToken, Map.of(
                        "requestId", request.getRequestId(),
                        "grantedAt", request.getGrantedAt(),
                        "rateKey", request.getRateKey()
                    ));
                    
                    log.info("Completed activity for request: {}", request.getRequestId());
                    
                    // Update the request status
                    rateLimitService.markAsCompleted(request.getRequestId());
                    
                } catch (ActivityNotFoundException e) {
                    log.warn("Activity not found for request: {}", request.getRequestId());
                    // The activity may already have timed out or been cancelled
                    rateLimitService.markAsFailed(request.getRequestId(), "Activity not found");
                } catch (Exception e) {
                    log.error("Failed to complete activity", e);
                    // Record the failure; manual intervention may be required
                    rateLimitService.markAsFailed(request.getRequestId(), e.getMessage());
                }
            });
            
            // 5. Acknowledge immediately; the activity is completed in the background
            return ResponseEntity.ok(Map.of(
                "status", "success",
                "message", "Callback received and processing"
            ));
            
        } catch (Exception e) {
            log.error("Failed to handle callback", e);
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(Map.of("error", "Internal server error"));
        }
    }
    
    /**
     * Batch callback handling
     */
    @PostMapping("/rate-limit-granted/batch")
    public ResponseEntity<BatchCallbackResponse> handleBatchCallback(
            @RequestBody List<RateLimitGrantedRequest> batchRequest) {
        
        BatchCallbackResponse response = new BatchCallbackResponse();
        List<String> successIds = new ArrayList<>();
        List<FailedCallback> failedCallbacks = new ArrayList<>();
        
        for (RateLimitGrantedRequest request : batchRequest) {
            try {
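                // Process one callback. handleSingleCallback (not shown) is expected to run
                // the same validation and completion steps as handleRateLimitCallback
                // and to throw on failure.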
                handleSingleCallback(request);
                successIds.add(request.getRequestId());
                
            } catch (Exception e) {
                log.error("Failed to process callback for: {}", request.getRequestId(), e);
                failedCallbacks.add(new FailedCallback(
                    request.getRequestId(),
                    e.getMessage()
                ));
            }
        }
        
        response.setSuccessCount(successIds.size());
        response.setFailedCount(failedCallbacks.size());
        response.setSuccessIds(successIds);
        response.setFailedCallbacks(failedCallbacks);
        
        return ResponseEntity.ok(response);
    }
    
    /**
     * Health check endpoint
     */
    @GetMapping("/health")
    public ResponseEntity<HealthResponse> healthCheck() {
        try {
            // Check the Temporal connection
            boolean temporalConnected = temporalClient.isConnected();
            
            // Check the database connection
            boolean dbConnected = rateLimitService.isDatabaseConnected();
            
            HealthStatus status = temporalConnected && dbConnected ? 
                HealthStatus.UP : HealthStatus.DOWN;
            
            Map<String, Object> details = new HashMap<>();
            details.put("temporal", temporalConnected ? "connected" : "disconnected");
            details.put("database", dbConnected ? "connected" : "disconnected");
            details.put("timestamp", System.currentTimeMillis());
            
            HealthResponse response = new HealthResponse();
            response.setStatus(status);
            response.setDetails(details);
            
            // Return 503 when DOWN so probes and load balancers that only inspect
            // the HTTP status code take the instance out of rotation
            HttpStatus httpStatus = status == HealthStatus.UP
                ? HttpStatus.OK : HttpStatus.SERVICE_UNAVAILABLE;
            return ResponseEntity.status(httpStatus).body(response);
            
        } catch (Exception e) {
            log.error("Health check failed", e);
            // Map.of rejects null values, and getMessage() may return null
            return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
                .body(new HealthResponse(HealthStatus.DOWN,
                    Map.of("error", String.valueOf(e.getMessage()))));
        }
    }
}
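
The controller reads the status and then completes the activity on another thread, so two deliveries of the same callback can both observe PENDING before either one marks the request completed. One way to close that window is an atomic PENDING to COMPLETING transition that only one caller can win. The sketch below uses a `ConcurrentHashMap` as an in-process stand-in for the shared store; in a clustered deployment the same check-and-set would have to run in Redis or as a conditional database update. `CallbackClaimRegistry` and its states are hypothetical names.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Hypothetical in-memory illustration of an atomic "claim" before completing an activity.
 * Only the caller that moves a request from PENDING to COMPLETING proceeds; duplicates back off.
 */
final class CallbackClaimRegistry {

    enum State { PENDING, COMPLETING, COMPLETED, FAILED }

    private final ConcurrentMap<String, State> states = new ConcurrentHashMap<>();

    void register(String requestId) {
        states.putIfAbsent(requestId, State.PENDING);
    }

    /** Atomically claims the request; returns false for duplicates and unknown IDs. */
    boolean tryClaim(String requestId) {
        // replace(key, expected, newValue) is an atomic compare-and-set on the entry
        return states.replace(requestId, State.PENDING, State.COMPLETING);
    }

    void markCompleted(String requestId) {
        states.replace(requestId, State.COMPLETING, State.COMPLETED);
    }

    /** Releases a claim after a failed completion so a later retry can try again. */
    void release(String requestId) {
        states.replace(requestId, State.COMPLETING, State.PENDING);
    }

    State stateOf(String requestId) {
        return states.get(requestId);
    }
}
```

In the controller, `tryClaim` would take the place of the plain PENDING check, `markCompleted` would follow a successful completion, and `release` (or a transition to FAILED) would cover the error paths.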

III. Deployment and Operations Design

IV. Monitoring and Alerting Design

4.1 Monitoring Metrics

java
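/**
 * Metric definitions
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class RateLimitMetricsCollector {
    
    private final MeterRegistry meterRegistry;
    // queueManager, metricsExporter, elasticsearchService and alertService are
    // further injected collaborators; their declarations are omitted here
    
    @PostConstruct
    public void initMetrics() {
        // Register the core metrics; periodic export runs via @Scheduled on exportMetrics()
        registerCoreMetrics();
    }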
    
    private void registerCoreMetrics() {
        // 1. Request metrics
        Counter.builder("rate_limit.requests.total")
            .description("Total rate limit requests")
            .tag("type", "registration")
            .register(meterRegistry);
        
        // 2. Queue depth (evaluated each time metrics are published)
        Gauge.builder("rate_limit.queue.depth", queueManager::getQueueDepth)
            .description("Current queue depth")
            .tag("queue", "default")
            .register(meterRegistry);
        
        // 3. Token bucket acquisition latency
        Timer.builder("rate_limit.token.acquisition.time")
            .description("Time taken to acquire tokens")
            .publishPercentiles(0.5, 0.95, 0.99)
            .register(meterRegistry);
        
        // 4. Callbacks: register both outcomes so the failure-ratio alert
        //    always has a numerator and a denominator series
        for (String outcome : List.of("success", "failed")) {
            Counter.builder("rate_limit.callbacks.total")
                .description("Total callbacks sent")
                .tag("status", outcome)
                .register(meterRegistry);
        }
        
        // 5. Errors
        Counter.builder("rate_limit.errors.total")
            .description("Total errors")
            .tag("type", "registration_error")
            .register(meterRegistry);
        
        // 6. Wait time distribution
        DistributionSummary.builder("rate_limit.wait.time.distribution")
            .description("Distribution of wait times")
            .baseUnit("milliseconds")
            .publishPercentiles(0.5, 0.9, 0.95, 0.99)
            .register(meterRegistry);
    }
    
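    /**
     * Record key business metrics for one registration.
     * Each distinct rate_key value creates new time series, so this tagging is
     * only safe while the set of rate keys stays small and bounded.
     */
    public void recordRegistration(String rateKey, int priority, long duration) {
        // Record the measured duration in the MeterRegistry. (Timer.start() followed
        // immediately by stop() would time only those two calls, not the registration.)
        Timer.builder("rate_limit.registration.time")
            .tag("rate_key", rateKey)
            .tag("priority", String.valueOf(priority))
            .register(meterRegistry)
            .record(duration, TimeUnit.MILLISECONDS);
        
        // Debug log (optional)
        log.debug("Registration completed: rateKey={}, priority={}, duration={}ms", 
            rateKey, priority, duration);
        
        // Registration counter
        meterRegistry.counter("rate_limit.registration.count", 
            "rate_key", rateKey,
            "priority", String.valueOf(priority)
        ).increment();
    }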
    
    /**
     * Export metrics to external systems
     */
    @Scheduled(fixedRate = 60000) // once per minute
    public void exportMetrics() {
        try {
            // 1. Snapshot all current metrics
            Map<String, Object> metrics = collectAllMetrics();
            
            // 2. Push to the monitoring system
            metricsExporter.export(metrics);
            
            // 3. Index into Elasticsearch for later analysis
            elasticsearchService.indexMetrics(metrics);
            
            // 4. Check thresholds and raise alerts
            checkThresholds(metrics);
            
        } catch (Exception e) {
            log.error("Failed to export metrics", e);
        }
    }
    
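    private void checkThresholds(Map<String, Object> metrics) {
        // Read values as Number and skip missing entries: a direct (Long)/(Double) cast
        // throws ClassCastException for other numeric types and NPE when unboxing null
        
        // Queue depth
        Number queueDepth = (Number) metrics.get("queue.depth");
        if (queueDepth != null && queueDepth.longValue() > 10_000) {
            alertService.sendAlert(
                AlertLevel.WARNING,
                "Queue depth exceeded threshold",
                Map.of("queue_depth", queueDepth)
            );
        }
        
        // Error rate: ratio of failed to total requests (0.01 = 1%)
        Number errorRate = (Number) metrics.get("error.rate");
        if (errorRate != null && errorRate.doubleValue() > 0.01) {
            alertService.sendAlert(
                AlertLevel.ERROR,
                "Error rate exceeded threshold",
                Map.of("error_rate", errorRate)
            );
        }
        
        // Average wait time in milliseconds
        Number avgWaitTime = (Number) metrics.get("wait.time.avg");
        if (avgWaitTime != null && avgWaitTime.doubleValue() > 300_000) { // 5 minutes
            alertService.sendAlert(
                AlertLevel.WARNING,
                "Average wait time too high",
                Map.of("avg_wait_time", avgWaitTime)
            );
        }
    }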
}
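
`checkThresholds` compares `error.rate` against 0.01, so that value has to be a ratio of failed to total requests over some window, not a raw error count. A minimal sketch of deriving it from two successive snapshots of cumulative counters (the same idea as dividing two `rate()` expressions in PromQL) follows; `ErrorRateCalculator` and its `Snapshot` record are hypothetical names, and the reset handling is an assumption about how counters behave after a restart.

```java
/**
 * Hypothetical helper: error ratio between two snapshots of monotonically increasing counters.
 */
final class ErrorRateCalculator {

    /** Cumulative counter values captured at one export tick. */
    record Snapshot(long totalRequests, long totalErrors) {}

    private ErrorRateCalculator() {}

    /**
     * Errors divided by requests during the interval between two snapshots.
     * Returns 0.0 when no requests arrived, and treats a counter that went
     * backwards (e.g. after a process restart) as having restarted from zero.
     */
    static double errorRatio(Snapshot previous, Snapshot current) {
        long requests = current.totalRequests() - previous.totalRequests();
        long errors = current.totalErrors() - previous.totalErrors();
        if (requests < 0 || errors < 0) {
            requests = current.totalRequests();
            errors = current.totalErrors();
        }
        return requests == 0 ? 0.0 : (double) errors / requests;
    }

    /** Same 1% threshold as checkThresholds. */
    static boolean exceedsThreshold(double ratio) {
        return ratio > 0.01;
    }
}
```

For example, going from (10,000 requests, 50 errors) to (12,000 requests, 80 errors) gives 30 / 2,000 = 1.5%, which exceeds the threshold.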

4.2 Alerting Rules

yaml
# alert-rules.yml
groups:
- name: rate-limit-service
  rules:
  - alert: HighQueueDepth
    expr: rate_limit_queue_depth > 10000
    for: 5m
    labels:
      severity: warning
      service: rate-limit
    annotations:
      summary: "Rate limit queue depth is high"
      description: "Queue depth is {{ $value }} (threshold: 10000)"
      
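  - alert: HighErrorRate
    # Errors per second alone is not a percentage: divide by the request rate.
    # sum() drops labels such as "type" so both sides of the division match.
    expr: sum(rate(rate_limit_errors_total[5m])) / sum(rate(rate_limit_requests_total[5m])) > 0.01
    for: 2m
    labels:
      severity: critical
      service: rate-limit
    annotations:
      summary: "High error rate in rate limit service"
      description: "Error rate is {{ $value | humanizePercentage }} (threshold: 1%)"
      
  - alert: CallbackFailure
    # Without sum(), one-to-one matching pairs the failed series only with itself
    # (the status labels differ), so the ratio would be 1 whenever failures occur
    expr: sum(rate(rate_limit_callbacks_total{status="failed"}[5m])) / sum(rate(rate_limit_callbacks_total[5m])) > 0.05
    for: 3m
    labels:
      severity: warning
      service: rate-limit
    annotations:
      summary: "High callback failure rate"
      description: "Callback failure rate is {{ $value | humanizePercentage }} (threshold: 5%)"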
      
  - alert: RedisConnectionIssues
    expr: redis_up == 0
    for: 1m
    labels:
      severity: critical
      service: rate-limit
    annotations:
      summary: "Redis connection lost"
      description: "Rate limit service cannot connect to Redis"
      
  - alert: ServiceDown
    expr: up{job="rate-limit-service"} == 0
    for: 2m
    labels:
      severity: critical
      service: rate-limit
    annotations:
      summary: "Rate limit service is down"
      description: "Service {{ $labels.instance }} is not responding"

V. Security Design

5.1 Authentication and Authorization

java
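/**
 * Security configuration.
 * Written against Spring Security 6: WebSecurityConfigurerAdapter was deprecated
 * in 5.7 and removed in 6.0, so the chain is exposed as a SecurityFilterChain bean.
 */
@Configuration
@EnableWebSecurity
public class SecurityConfig {
    
    @Bean
    public SecurityFilterChain securityFilterChain(HttpSecurity http,
                                                   ApiKeyAuthenticationFilter apiKeyFilter,
                                                   TemporalSignatureFilter temporalSignatureFilter)
            throws Exception {
        http
            .csrf(csrf -> csrf.disable())  // stateless API authenticated by headers, so CSRF is usually disabled
            .cors(cors -> cors.configurationSource(corsConfigurationSource()))
            .authorizeHttpRequests(auth -> auth
                // Public endpoints (rules are evaluated in order, so these come first)
                .requestMatchers("/api/v1/rate-limit/health", "/actuator/health").permitAll()
                .requestMatchers(HttpMethod.GET, "/temporal/callback/health").permitAll()
                
                // Registration requires an API key
                .requestMatchers(HttpMethod.POST, "/api/v1/rate-limit/register")
                    .hasAuthority("API_CLIENT")
                
                // Admin endpoints require the ADMIN role
                .requestMatchers("/api/v1/admin/**").hasRole("ADMIN")
                
                // Temporal callback endpoints require a verified request signature
                .requestMatchers("/temporal/callback/**").hasAuthority("TEMPORAL_SERVICE")
                
                .anyRequest().authenticated())
            .addFilterBefore(apiKeyFilter, UsernamePasswordAuthenticationFilter.class)
            .addFilterBefore(temporalSignatureFilter, UsernamePasswordAuthenticationFilter.class)
            .exceptionHandling(ex -> ex
                .authenticationEntryPoint(authenticationEntryPoint())
                .accessDeniedHandler(accessDeniedHandler()));
        return http.build();
    }
    
    /**
     * Temporal signature filter
     */
    @Bean
    public TemporalSignatureFilter temporalSignatureFilter() {
        return new TemporalSignatureFilter();
    }
    
    /**
     * Spring Boot also registers Filter beans as plain servlet filters. Disabling
     * those registrations keeps the authentication filters running only inside the
     * security chain (ApiKeyAuthenticationFilter is a {@code @Component}).
     */
    @Bean
    public FilterRegistrationBean<ApiKeyAuthenticationFilter> apiKeyFilterRegistration(
            ApiKeyAuthenticationFilter filter) {
        FilterRegistrationBean<ApiKeyAuthenticationFilter> registration = new FilterRegistrationBean<>(filter);
        registration.setEnabled(false);
        return registration;
    }
    
    @Bean
    public FilterRegistrationBean<TemporalSignatureFilter> temporalSignatureFilterRegistration(
            TemporalSignatureFilter filter) {
        FilterRegistrationBean<TemporalSignatureFilter> registration = new FilterRegistrationBean<>(filter);
        registration.setEnabled(false);
        return registration;
    }
    
    /**
     * Inbound rate limiting for this service's own API
     */
    @Bean
    public FilterRegistrationBean<RateLimitFilter> rateLimitFilter() {
        FilterRegistrationBean<RateLimitFilter> registration = new FilterRegistrationBean<>();
        registration.setFilter(new RateLimitFilter());
        registration.addUrlPatterns("/api/v1/*");
        registration.setOrder(1);
        return registration;
    }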
}
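
The configuration relies on a `TemporalSignatureFilter` (and the callback controller on `CallbackValidator.validateSignature`) whose internals are not shown. A common scheme, sketched here under assumptions, is an HMAC-SHA256 over a timestamp plus the raw request body, with the timestamp checked against a replay window. The message layout (`timestamp + "." + body`), the 5-minute window, lowercase hex encoding and the class name `CallbackSignature` are all assumptions; sender and receiver must agree on them and share the secret.

```java
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.HexFormat;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

/**
 * Hypothetical callback signer/verifier: HMAC-SHA256 over "timestamp.body" plus a replay window.
 */
final class CallbackSignature {

    private static final String ALGORITHM = "HmacSHA256";
    private static final Duration MAX_SKEW = Duration.ofMinutes(5); // assumed replay window

    private final SecretKeySpec key;

    CallbackSignature(byte[] sharedSecret) {
        this.key = new SecretKeySpec(sharedSecret, ALGORITHM);
    }

    /** Lowercase hex signature the sender would put in a request header. */
    String sign(long epochSeconds, String body) {
        try {
            Mac mac = Mac.getInstance(ALGORITHM); // Mac is not thread-safe: one instance per call
            mac.init(key);
            byte[] digest = mac.doFinal((epochSeconds + "." + body).getBytes(StandardCharsets.UTF_8));
            return HexFormat.of().formatHex(digest);
        } catch (NoSuchAlgorithmException | InvalidKeyException e) {
            throw new IllegalStateException("HMAC-SHA256 unavailable", e);
        }
    }

    /** Rejects stale or future timestamps, then compares signatures in constant time. */
    boolean verify(long epochSeconds, String body, String signatureHex, Instant now) {
        long skew = Math.abs(now.getEpochSecond() - epochSeconds);
        if (skew > MAX_SKEW.toSeconds()) {
            return false;
        }
        byte[] expected = sign(epochSeconds, body).getBytes(StandardCharsets.US_ASCII);
        byte[] actual = signatureHex.getBytes(StandardCharsets.US_ASCII);
        return MessageDigest.isEqual(expected, actual);
    }
}
```

Signing the timestamp together with the body prevents an attacker from replaying an old, validly signed request with a fresh timestamp. `MessageDigest.isEqual` avoids leaking through response timing how many leading characters of a forged signature matched.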

5.2 API Key Authentication

java
/**
 * API key authentication filter
 */
@Component
@RequiredArgsConstructor
public class ApiKeyAuthenticationFilter extends OncePerRequestFilter {
    
    // ObjectMapper is thread-safe once configured, so one shared instance is enough
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    
    private final ApiKeyService apiKeyService;
    
    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                  HttpServletResponse response,
                                  FilterChain filterChain) throws ServletException, IOException {
        
        // Extract the API key
        String apiKey = extractApiKey(request);
        
        if (apiKey == null) {
            // No key: continue unauthenticated; the authorization rules reject
            // the request if the endpoint requires the API_CLIENT authority
            filterChain.doFilter(request, response);
            return;
        }
        
        // Validate the API key
        ApiKeyInfo apiKeyInfo = apiKeyService.validateApiKey(apiKey);
        
        if (apiKeyInfo == null || !apiKeyInfo.isValid()) {
            sendErrorResponse(response, HttpStatus.UNAUTHORIZED, "Invalid API Key");
            return;
        }
        
        // Check the IP allowlist. Behind a proxy or load balancer, getRemoteAddr()
        // returns the proxy's address unless forwarded headers are processed.
        if (!apiKeyInfo.isIpAllowed(request.getRemoteAddr())) {
            sendErrorResponse(response, HttpStatus.FORBIDDEN, "IP not allowed");
            return;
        }
        
        // Populate the security context
        ApiKeyAuthentication authentication = new ApiKeyAuthentication(
            apiKeyInfo.getClientId(),
            apiKey,
            apiKeyInfo.getAuthorities()
        );
        
        SecurityContextHolder.getContext().setAuthentication(authentication);
        
        // Write the access log
        logAccess(request, apiKeyInfo);
        
        filterChain.doFilter(request, response);
    }
    
    private String extractApiKey(HttpServletRequest request) {
        // Accept the key from the header only: a query parameter (?api_key=...)
        // ends up in access logs, proxy logs and browser history
        String apiKey = request.getHeader("X-API-Key");
        return (apiKey == null || apiKey.isBlank()) ? null : apiKey;
    }
    
    private void sendErrorResponse(HttpServletResponse response, 
                                 HttpStatus status, 
                                 String message) throws IOException {
        response.setStatus(status.value());
        response.setContentType("application/json");
        response.setCharacterEncoding("UTF-8");
        
        Map<String, Object> error = new HashMap<>();
        error.put("timestamp", System.currentTimeMillis());
        error.put("status", status.value());
        error.put("error", status.getReasonPhrase());
        error.put("message", message);
        
        response.getWriter().write(OBJECT_MAPPER.writeValueAsString(error));
    }
}
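
`ApiKeyService.validateApiKey` is not shown. One common design, sketched here as an assumption, stores only a SHA-256 hash of each issued key and looks keys up by hash, so a leaked table does not expose usable keys. A plain, fast hash is acceptable only because API keys are long random strings; user-chosen passwords would need a slow hash such as bcrypt. `HashedApiKeyStore` and its methods are hypothetical, and a real service would also keep per-key metadata (authorities, expiry, IP allowlist).

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.HexFormat;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hypothetical API key store that keeps only SHA-256 hashes of issued keys.
 */
final class HashedApiKeyStore {

    private static final SecureRandom RANDOM = new SecureRandom();

    // hex-encoded SHA-256 of the key -> client ID
    private final Map<String, String> clientIdByHash = new ConcurrentHashMap<>();

    /** Issues a new random key; the plaintext is returned once and never stored. */
    String issueKey(String clientId) {
        byte[] raw = new byte[32]; // 256 bits of randomness
        RANDOM.nextBytes(raw);
        String apiKey = Base64.getUrlEncoder().withoutPadding().encodeToString(raw);
        clientIdByHash.put(sha256Hex(apiKey), clientId);
        return apiKey;
    }

    /** Looks up the client for a presented key; empty if the key is unknown or revoked. */
    Optional<String> findClientId(String presentedKey) {
        if (presentedKey == null || presentedKey.isEmpty()) {
            return Optional.empty();
        }
        return Optional.ofNullable(clientIdByHash.get(sha256Hex(presentedKey)));
    }

    void revoke(String apiKey) {
        clientIdByHash.remove(sha256Hex(apiKey));
    }

    static String sha256Hex(String value) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            return HexFormat.of().formatHex(digest.digest(value.getBytes(StandardCharsets.UTF_8)));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 unavailable", e);
        }
    }
}
```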

VI. Cost Optimization Recommendations

6.1 Cloud Cost Optimization

java
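/**
 * Cost optimization strategies.
 * Collaborators such as loadPredictor, kubernetesClient, redisMonitor and
 * cloudProvider are injected; their declarations are omitted here.
 */
@Component
public class CostOptimizationService {
    
    /**
     * Dynamic resource adjustment.
     * Every replica runs scheduled jobs, so in a multi-instance deployment this
     * should be guarded by a distributed lock or leader election.
     */
    @Scheduled(cron = "0 0 * * * *") // at the top of every hour
    public void optimizeResources() {
        // 1. Size the replica count from the load forecast
        int predictedLoad = loadPredictor.predictNextHourLoad();
        int optimalReplicas = calculateOptimalReplicas(predictedLoad);
        
        // Scale the Kubernetes deployment (if a HorizontalPodAutoscaler also
        // manages this deployment, the two will overwrite each other)
        kubernetesClient.scaleDeployment("rate-limit-service", optimalReplicas);
        
        // 2. Right-size the Redis instance
        optimizeRedisResources();
        
        // 3. Remove data that is no longer needed
        cleanupOldData();
    }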
    
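    private void optimizeRedisResources() {
        // Redis memory utilization as a fraction (0.0 to 1.0)
        double memoryUsage = redisMonitor.getMemoryUsage();
        
        // The gap between 0.3 and 0.8 acts as hysteresis so the instance does not
        // flap between sizes; resizing a managed Redis may briefly disrupt clients
        if (memoryUsage < 0.3) {
            // Low utilization: consider a smaller instance type
            cloudProvider.downgradeRedisInstance();
        } else if (memoryUsage > 0.8) {
            // High utilization: consider a larger instance or sharding
            cloudProvider.upgradeRedisInstance();
        }
    }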
    
    /**
     * Data lifecycle management
     */
    private void cleanupOldData() {
        // Delete request records older than 30 days
        requestRepository.deleteOlderThan(30, TimeUnit.DAYS);
        
        // Compress historical monitoring data
        metricsRepository.compressOldMetrics();
        
        // Archive old logs
        logArchiver.archiveOldLogs();
    }
    
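    /**
     * Cache tuning
     */
    public void optimizeCachingStrategy() {
        // Adjust each cache's capacity based on its observed hit rate
        Map<String, CacheHitRate> hitRates = cacheMonitor.getHitRates();
        
        hitRates.forEach((key, hitRate) -> {
            if (hitRate.getRate() < 0.5) {
                // Low hit rate: entries are rarely reused, so shrink the cache.
                // (Rule out an undersized cache first; that also lowers the hit rate.)
                cacheManager.reduceCacheSize(key);
            } else if (hitRate.getRate() > 0.9) {
                // High hit rate: the cache pays off, so allow it to grow
                cacheManager.increaseCacheSize(key);
            }
        });
    }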
}
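
`optimizeResources` calls `calculateOptimalReplicas`, which is not shown. A minimal sketch, assuming the forecast is in requests per second and each replica has a measured sustainable capacity: scale the forecast by a headroom factor, divide by per-replica capacity, round up, and clamp to a floor (for availability) and a ceiling (for cost). The class name and all parameter values are illustrative assumptions.

```java
/**
 * Hypothetical replica sizing: ceil(load * (1 + headroom) / capacityPerReplica), clamped.
 * Integer arithmetic avoids floating-point rounding at exact boundaries.
 */
final class ReplicaCalculator {

    private final long capacityPerReplica; // sustainable requests/s per replica (assumed measured)
    private final int headroomPercent;     // spare capacity for forecast error, e.g. 20
    private final int minReplicas;         // floor for availability
    private final int maxReplicas;         // ceiling for cost control

    ReplicaCalculator(long capacityPerReplica, int headroomPercent, int minReplicas, int maxReplicas) {
        if (capacityPerReplica <= 0 || headroomPercent < 0 || minReplicas < 1 || maxReplicas < minReplicas) {
            throw new IllegalArgumentException("invalid sizing parameters");
        }
        this.capacityPerReplica = capacityPerReplica;
        this.headroomPercent = headroomPercent;
        this.minReplicas = minReplicas;
        this.maxReplicas = maxReplicas;
    }

    int calculateOptimalReplicas(long predictedLoad) {
        long load = Math.max(0, predictedLoad);
        long numerator = load * (100 + headroomPercent);
        long denominator = 100 * capacityPerReplica;
        long needed = (numerator + denominator - 1) / denominator; // ceiling division
        return (int) Math.min(maxReplicas, Math.max(minReplicas, needed));
    }
}
```

With a capacity of 100 req/s per replica, 20% headroom and bounds of [2, 20], a forecast of 950 req/s yields 12 replicas (950 × 1.2 / 100 = 11.4, rounded up), and a forecast of 0 still keeps 2 replicas running.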

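Summary

This third-party rate control service design has the following properties:

  1. High availability: clustered deployment with automatic failover
  2. High performance: batching, asynchronous processing, and caching
  3. Scalability: horizontal scaling, partitioned processing, and dynamic tuning
  4. Reliability: data persistence, retries, and monitoring with alerting
  5. Security: API key authentication, request signing, and access control
  6. Observability: comprehensive metrics, log aggregation, and distributed tracing
  7. Cost optimization: autoscaling, resource right-sizing, and data lifecycle management

Recommended Tech Stack

  • Backend: Spring Boot + Spring Cloud
  • Database: Redis + PostgreSQL/MySQL
  • Message queue: Kafka/RabbitMQ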