Skip to main content
通过实现运行在代理执行流程特定阶段的钩子来构建自定义中间件。

钩子

中间件提供两种风格的钩子来拦截代理执行:

节点式钩子

在特定的执行点按顺序运行。

包装式钩子

在每个模型或工具调用周围运行。

节点式钩子

在特定的执行点按顺序运行。用于日志记录、验证和状态更新。 选择您的中间件需要的钩子。您可以在节点式钩子和包装式钩子之间进行选择。 节点式钩子在特定的执行点运行:
钩子何时运行
before_agent代理启动前(每次调用一次)
before_model每次模型调用前
after_model每次模型响应后
after_agent代理完成后(每次调用一次)
包装式钩子在每个调用周围运行,让您控制执行:
钩子何时运行
wrap_model_call每个模型调用周围
wrap_tool_call每个工具调用周围
示例:
from langchain.agents.middleware import before_model, after_model, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


@before_model(can_jump_to=["end"])
def check_message_limit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    if len(state["messages"]) >= 50:
        return {
            "messages": [AIMessage("Conversation limit reached.")],
            "jump_to": "end"
        }
    return None

@after_model
def log_response(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    print(f"Model returned: {state['messages'][-1].content}")
    return None

包装式钩子

拦截执行并控制处理程序的调用时机。用于重试、缓存和转换。 您可以决定处理程序被调用零次(短路)、一次(正常流程)还是多次(重试逻辑)。 可用钩子:
  • wrap_model_call - 每个模型调用周围
  • wrap_tool_call - 每个工具调用周围
示例:
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from typing import Callable


@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    for attempt in range(3):
        try:
            return handler(request)
        except Exception as e:
            if attempt == 2:
                raise
            print(f"Retry {attempt + 1}/3 after error: {e}")

状态更新

节点式钩子和包装式钩子都可以更新代理状态。机制有所不同:
  • 节点式钩子 (before_agent, before_model, after_model, after_agent):直接返回一个字典。该字典使用图的归约器应用于代理状态。
  • 包装式钩子 (wrap_model_call, wrap_tool_call):对于模型调用,返回 ExtendedModelResponse 并附带 Command 以将状态更新与模型响应一起注入。对于工具调用,直接返回 Command。当您需要根据在模型或工具调用期间运行的逻辑来跟踪或更新状态时使用,例如摘要触发点、使用情况元数据,或从请求或响应计算出的自定义字段。

节点式钩子

从节点式钩子返回一个字典以将更新合并到代理状态中。字典键映射到状态字段。
from langchain.agents.middleware import after_model, AgentState
from langgraph.runtime import Runtime
from typing import Any
from typing_extensions import NotRequired


class TrackingState(AgentState):
    model_call_count: NotRequired[int]


@after_model(state_schema=TrackingState)
def increment_after_model(state: TrackingState, runtime: Runtime) -> dict[str, Any] | None:
    return {"model_call_count": state.get("model_call_count", 0) + 1}

包装式钩子

wrap_model_call 返回带有 CommandExtendedModelResponse 以从模型调用层注入状态更新:
from typing import Callable
from langchain.agents.middleware import (
    wrap_model_call,
    ModelRequest,
    ModelResponse,
    AgentState,
    ExtendedModelResponse
)
from langgraph.types import Command
from typing_extensions import NotRequired

class UsageTrackingState(AgentState):
    """Agent state with token usage tracking."""

    last_model_call_tokens: NotRequired[int]


@wrap_model_call(state_schema=UsageTrackingState)
def track_usage(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
    response = handler(request)
    return ExtendedModelResponse(
        model_response=response,
        command=Command(update={"last_model_call_tokens": 150}),
    )
Command 流经图的归约器,因此更新会正确应用,消息是累加的而不是替换现有状态。

与多个中间件的组合

当多个中间件层返回 ExtendedModelResponse 时,它们的命令会组合:
  • 命令通过归约器应用: 每个 Command 成为单独的状态更新。对于消息,这意味着它们是累加的。
  • 冲突时外层获胜: 对于非归约器状态字段,命令先应用内层,然后应用外层。最外层中间件的值在冲突键上具有优先级。
  • 重试安全: 如果外层中间件实现了可能导致再次多次调用 handler() 的逻辑(例如重试逻辑),则早期调用的命令会被丢弃。
from typing import Annotated, Callable

from langchain.agents.middleware import (
    AgentMiddleware,
    AgentState,
    ExtendedModelResponse,
    ModelRequest,
    ModelResponse,
)
from langchain.messages import SystemMessage
from langgraph.types import Command
from typing_extensions import NotRequired


def _last_wins(_a: str, b: str) -> str:
    """Reducer: last writer wins (outer overwrites inner)."""
    return b


class CustomMiddlewareState(AgentState):
    """Agent state: trace_layer uses last-wins (outer wins), messages use additive reducer."""

    # Non-reducer field with last-wins: both middleware write; outermost value wins
    trace_layer: NotRequired[Annotated[str, _last_wins]]


class OuterMiddleware(AgentMiddleware):
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ExtendedModelResponse:
        response = handler(request)
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={
                "trace_layer": "outer",
                "messages": [SystemMessage(content="[Outer ran]")],
            }),
        )


