diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index c0a42a6..fc6e7bd 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -14,7 +14,7 @@ from typing import Any, cast from .config import settings -from .db import init_db, init_pool +from .db import init_db, init_pool, pool_status from .health import start_health_server from .monitor import PollResult, Scheduler from .protocols import DataSource, OpsAlertFn @@ -171,6 +171,7 @@ async def _async_main() -> None: bolt_thread = None mq = None app = None + pool = None _register_shutdown_signals(asyncio.get_running_loop(), shutdown_event, shutdown_reason) @@ -228,62 +229,51 @@ async def _async_main() -> None: launch_time = datetime.now(timezone.utc) - pool = init_pool(settings.database_url) - init_db(pool) - - state = ProbeState(pool) - user_watchlist = UserWatchlist(pool) - index = WG21Index(pool, cfg=settings) - prober = ISOProber(index, state, user_watchlist) - sources: list[DataSource] = [index, prober] - if settings.enable_open_std: - sources.append(OpenStdSource()) - app = create_app() - mq = MessageQueue(app) - mq.start() - - def paper_count_fn() -> int: - return len(index.papers) - - def _on_poll_result(result: PollResult) -> None: - notify_channel(app, result, mq) - notify_users(app, result, mq) - - def _ops_alert(msg: str) -> None: - if settings.ops_alert_channel: - mq.enqueue( - settings.ops_alert_channel, - f":rotating_light: PaperScout alert: {msg}", - ) - - def _pool_status(p: Any) -> dict[str, Any]: - """Best-effort pool stats (psycopg2 ThreadedConnectionPool uses private attrs).""" - status: dict[str, Any] = {"max": getattr(p, "maxconn", None)} - try: - status["in_use"] = len(p._used) - status["available"] = len(p._pool) - except AttributeError: - status["in_use"] = None - status["available"] = None - return status - - scheduler = Scheduler( - sources=sources, - user_watchlist=user_watchlist, - state=state, - cfg=settings, - notify_callback=_on_poll_result, - 
ops_alert_fn=cast(OpsAlertFn, _ops_alert), - ) + try: + pool = init_pool(settings.database_url) + init_db(pool) + + state = ProbeState(pool) + user_watchlist = UserWatchlist(pool) + index = WG21Index(pool, cfg=settings) + prober = ISOProber(index, state, user_watchlist) + sources: list[DataSource] = [index, prober] + if settings.enable_open_std: + sources.append(OpenStdSource()) + app = create_app() + mq = MessageQueue(app) + mq.start() + + def paper_count_fn() -> int: + return len(index.papers) + + def _on_poll_result(result: PollResult) -> None: + notify_channel(app, result, mq) + notify_users(app, result, mq) + + def _ops_alert(msg: str) -> None: + if settings.ops_alert_channel: + mq.enqueue( + settings.ops_alert_channel, + f":rotating_light: PaperScout alert: {msg}", + ) - def _extra_health_fields() -> dict[str, Any]: - return _merge_extra_health_fields( - scheduler.health_snapshot(), - _mq_health_fields(mq), - _pool_status(pool), + scheduler = Scheduler( + sources=sources, + user_watchlist=user_watchlist, + state=state, + cfg=settings, + notify_callback=_on_poll_result, + ops_alert_fn=cast(OpsAlertFn, _ops_alert), ) - try: + def _extra_health_fields() -> dict[str, Any]: + return _merge_extra_health_fields( + scheduler.health_snapshot(), + _mq_health_fields(mq), + pool_status(pool), + ) + register_handlers(app, user_watchlist, state, paper_count_fn, launch_time) health_server = start_health_server( @@ -318,6 +308,7 @@ def _extra_health_fields() -> dict[str, Any]: bolt_thread=bolt_thread, mq_drain_timeout=settings.shutdown_mq_drain_timeout_seconds, thread_join_timeout=settings.shutdown_thread_join_timeout_seconds, + pool=pool, ) diff --git a/src/paperscout/db.py b/src/paperscout/db.py index 02a627b..91b9f91 100644 --- a/src/paperscout/db.py +++ b/src/paperscout/db.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from typing import Any from psycopg2 import pool as pg_pool @@ -48,6 +49,29 @@ def init_pool(dsn: str, minconn: int = 1, maxconn: 
def pool_status(p: pg_pool.ThreadedConnectionPool) -> dict[str, Any]:
    """Report whether the database behind pool *p* answers a trivial query.

    Borrows one connection with ``getconn``, runs ``SELECT 1`` on it and
    hands it back with ``putconn``. Borrowing alone is not a liveness
    signal (the pool may hand out an already-open idle connection), so the
    round trip is what decides ``reachable``.

    Args:
        p: Connection pool exposing ``getconn`` / ``putconn``.

    Returns:
        ``{"reachable": True}`` when the query succeeded and the connection
        was returned to the pool, otherwise ``{"reachable": False}``.

    Raises:
        AttributeError: if *p* does not provide ``getconn`` (i.e. it is not
            a pool); this is a programming error, not an outage.
    """
    try:
        conn = p.getconn()
    except AttributeError:
        # Wrong object type: surface the bug instead of reporting an outage.
        raise
    except Exception:
        return {"reachable": False}

    # Round-trip to the server; a borrowed connection may be stale.
    try:
        cur = conn.cursor()
        try:
            cur.execute("SELECT 1")
        finally:
            cur.close()
    except Exception:
        alive = False
    else:
        alive = True

    try:
        if alive:
            p.putconn(conn)
        else:
            # Ask the pool to discard the broken connection rather than
            # recycle it for the next caller.
            p.putconn(conn, close=True)
    except Exception:
        # The pool refused the connection; close it ourselves so the
        # socket is not leaked. Deliberately best-effort: a failing close
        # must not turn a health check into a crash.
        try:
            conn.close()
        except Exception:
            pass
        return {"reachable": False}
    return {"reachable": alive}
Returns the number of messages drained from the queue.""" drained = 0 @@ -83,4 +85,12 @@ def shutdown_services( reason, drained, ) + + if pool is not None: + try: + pool.closeall() + log.info("shutdown: DB pool closed") + except Exception: + log.exception("shutdown: DB pool close failed") + return drained diff --git a/tests/conftest.py b/tests/conftest.py index 13249fb..39fe9e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -194,6 +194,7 @@ def __init__(self): self._store = _FakeStore() self.fail_on_commit = False self.rollback_count = 0 + self.closeall_called = False self.call_log: list[tuple[str, Sequence]] = [] def calls_matching(self, sql_fragment: str) -> list[tuple[str, Sequence]]: @@ -221,6 +222,9 @@ def getconn(self): def putconn(self, conn): pass + def closeall(self) -> None: + self.closeall_called = True + def seed_watchlist_raw(self, rows: list[tuple[str, str, str]]) -> None: """Directly populate ``user_watchlist`` rows for edge-case tests.""" for uid, entry, etype in rows: diff --git a/tests/test_db.py b/tests/test_db.py index 3df8e63..1b19675 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -6,7 +6,7 @@ import pytest -from paperscout.db import init_db, init_pool +from paperscout.db import init_db, init_pool, pool_status @patch("paperscout.db.pg_pool.ThreadedConnectionPool") @@ -63,3 +63,32 @@ def test_init_db_putconn_even_when_execute_fails(): init_db(pool) pool.putconn.assert_called_once_with(conn) + + +def test_pool_status_uses_public_api_only(): + pool = MagicMock() + conn = MagicMock() + pool.getconn.return_value = conn + assert pool_status(pool) == {"reachable": True} + pool.getconn.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_pool_status_returns_false_when_getconn_fails(): + pool = MagicMock() + pool.getconn.side_effect = RuntimeError("pool exhausted") + assert pool_status(pool) == {"reachable": False} + + +def test_pool_status_closes_conn_when_putconn_fails(): + pool = MagicMock() + conn = 
MagicMock() + pool.getconn.return_value = conn + pool.putconn.side_effect = RuntimeError("putconn failed") + assert pool_status(pool) == {"reachable": False} + conn.close.assert_called_once() + + +def test_pool_status_raises_for_incompatible_pool(): + with pytest.raises(AttributeError): + pool_status(object()) # type: ignore[arg-type] diff --git a/tests/test_health.py b/tests/test_health.py index c258d97..8d981f4 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -51,7 +51,7 @@ def health_url_with_extras(): "probe_stats": {}, "probe_success_rate": 0.5, "mq_depth": 3, - "db_pool": {"max": 10, "in_use": 1, "available": 9}, + "db_pool": {"reachable": True}, }, ) yield f"http://127.0.0.1:{port}" @@ -121,7 +121,7 @@ def test_health_extra_fields_merged(self, health_url_with_extras): assert data["last_successful_poll"] == "2026-03-16T12:00:00+00:00" assert data["probe_success_rate"] == 0.5 assert data["mq_depth"] == 3 - assert data["db_pool"] == {"max": 10, "in_use": 1, "available": 9} + assert data["db_pool"] == {"reachable": True} @dataclass(frozen=True, slots=True) diff --git a/tests/test_main_health_merge.py b/tests/test_main_health_merge.py index ab376bc..4a3d69d 100644 --- a/tests/test_main_health_merge.py +++ b/tests/test_main_health_merge.py @@ -21,11 +21,11 @@ def test_merge_scheduler_wins_on_key_conflict(caplog): "poll_count": 99, } with caplog.at_level(logging.DEBUG, logger="paperscout"): - out = _merge_extra_health_fields(scheduler, mq_extra, {"max": 10}) + out = _merge_extra_health_fields(scheduler, mq_extra, {"reachable": True}) assert out["last_updated"] == "2026-01-01T00:00:00+00:00" assert out["poll_count"] == 1 assert out["mq_depth"] == 5 - assert out["db_pool"] == {"max": 10} + assert out["db_pool"] == {"reachable": True} assert any("not allow-listed" in r.message for r in caplog.records) diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py index 1acb092..838184a 100644 --- a/tests/test_shutdown.py +++ b/tests/test_shutdown.py @@ 
-79,3 +79,53 @@ def test_shutdown_services_continues_after_mq_drain_failure(self, caplog): assert drained == 0 health_server.shutdown.assert_called_once() assert any("drained 0" in r.message for r in caplog.records) + + def test_shutdown_services_closes_pool(self, caplog): + pool = MagicMock() + with caplog.at_level(logging.INFO, logger="paperscout"): + shutdown_services( + reason="SIGTERM", + mq=None, + health_server=None, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + pool=pool, + ) + pool.closeall.assert_called_once() + assert any("DB pool closed" in r.message for r in caplog.records) + + def test_shutdown_services_pool_close_raises_continues(self, caplog): + mq = MagicMock() + mq.drain.return_value = 1 + pool = MagicMock() + pool.closeall.side_effect = RuntimeError("pool close boom") + with caplog.at_level(logging.ERROR, logger="paperscout"): + drained = shutdown_services( + reason="SIGTERM", + mq=mq, + health_server=None, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + pool=pool, + ) + assert drained == 1 + pool.closeall.assert_called_once() + + def test_shutdown_services_skips_pool_when_none(self): + shutdown_services( + reason="unknown", + mq=None, + health_server=None, + health_thread=None, + app=None, + bolt_thread=None, + mq_drain_timeout=30.0, + thread_join_timeout=5.0, + pool=None, + )