Spring AI Alibaba 1.x 系列【20】MessagesAgentHook 、MessagesModelHook 相关实现类

文章目录

  • [1. MessagesAgentHook 实现类](#1. MessagesAgentHook 实现类)
    • [1.1 InstructionAgentHook](#1.1 InstructionAgentHook)
  • [2. MessagesModelHook 实现类](#2. MessagesModelHook 实现类)
    • [2.1 SummarizationHook](#2.1 SummarizationHook)
    • [2.2 ReturnDirectModelHook](#2.2 ReturnDirectModelHook)
    • [2.3 PIIDetectionHook](#2.3 PIIDetectionHook)

1. MessagesAgentHook 实现类

1.1 InstructionAgentHook

指令注入 Agent 钩子,在 Agent 每次运行前将 ReactAgent 的系统指令注入消息列表。

  • 执行时机:HookPosition.BEFORE_AGENTAgent 执行前),自动读取 Agent 指令并添加到消息头部,避免子图场景下重复注入指令。
  • 核心逻辑:读取 ReactAgent 的指令配置,非空时自动添加 AgentInstructionMessage 到消息头部
  • 策略特性:返回 REPLACE 更新策略,避免作为子图节点时重复添加指令
  • 默认机制:ReactAgent 无其他指令类钩子时,会自动加载该钩子,且拥有最高执行优先级(最先运行)。

属性定义:

java 复制代码
@HookPositions(HookPosition.BEFORE_AGENT)
public class InstructionAgentHook extends MessagesAgentHook {

	/**
	 * 关联的 ReactAgent 实例,用于获取系统指令
	 */
	private ReactAgent reactAgent;
	}

重写 Agent 执行前钩子,注入系统指令消息:

java 复制代码
/**
	 * 重写 Agent 执行前钩子:注入系统指令消息
	 * <p>
	 * 核心逻辑:
	 * <ol>
	 *   <li>校验 ReactAgent 实例是否为空</li>
	 *   <li>获取 Agent 的系统指令,为空则直接返回原消息</li>
	 *   <li>构建指令消息,追加到消息列表头部</li>
	 *   <li>返回包含新消息的 AgentCommand</li>
	 * </ol>
	 * </p>
	 * @param previousMessages 原始消息列表
	 * @param config 运行配置
	 * @return 包含指令消息的 AgentCommand
	 */
	@Override
	public AgentCommand beforeAgent(List<Message> previousMessages, RunnableConfig config) {
		if (reactAgent == null) {
			return new AgentCommand(previousMessages);
		}
		String instruction = reactAgent.instruction();
		if (!StringUtils.hasLength(instruction)) {
			return new AgentCommand(previousMessages);
		}
		AgentInstructionMessage instructionMessage = AgentInstructionMessage.builder().text(instruction).build();
		List<Message> newMessages = new ArrayList<>(previousMessages);
		newMessages.add(instructionMessage);
		return new AgentCommand(newMessages);
	}

其他源码内容:

java 复制代码
	/**
	 * 获取钩子名称
	 * @return 固定返回 InstructionAgentHook
	 */
	@Override
	public String getName() {
		return "InstructionAgentHook";
	}

	/**
	 * 获取钩子执行优先级
	 * @return -100,最低优先级值,代表最先执行
	 */
	@Override
	public int getOrder() {
		return -100;
	}

	/**
	 * 获取关联的 ReactAgent 实例
	 * @return ReactAgent 实例
	 */
	@Override
	public ReactAgent getAgent() {
		return reactAgent;
	}

	/**
	 * 设置关联的 ReactAgent 实例
	 * @param agent ReactAgent 实例
	 */
	@Override
	public void setAgent(ReactAgent agent) {
		this.reactAgent = agent;
	}

	/**
	 * 创建默认的 InstructionAgentHook 实例
	 * <p>无其他指令类钩子处理指令时使用</p>
	 * @return 新的 InstructionAgentHook 实例
	 */
	public static InstructionAgentHook create() {
		return new InstructionAgentHook();
	}
}

2. MessagesModelHook 实现类

2.1 SummarizationHook

对话历史总结钩子。模型调用前执行,监控对话消息的 Token 数量,当达到阈值时自动总结历史消息,防止 Token 超限,同时保留关键上下文(首条用户消息、最新消息),保证对话连贯性。

核心特性:

  • 自动检测 Token 用量,触发智能总结
  • 安全切割消息,不拆分 AI 工具调用配对
  • 可配置保留消息数、总结提示词
  • 支持保留首条用户消息,锁定核心对话意图

使用示例:

java 复制代码
SummarizationHook summarizer = SummarizationHook.builder()
      .model(chatModel)
      .maxTokensBeforeSummary(4000)
      .messagesToKeep(20)
      .keepFirstUserMessage(true)
      .build();

常量定义:

java 复制代码
@HookPositions({HookPosition.BEFORE_MODEL})
public class SummarizationHook extends MessagesModelHook {

	private static final Logger log = LoggerFactory.getLogger(SummarizationHook.class);

	/**
	 * 默认总结提示词:提取对话历史中的核心上下文
	 */
	private static final String DEFAULT_SUMMARY_PROMPT =
			"<role>\nContext Extraction Assistant\n</role>\n\n" +
					"<primary_objective>\n" +
					"Your sole objective in this task is to extract the highest quality/most relevant context " +
					"from the conversation history below.\n</primary_objective>\n\n" +
					"<instructions>\n" +
					"The conversation history below will be replaced with the context you extract in this step. " +
					"Extract and record all of the most important context from the conversation history.\n" +
					"Respond ONLY with the extracted context. Do not include any additional information.\n" +
					"</instructions>\n\n" +
					"<messages>\nMessages to summarize:\n%s\n</messages>";

	/**
	 * 总结消息的前缀标识,用于区分普通消息
	 */
	private static final String SUMMARY_PREFIX = "## Previous conversation summary:";
	
	/**
	 * 默认保留的最新消息数量
	 */
	private static final int DEFAULT_MESSAGES_TO_KEEP = 20;
	
	/**
	 * 工具消息配对检索范围,避免拆分相关消息
	 */
	private static final int SEARCH_RANGE_FOR_TOOL_PAIRS = 5;
	
	/**
	 * 默认保留首条用户消息
	 */
	private static final boolean DEFAULT_KEEP_FIRST_USER_MESSAGE = true;

属性定义:

java 复制代码
	/**
	 * 用于生成总结的大模型实例
	 */
	private final ChatModel model;
	
	/**
	 * 触发总结的Token阈值,null表示不触发自动总结
	 */
	private final Integer maxTokensBeforeSummary;
	
	/**
	 * 总结后保留的最新消息数量
	 */
	private final int messagesToKeep;
	
	/**
	 * Token计数器,用于估算消息Token用量
	 */
	private final TokenCounter tokenCounter;
	
	/**
	 * 自定义总结提示词,用于引导模型生成高质量总结
	 */
	private final String summaryPrompt;
	
	/**
	 * 总结消息的前缀,用于标识总结内容
	 */
	private final String summaryPrefix;
	
	/**
	 * 是否保留首条用户消息,默认true
	 */
	private final boolean keepFirstUserMessage;

模型调用前执行的核心方法,检查 Token 用量并生成总结:

java 复制代码
	/**
	 * 模型调用前执行的核心方法:检查Token用量并生成总结
	 * <p>
	 * 执行流程:
	 * <ol>
	 *   <li>校验配置,无模型或无阈值则返回原消息</li>
	 *   <li>计算当前消息Token总量,未达阈值则返回原消息</li>
	 *   <li>查找安全切割点,保留最新消息</li>
	 *   <li>生成历史消息总结,保留首条用户消息</li>
	 *   <li>构建新消息列表,替换原消息</li>
	 * </ol>
	 *
	 * @param previousMessages 原始消息列表
	 * @param config 运行配置
	 * @return 包含总结后的新消息列表的AgentCommand
	 */
	@Override
	public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
		if (maxTokensBeforeSummary == null) {
			return new AgentCommand(previousMessages);
		}

		int totalTokens = tokenCounter.countTokens(previousMessages);

		if (totalTokens < maxTokensBeforeSummary) {
			return new AgentCommand(previousMessages);
		}

		log.info("Token count {} exceeds threshold {}, triggering summarization",
				totalTokens, maxTokensBeforeSummary);

		int cutoffIndex = findSafeCutoff(previousMessages);

		if (cutoffIndex <= 0) {
			log.warn("Cannot find safe cutoff point for summarization");
			return new AgentCommand(previousMessages);
		}

		UserMessage firstUserMessage = null;
		if (keepFirstUserMessage) {
			for (Message msg : previousMessages) {
				if (msg instanceof UserMessage) {
					firstUserMessage = (UserMessage) msg;
					break;
				}
			}
		}

		List<Message> toSummarize = new ArrayList<>();
		for (int i = 0; i < cutoffIndex; i++) {
			Message msg = previousMessages.get(i);
			if (msg != firstUserMessage) {
				toSummarize.add(msg);
			}
		}

		String summary = createSummary(toSummarize);

		SystemMessage summaryMessage = new SystemMessage(summaryPrefix + "\n" + summary);

		List<Message> recentMessages = new ArrayList<>();
		for (int i = cutoffIndex; i < previousMessages.size(); i++) {
			recentMessages.add(previousMessages.get(i));
		}

		List<Message> newMessages = new ArrayList<>();
		if (firstUserMessage != null) {
			newMessages.add(firstUserMessage);
		}
		newMessages.add(summaryMessage);
		newMessages.addAll(recentMessages);

		if (firstUserMessage != null) {
			log.info("Summarized {} messages, keeping {} recent messages (First UserMessage preserved)",
					toSummarize.size(), recentMessages.size());
		} else {
			log.info("Summarized {} messages, keeping {} recent messages",
					toSummarize.size(), recentMessages.size());
		}

		return new AgentCommand(newMessages, UpdatePolicy.REPLACE);
	}

	/**
	 * 查找安全的消息切割点,确保不拆分AI/工具消息配对
	 *
	 * @param messages 消息列表
	 * @return 安全切割点索引,返回0表示无法安全切割
	 */
	private int findSafeCutoff(List<Message> messages) {
		if (messages.size() <= messagesToKeep) {
			return 0;
		}

		int targetCutoff = messages.size() - messagesToKeep;

		// 从目标切割点向后搜索,找到安全的切割位置
		for (int i = targetCutoff; i >= 0; i--) {
			if (isSafeCutoffPoint(messages, i)) {
				return i;
			}
		}

		return 0;
	}

	/**
	 * 检查切割点是否安全,不会拆分相关AI和工具消息
	 *
	 * @param messages 消息列表
	 * @param cutoffIndex 切割点索引
	 * @return true表示安全,false表示不安全
	 */
	private boolean isSafeCutoffPoint(List<Message> messages, int cutoffIndex) {
		if (cutoffIndex >= messages.size()) {
			return true;
		}

		int searchStart = Math.max(0, cutoffIndex - SEARCH_RANGE_FOR_TOOL_PAIRS);
		int searchEnd = Math.min(messages.size(), cutoffIndex + SEARCH_RANGE_FOR_TOOL_PAIRS);

		for (int i = searchStart; i < searchEnd; i++) {
			if (!hasToolCalls(messages.get(i))) {
				continue;
			}

			AssistantMessage aiMessage = (AssistantMessage) messages.get(i);
			Set<String> toolCallIds = extractToolCallIds(aiMessage);
			if (cutoffSeparatesToolPair(messages, i, cutoffIndex, toolCallIds)) {
				return false;
			}
		}

		return true;
	}

	/**
	 * 检查消息是否为包含工具调用的AI消息
	 *
	 * @param message 消息对象
	 * @return true表示是包含工具调用的AI消息
	 */
	private boolean hasToolCalls(Message message) {
		return message instanceof AssistantMessage assistantMessage && !assistantMessage.getToolCalls().isEmpty();
	}

	/**
	 * 从AI消息中提取工具调用ID集合
	 *
	 * @param aiMessage AI消息对象
	 * @return 工具调用ID集合
	 */
	private Set<String> extractToolCallIds(AssistantMessage aiMessage) {
		Set<String> toolCallIds = new HashSet<>();
		for (AssistantMessage.ToolCall toolCall : aiMessage.getToolCalls()) {
			String callId = toolCall.id();
			toolCallIds.add(callId);
		}
		return toolCallIds;
	}

	/**
	 * 检查切割点是否会拆分AI消息和对应的工具响应消息
	 *
	 * @param messages 消息列表
	 * @param aiMessageIndex AI消息索引
	 * @param cutoffIndex 切割点索引
	 * @param toolCallIds 工具调用ID集合
	 * @return true表示会拆分,false表示不会拆分
	 */
	private boolean cutoffSeparatesToolPair(
			List<Message> messages,
			int aiMessageIndex,
			int cutoffIndex,
			Set<String> toolCallIds) {
		for (int j = aiMessageIndex + 1; j < messages.size(); j++) {
			Message message = messages.get(j);
			if (message instanceof ToolResponseMessage toolResponseMessage) {
				// 检查工具响应消息是否包含当前AI消息的工具调用ID
				for (ToolResponseMessage.ToolResponse response : toolResponseMessage.getResponses()) {
					if (toolCallIds.contains(response.id())) {
						boolean aiBeforeCutoff = aiMessageIndex < cutoffIndex;
						boolean toolBeforeCutoff = j < cutoffIndex;
						// 如果AI消息和工具响应消息被切割点分开,则返回true
						if (aiBeforeCutoff != toolBeforeCutoff) {
							return true;
						}
					}
				}
			}
		}
		return false;
	}

	/**
	 * 使用大模型生成消息总结
	 *
	 * @param messages 待总结的消息列表
	 * @return 生成的总结文本
	 */
	private String createSummary(List<Message> messages) {
		if (messages.isEmpty()) {
			return "No previous conversation.";
		}

		StringBuilder messageText = new StringBuilder();
		for (Message msg : messages) {
			String role = getRoleName(msg);
			messageText.append(role).append(": ").append(msg.getText()).append("\n");
		}

		String prompt = String.format(summaryPrompt, messageText.toString());

		try {
			Prompt summaryPromptObj = new Prompt(List.of(new UserMessage(prompt)));
			var response = model.call(summaryPromptObj);
			return response.getResult().getOutput().getText();
		}
		catch (Exception e) {
			log.error("Failed to create summary: {}", e.getMessage());
			return "Summary generation failed: " + e.getMessage();
		}
	}

	/**
	 * 获取消息角色名称,用于总结文本格式化
	 *
	 * @param message 消息对象
	 * @return 角色名称(Human/Assistant/System/Tool/Unknown)
	 */
	private String getRoleName(Message message) {
		if (message instanceof UserMessage) {
			return "Human";
		}
		else if (message instanceof AssistantMessage) {
			return "Assistant";
		}
		else if (message instanceof SystemMessage) {
			return "System";
		}
		else if (message instanceof ToolResponseMessage) {
			return "Tool";
		}
		else {
			return "Unknown";
		}
	}

其他源码内容:

java 复制代码
	/**
	 * 私有构造方法,通过建造者模式创建实例
	 * @param builder 建造者对象,包含所有配置参数
	 */
	private SummarizationHook(Builder builder) {
		this.model = builder.model;
		this.maxTokensBeforeSummary = builder.maxTokensBeforeSummary;
		this.messagesToKeep = builder.messagesToKeep;
		this.tokenCounter = builder.tokenCounter;
		this.summaryPrompt = builder.summaryPrompt;
		this.summaryPrefix = builder.summaryPrefix;
		this.keepFirstUserMessage = builder.keepFirstUserMessage;
	}

	/**
	 * 获取建造者实例,用于配置创建钩子
	 * @return 建造者对象
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * 获取钩子名称,用于日志和调试
	 * @return 钩子名称
	 */
	@Override
	public String getName() {
		return "Summarization";
	}

	/**
	 * 获取钩子支持的跳转能力,当前不支持任何跳转
	 * @return 空列表
	 */
	@Override
	public List<JumpTo> canJumpTo() {
		return List.of();
	}

	/**
	 * 建造者类:用于配置和创建SummarizationHook实例
	 * <p>支持链式调用配置所有参数,确保实例创建的安全性和灵活性</p>
	 */
	public static class Builder {
		private ChatModel model;
		private Integer maxTokensBeforeSummary;
		private int messagesToKeep = DEFAULT_MESSAGES_TO_KEEP;
		private TokenCounter tokenCounter = TokenCounter.approximateMsgCounter();
		private String summaryPrompt = DEFAULT_SUMMARY_PROMPT;
		private String summaryPrefix = SUMMARY_PREFIX;
		private boolean keepFirstUserMessage = DEFAULT_KEEP_FIRST_USER_MESSAGE;

		/**
		 * 设置大模型实例(必填)
		 * @param model 大模型实例
		 * @return 建造者自身
		 */
		public Builder model(ChatModel model) {
			this.model = model;
			return this;
		}

		/**
		 * 设置触发总结的Token阈值
		 * @param maxTokens Token阈值
		 * @return 建造者自身
		 */
		public Builder maxTokensBeforeSummary(Integer maxTokens) {
			this.maxTokensBeforeSummary = maxTokens;
			return this;
		}

		/**
		 * 设置总结后保留的最新消息数量
		 * @param count 消息数量
		 * @return 建造者自身
		 */
		public Builder messagesToKeep(int count) {
			this.messagesToKeep = count;
			return this;
		}

		/**
		 * 设置自定义总结提示词
		 * @param prompt 总结提示词
		 * @return 建造者自身
		 */
		public Builder summaryPrompt(String prompt) {
			this.summaryPrompt = prompt;
			return this;
		}

		/**
		 * 设置总结消息前缀
		 * @param prefix 总结前缀
		 * @return 建造者自身
		 */
		public Builder summaryPrefix(String prefix) {
			this.summaryPrefix = prefix;
			return this;
		}

		/**
		 * 设置Token计数器
		 * @param counter Token计数器
		 * @return 建造者自身
		 */
		public Builder tokenCounter(TokenCounter counter) {
			this.tokenCounter = counter;
			return this;
		}

		/**
		 * 设置是否保留首条用户消息
		 * @param keep true表示保留,false表示不保留
		 * @return 建造者自身
		 */
		public Builder keepFirstUserMessage(boolean keep) {
			this.keepFirstUserMessage = keep;
			return this;
		}

		/**
		 * 构建SummarizationHook实例
		 * @return 配置完成的SummarizationHook实例
		 * @throws IllegalArgumentException 当model为null时抛出
		 */
		public SummarizationHook build() {
			if (model == null) {
				throw new IllegalArgumentException("model must be specified");
			}
			return new SummarizationHook(this);
		}
	}
}

2.2 ReturnDirectModelHook

工具调用直接返回钩子,模型调用前执行的最高优先级钩子,用于检测工具响应消息中的直接返回标识

  • 检测最后一条消息是否为带结束标识的工具响应消息
  • 识别到 returnDirect 标识后,自动生成助手消息
  • 强制跳转到流程结束节点,终止后续模型调用
  • 最高执行优先级,确保优先拦截直接返回场景

适用场景:工具调用配置 returnDirect=true 时,无需再次调用大模型,直接返回结果

模型调用前核心逻辑,检测直接返回标识并处理:

java 复制代码
	/**
	 * 模型调用前核心逻辑:检测直接返回标识并处理
	 * <p>
	 * 执行流程:
	 * <ol>
	 *   <li>校验消息列表是否为空,为空则直接返回</li>
	 *   <li>判断最后一条消息是否为工具响应消息,非则直接返回</li>
	 *   <li>检查元数据中是否包含结束标识(FINISH_REASON)</li>
	 *   <li>识别到直接返回标识:生成助手消息,跳转到结束节点</li>
	 *   <li>无标识:正常执行后续流程</li>
	 * </ol>
	 * </p>
	 * @param previousMessages 历史消息列表
	 * @param config 运行配置
	 * @return 包含跳转指令/新消息的AgentCommand
	 */
	@Override
	public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
		// 消息列表为空,直接返回
		if (previousMessages.isEmpty()) {
			return new AgentCommand(previousMessages);
		}

		Message lastMessage = previousMessages.get(previousMessages.size() - 1);
		// 最后一条消息不是工具响应消息,正常执行
		if (!(lastMessage instanceof ToolResponseMessage toolResponseMessage)) {
			return new AgentCommand(previousMessages);
		}

		// 检查工具响应消息的元数据是否包含直接返回标识
		// 该标识由 AgentToolNode 在 returnDirect=true 时设置
		boolean returnDirect = false;
		Map<String, Object> metadata = toolResponseMessage.getMetadata();
		if (metadata.containsKey(FINISH_REASON_METADATA_KEY)) {
			Object finishReason = metadata.get(FINISH_REASON_METADATA_KEY);
			if (FINISH_REASON.equals(finishReason)) {
				returnDirect = true;
			}
		}

		// 触发直接返回逻辑
		if (returnDirect) {
			// 根据工具响应生成助手消息
			String generatedText = generateAssistantMessageText(toolResponseMessage);
			AssistantMessage newAssistantMessage = AssistantMessage.builder()
					.content(generatedText)
					.build();

			// 构建新的消息列表
			List<Message> newMessages = new ArrayList<>(previousMessages);
			newMessages.add(newAssistantMessage);

			// 跳转到结束节点,终止流程
			return new AgentCommand(JumpTo.end, newMessages);
		}

		// 无直接返回标识,正常执行
		return new AgentCommand(previousMessages);
	}

	/**
	 * 根据工具响应消息生成助手消息内容
	 * <p>
	 * 生成规则:
	 * <ul>
	 *   <li>单条响应:直接返回响应数据</li>
	 *   <li>多条响应:组装为标准JSON数组</li>
	 *   <li>空响应:返回空字符串</li>
	 *   <li>自动处理JSON格式与字符串转义</li>
	 * </ul>
	 * </p>
	 * @param toolResponseMessage 工具响应消息
	 * @return 格式化后的助手消息文本
	 */
	private String generateAssistantMessageText(ToolResponseMessage toolResponseMessage) {
		List<ToolResponseMessage.ToolResponse> responses = toolResponseMessage.getResponses();
		if (responses.isEmpty()) {
			return "";
		} else if (responses.size() == 1) {
			// 单条响应直接返回数据
			return responses.get(0).responseData();
		} else {
			// 多条响应组装为JSON数组
			StringBuilder jsonArray = new StringBuilder("[");
			for (int i = 0; i < responses.size(); i++) {
				if (i > 0) {
					jsonArray.append(",");
				}
				String responseData = responses.get(i).responseData();
				// 处理空值
				if (responseData == null) {
					jsonArray.append("null");
				} else {
					String trimmed = responseData.trim();
					// 已为JSON格式直接拼接,否则转为字符串
					if (trimmed.startsWith("{") || trimmed.startsWith("[")) {
						jsonArray.append(responseData);
					} else {
						jsonArray.append("\"").append(escapeJsonString(responseData)).append("\"");
					}
				}
			}
			jsonArray.append("]");
			return jsonArray.toString();
		}
	}

其他源码内容:

java 复制代码
@HookPositions({HookPosition.BEFORE_MODEL})
public class ReturnDirectModelHook extends MessagesModelHook {

	/**
	 * 获取钩子名称
	 * @return 钩子唯一标识
	 */
	@Override
	public String getName() {
		return "finish_reason_check_messages_model_hook";
	}

	/**
	 * 获取钩子执行优先级
	 * @return 最高优先级,确保最先执行
	 */
	@Override
	public int getOrder() {
		return Prioritized.HIGHEST_PRECEDENCE;
	}

	/**
	 * 支持的流程跳转目标
	 * @return 仅支持跳转到结束节点
	 */
	@Override
	public List<JumpTo> canJumpTo() {
		return List.of(JumpTo.end);
	}

	/**
	 * JSON字符串转义:处理特殊字符,保证JSON格式合法
	 * <p>
	 * 转义字符:双引号、反斜杠、退格、换页、换行、回车、制表符、控制字符
	 * </p>
	 * @param str 原始字符串
	 * @return 转义后的安全字符串
	 */
	private String escapeJsonString(String str) {
		if (str == null) {
			return "";
		}
		StringBuilder sb = new StringBuilder();
		for (char c : str.toCharArray()) {
			switch (c) {
				case '"':
					sb.append("\\\"");
					break;
				case '\\':
					sb.append("\\\\");
					break;
				case '\b':
					sb.append("\\b");
					break;
				case '\f':
					sb.append("\\f");
					break;
				case '\n':
					sb.append("\\n");
					break;
				case '\r':
					sb.append("\\r");
					break;
				case '\t':
					sb.append("\\t");
					break;
				default:
					if (c < 0x20) {
						sb.append(String.format("\\u%04x", (int) c));
					} else {
						sb.append(c);
					}
					break;
			}
		}
		return sb.toString();
	}
}

2.3 PIIDetectionHook

个人身份信息(PII)检测与处理钩子。执行于模型调用前/调用后,用于检测对话中的敏感个人信息(PII),并支持脱敏、掩码、哈希、拦截四种处理策略。

支持范围

  • 检测类型:邮箱、信用卡、IP地址、MAC地址、URL
  • 处理场景:用户输入、助手输出、工具响应结果

使用示例

java 复制代码
PIIDetectionHook pii = PIIDetectionHook.builder()
    .piiType(PIIType.EMAIL)
    .strategy(RedactionStrategy.REDACT)
    .applyToInput(true)
    .build();

属性定义:

java 复制代码
@HookPositions({HookPosition.BEFORE_MODEL, HookPosition.AFTER_MODEL})
public class PIIDetectionHook extends MessagesModelHook {

	/** 待检测的PII类型(邮箱/信用卡/IP等) */
	private final PIIType piiType;
	/** PII处理策略(脱敏/掩码/哈希/拦截) */
	private final RedactionStrategy strategy;
	/** PII检测器实例 */
	private final PIIDetector detector;
	/** 是否应用于用户输入消息 */
	private final boolean applyToInput;
	/** 是否应用于助手输出消息 */
	private final boolean applyToOutput;
	/** 是否应用于工具响应结果 */
	private final boolean applyToToolResults;

模型调用前处理,检测并处理用户输入/工具响应中的 PII

java 复制代码
	/**
	 * 模型调用前处理:检测并处理用户输入/工具响应中的PII
	 * <p>遍历所有消息,根据配置处理用户消息、工具响应消息,替换敏感信息</p>
	 * @param previousMessages 历史消息列表
	 * @param config 运行配置
	 * @return 处理后的消息命令
	 */
	@Override
	public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
		List<Message> processedMessages = new ArrayList<>();
		boolean hasChanges = false;

		for (Message message : previousMessages) {
			Message processed = processMessage(message);
			processedMessages.add(processed);
			if (processed != message) {
				hasChanges = true;
			}
		}
		if (hasChanges) {
			return new AgentCommand(processedMessages, UpdatePolicy.REPLACE);
		}

		return new AgentCommand(previousMessages);
	}

模型调用后处理,检测并处理助手输出中的 PII

java 复制代码
	/**
	 * 模型调用后处理:检测并处理助手输出中的PII
	 * <p>仅处理最后一条助手消息,支持拦截/脱敏替换</p>
	 * @param previousMessages 模型输出后的消息列表
	 * @param config 运行配置
	 * @return 处理后的消息命令
	 */
	@Override
	public AgentCommand afterModel(List<Message> previousMessages, RunnableConfig config) {
		// 仅当启用输出处理时执行
		if (!applyToOutput) {
			return new AgentCommand(previousMessages);
		}

		if (previousMessages.isEmpty()) {
			return new AgentCommand(previousMessages);
		}

		// 查找最后一条助手消息
		AssistantMessage aiMessage = null;
		int lastIndex = -1;
		for (int i = previousMessages.size() - 1; i >= 0; i--) {
			if (previousMessages.get(i) instanceof AssistantMessage am) {
				aiMessage = am;
				lastIndex = i;
				break;
			}
		}

		if (aiMessage == null) {
			return new AgentCommand(previousMessages);
		}

		String content = aiMessage.getText();

		if (content == null || content.isEmpty()) {
			return new AgentCommand(previousMessages);
		}

		// 检测PII
		ProcessResult result = processText(content);

		if (!result.hasMatches) {
			return new AgentCommand(previousMessages);
		}

		// 拦截策略:直接抛出异常
		if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
			throw new PIIDetectionException(piiType.name(), result.matches);
		}

		if (result.redactedText.equals(content)) {
			return new AgentCommand(previousMessages);
		}

		// 构建脱敏后的助手消息
		AssistantMessage updatedMessage = AssistantMessage.builder()
			.content(result.redactedText)
			.properties(aiMessage.getMetadata())
			.toolCalls(aiMessage.getToolCalls())
			.media(aiMessage.getMedia())
			.build();

		List<Message> updatedMessages = new ArrayList<>(previousMessages);
		updatedMessages.set(lastIndex, updatedMessage);

		return new AgentCommand(updatedMessages, UpdatePolicy.REPLACE);
	}

其他源码内容:

java 复制代码
	/**
	 * 私有构造方法,通过建造者模式创建实例
	 * @param builder 建造者配置对象
	 */
	private PIIDetectionHook(Builder builder) {
		this.piiType = builder.piiType;
		this.strategy = builder.strategy;
		this.detector = builder.detector != null ? builder.detector : getDefaultDetector(piiType);
		this.applyToInput = builder.applyToInput;
		this.applyToOutput = builder.applyToOutput;
		this.applyToToolResults = builder.applyToToolResults;
	}

	/**
	 * 获取建造者实例,用于配置PII检测钩子
	 * @return Builder 建造者对象
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * 消息分发处理:根据消息类型和配置执行PII处理
	 * @param message 原始消息
	 * @return 处理后的消息(无变化则返回原对象)
	 */
	private Message processMessage(Message message) {
		if (applyToInput && message instanceof UserMessage) {
			return processContent((UserMessage) message);
		}
		else if (applyToOutput && message instanceof AssistantMessage) {
			return processContent((AssistantMessage) message);
		}
		else if (applyToToolResults && message instanceof ToolResponseMessage) {
			return processToolResponse((ToolResponseMessage) message);
		}
		return message;
	}

	/**
	 * 处理用户消息:检测并脱敏PII,拦截策略直接抛异常
	 * @param message 用户消息
	 * @return 处理后的用户消息
	 */
	private UserMessage processContent(UserMessage message) {
		String content = message.getText();
		ProcessResult result = processText(content);

		if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
			throw new PIIDetectionException(piiType.name(), result.matches);
		}

		if (result.redactedText.equals(content)) {
			return message;
		}

		return UserMessage.builder().text(result.redactedText).metadata(message.getMetadata()).build();
	}

	/**
	 * 处理助手消息:检测并脱敏PII,拦截策略直接抛异常
	 * @param message 助手消息
	 * @return 处理后的助手消息
	 */
	private AssistantMessage processContent(AssistantMessage message) {
		String content = message.getText();
		ProcessResult result = processText(content);

		if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
			throw new PIIDetectionException(piiType.name(), result.matches);
		}

		if (result.redactedText.equals(content)) {
			return message;
		}

		return AssistantMessage.builder()
			.content(result.redactedText)
			.properties(message.getMetadata())
			.toolCalls(message.getToolCalls())
			.media(message.getMedia())
			.build();
	}

	/**
	 * 处理工具响应消息:遍历所有响应结果,检测并脱敏PII
	 * @param message 工具响应消息
	 * @return 处理后的工具响应消息
	 */
	private ToolResponseMessage processToolResponse(ToolResponseMessage message) {
		List<ToolResponseMessage.ToolResponse> responses = new ArrayList<>();
		boolean hasChanges = false;

		for (ToolResponseMessage.ToolResponse response : message.getResponses()) {
			String content = response.responseData();
			ProcessResult result = processText(content);

			if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
				throw new PIIDetectionException(piiType.name(), result.matches);
			}

			if (!result.redactedText.equals(content)) {
				responses.add(new ToolResponseMessage.ToolResponse(
						response.id(), response.name(), result.redactedText));
				hasChanges = true;
			}
			else {
				responses.add(response);
			}
		}

		return hasChanges
			? ToolResponseMessage.builder()
				.responses(responses)
				.metadata(message.getMetadata())
				.build()
			: message;
	}

	/**
	 * 文本处理核心:调用检测器识别PII,应用处理策略
	 * @param text 待检测文本
	 * @return 处理结果(脱敏文本、匹配标记、匹配列表)
	 */
	private ProcessResult processText(String text) {
		List<PIIMatch> matches = detector.detect(text);

		if (matches.isEmpty()) {
			return new ProcessResult(text, false, matches);
		}

		String redacted = applyStrategy(text, matches);
		return new ProcessResult(redacted, true, matches);
	}

	/**
	 * 应用PII处理策略:脱敏/掩码/哈希
	 * @param text 原始文本
	 * @param matches PII匹配结果
	 * @return 策略处理后的文本
	 */
	private String applyStrategy(String text, List<PIIMatch> matches) {
		if (matches.isEmpty()) {
			return text;
		}

		StringBuilder result = new StringBuilder();
		int lastEnd = 0;

		// 按起始位置排序匹配项
		matches.sort(Comparator.comparingInt(m -> m.start));

		for (PIIMatch match : matches) {
			result.append(text, lastEnd, match.start);

			switch (strategy) {
			case REDACT:
				result.append("[REDACTED_").append(piiType.name()).append("]");
				break;
			case MASK:
				result.append(maskValue(match.value));
				break;
			case HASH:
				result.append(hashValue(match.value));
				break;
			case BLOCK:
				// 拦截逻辑已在上层处理
				break;
			}

			lastEnd = match.end;
		}

		result.append(text.substring(lastEnd));
		return result.toString();
	}

	/**
	 * 掩码处理:保留最后4位,其余替换为*
	 * @param value 原始敏感值
	 * @return 掩码后的值
	 */
	private String maskValue(String value) {
		if (value.length() <= 4) {
			return "****";
		}
		int visibleChars = 4;
		String masked = "*".repeat(value.length() - visibleChars);
		return masked + value.substring(value.length() - visibleChars);
	}

	/**
	 * 哈希处理:生成固定格式的哈希标识
	 * @param value 原始敏感值
	 * @return 哈希后的值
	 */
	private String hashValue(String value) {
		int hash = value.hashCode();
		return String.format("<%s_hash:%08x>", piiType.name().toLowerCase(), hash);
	}

	/**
	 * 根据PII类型获取默认检测器
	 * @param type PII类型
	 * @return 对应的检测器实例
	 */
	private PIIDetector getDefaultDetector(PIIType type) {
		switch (type) {
		case EMAIL:
			return PIIDetectors.emailDetector();
		case CREDIT_CARD:
			return PIIDetectors.creditCardDetector();
		case IP:
			return PIIDetectors.ipDetector();
		case MAC_ADDRESS:
			return PIIDetectors.macAddressDetector();
		case URL:
			return PIIDetectors.urlDetector();
		default:
			throw new IllegalArgumentException("No default detector for PII type: " + type);
		}
	}

	/**
	 * 获取钩子名称
	 * @return 带PII类型的钩子名称
	 */
	@Override
	public String getName() {
		return "PIIDetection[" + piiType.name() + "]";
	}

	/**
	 * 支持的流程跳转:无跳转能力
	 * @return 空列表
	 */
	@Override
	public List<JumpTo> canJumpTo() {
		return List.of();
	}

	/**
	 * PII处理结果内部类:封装脱敏文本、匹配状态、匹配列表
	 */
	private static class ProcessResult {
		/** 脱敏后的文本 */
		final String redactedText;
		/** 是否匹配到PII */
		final boolean hasMatches;
		/** PII匹配详情列表 */
		final List<PIIMatch> matches;

		/**
		 * 构造方法
		 * @param redactedText 脱敏文本
		 * @param hasMatches 是否匹配
		 * @param matches 匹配列表
		 */
		ProcessResult(String redactedText, boolean hasMatches, List<PIIMatch> matches) {
			this.redactedText = redactedText;
			this.hasMatches = hasMatches;
			this.matches = matches;
		}
	}

	/**
	 * 建造者类:配置化创建PIIDetectionHook实例
	 */
	public static class Builder {
		private PIIType piiType;
		private RedactionStrategy strategy = RedactionStrategy.REDACT;
		private PIIDetector detector;
		private boolean applyToInput = true;
		private boolean applyToOutput = false;
		private boolean applyToToolResults = false;

		/**
		 * 设置PII检测类型(必填)
		 * @param piiType 敏感信息类型
		 * @return Builder
		 */
		public Builder piiType(PIIType piiType) {
			this.piiType = piiType;
			return this;
		}

		/**
		 * 设置PII处理策略,默认脱敏
		 * @param strategy 处理策略
		 * @return Builder
		 */
		public Builder strategy(RedactionStrategy strategy) {
			this.strategy = strategy;
			return this;
		}

		/**
		 * 自定义PII检测器
		 * @param detector 检测器实例
		 * @return Builder
		 */
		public Builder detector(PIIDetector detector) {
			this.detector = detector;
			return this;
		}

		/**
		 * 是否处理用户输入,默认开启
		 * @param applyToInput 开关
		 * @return Builder
		 */
		public Builder applyToInput(boolean applyToInput) {
			this.applyToInput = applyToInput;
			return this;
		}

		/**
		 * 是否处理助手输出,默认关闭
		 * @param applyToOutput 开关
		 * @return Builder
		 */
		public Builder applyToOutput(boolean applyToOutput) {
			this.applyToOutput = applyToOutput;
			return this;
		}

		/**
		 * 是否处理工具响应结果,默认关闭
		 * @param applyToToolResults 开关
		 * @return Builder
		 */
		public Builder applyToToolResults(boolean applyToToolResults) {
			this.applyToToolResults = applyToToolResults;
			return this;
		}

		/**
		 * 构建PIIDetectionHook实例
		 * @return 钩子实例
		 * @throws IllegalArgumentException 未指定piiType时抛出
		 */
		public PIIDetectionHook build() {
			if (piiType == null) {
				throw new IllegalArgumentException("piiType must be specified");
			}
			return new PIIDetectionHook(this);
		}
	}
}
相关推荐
程序员小嬛1 小时前
中科院一区TOP:用于求解偏微分方程的物理信息神经网络前沿创新思路
人工智能·深度学习·神经网络·机器学习
霸道流氓气质1 小时前
SpringBoot中集成LangChain4j实现集成阿里百炼平台进行AI对话记忆功能和对话隔离功能
java·人工智能·spring boot·langchain4j
xiaotao1311 小时前
01-编程基础与数学基石:Matplotlib & Seaborn
人工智能·python·matplotlib
用户2018792831671 小时前
解密「并行派发特工」dispatching-parallel-agents:一个让AI工作效率×3的超级技能
人工智能
XS0301061 小时前
Java 基础笔记(二)
java·笔记·python
CHPCWWHSU1 小时前
智慧城市可视化:基于osgPotree 的都柏林大规模点云高程着色实践
人工智能·智能城市
JoyCong19981 小时前
ToDesk企业版助力伯锐锶:远程连接打破时空壁垒,国产高端电镜跑出“加速度”
大数据·人工智能·经验分享·物联网
papaofdoudou1 小时前
AMD-V 嵌套分页白皮书翻译
java·linux·服务器
Zldaisy3d1 小时前
联泰科技全链路鞋业智造解决方案出海印尼
大数据·人工智能