Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 45 additions & 54 deletions src/paperscout/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, cast

from .config import settings
from .db import init_db, init_pool
from .db import init_db, init_pool, pool_status
from .health import start_health_server
from .monitor import PollResult, Scheduler
from .protocols import DataSource, OpsAlertFn
Expand Down Expand Up @@ -171,6 +171,7 @@ async def _async_main() -> None:
bolt_thread = None
mq = None
app = None
pool = None

_register_shutdown_signals(asyncio.get_running_loop(), shutdown_event, shutdown_reason)

Expand Down Expand Up @@ -228,62 +229,51 @@ async def _async_main() -> None:

launch_time = datetime.now(timezone.utc)

pool = init_pool(settings.database_url)
init_db(pool)

state = ProbeState(pool)
user_watchlist = UserWatchlist(pool)
index = WG21Index(pool, cfg=settings)
prober = ISOProber(index, state, user_watchlist)
sources: list[DataSource] = [index, prober]
if settings.enable_open_std:
sources.append(OpenStdSource())
app = create_app()
mq = MessageQueue(app)
mq.start()

def paper_count_fn() -> int:
return len(index.papers)

def _on_poll_result(result: PollResult) -> None:
notify_channel(app, result, mq)
notify_users(app, result, mq)

def _ops_alert(msg: str) -> None:
if settings.ops_alert_channel:
mq.enqueue(
settings.ops_alert_channel,
f":rotating_light: PaperScout alert: {msg}",
)

def _pool_status(p: Any) -> dict[str, Any]:
"""Best-effort pool stats (psycopg2 ThreadedConnectionPool uses private attrs)."""
status: dict[str, Any] = {"max": getattr(p, "maxconn", None)}
try:
status["in_use"] = len(p._used)
status["available"] = len(p._pool)
except AttributeError:
status["in_use"] = None
status["available"] = None
return status

scheduler = Scheduler(
sources=sources,
user_watchlist=user_watchlist,
state=state,
cfg=settings,
notify_callback=_on_poll_result,
ops_alert_fn=cast(OpsAlertFn, _ops_alert),
)
try:
pool = init_pool(settings.database_url)
init_db(pool)

state = ProbeState(pool)
user_watchlist = UserWatchlist(pool)
index = WG21Index(pool, cfg=settings)
prober = ISOProber(index, state, user_watchlist)
sources: list[DataSource] = [index, prober]
if settings.enable_open_std:
sources.append(OpenStdSource())
app = create_app()
mq = MessageQueue(app)
mq.start()

def paper_count_fn() -> int:
return len(index.papers)

def _on_poll_result(result: PollResult) -> None:
notify_channel(app, result, mq)
notify_users(app, result, mq)

def _ops_alert(msg: str) -> None:
if settings.ops_alert_channel:
mq.enqueue(
settings.ops_alert_channel,
f":rotating_light: PaperScout alert: {msg}",
)

