文章目录
- [1. MessagesAgentHook 实现类](#1. MessagesAgentHook 实现类)
-
- [1.1 InstructionAgentHook](#1.1 InstructionAgentHook)
- [2. MessagesModelHook 实现类](#2. MessagesModelHook 实现类)
-
- [2.1 SummarizationHook](#2.1 SummarizationHook)
- [2.2 ReturnDirectModelHook](#2.2 ReturnDirectModelHook)
- [2.3 PIIDetectionHook](#2.3 PIIDetectionHook)
1. MessagesAgentHook 实现类
1.1 InstructionAgentHook
指令注入 Agent 钩子,在 Agent 每次运行前将 ReactAgent 的系统指令注入消息列表。
- 执行时机:
HookPosition.BEFORE_AGENT(Agent执行前),自动读取Agent指令并添加到消息头部,避免子图场景下重复注入指令。 - 核心逻辑:读取
ReactAgent的指令配置,非空时自动添加AgentInstructionMessage到消息头部 - 策略特性:返回
REPLACE更新策略,避免作为子图节点时重复添加指令 - 默认机制:
ReactAgent无其他指令类钩子时,会自动加载该钩子,且拥有最高执行优先级(最先运行)。
属性定义:
java
@HookPositions(HookPosition.BEFORE_AGENT)
public class InstructionAgentHook extends MessagesAgentHook {
/**
* 关联的 ReactAgent 实例,用于获取系统指令
*/
private ReactAgent reactAgent;
}
重写 Agent 执行前钩子,注入系统指令消息:
java
/**
* 重写 Agent 执行前钩子:注入系统指令消息
* <p>
* 核心逻辑:
* <ol>
* <li>校验 ReactAgent 实例是否为空</li>
* <li>获取 Agent 的系统指令,为空则直接返回原消息</li>
* <li>构建指令消息,追加到消息列表头部</li>
* <li>返回包含新消息的 AgentCommand</li>
* </ol>
* </p>
* @param previousMessages 原始消息列表
* @param config 运行配置
* @return 包含指令消息的 AgentCommand
*/
@Override
public AgentCommand beforeAgent(List<Message> previousMessages, RunnableConfig config) {
if (reactAgent == null) {
return new AgentCommand(previousMessages);
}
String instruction = reactAgent.instruction();
if (!StringUtils.hasLength(instruction)) {
return new AgentCommand(previousMessages);
}
AgentInstructionMessage instructionMessage = AgentInstructionMessage.builder().text(instruction).build();
List<Message> newMessages = new ArrayList<>(previousMessages);
newMessages.add(instructionMessage);
return new AgentCommand(newMessages);
}
其他源码内容:
java
/**
* 获取钩子名称
* @return 固定返回 InstructionAgentHook
*/
@Override
public String getName() {
return "InstructionAgentHook";
}
/**
* 获取钩子执行优先级
* @return -100,最低优先级值,代表最先执行
*/
@Override
public int getOrder() {
return -100;
}
/**
* 获取关联的 ReactAgent 实例
* @return ReactAgent 实例
*/
@Override
public ReactAgent getAgent() {
return reactAgent;
}
/**
* 设置关联的 ReactAgent 实例
* @param agent ReactAgent 实例
*/
@Override
public void setAgent(ReactAgent agent) {
this.reactAgent = agent;
}
/**
* 创建默认的 InstructionAgentHook 实例
* <p>无其他指令类钩子处理指令时使用</p>
* @return 新的 InstructionAgentHook 实例
*/
public static InstructionAgentHook create() {
return new InstructionAgentHook();
}
}
2. MessagesModelHook 实现类
2.1 SummarizationHook
对话历史总结钩子。模型调用前执行,监控对话消息的 Token 数量,当达到阈值时自动总结历史消息,防止 Token 超限,同时保留关键上下文(首条用户消息、最新消息),保证对话连贯性。
核心特性:
- 自动检测
Token用量,触发智能总结 - 安全切割消息,不拆分
AI工具调用配对 - 可配置保留消息数、总结提示词
- 支持保留首条用户消息,锁定核心对话意图
使用示例:
java
SummarizationHook summarizer = SummarizationHook.builder()
.model(chatModel)
.maxTokensBeforeSummary(4000)
.messagesToKeep(20)
.keepFirstUserMessage(true)
.build();
常量定义:
java
@HookPositions({HookPosition.BEFORE_MODEL})
public class SummarizationHook extends MessagesModelHook {
private static final Logger log = LoggerFactory.getLogger(SummarizationHook.class);
/**
* 默认总结提示词:提取对话历史中的核心上下文
*/
private static final String DEFAULT_SUMMARY_PROMPT =
"<role>\nContext Extraction Assistant\n</role>\n\n" +
"<primary_objective>\n" +
"Your sole objective in this task is to extract the highest quality/most relevant context " +
"from the conversation history below.\n</primary_objective>\n\n" +
"<instructions>\n" +
"The conversation history below will be replaced with the context you extract in this step. " +
"Extract and record all of the most important context from the conversation history.\n" +
"Respond ONLY with the extracted context. Do not include any additional information.\n" +
"</instructions>\n\n" +
"<messages>\nMessages to summarize:\n%s\n</messages>";
/**
* 总结消息的前缀标识,用于区分普通消息
*/
private static final String SUMMARY_PREFIX = "## Previous conversation summary:";
/**
* 默认保留的最新消息数量
*/
private static final int DEFAULT_MESSAGES_TO_KEEP = 20;
/**
* 工具消息配对检索范围,避免拆分相关消息
*/
private static final int SEARCH_RANGE_FOR_TOOL_PAIRS = 5;
/**
* 默认保留首条用户消息
*/
private static final boolean DEFAULT_KEEP_FIRST_USER_MESSAGE = true;
属性定义:
java
/**
* 用于生成总结的大模型实例
*/
private final ChatModel model;
/**
* 触发总结的Token阈值,null表示不触发自动总结
*/
private final Integer maxTokensBeforeSummary;
/**
* 总结后保留的最新消息数量
*/
private final int messagesToKeep;
/**
* Token计数器,用于估算消息Token用量
*/
private final TokenCounter tokenCounter;
/**
* 自定义总结提示词,用于引导模型生成高质量总结
*/
private final String summaryPrompt;
/**
* 总结消息的前缀,用于标识总结内容
*/
private final String summaryPrefix;
/**
* 是否保留首条用户消息,默认true
*/
private final boolean keepFirstUserMessage;
模型调用前执行的核心方法,检查 Token 用量并生成总结:
java
/**
* 模型调用前执行的核心方法:检查Token用量并生成总结
* <p>
* 执行流程:
* <ol>
* <li>校验配置,无模型或无阈值则返回原消息</li>
* <li>计算当前消息Token总量,未达阈值则返回原消息</li>
* <li>查找安全切割点,保留最新消息</li>
* <li>生成历史消息总结,保留首条用户消息</li>
* <li>构建新消息列表,替换原消息</li>
* </ol>
*
* @param previousMessages 原始消息列表
* @param config 运行配置
* @return 包含总结后的新消息列表的AgentCommand
*/
@Override
public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
if (maxTokensBeforeSummary == null) {
return new AgentCommand(previousMessages);
}
int totalTokens = tokenCounter.countTokens(previousMessages);
if (totalTokens < maxTokensBeforeSummary) {
return new AgentCommand(previousMessages);
}
log.info("Token count {} exceeds threshold {}, triggering summarization",
totalTokens, maxTokensBeforeSummary);
int cutoffIndex = findSafeCutoff(previousMessages);
if (cutoffIndex <= 0) {
log.warn("Cannot find safe cutoff point for summarization");
return new AgentCommand(previousMessages);
}
UserMessage firstUserMessage = null;
if (keepFirstUserMessage) {
for (Message msg : previousMessages) {
if (msg instanceof UserMessage) {
firstUserMessage = (UserMessage) msg;
break;
}
}
}
List<Message> toSummarize = new ArrayList<>();
for (int i = 0; i < cutoffIndex; i++) {
Message msg = previousMessages.get(i);
if (msg != firstUserMessage) {
toSummarize.add(msg);
}
}
String summary = createSummary(toSummarize);
SystemMessage summaryMessage = new SystemMessage(summaryPrefix + "\n" + summary);
List<Message> recentMessages = new ArrayList<>();
for (int i = cutoffIndex; i < previousMessages.size(); i++) {
recentMessages.add(previousMessages.get(i));
}
List<Message> newMessages = new ArrayList<>();
if (firstUserMessage != null) {
newMessages.add(firstUserMessage);
}
newMessages.add(summaryMessage);
newMessages.addAll(recentMessages);
if (firstUserMessage != null) {
log.info("Summarized {} messages, keeping {} recent messages (First UserMessage preserved)",
toSummarize.size(), recentMessages.size());
} else {
log.info("Summarized {} messages, keeping {} recent messages",
toSummarize.size(), recentMessages.size());
}
return new AgentCommand(newMessages, UpdatePolicy.REPLACE);
}
/**
* 查找安全的消息切割点,确保不拆分AI/工具消息配对
*
* @param messages 消息列表
* @return 安全切割点索引,返回0表示无法安全切割
*/
private int findSafeCutoff(List<Message> messages) {
if (messages.size() <= messagesToKeep) {
return 0;
}
int targetCutoff = messages.size() - messagesToKeep;
// 从目标切割点向后搜索,找到安全的切割位置
for (int i = targetCutoff; i >= 0; i--) {
if (isSafeCutoffPoint(messages, i)) {
return i;
}
}
return 0;
}
/**
* 检查切割点是否安全,不会拆分相关AI和工具消息
*
* @param messages 消息列表
* @param cutoffIndex 切割点索引
* @return true表示安全,false表示不安全
*/
private boolean isSafeCutoffPoint(List<Message> messages, int cutoffIndex) {
if (cutoffIndex >= messages.size()) {
return true;
}
int searchStart = Math.max(0, cutoffIndex - SEARCH_RANGE_FOR_TOOL_PAIRS);
int searchEnd = Math.min(messages.size(), cutoffIndex + SEARCH_RANGE_FOR_TOOL_PAIRS);
for (int i = searchStart; i < searchEnd; i++) {
if (!hasToolCalls(messages.get(i))) {
continue;
}
AssistantMessage aiMessage = (AssistantMessage) messages.get(i);
Set<String> toolCallIds = extractToolCallIds(aiMessage);
if (cutoffSeparatesToolPair(messages, i, cutoffIndex, toolCallIds)) {
return false;
}
}
return true;
}
/**
* 检查消息是否为包含工具调用的AI消息
*
* @param message 消息对象
* @return true表示是包含工具调用的AI消息
*/
private boolean hasToolCalls(Message message) {
return message instanceof AssistantMessage assistantMessage && !assistantMessage.getToolCalls().isEmpty();
}
/**
* 从AI消息中提取工具调用ID集合
*
* @param aiMessage AI消息对象
* @return 工具调用ID集合
*/
private Set<String> extractToolCallIds(AssistantMessage aiMessage) {
Set<String> toolCallIds = new HashSet<>();
for (AssistantMessage.ToolCall toolCall : aiMessage.getToolCalls()) {
String callId = toolCall.id();
toolCallIds.add(callId);
}
return toolCallIds;
}
/**
* 检查切割点是否会拆分AI消息和对应的工具响应消息
*
* @param messages 消息列表
* @param aiMessageIndex AI消息索引
* @param cutoffIndex 切割点索引
* @param toolCallIds 工具调用ID集合
* @return true表示会拆分,false表示不会拆分
*/
private boolean cutoffSeparatesToolPair(
List<Message> messages,
int aiMessageIndex,
int cutoffIndex,
Set<String> toolCallIds) {
for (int j = aiMessageIndex + 1; j < messages.size(); j++) {
Message message = messages.get(j);
if (message instanceof ToolResponseMessage toolResponseMessage) {
// 检查工具响应消息是否包含当前AI消息的工具调用ID
for (ToolResponseMessage.ToolResponse response : toolResponseMessage.getResponses()) {
if (toolCallIds.contains(response.id())) {
boolean aiBeforeCutoff = aiMessageIndex < cutoffIndex;
boolean toolBeforeCutoff = j < cutoffIndex;
// 如果AI消息和工具响应消息被切割点分开,则返回true
if (aiBeforeCutoff != toolBeforeCutoff) {
return true;
}
}
}
}
}
return false;
}
/**
* 使用大模型生成消息总结
*
* @param messages 待总结的消息列表
* @return 生成的总结文本
*/
private String createSummary(List<Message> messages) {
if (messages.isEmpty()) {
return "No previous conversation.";
}
StringBuilder messageText = new StringBuilder();
for (Message msg : messages) {
String role = getRoleName(msg);
messageText.append(role).append(": ").append(msg.getText()).append("\n");
}
String prompt = String.format(summaryPrompt, messageText.toString());
try {
Prompt summaryPromptObj = new Prompt(List.of(new UserMessage(prompt)));
var response = model.call(summaryPromptObj);
return response.getResult().getOutput().getText();
}
catch (Exception e) {
log.error("Failed to create summary: {}", e.getMessage());
return "Summary generation failed: " + e.getMessage();
}
}
/**
* 获取消息角色名称,用于总结文本格式化
*
* @param message 消息对象
* @return 角色名称(Human/Assistant/System/Tool/Unknown)
*/
private String getRoleName(Message message) {
if (message instanceof UserMessage) {
return "Human";
}
else if (message instanceof AssistantMessage) {
return "Assistant";
}
else if (message instanceof SystemMessage) {
return "System";
}
else if (message instanceof ToolResponseMessage) {
return "Tool";
}
else {
return "Unknown";
}
}
其他源码内容:
java
/**
* 私有构造方法,通过建造者模式创建实例
* @param builder 建造者对象,包含所有配置参数
*/
private SummarizationHook(Builder builder) {
this.model = builder.model;
this.maxTokensBeforeSummary = builder.maxTokensBeforeSummary;
this.messagesToKeep = builder.messagesToKeep;
this.tokenCounter = builder.tokenCounter;
this.summaryPrompt = builder.summaryPrompt;
this.summaryPrefix = builder.summaryPrefix;
this.keepFirstUserMessage = builder.keepFirstUserMessage;
}
/**
* 获取建造者实例,用于配置创建钩子
* @return 建造者对象
*/
public static Builder builder() {
return new Builder();
}
/**
* 获取钩子名称,用于日志和调试
* @return 钩子名称
*/
@Override
public String getName() {
return "Summarization";
}
/**
* 获取钩子支持的跳转能力,当前不支持任何跳转
* @return 空列表
*/
@Override
public List<JumpTo> canJumpTo() {
return List.of();
}
/**
* 建造者类:用于配置和创建SummarizationHook实例
* <p>支持链式调用配置所有参数,确保实例创建的安全性和灵活性</p>
*/
public static class Builder {
private ChatModel model;
private Integer maxTokensBeforeSummary;
private int messagesToKeep = DEFAULT_MESSAGES_TO_KEEP;
private TokenCounter tokenCounter = TokenCounter.approximateMsgCounter();
private String summaryPrompt = DEFAULT_SUMMARY_PROMPT;
private String summaryPrefix = SUMMARY_PREFIX;
private boolean keepFirstUserMessage = DEFAULT_KEEP_FIRST_USER_MESSAGE;
/**
* 设置大模型实例(必填)
* @param model 大模型实例
* @return 建造者自身
*/
public Builder model(ChatModel model) {
this.model = model;
return this;
}
/**
* 设置触发总结的Token阈值
* @param maxTokens Token阈值
* @return 建造者自身
*/
public Builder maxTokensBeforeSummary(Integer maxTokens) {
this.maxTokensBeforeSummary = maxTokens;
return this;
}
/**
* 设置总结后保留的最新消息数量
* @param count 消息数量
* @return 建造者自身
*/
public Builder messagesToKeep(int count) {
this.messagesToKeep = count;
return this;
}
/**
* 设置自定义总结提示词
* @param prompt 总结提示词
* @return 建造者自身
*/
public Builder summaryPrompt(String prompt) {
this.summaryPrompt = prompt;
return this;
}
/**
* 设置总结消息前缀
* @param prefix 总结前缀
* @return 建造者自身
*/
public Builder summaryPrefix(String prefix) {
this.summaryPrefix = prefix;
return this;
}
/**
* 设置Token计数器
* @param counter Token计数器
* @return 建造者自身
*/
public Builder tokenCounter(TokenCounter counter) {
this.tokenCounter = counter;
return this;
}
/**
* 设置是否保留首条用户消息
* @param keep true表示保留,false表示不保留
* @return 建造者自身
*/
public Builder keepFirstUserMessage(boolean keep) {
this.keepFirstUserMessage = keep;
return this;
}
/**
* 构建SummarizationHook实例
* @return 配置完成的SummarizationHook实例
* @throws IllegalArgumentException 当model为null时抛出
*/
public SummarizationHook build() {
if (model == null) {
throw new IllegalArgumentException("model must be specified");
}
return new SummarizationHook(this);
}
}
}
2.2 ReturnDirectModelHook
工具调用直接返回钩子,模型调用前执行的最高优先级钩子,用于检测工具响应消息中的直接返回标识:
- 检测最后一条消息是否为带结束标识的工具响应消息
- 识别到
returnDirect标识后,自动生成助手消息 - 强制跳转到流程结束节点,终止后续模型调用
- 最高执行优先级,确保优先拦截直接返回场景
适用场景:工具调用配置 returnDirect=true 时,无需再次调用大模型,直接返回结果
模型调用前核心逻辑,检测直接返回标识并处理:
java
/**
* 模型调用前核心逻辑:检测直接返回标识并处理
* <p>
* 执行流程:
* <ol>
* <li>校验消息列表是否为空,为空则直接返回</li>
* <li>判断最后一条消息是否为工具响应消息,非则直接返回</li>
* <li>检查元数据中是否包含结束标识(FINISH_REASON)</li>
* <li>识别到直接返回标识:生成助手消息,跳转到结束节点</li>
* <li>无标识:正常执行后续流程</li>
* </ol>
* </p>
* @param previousMessages 历史消息列表
* @param config 运行配置
* @return 包含跳转指令/新消息的AgentCommand
*/
@Override
public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
// 消息列表为空,直接返回
if (previousMessages.isEmpty()) {
return new AgentCommand(previousMessages);
}
Message lastMessage = previousMessages.get(previousMessages.size() - 1);
// 最后一条消息不是工具响应消息,正常执行
if (!(lastMessage instanceof ToolResponseMessage toolResponseMessage)) {
return new AgentCommand(previousMessages);
}
// 检查工具响应消息的元数据是否包含直接返回标识
// 该标识由 AgentToolNode 在 returnDirect=true 时设置
boolean returnDirect = false;
Map<String, Object> metadata = toolResponseMessage.getMetadata();
if (metadata.containsKey(FINISH_REASON_METADATA_KEY)) {
Object finishReason = metadata.get(FINISH_REASON_METADATA_KEY);
if (FINISH_REASON.equals(finishReason)) {
returnDirect = true;
}
}
// 触发直接返回逻辑
if (returnDirect) {
// 根据工具响应生成助手消息
String generatedText = generateAssistantMessageText(toolResponseMessage);
AssistantMessage newAssistantMessage = AssistantMessage.builder()
.content(generatedText)
.build();
// 构建新的消息列表
List<Message> newMessages = new ArrayList<>(previousMessages);
newMessages.add(newAssistantMessage);
// 跳转到结束节点,终止流程
return new AgentCommand(JumpTo.end, newMessages);
}
// 无直接返回标识,正常执行
return new AgentCommand(previousMessages);
}
/**
* 根据工具响应消息生成助手消息内容
* <p>
* 生成规则:
* <ul>
* <li>单条响应:直接返回响应数据</li>
* <li>多条响应:组装为标准JSON数组</li>
* <li>空响应:返回空字符串</li>
* <li>自动处理JSON格式与字符串转义</li>
* </ul>
* </p>
* @param toolResponseMessage 工具响应消息
* @return 格式化后的助手消息文本
*/
private String generateAssistantMessageText(ToolResponseMessage toolResponseMessage) {
List<ToolResponseMessage.ToolResponse> responses = toolResponseMessage.getResponses();
if (responses.isEmpty()) {
return "";
} else if (responses.size() == 1) {
// 单条响应直接返回数据
return responses.get(0).responseData();
} else {
// 多条响应组装为JSON数组
StringBuilder jsonArray = new StringBuilder("[");
for (int i = 0; i < responses.size(); i++) {
if (i > 0) {
jsonArray.append(",");
}
String responseData = responses.get(i).responseData();
// 处理空值
if (responseData == null) {
jsonArray.append("null");
} else {
String trimmed = responseData.trim();
// 已为JSON格式直接拼接,否则转为字符串
if (trimmed.startsWith("{") || trimmed.startsWith("[")) {
jsonArray.append(responseData);
} else {
jsonArray.append("\"").append(escapeJsonString(responseData)).append("\"");
}
}
}
jsonArray.append("]");
return jsonArray.toString();
}
}
其他源码内容:
java
@HookPositions({HookPosition.BEFORE_MODEL})
public class ReturnDirectModelHook extends MessagesModelHook {
/**
* 获取钩子名称
* @return 钩子唯一标识
*/
@Override
public String getName() {
return "finish_reason_check_messages_model_hook";
}
/**
* 获取钩子执行优先级
* @return 最高优先级,确保最先执行
*/
@Override
public int getOrder() {
return Prioritized.HIGHEST_PRECEDENCE;
}
/**
* 支持的流程跳转目标
* @return 仅支持跳转到结束节点
*/
@Override
public List<JumpTo> canJumpTo() {
return List.of(JumpTo.end);
}
/**
* JSON字符串转义:处理特殊字符,保证JSON格式合法
* <p>
* 转义字符:双引号、反斜杠、退格、换页、换行、回车、制表符、控制字符
* </p>
* @param str 原始字符串
* @return 转义后的安全字符串
*/
private String escapeJsonString(String str) {
if (str == null) {
return "";
}
StringBuilder sb = new StringBuilder();
for (char c : str.toCharArray()) {
switch (c) {
case '"':
sb.append("\\\"");
break;
case '\\':
sb.append("\\\\");
break;
case '\b':
sb.append("\\b");
break;
case '\f':
sb.append("\\f");
break;
case '\n':
sb.append("\\n");
break;
case '\r':
sb.append("\\r");
break;
case '\t':
sb.append("\\t");
break;
default:
if (c < 0x20) {
sb.append(String.format("\\u%04x", (int) c));
} else {
sb.append(c);
}
break;
}
}
return sb.toString();
}
}
2.3 PIIDetectionHook
个人身份信息(PII)检测与处理钩子。执行于模型调用前/调用后,用于检测对话中的敏感个人信息(PII),并支持脱敏、掩码、哈希、拦截四种处理策略。
支持范围:
- 检测类型:邮箱、信用卡、IP地址、MAC地址、URL
- 处理场景:用户输入、助手输出、工具响应结果
使用示例:
java
PIIDetectionHook pii = PIIDetectionHook.builder()
.piiType(PIIType.EMAIL)
.strategy(RedactionStrategy.REDACT)
.applyToInput(true)
.build();
属性定义:
java
@HookPositions({HookPosition.BEFORE_MODEL, HookPosition.AFTER_MODEL})
public class PIIDetectionHook extends MessagesModelHook {
/** 待检测的PII类型(邮箱/信用卡/IP等) */
private final PIIType piiType;
/** PII处理策略(脱敏/掩码/哈希/拦截) */
private final RedactionStrategy strategy;
/** PII检测器实例 */
private final PIIDetector detector;
/** 是否应用于用户输入消息 */
private final boolean applyToInput;
/** 是否应用于助手输出消息 */
private final boolean applyToOutput;
/** 是否应用于工具响应结果 */
private final boolean applyToToolResults;
模型调用前处理,检测并处理用户输入/工具响应中的 PII :
java
/**
* 模型调用前处理:检测并处理用户输入/工具响应中的PII
* <p>遍历所有消息,根据配置处理用户消息、工具响应消息,替换敏感信息</p>
* @param previousMessages 历史消息列表
* @param config 运行配置
* @return 处理后的消息命令
*/
@Override
public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
List<Message> processedMessages = new ArrayList<>();
boolean hasChanges = false;
for (Message message : previousMessages) {
Message processed = processMessage(message);
processedMessages.add(processed);
if (processed != message) {
hasChanges = true;
}
}
if (hasChanges) {
return new AgentCommand(processedMessages, UpdatePolicy.REPLACE);
}
return new AgentCommand(previousMessages);
}
模型调用后处理,检测并处理助手输出中的 PII:
java
/**
* 模型调用后处理:检测并处理助手输出中的PII
* <p>仅处理最后一条助手消息,支持拦截/脱敏替换</p>
* @param previousMessages 模型输出后的消息列表
* @param config 运行配置
* @return 处理后的消息命令
*/
@Override
public AgentCommand afterModel(List<Message> previousMessages, RunnableConfig config) {
// 仅当启用输出处理时执行
if (!applyToOutput) {
return new AgentCommand(previousMessages);
}
if (previousMessages.isEmpty()) {
return new AgentCommand(previousMessages);
}
// 查找最后一条助手消息
AssistantMessage aiMessage = null;
int lastIndex = -1;
for (int i = previousMessages.size() - 1; i >= 0; i--) {
if (previousMessages.get(i) instanceof AssistantMessage am) {
aiMessage = am;
lastIndex = i;
break;
}
}
if (aiMessage == null) {
return new AgentCommand(previousMessages);
}
String content = aiMessage.getText();
if (content == null || content.isEmpty()) {
return new AgentCommand(previousMessages);
}
// 检测PII
ProcessResult result = processText(content);
if (!result.hasMatches) {
return new AgentCommand(previousMessages);
}
// 拦截策略:直接抛出异常
if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
throw new PIIDetectionException(piiType.name(), result.matches);
}
if (result.redactedText.equals(content)) {
return new AgentCommand(previousMessages);
}
// 构建脱敏后的助手消息
AssistantMessage updatedMessage = AssistantMessage.builder()
.content(result.redactedText)
.properties(aiMessage.getMetadata())
.toolCalls(aiMessage.getToolCalls())
.media(aiMessage.getMedia())
.build();
List<Message> updatedMessages = new ArrayList<>(previousMessages);
updatedMessages.set(lastIndex, updatedMessage);
return new AgentCommand(updatedMessages, UpdatePolicy.REPLACE);
}
其他源码内容:
java
/**
* 私有构造方法,通过建造者模式创建实例
* @param builder 建造者配置对象
*/
private PIIDetectionHook(Builder builder) {
this.piiType = builder.piiType;
this.strategy = builder.strategy;
this.detector = builder.detector != null ? builder.detector : getDefaultDetector(piiType);
this.applyToInput = builder.applyToInput;
this.applyToOutput = builder.applyToOutput;
this.applyToToolResults = builder.applyToToolResults;
}
/**
* 获取建造者实例,用于配置PII检测钩子
* @return Builder 建造者对象
*/
public static Builder builder() {
return new Builder();
}
/**
* 消息分发处理:根据消息类型和配置执行PII处理
* @param message 原始消息
* @return 处理后的消息(无变化则返回原对象)
*/
private Message processMessage(Message message) {
if (applyToInput && message instanceof UserMessage) {
return processContent((UserMessage) message);
}
else if (applyToOutput && message instanceof AssistantMessage) {
return processContent((AssistantMessage) message);
}
else if (applyToToolResults && message instanceof ToolResponseMessage) {
return processToolResponse((ToolResponseMessage) message);
}
return message;
}
/**
* 处理用户消息:检测并脱敏PII,拦截策略直接抛异常
* @param message 用户消息
* @return 处理后的用户消息
*/
private UserMessage processContent(UserMessage message) {
String content = message.getText();
ProcessResult result = processText(content);
if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
throw new PIIDetectionException(piiType.name(), result.matches);
}
if (result.redactedText.equals(content)) {
return message;
}
return UserMessage.builder().text(result.redactedText).metadata(message.getMetadata()).build();
}
/**
* 处理助手消息:检测并脱敏PII,拦截策略直接抛异常
* @param message 助手消息
* @return 处理后的助手消息
*/
private AssistantMessage processContent(AssistantMessage message) {
String content = message.getText();
ProcessResult result = processText(content);
if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
throw new PIIDetectionException(piiType.name(), result.matches);
}
if (result.redactedText.equals(content)) {
return message;
}
return AssistantMessage.builder()
.content(result.redactedText)
.properties(message.getMetadata())
.toolCalls(message.getToolCalls())
.media(message.getMedia())
.build();
}
/**
* 处理工具响应消息:遍历所有响应结果,检测并脱敏PII
* @param message 工具响应消息
* @return 处理后的工具响应消息
*/
private ToolResponseMessage processToolResponse(ToolResponseMessage message) {
List<ToolResponseMessage.ToolResponse> responses = new ArrayList<>();
boolean hasChanges = false;
for (ToolResponseMessage.ToolResponse response : message.getResponses()) {
String content = response.responseData();
ProcessResult result = processText(content);
if (result.hasMatches && strategy == RedactionStrategy.BLOCK) {
throw new PIIDetectionException(piiType.name(), result.matches);
}
if (!result.redactedText.equals(content)) {
responses.add(new ToolResponseMessage.ToolResponse(
response.id(), response.name(), result.redactedText));
hasChanges = true;
}
else {
responses.add(response);
}
}
return hasChanges
? ToolResponseMessage.builder()
.responses(responses)
.metadata(message.getMetadata())
.build()
: message;
}
/**
* 文本处理核心:调用检测器识别PII,应用处理策略
* @param text 待检测文本
* @return 处理结果(脱敏文本、匹配标记、匹配列表)
*/
private ProcessResult processText(String text) {
List<PIIMatch> matches = detector.detect(text);
if (matches.isEmpty()) {
return new ProcessResult(text, false, matches);
}
String redacted = applyStrategy(text, matches);
return new ProcessResult(redacted, true, matches);
}
/**
* 应用PII处理策略:脱敏/掩码/哈希
* @param text 原始文本
* @param matches PII匹配结果
* @return 策略处理后的文本
*/
private String applyStrategy(String text, List<PIIMatch> matches) {
if (matches.isEmpty()) {
return text;
}
StringBuilder result = new StringBuilder();
int lastEnd = 0;
// 按起始位置排序匹配项
matches.sort(Comparator.comparingInt(m -> m.start));
for (PIIMatch match : matches) {
result.append(text, lastEnd, match.start);
switch (strategy) {
case REDACT:
result.append("[REDACTED_").append(piiType.name()).append("]");
break;
case MASK:
result.append(maskValue(match.value));
break;
case HASH:
result.append(hashValue(match.value));
break;
case BLOCK:
// 拦截逻辑已在上层处理
break;
}
lastEnd = match.end;
}
result.append(text.substring(lastEnd));
return result.toString();
}
/**
* 掩码处理:保留最后4位,其余替换为*
* @param value 原始敏感值
* @return 掩码后的值
*/
private String maskValue(String value) {
if (value.length() <= 4) {
return "****";
}
int visibleChars = 4;
String masked = "*".repeat(value.length() - visibleChars);
return masked + value.substring(value.length() - visibleChars);
}
/**
* 哈希处理:生成固定格式的哈希标识
* @param value 原始敏感值
* @return 哈希后的值
*/
private String hashValue(String value) {
int hash = value.hashCode();
return String.format("<%s_hash:%08x>", piiType.name().toLowerCase(), hash);
}
/**
* 根据PII类型获取默认检测器
* @param type PII类型
* @return 对应的检测器实例
*/
private PIIDetector getDefaultDetector(PIIType type) {
switch (type) {
case EMAIL:
return PIIDetectors.emailDetector();
case CREDIT_CARD:
return PIIDetectors.creditCardDetector();
case IP:
return PIIDetectors.ipDetector();
case MAC_ADDRESS:
return PIIDetectors.macAddressDetector();
case URL:
return PIIDetectors.urlDetector();
default:
throw new IllegalArgumentException("No default detector for PII type: " + type);
}
}
/**
* 获取钩子名称
* @return 带PII类型的钩子名称
*/
@Override
public String getName() {
return "PIIDetection[" + piiType.name() + "]";
}
/**
* 支持的流程跳转:无跳转能力
* @return 空列表
*/
@Override
public List<JumpTo> canJumpTo() {
return List.of();
}
/**
* PII处理结果内部类:封装脱敏文本、匹配状态、匹配列表
*/
private static class ProcessResult {
/** 脱敏后的文本 */
final String redactedText;
/** 是否匹配到PII */
final boolean hasMatches;
/** PII匹配详情列表 */
final List<PIIMatch> matches;
/**
* 构造方法
* @param redactedText 脱敏文本
* @param hasMatches 是否匹配
* @param matches 匹配列表
*/
ProcessResult(String redactedText, boolean hasMatches, List<PIIMatch> matches) {
this.redactedText = redactedText;
this.hasMatches = hasMatches;
this.matches = matches;
}
}
/**
* 建造者类:配置化创建PIIDetectionHook实例
*/
public static class Builder {
private PIIType piiType;
private RedactionStrategy strategy = RedactionStrategy.REDACT;
private PIIDetector detector;
private boolean applyToInput = true;
private boolean applyToOutput = false;
private boolean applyToToolResults = false;
/**
* 设置PII检测类型(必填)
* @param piiType 敏感信息类型
* @return Builder
*/
public Builder piiType(PIIType piiType) {
this.piiType = piiType;
return this;
}
/**
* 设置PII处理策略,默认脱敏
* @param strategy 处理策略
* @return Builder
*/
public Builder strategy(RedactionStrategy strategy) {
this.strategy = strategy;
return this;
}
/**
* 自定义PII检测器
* @param detector 检测器实例
* @return Builder
*/
public Builder detector(PIIDetector detector) {
this.detector = detector;
return this;
}
/**
* 是否处理用户输入,默认开启
* @param applyToInput 开关
* @return Builder
*/
public Builder applyToInput(boolean applyToInput) {
this.applyToInput = applyToInput;
return this;
}
/**
* 是否处理助手输出,默认关闭
* @param applyToOutput 开关
* @return Builder
*/
public Builder applyToOutput(boolean applyToOutput) {
this.applyToOutput = applyToOutput;
return this;
}
/**
* 是否处理工具响应结果,默认关闭
* @param applyToToolResults 开关
* @return Builder
*/
public Builder applyToToolResults(boolean applyToToolResults) {
this.applyToToolResults = applyToToolResults;
return this;
}
/**
* 构建PIIDetectionHook实例
* @return 钩子实例
* @throws IllegalArgumentException 未指定piiType时抛出
*/
public PIIDetectionHook build() {
if (piiType == null) {
throw new IllegalArgumentException("piiType must be specified");
}
return new PIIDetectionHook(this);
}
}
}