Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions agentrun/integration/langgraph/agent_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from typing import Any, Dict, Iterator, List, Optional, Union

from agentrun.server.model import AgentResult, EventType
from agentrun.utils.error_utils import (
build_error_event_data,
is_rate_limited_error,
)
from agentrun.utils.log import logger

# 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象)
Expand Down Expand Up @@ -952,10 +956,15 @@ def _convert_astream_events_event(

yield AgentResult(
event=EventType.ERROR,
data={
"message": f"LLM error: {error_message}",
"code": "LLM_ERROR",
},
data=build_error_event_data(
error,
fallback_code="LLM_ERROR",
fallback_message=(
error_message
if is_rate_limited_error(error)
else f"LLM error: {error_message}"
),
),
)

# 8. Chain 错误
Expand Down
17 changes: 12 additions & 5 deletions agentrun/server/agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# ============================================================================

# Default URL prefix for the AG-UI agent endpoint
# (presumably the route mount point; confirm against the router setup).
DEFAULT_PREFIX = "/ag-ui/agent"
# Optional keys that, when present in an ERROR event's data, are copied onto
# the encoded RunErrorEvent in addition to ``message`` and ``code``.
RUN_ERROR_EXTRA_FIELDS = ("retryable", "retryAfterMs", "traceId")


@dataclass
Expand Down Expand Up @@ -743,12 +744,18 @@ def _process_event_with_boundaries(

# ERROR 事件
if event.event == EventType.ERROR:
yield self._encoder.encode(
RunErrorEvent(
message=event.data.get("message", ""),
code=event.data.get("code"),
)
agui_event = RunErrorEvent(
message=event.data.get("message", ""),
code=event.data.get("code"),
)
extra_fields = {
key: event.data[key]
for key in RUN_ERROR_EXTRA_FIELDS
if key in event.data
}
if extra_fields:
agui_event = agui_event.model_copy(update=extra_fields)
yield self._encoder.encode(agui_event)
return

# STATE 事件
Expand Down
56 changes: 29 additions & 27 deletions agentrun/server/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
)
import uuid

from agentrun.utils.error_utils import (
build_error_event_data,
is_rate_limited_error,
)
from .model import AgentEvent, AgentRequest, EventType
from .protocol import (
AsyncInvokeAgentHandler,
Expand Down Expand Up @@ -117,10 +121,7 @@ async def invoke_stream(
if isinstance(item, str):
if not item: # 跳过空字符串
continue
yield AgentEvent(
event=EventType.TEXT,
data={"delta": item},
)
yield self._wrap_text(item)

elif isinstance(item, AgentEvent):
# 处理用户返回的事件
Expand All @@ -142,7 +143,11 @@ async def invoke_stream(
logger.error(f"Agent 调用出错: {e}", exc_info=True)
yield AgentEvent(
event=EventType.ERROR,
data={"message": str(e), "code": type(e).__name__},
data=build_error_event_data(
e,
fallback_code=type(e).__name__,
fallback_message=str(e),
),
)

def _process_user_event(
Expand Down Expand Up @@ -227,12 +232,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]:
return results

if isinstance(result, str):
results.append(
AgentEvent(
event=EventType.TEXT,
data={"delta": result},
)
)
results.append(self._wrap_text(result))

elif isinstance(result, AgentEvent):
# 处理可能的 TOOL_CALL 展开
Expand All @@ -243,12 +243,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]:
if isinstance(item, AgentEvent):
results.extend(self._process_user_event(item))
elif isinstance(item, str) and item:
results.append(
AgentEvent(
event=EventType.TEXT,
data={"delta": item},
)
)
results.append(self._wrap_text(item))
else:
results.extend(self._wrap_model_chunk(item))

Expand All @@ -275,10 +270,7 @@ async def _wrap_stream(
if isinstance(item, str):
if not item:
continue
yield AgentEvent(
event=EventType.TEXT,
data={"delta": item},
)
yield self._wrap_text(item)

elif isinstance(item, AgentEvent):
for processed_event in self._process_user_event(item):
Expand Down Expand Up @@ -346,15 +338,25 @@ def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]:

content = self._read_attr_or_key(item, "content")
if isinstance(content, str) and content:
events.append(
AgentEvent(
event=EventType.TEXT,
data={"delta": content},
)
)
events.append(self._wrap_text(content))

return events

def _wrap_text(self, text: str) -> AgentEvent:
    """Wrap a text chunk in an ``AgentEvent``.

    Text that ``is_rate_limited_error`` flags is emitted as an ERROR event
    whose data is produced by ``build_error_event_data``; every other chunk
    becomes a TEXT event carrying the text as its ``delta``.

    Args:
        text: The text chunk returned or yielded by the agent.

    Returns:
        The TEXT or ERROR event for this chunk.
    """
    # Common case first: ordinary text is passed through untouched.
    if not is_rate_limited_error(text):
        return AgentEvent(event=EventType.TEXT, data={"delta": text})

    error_payload = build_error_event_data(
        text,
        fallback_code=type(text).__name__,
        fallback_message=text,
    )
    return AgentEvent(event=EventType.ERROR, data=error_payload)

def _read_attr_or_key(self, obj: Any, key: str) -> Any:
if isinstance(obj, dict):
return obj.get(key)
Expand Down
85 changes: 85 additions & 0 deletions agentrun/utils/error_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Small helpers for model rate-limit errors."""

import re
from typing import Any, Dict, Optional

# Error code placed in event data when an error is classified as rate limiting.
RATE_LIMITED_CODE = "RATE_LIMITED"
# Value emitted under the ``retryAfterMs`` key for rate-limited errors.
RATE_LIMITED_RETRY_AFTER_MS = 2000
# Free-text fallback used on ``str(error)``. The 429 alternative is fenced by
# digit lookarounds so that unrelated numbers which merely contain the digits
# (e.g. "14290 tokens") are not classified as rate limiting.
_RATE_LIMIT_TEXT_RE = re.compile(
    r"(?<!\d)429(?!\d)|too[-_\s]*many[-_\s]*requests|rate[-_\s]*limit|throttl",
    re.IGNORECASE,
)
# Provider error codes in the form produced by ``_normalize_code`` (lowercase,
# alphanumeric only, trailing "exception"/"error" removed). "ratelimit" is what
# codes such as "rate_limit_error" or "RateLimitError" normalize to.
_RATE_LIMIT_CODES = {
    "ratelimit",
    "ratelimitexceeded",
    "ratelimited",
    "throttling",
    "toomanyrequests",
}


def build_error_event_data(
    error: Any,
    *,
    fallback_code: str,
    fallback_message: str,
) -> Dict[str, Any]:
    """Build the ``data`` payload for an ERROR event.

    The message is always ``fallback_message``. When ``error`` is classified
    as rate limiting, the code becomes ``RATE_LIMITED_CODE`` and retry
    metadata is attached; otherwise ``fallback_code`` is used as-is.

    Args:
        error: The exception, mapping, or text to classify.
        fallback_code: Code reported when the error is not rate limiting.
        fallback_message: Message reported in every case.

    Returns:
        A dict with ``message`` and ``code``, plus ``retryable``,
        ``retryAfterMs`` and (when the error carries one) ``traceId`` for
        rate-limited errors.
    """
    payload: Dict[str, Any] = {
        "message": fallback_message,
        "code": fallback_code,
    }
    if is_rate_limited_error(error):
        payload["code"] = RATE_LIMITED_CODE
        payload["retryable"] = True
        payload["retryAfterMs"] = RATE_LIMITED_RETRY_AFTER_MS
        # Accept both snake_case and camelCase spellings of the trace id.
        trace = _get_value(error, "trace_id") or _get_value(error, "traceId")
        if trace:
            payload["traceId"] = str(trace)
    return payload


def is_rate_limited_error(error: Any) -> bool:
    """Return True when *error* looks like a rate-limit failure.

    Checks, in order: a 429 status on the error or its ``response``, a known
    rate-limit code on either of them, and finally the rate-limit text
    pattern against ``str(error)``. ``None`` is never rate limited.
    """
    if error is None:
        return False

    # The error itself and the response it may wrap are inspected the same way.
    sources = (error, _get_value(error, "response"))
    if any(_status_code(source) == 429 for source in sources):
        return True
    if any(_rate_limit_code(source) for source in sources):
        return True
    return _RATE_LIMIT_TEXT_RE.search(str(error)) is not None


def _status_code(obj: Any) -> Optional[int]:
    """Return the first integer-convertible status value found on *obj*.

    The names ``status_code``, ``status``, ``http_status`` and ``statusCode``
    are tried in that order via ``_get_value``.

    Args:
        obj: Object or dict to inspect; ``None`` yields ``None``.

    Returns:
        The status as an ``int``, or ``None`` when no candidate name holds a
        value that ``int()`` accepts.
    """
    for name in ("status_code", "status", "http_status", "statusCode"):
        value = _get_value(obj, name)
        if value is None:
            continue
        try:
            return int(value)
        except (TypeError, ValueError):
            # A non-numeric value under one name (e.g. status="failed" or a
            # bound method) must not hide a usable code under a later name.
            continue
    return None


def _rate_limit_code(obj: Any) -> bool:
    """Return True when *obj* carries a known rate-limit error code.

    Looks at ``code``, ``error_code`` and ``errorCode`` (via ``_get_value``);
    each non-empty value is normalized with ``_normalize_code`` and compared
    against ``_RATE_LIMIT_CODES``.
    """
    raw_codes = (
        _get_value(obj, field) for field in ("code", "error_code", "errorCode")
    )
    return any(
        raw and _normalize_code(raw) in _RATE_LIMIT_CODES for raw in raw_codes
    )


def _get_value(obj: Any, name: str) -> Optional[Any]:
    """Read *name* from a dict key or an object attribute.

    Returns ``None`` when *obj* is ``None``, when the dict has no such key,
    or when the object has no such attribute.
    """
    if isinstance(obj, dict):
        return obj.get(name)
    # None is handled explicitly so that attributes of NoneType itself are
    # never returned.
    return None if obj is None else getattr(obj, name, None)


def _normalize_code(code: Any) -> str:
    """Reduce an error code to a comparable lowercase alphanumeric token.

    The code is stringified and lowercased, every non-alphanumeric character
    is dropped, and at most one trailing ``exception`` or ``error`` suffix
    (checked in that order) is removed.
    """
    compact = "".join(filter(str.isalnum, str(code).lower()))
    for tail in ("exception", "error"):
        if compact.endswith(tail):
            # Only the first matching suffix is stripped.
            return compact[: len(compact) - len(tail)]
    return compact
21 changes: 20 additions & 1 deletion tests/unittests/integration/test_langgraph_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def test_on_llm_error(self):
"event": "on_llm_error",
"run_id": "run_llm",
"data": {
"error": RuntimeError("API rate limit exceeded"),
"error": RuntimeError("Model backend failed"),
},
}

Expand All @@ -828,6 +828,25 @@ def test_on_llm_error(self):
assert "RuntimeError" in results[0].data["message"]
assert results[0].data["code"] == "LLM_ERROR"

def test_on_llm_error_rate_limited(self):
    """A rate-limited on_llm_error is normalized and keeps the original message."""
    llm_error = RuntimeError("Error code: 429 - rate limit exceeded")
    event = {
        "event": "on_llm_error",
        "run_id": "run_llm",
        "data": {"error": llm_error},
    }

    results = list(AgentRunConverter().to_agui_events(event))

    assert len(results) == 1
    error_event = results[0]
    assert error_event.event == EventType.ERROR

    payload = error_event.data
    assert payload["message"] == "RuntimeError: Error code: 429 - rate limit exceeded"
    assert payload["code"] == "RATE_LIMITED"
    assert payload["retryable"] is True
    assert payload["retryAfterMs"] == 2000

def test_on_chain_error(self):
"""测试 on_chain_error 事件

Expand Down
21 changes: 20 additions & 1 deletion tests/unittests/integration/test_langgraph_to_agent_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def test_on_llm_error(self):
"event": "on_llm_error",
"run_id": "run_llm",
"data": {
"error": RuntimeError("API rate limit exceeded"),
"error": RuntimeError("Model backend failed"),
},
}

Expand All @@ -826,6 +826,25 @@ def test_on_llm_error(self):
assert "RuntimeError" in results[0].data["message"]
assert results[0].data["code"] == "LLM_ERROR"

def test_on_llm_error_rate_limited(self):
    """A rate-limited on_llm_error is normalized and keeps the original message."""
    raised = RuntimeError("Error code: 429 - rate limit exceeded")
    event = {
        "event": "on_llm_error",
        "run_id": "run_llm",
        "data": {"error": raised},
    }

    converted = list(AgentRunConverter().to_agui_events(event))

    assert len(converted) == 1
    only_event = converted[0]
    assert only_event.event == EventType.ERROR

    data = only_event.data
    assert data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded"
    assert data["code"] == "RATE_LIMITED"
    assert data["retryable"] is True
    assert data["retryAfterMs"] == 2000

def test_on_chain_error(self):
"""测试 on_chain_error 事件

Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/server/test_agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,31 @@ def invoke_agent(request: AgentRequest):

assert "RUN_ERROR" in types

@pytest.mark.asyncio
async def test_text_rate_limit_error_stream_payload(self):
    """A plain-text 429 reply must surface as RUN_ERROR with no RUN_FINISHED."""
    rate_limit_text = "Error code: 429 - rate limit exceeded"

    def invoke_agent(request: AgentRequest):
        return rate_limit_text

    request_body = {"messages": [{"role": "user", "content": "Hello"}]}
    response = self.get_client(invoke_agent).post(
        "/ag-ui/agent",
        json=request_body,
    )
    assert response.status_code == 200

    events = _agui_sse_events(response)
    types = [evt.get("type") for evt in events]
    run_error = next(evt for evt in events if evt.get("type") == "RUN_ERROR")

    assert "RUN_FINISHED" not in types
    assert run_error["message"] == "Error code: 429 - rate limit exceeded"
    assert run_error["code"] == "RATE_LIMITED"
    assert run_error["retryable"] is True
    assert run_error["retryAfterMs"] == 2000

@pytest.mark.asyncio
async def test_exception_in_parse_request(self):
"""测试 parse_request 中的异常处理(覆盖 155-156 行)
Expand Down
36 changes: 36 additions & 0 deletions tests/unittests/server/test_error_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Tests for model rate-limit helpers."""

from agentrun.utils.error_utils import (
build_error_event_data,
is_rate_limited_error,
)


def test_text_429_rate_limit_is_rate_limited():
    """Plain text mentioning a 429 rate limit is classified as rate limited."""
    message = "Error code: 429 - rate limit exceeded"
    assert is_rate_limited_error(message)


def test_structured_status_429_is_rate_limited():
    """A 429 ``status_code`` attribute is enough, whatever the message says."""

    class RateLimitError(RuntimeError):
        status_code = 429

    error = RateLimitError("provider overloaded")
    assert is_rate_limited_error(error)


def test_non_rate_limit_text_is_not_rate_limited():
    """Ordinary text must not be classified as rate limited."""
    ordinary_text = "normal response"
    assert not is_rate_limited_error(ordinary_text)


def test_rate_limit_event_uses_original_message():
    """Rate-limit payloads keep the caller's message and add retry metadata."""
    original = "Error code: 429 - rate limit exceeded"
    expected = {
        "message": original,
        "code": "RATE_LIMITED",
        "retryable": True,
        "retryAfterMs": 2000,
    }

    data = build_error_event_data(
        original,
        fallback_code="str",
        fallback_message=original,
    )

    assert data == expected
Loading
Loading