class InnerMiddleware(AgentMiddleware):
    """Adds trace_layer and message. Outer adds to same keys; trace_layer: outer wins, messages: additive."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ):
        response = handler(request)
        return ExtendedModelResponse(
            model_response=response,
            command=Command(update={
                "trace_layer": "inner",
                "messages": [SystemMessage(content="[Inner ran]")],
            }),
        )

创建中间件

您可以通过两种方式创建中间件:

基于装饰器的中间件

简单快捷,适用于单钩子中间件。使用装饰器包装单个函数。

基于类的中间件

功能更强大,适用于具有多个钩子或配置的复杂中间件。

基于装饰器的中间件

简单快捷,适用于单钩子中间件。使用装饰器包装单个函数。 可用装饰器: 节点式: 包装式: 便捷功能: 示例:
from langchain.agents.middleware import (
    before_model,
    wrap_model_call,
    AgentState,
    ModelRequest,
    ModelResponse,
)
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from typing import Any, Callable


@before_model
def log_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    print(f"About to call model with {len(state['messages'])} messages")
    return None

@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    for attempt in range(3):
        try:
            return handler(request)
        except Exception as e:
            if attempt == 2:
                raise
            print(f"Retry {attempt + 1}/3 after error: {e}")

agent = create_agent(
    model="gpt-4.1",
    middleware=[log_before_model, retry_model],
    tools=[...],
)
何时使用装饰器:
  • 需要单个钩子
  • 无需复杂配置
  • 快速原型设计

基于类的中间件

功能更强大,适用于具有多个钩子或配置的复杂中间件。当您需要为同一个钩子定义同步和异步实现,或者希望在一个中间件中组合多个钩子时使用类。 示例:
from langchain.agents.middleware import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    ModelResponse,
)
from langgraph.runtime import Runtime
from typing import Any, Callable

class LoggingMiddleware(AgentMiddleware):
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"About to call model with {len(state['messages'])} messages")
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"Model returned: {state['messages'][-1].content}")
        return None

    async def abefore_model(
        self, state: AgentState, runtime: Runtime
    ) -> dict[str, Any] | None:
        # Async version of before_model
        return None

    async def aafter_model(
        self, state: AgentState, runtime: Runtime
    ) -> dict[str, Any] | None:
        # Async version of after_model
        print(f"Model returned: {state['messages'][-1].content}")
        return None


agent = create_agent(
    model="gpt-4.1",
    middleware=[LoggingMiddleware()],
    tools=[...],
)
何时使用类:
  • 为同一个钩子定义同步和异步实现
  • 需要在单个中间件中定义多个钩子
  • 需要复杂配置(例如可配置阈值、自定义模型)
  • 跨项目重用,带初始化时配置

自定义状态模式

如果您的中间件需要在钩子之间跟踪状态,中间件可以使用自定义属性扩展代理状态。这使得中间件能够:
  • 在执行过程中跟踪状态:维护在整个代理执行生命周期中持续存在的计数器、标志或其他值
  • 在钩子之间共享数据:从 before_modelafter_model 或不同中间件实例之间传递信息
  • 实现横切关注点:添加功能,如速率限制、使用情况跟踪、用户上下文或审计日志,而无需修改核心代理逻辑
  • 进行条件决策:使用累积状态来确定是否继续执行、跳转到不同节点或动态修改行为
from langchain.agents import create_agent
from langchain.messages import HumanMessage
from langchain.agents.middleware import AgentState, before_model, after_model
from typing_extensions import NotRequired
from typing import Any
from langgraph.runtime import Runtime


class CustomState(AgentState):
    model_call_count: NotRequired[int]
    user_id: NotRequired[str]


@before_model(state_schema=CustomState, can_jump_to=["end"])
def check_call_limit(state: CustomState, runtime: Runtime) -> dict[str, Any] | None:
    count = state.get("model_call_count", 0)
    if count > 10:
        return {"jump_to": "end"}
    return None


@after_model(state_schema=CustomState)
def increment_counter(state: CustomState, runtime: Runtime) -> dict[str, Any] | None:
    return {"model_call_count": state.get("model_call_count", 0) + 1}


agent = create_agent(
    model="gpt-4.1",
    middleware=[check_call_limit, increment_counter],
    tools=[],
)

# Invoke with custom state
result = agent.invoke({
    "messages": [HumanMessage("Hello")],
    "model_call_count": 0,
    "user_id": "user-123",
})

执行顺序

使用多个中间件时,了解它们的执行方式:
agent = create_agent(
    model="gpt-4.1",
    middleware=[middleware1, middleware2, middleware3],
    tools=[...],
)
Before 钩子按顺序运行:
  1. middleware1.before_agent()
  2. middleware2.before_agent()
  3. middleware3.before_agent()
代理循环开始
  1. middleware1.before_model()
  2. middleware2.before_model()
  3. middleware3.before_model()
包装钩子像函数调用一样嵌套:
  1. middleware1.wrap_model_call()middleware2.wrap_model_call()middleware3.wrap_model_call() → 模型
After 钩子按相反顺序运行:
  1. middleware3.after_model()
  2. middleware2.after_model()
  3. middleware1.after_model()
代理循环结束
  1. middleware3.after_agent()
  2. middleware2.after_agent()
  3. middleware1.after_agent()
关键规则:
  • before_* 钩子:第一个到最后
  • after_* 钩子:最后到第一个(反向)
  • wrap_* 钩子:嵌套(第一个中间件包装所有其他中间件)

代理跳转

要从中间件提前退出,返回包含 jump_to 的字典: 可用跳转目标:
  • 'end':跳转到代理执行的末尾(或第一个 after_agent 钩子)
  • 'tools':跳转到工具节点
  • 'model':跳转到模型节点(或第一个 before_model 钩子)
from langchain.agents.middleware import after_model, hook_config, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any


@after_model
@hook_config(can_jump_to=["end"])
def check_for_blocked(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    last_message = state["messages"][-1]
    if "BLOCKED" in last_message.content:
        return {
            "messages": [AIMessage("I cannot respond to that request.")],
            "jump_to": "end"
        }
    return None

最佳实践

  1. 保持中间件专注 - 每个应做好一件事
  2. 优雅地处理错误 - 不要让中间件错误导致代理崩溃
  3. 使用适当的钩子类型
    • 节点式用于顺序逻辑(日志记录、验证)
    • 包装式用于控制流(重试、回退、缓存)
  4. 清楚记录任何自定义状态属性
  5. 集成前独立单元测试中间件
  6. 考虑执行顺序 - 将关键中间件放在列表前面
  7. 尽可能使用内置中间件

示例

动态提示词

在运行时动态修改系统提示词,以便在每次模型调用之前注入上下文、用户特定指令或其他信息。这是最常见的中间件用例之一。 使用 ModelRequest 上的 system_message 字段读取和修改系统提示词。它包含一个 SystemMessage 对象(即使代理是使用字符串 system_prompt 创建的)。
from collections.abc import Callable

from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call
from langchain.messages import SystemMessage


@wrap_model_call
def add_context(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    new_content = list(request.system_message.content_blocks) + [
        {"type": "text", "text": "Additional context."}
    ]
    new_system_message = SystemMessage(content=new_content)
    return handler(request.override(system_message=new_system_message))
  • ModelRequest.system_message 始终是一个 SystemMessage 对象,即使代理是使用 system_prompt="string" 创建的
  • 使用 SystemMessage.content_blocks 将内容作为块列表访问,无论原始内容是字符串还是列表
  • 修改系统消息时,使用 content_blocks 并追加新块以保留现有结构
  • 您可以直接将 SystemMessage 对象传递给 create_agentsystem_prompt 参数,用于高级用例,如缓存控制

动态模型选择

from collections.abc import Callable

from langchain.agents.middleware import ModelRequest, ModelResponse, wrap_model_call
from langchain.chat_models import init_chat_model

complex_model = init_chat_model("claude-sonnet-4-6")
simple_model = init_chat_model("claude-haiku-4-5-20251001")


@wrap_model_call
def dynamic_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    if len(request.messages) > 10:
        model = complex_model
    else:
        model = simple_model
    return handler(request.override(model=model))

动态选择工具

在运行时选择相关工具以提高性能和准确性。本节介绍过滤预注册的工具。有关注册在运行时发现的工具(例如来自 MCP 服务器),请参阅 Runtime tool registration 好处:
  • 更短的提示词 - 通过仅暴露相关工具来降低复杂性
  • 更好的准确性 - 模型从较少的选项中正确选择
  • 权限控制 - 根据用户访问权限动态过滤工具
from langchain.agents import create_agent
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from typing import Callable


@wrap_model_call
def select_tools(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Middleware to select relevant tools based on state/context."""
    # Select a small, relevant subset of tools based on state/context
    relevant_tools = select_relevant_tools(request.state, request.runtime)
    return handler(request.override(tools=relevant_tools))

