diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index f00a46b..8f1d069 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -30,6 +30,10 @@ from typing import Any, Dict, Iterator, List, Optional, Union from agentrun.server.model import AgentResult, EventType +from agentrun.utils.error_utils import ( + build_error_event_data, + is_rate_limited_error, +) from agentrun.utils.log import logger # 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象) @@ -952,10 +956,15 @@ def _convert_astream_events_event( yield AgentResult( event=EventType.ERROR, - data={ - "message": f"LLM error: {error_message}", - "code": "LLM_ERROR", - }, + data=build_error_event_data( + error, + fallback_code="LLM_ERROR", + fallback_message=( + error_message + if is_rate_limited_error(error) + else f"LLM error: {error_message}" + ), + ), ) # 8. Chain 错误 diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index 5e8ccb4..a9d6d43 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -53,6 +53,7 @@ # ============================================================================ DEFAULT_PREFIX = "/ag-ui/agent" +RUN_ERROR_EXTRA_FIELDS = ("retryable", "retryAfterMs", "traceId") @dataclass @@ -743,12 +744,18 @@ def _process_event_with_boundaries( # ERROR 事件 if event.event == EventType.ERROR: - yield self._encoder.encode( - RunErrorEvent( - message=event.data.get("message", ""), - code=event.data.get("code"), - ) + agui_event = RunErrorEvent( + message=event.data.get("message", ""), + code=event.data.get("code"), ) + extra_fields = { + key: event.data[key] + for key in RUN_ERROR_EXTRA_FIELDS + if key in event.data + } + if extra_fields: + agui_event = agui_event.model_copy(update=extra_fields) + yield self._encoder.encode(agui_event) return # STATE 事件 diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 763e6a0..26f8956 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,6 +24,10 @@ ) import uuid +from agentrun.utils.error_utils import ( + build_error_event_data, + is_rate_limited_error, +) from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, @@ -117,10 +121,7 @@ async def invoke_stream( if isinstance(item, str): if not item: # 跳过空字符串 continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): # 处理用户返回的事件 @@ -142,7 +143,11 @@ async def invoke_stream( logger.error(f"Agent 调用出错: {e}", exc_info=True) yield AgentEvent( event=EventType.ERROR, - data={"message": str(e), "code": type(e).__name__}, + data=build_error_event_data( + e, + fallback_code=type(e).__name__, + fallback_message=str(e), + ), ) def _process_user_event( @@ -227,12 +232,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: return results if isinstance(result, str): - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": result}, - ) - ) + results.append(self._wrap_text(result)) elif isinstance(result, AgentEvent): # 处理可能的 TOOL_CALL 展开 @@ -243,12 +243,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: if isinstance(item, AgentEvent): results.extend(self._process_user_event(item)) elif isinstance(item, str) and item: - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) - ) + results.append(self._wrap_text(item)) else: results.extend(self._wrap_model_chunk(item)) @@ -275,10 +270,7 @@ async def _wrap_stream( if isinstance(item, str): if not item: continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): for processed_event in self._process_user_event(item): @@ -346,15 +338,25 @@ def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]: content = self._read_attr_or_key(item, "content") if isinstance(content, str) and content: - events.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": content}, - ) - ) + events.append(self._wrap_text(content)) return events + def _wrap_text(self, text: str) -> AgentEvent: + if is_rate_limited_error(text): + return AgentEvent( + event=EventType.ERROR, + data=build_error_event_data( + text, + fallback_code=type(text).__name__, + fallback_message=text, + ), + ) + return AgentEvent( + event=EventType.TEXT, + data={"delta": text}, + ) + def _read_attr_or_key(self, obj: Any, key: str) -> Any: if isinstance(obj, dict): return obj.get(key) diff --git a/agentrun/utils/error_utils.py b/agentrun/utils/error_utils.py new file mode 100644 index 0000000..f7ebee6 --- /dev/null +++ b/agentrun/utils/error_utils.py @@ -0,0 +1,85 @@ +"""Small helpers for model rate-limit errors.""" + +import re +from typing import Any, Dict, Optional + +RATE_LIMITED_CODE = "RATE_LIMITED" +RATE_LIMITED_RETRY_AFTER_MS = 2000 +_RATE_LIMIT_TEXT_RE = re.compile( + r"429|too[-_\s]*many[-_\s]*requests|rate[-_\s]*limit|throttl", + re.IGNORECASE, +) +_RATE_LIMIT_CODES = { + "ratelimitexceeded", + "ratelimited", + "throttling", + "toomanyrequests", +} + + +def build_error_event_data( + error: Any, + *, + fallback_code: str, + fallback_message: str, +) -> Dict[str, Any]: + """Keep the original message; add rate-limit metadata only when matched.""" + if not is_rate_limited_error(error): + return {"message": fallback_message, "code": fallback_code} + + data: Dict[str, Any] = { + "message": fallback_message, + "code": RATE_LIMITED_CODE, + "retryable": True, + "retryAfterMs": RATE_LIMITED_RETRY_AFTER_MS, + } + trace_id = _get_value(error, "trace_id") or _get_value(error, "traceId") + if trace_id: + data["traceId"] = str(trace_id) + return data + + +def is_rate_limited_error(error: Any) -> bool: + if error is None: + return False + if _status_code(error) == 429 or _status_code(_get_value(error, "response")) == 429: + return True + if _rate_limit_code(error) or _rate_limit_code(_get_value(error, "response")): + return True + return bool(_RATE_LIMIT_TEXT_RE.search(str(error))) + + +def _status_code(obj: Any) -> Optional[int]: + for name in ("status_code", "status", "http_status", "statusCode"): + value = _get_value(obj, name) + if value is None: + continue + try: + return int(value) + except (TypeError, ValueError): + return None + return None + + +def _rate_limit_code(obj: Any) -> bool: + for name in ("code", "error_code", "errorCode"): + code = _get_value(obj, name) + if code and _normalize_code(code) in _RATE_LIMIT_CODES: + return True + return False + + +def _get_value(obj: Any, name: str) -> Optional[Any]: + if obj is None: + return None + if isinstance(obj, dict): + return obj.get(name) + return getattr(obj, name, None) + + +def _normalize_code(code: Any) -> str: + value = "".join(ch for ch in str(code).lower() if ch.isalnum()) + for suffix in ("exception", "error"): + if value.endswith(suffix): + return value[: -len(suffix)] + return value diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 0e51714..7d88111 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -816,7 +816,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -828,6 +828,25 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("Error code: 429 - rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 74933d1..19bbdba 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -814,7 +814,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -826,6 +826,25 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("Error code: 429 - rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 0896a08..3188aab 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -101,6 +101,31 @@ def invoke_agent(request: AgentRequest): assert "RUN_ERROR" in types + @pytest.mark.asyncio + async def test_text_rate_limit_error_stream_payload(self): + """测试文本形式的 429 错误输出 RUN_ERROR 且无 RUN_FINISHED""" + + def invoke_agent(request: AgentRequest): + return "Error code: 429 - rate limit exceeded" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + events = _agui_sse_events(response) + types = [event.get("type") for event in events] + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert "RUN_FINISHED" not in types + assert run_error["message"] == "Error code: 429 - rate limit exceeded" + assert run_error["code"] == "RATE_LIMITED" + assert run_error["retryable"] is True + assert run_error["retryAfterMs"] == 2000 + @pytest.mark.asyncio async def test_exception_in_parse_request(self): """测试 parse_request 中的异常处理(覆盖 155-156 行) diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py new file mode 100644 index 0000000..0299fc1 --- /dev/null +++ b/tests/unittests/server/test_error_utils.py @@ -0,0 +1,36 @@ +"""Tests for model rate-limit helpers.""" + +from agentrun.utils.error_utils import ( + build_error_event_data, + is_rate_limited_error, +) + + +def test_text_429_rate_limit_is_rate_limited(): + assert is_rate_limited_error("Error code: 429 - rate limit exceeded") + + +def test_structured_status_429_is_rate_limited(): + class RateLimitError(RuntimeError): + status_code = 429 + + assert is_rate_limited_error(RateLimitError("provider overloaded")) + + +def test_non_rate_limit_text_is_not_rate_limited(): + assert not is_rate_limited_error("normal response") + + +def test_rate_limit_event_uses_original_message(): + data = build_error_event_data( + "Error code: 429 - rate limit exceeded", + fallback_code="str", + fallback_message="Error code: 429 - rate limit exceeded", + ) + + assert data == { + "message": "Error code: 429 - rate limit exceeded", + "code": "RATE_LIMITED", + "retryable": True, + "retryAfterMs": 2000, + } diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index 76d8c65..8337db3 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -188,6 +188,26 @@ async def invoke_agent(req: AgentRequest) -> str: assert "Test error" in error_event.data["message"] assert error_event.data["code"] == "ValueError" + @pytest.mark.asyncio + async def test_invoke_stream_text_rate_limit_error(self, req): + """测试字符串形式的模型限流错误被转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "Error code: 429 - rate limit exceeded" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.ERROR + assert items[0].data["message"] == "Error code: 429 - rate limit exceeded" + assert items[0].data["code"] == "RATE_LIMITED" + assert items[0].data["retryable"] is True + assert items[0].data["retryAfterMs"] == 2000 + class TestInvokerSync: """同步调用测试"""