def _extra_health_fields() -> dict[str, Any]:
return _merge_extra_health_fields(
scheduler.health_snapshot(),
_mq_health_fields(mq),
_pool_status(pool),
scheduler = Scheduler(
sources=sources,
user_watchlist=user_watchlist,
state=state,
cfg=settings,
notify_callback=_on_poll_result,
ops_alert_fn=cast(OpsAlertFn, _ops_alert),
)

try:
def _extra_health_fields() -> dict[str, Any]:
return _merge_extra_health_fields(
scheduler.health_snapshot(),
_mq_health_fields(mq),
pool_status(pool),
)

register_handlers(app, user_watchlist, state, paper_count_fn, launch_time)

health_server = start_health_server(
Expand Down Expand Up @@ -318,6 +308,7 @@ def _extra_health_fields() -> dict[str, Any]:
bolt_thread=bolt_thread,
mq_drain_timeout=settings.shutdown_mq_drain_timeout_seconds,
thread_join_timeout=settings.shutdown_thread_join_timeout_seconds,
pool=pool,
)


Expand Down
24 changes: 24 additions & 0 deletions src/paperscout/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
from typing import Any

from psycopg2 import pool as pg_pool

Expand Down Expand Up @@ -48,6 +49,29 @@ def init_pool(dsn: str, minconn: int = 1, maxconn: int = 10) -> pg_pool.Threaded
return p


def pool_status(p: pg_pool.ThreadedConnectionPool) -> dict[str, Any]:
    """Probe the pool by checking one connection out and straight back in.

    Only ``getconn`` and ``putconn`` are called on the pool; no pool
    attributes are inspected.

    Returns ``{"reachable": True}`` when both calls succeed and
    ``{"reachable": False}`` when either one raises. If ``putconn`` fails,
    the borrowed connection is closed on a best-effort basis.

    Raises:
        AttributeError: re-raised from ``getconn`` (for example when ``p``
            does not provide that method) instead of being reported as
            unreachable.
    """
    try:
        borrowed = p.getconn()
    except Exception as exc:
        # Let AttributeError surface; any other failure means "unreachable".
        if isinstance(exc, AttributeError):
            raise
        return {"reachable": False}

    try:
        p.putconn(borrowed)
    except Exception:
        # The connection could not be handed back: close it, ignoring
        # any error raised by the close itself.
        try:
            borrowed.close()
        except Exception:
            pass
        return {"reachable": False}
    else:
        return {"reachable": True}


def init_db(p: pg_pool.ThreadedConnectionPool) -> None:
"""Create all tables (idempotent)."""
conn = p.getconn()
Expand Down
12 changes: 11 additions & 1 deletion src/paperscout/shutdown.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Graceful process shutdown: drain MQ, stop HTTP servers, join worker threads."""
"""Graceful process shutdown: drain MQ, stop HTTP servers, join worker threads, close DB pool."""

from __future__ import annotations

import logging
import threading
from http.server import HTTPServer

from psycopg2 import pool as pg_pool
from slack_bolt import App

from .scout import MessageQueue
Expand Down Expand Up @@ -49,6 +50,7 @@ def shutdown_services(
bolt_thread: threading.Thread | None,
mq_drain_timeout: float,
thread_join_timeout: float,
pool: pg_pool.ThreadedConnectionPool | None = None,
) -> int:
"""Ordered teardown. Returns the number of messages drained from the queue."""
drained = 0
Expand Down Expand Up @@ -83,4 +85,12 @@ def shutdown_services(
reason,
drained,
)

if pool is not None:
try:
pool.closeall()
log.info("shutdown: DB pool closed")
except Exception:
log.exception("shutdown: DB pool close failed")

return drained
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self):
self._store = _FakeStore()
self.fail_on_commit = False
self.rollback_count = 0
self.closeall_called = False
self.call_log: list[tuple[str, Sequence]] = []

def calls_matching(self, sql_fragment: str) -> list[tuple[str, Sequence]]:
Expand Down Expand Up @@ -221,6 +222,9 @@ def getconn(self):
def putconn(self, conn):
    """No-op: the returned connection is ignored."""
    pass

def closeall(self) -> None:
    """Record that ``closeall`` was invoked by setting ``closeall_called`` to True."""
    self.closeall_called = True

def seed_watchlist_raw(self, rows: list[tuple[str, str, str]]) -> None:
"""Directly populate ``user_watchlist`` rows for edge-case tests."""
for uid, entry, etype in rows:
Expand Down
31 changes: 30 additions & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from paperscout.db import init_db, init_pool
from paperscout.db import init_db, init_pool, pool_status


@patch("paperscout.db.pg_pool.ThreadedConnectionPool")
Expand Down Expand Up @@ -63,3 +63,32 @@ def test_init_db_putconn_even_when_execute_fails():
init_db(pool)

pool.putconn.assert_called_once_with(conn)


def test_pool_status_uses_public_api_only():
    """A working pool is probed with one ``getconn`` and one matching ``putconn``."""
    borrowed = MagicMock()
    fake_pool = MagicMock()
    fake_pool.getconn.return_value = borrowed

    outcome = pool_status(fake_pool)

    assert outcome == {"reachable": True}
    fake_pool.getconn.assert_called_once()
    fake_pool.putconn.assert_called_once_with(borrowed)


def test_pool_status_returns_false_when_getconn_fails():
    """A ``getconn`` that raises ``RuntimeError`` is reported as unreachable."""
    fake_pool = MagicMock()
    fake_pool.getconn.side_effect = RuntimeError("pool exhausted")

    outcome = pool_status(fake_pool)

    assert outcome == {"reachable": False}


def test_pool_status_closes_conn_when_putconn_fails():
    """When ``putconn`` raises, the borrowed connection is closed and the pool is unreachable."""
    borrowed = MagicMock()
    fake_pool = MagicMock()
    fake_pool.getconn.return_value = borrowed
    fake_pool.putconn.side_effect = RuntimeError("putconn failed")

    outcome = pool_status(fake_pool)

    assert outcome == {"reachable": False}
    borrowed.close.assert_called_once()


def test_pool_status_raises_for_incompatible_pool():
    """An object without ``getconn`` makes ``pool_status`` raise ``AttributeError``."""
    not_a_pool = object()
    with pytest.raises(AttributeError):
        pool_status(not_a_pool)  # type: ignore[arg-type]
4 changes: 2 additions & 2 deletions tests/test_health.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def health_url_with_extras():
"probe_stats": {},
"probe_success_rate": 0.5,
"mq_depth": 3,
"db_pool": {"max": 10, "in_use": 1, "available": 9},
"db_pool": {"reachable": True},
},
)
yield f"http://127.0.0.1:{port}"
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_health_extra_fields_merged(self, health_url_with_extras):
assert data["last_successful_poll"] == "2026-03-16T12:00:00+00:00"
assert data["probe_success_rate"] == 0.5
assert data["mq_depth"] == 3
assert data["db_pool"] == {"max": 10, "in_use": 1, "available": 9}
assert data["db_pool"] == {"reachable": True}


@dataclass(frozen=True, slots=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_main_health_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def test_merge_scheduler_wins_on_key_conflict(caplog):
"poll_count": 99,
}
with caplog.at_level(logging.DEBUG, logger="paperscout"):
out = _merge_extra_health_fields(scheduler, mq_extra, {"max": 10})
out = _merge_extra_health_fields(scheduler, mq_extra, {"reachable": True})
assert out["last_updated"] == "2026-01-01T00:00:00+00:00"
assert out["poll_count"] == 1
assert out["mq_depth"] == 5
assert out["db_pool"] == {"max": 10}
assert out["db_pool"] == {"reachable": True}
assert any("not allow-listed" in r.message for r in caplog.records)


Expand Down
50 changes: 50 additions & 0 deletions tests/test_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,53 @@ def test_shutdown_services_continues_after_mq_drain_failure(self, caplog):
assert drained == 0
health_server.shutdown.assert_called_once()
assert any("drained 0" in r.message for r in caplog.records)

def test_shutdown_services_closes_pool(self, caplog):
    """``shutdown_services`` calls ``closeall`` on the given pool and logs the closure."""
    fake_pool = MagicMock()
    teardown_kwargs = {
        "reason": "SIGTERM",
        "mq": None,
        "health_server": None,
        "health_thread": None,
        "app": None,
        "bolt_thread": None,
        "mq_drain_timeout": 30.0,
        "thread_join_timeout": 5.0,
        "pool": fake_pool,
    }

    with caplog.at_level(logging.INFO, logger="paperscout"):
        shutdown_services(**teardown_kwargs)

    fake_pool.closeall.assert_called_once()
    logged = [record.message for record in caplog.records]
    assert any("DB pool closed" in text for text in logged)

def test_shutdown_services_pool_close_raises_continues(self, caplog):
    """A failing ``pool.closeall`` is logged and does not abort the shutdown.

    The drained-message count from the queue must still be returned, and
    the failure must surface as the "DB pool close failed" log record
    rather than being swallowed silently.
    """
    mq = MagicMock()
    mq.drain.return_value = 1
    pool = MagicMock()
    pool.closeall.side_effect = RuntimeError("pool close boom")
    with caplog.at_level(logging.ERROR, logger="paperscout"):
        drained = shutdown_services(
            reason="SIGTERM",
            mq=mq,
            health_server=None,
            health_thread=None,
            app=None,
            bolt_thread=None,
            mq_drain_timeout=30.0,
            thread_join_timeout=5.0,
            pool=pool,
        )
    assert drained == 1
    pool.closeall.assert_called_once()
    # The logs were captured but never checked before; verify the failure
    # is actually reported.
    assert any("DB pool close failed" in r.message for r in caplog.records)

def test_shutdown_services_skips_pool_when_none(self):
    """``shutdown_services`` completes without error when ``pool`` is ``None``."""
    teardown_kwargs = {
        "reason": "unknown",
        "mq": None,
        "health_server": None,
        "health_thread": None,
        "app": None,
        "bolt_thread": None,
        "mq_drain_timeout": 30.0,
        "thread_join_timeout": 5.0,
        "pool": None,
    }
    shutdown_services(**teardown_kwargs)