agent = create_agent(
    model="gpt-4.1",
    tools=all_tools,  # All available tools need to be registered upfront
    middleware=[select_tools],
)

工具调用监控

from collections.abc import Callable

from langchain.agents.middleware import wrap_tool_call
from langchain.messages import ToolMessage
from langchain.tools.tool_node import ToolCallRequest
from langgraph.types import Command


@wrap_tool_call
def monitor_tool(
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    print(f"Executing tool: {request.tool_call['name']}")
    print(f"Arguments: {request.tool_call['args']}")
    try:
        result = handler(request)
        print("Tool completed successfully")
        return result
    except Exception as e:
        print(f"Tool failed: {e}")
        raise

提示词缓存(Anthropic)

在使用 Anthropic 模型时,使用带有缓存控制指令的结构化内容块来缓存大型系统提示词:
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.messages import SystemMessage
from typing import Callable


@wrap_model_call
def add_cached_context(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    # Always work with content blocks
    new_content = list(request.system_message.content_blocks) + [
        {
            "type": "text",
            "text": "Here is a large document to analyze:\n\n<document>...</document>",
            # content up until this point is cached
            "cache_control": {"type": "ephemeral"}
        }
    ]

    new_system_message = SystemMessage(content=new_content)
    return handler(request.override(system_message=new_system_message))
注意:
  • ModelRequest.system_message 始终是一个 SystemMessage 对象,即使代理是使用 system_prompt="string" 创建的
  • 使用 SystemMessage.content_blocks 将内容作为块列表访问,无论原始内容是字符串还是列表
  • 修改系统消息时,使用 content_blocks 并追加新块以保留现有结构
  • 您可以直接将 SystemMessage 对象传递给 create_agentsystem_prompt 参数,用于高级用例,如缓存控制
:::

其他资源