diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..1e447e9 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,96 @@ +"""Session persistence + auto-reload. + +The long-running background server loads the session once at startup. These tests +pin down the behavior that lets a fresh ``og-veil login`` reach that running +server: the in-memory session re-reads ``session.json`` when another process +rewrites it, instead of forever using its stale (expired/revoked) token. +""" + +from __future__ import annotations + +import json +import time + +import pytest + +from veil.config import session_path +from veil.session import AuthError, Session + + +@pytest.fixture +def home(tmp_path, monkeypatch): + monkeypatch.setenv("OG_VEIL_HOME", str(tmp_path)) + return tmp_path + + +def _bundle(access_token: str, *, expires_at: float) -> dict: + return { + "type": "opengradient-cli-auth", + "access_token": access_token, + "refresh_token": "refresh-xyz", + "expires_at": expires_at, + "user": {"email": "me@example.com"}, + "config": { + "supabase_url": "https://supabase.example", + "supabase_anon_key": "anon", + "chat_api_base_url": "https://chat.example", + }, + } + + +def _write_session(data: dict) -> None: + session_path().write_text(json.dumps(data)) + + +def test_fresh_login_on_disk_is_picked_up_without_restart(home): + # The server loaded a session whose token is already expired and whose refresh + # token would be rejected upstream — the exact state behind the reported 401. + _write_session(_bundle("stale-token", expires_at=time.time() - 3600)) + session = Session.load() + + # The user runs `og-veil login`, which writes a brand-new, valid session to + # disk from another process. The running server must adopt it. + _write_session(_bundle("fresh-token", expires_at=time.time() + 3600)) + + # No refresh network call should be needed: the reloaded token is still valid. + assert session.access_token() == "fresh-token" + assert session.user_email == "me@example.com" + + +def test_unchanged_session_is_not_reloaded_and_no_refresh(home): + _write_session(_bundle("good-token", expires_at=time.time() + 3600)) + session = Session.load() + + # Make any refresh attempt explode so the test fails loudly if one happens. + def _boom(): + raise AssertionError("should not refresh a still-valid token") + + session._refresh = _boom # type: ignore[assignment] + assert session.access_token() == "good-token" + + +def test_refresh_writes_back_without_triggering_self_reload(home): + # A token that's expired in memory but has no fresh login on disk must still go + # through refresh; the write-back must not look like an external change. + _write_session(_bundle("expired-token", expires_at=time.time() - 3600)) + session = Session.load() + + calls = {"n": 0} + + def _fake_refresh(): + calls["n"] += 1 + session._data["access_token"] = "refreshed-token" + session._data["expires_at"] = time.time() + 3600 + session.save() + + session._refresh = _fake_refresh # type: ignore[assignment] + assert session.access_token() == "refreshed-token" + # A second call sees its own write (unchanged mtime), so it neither reloads nor + # refreshes again. + assert session.access_token() == "refreshed-token" + assert calls["n"] == 1 + + +def test_load_without_session_file_raises(home): + with pytest.raises(AuthError): + Session.load() diff --git a/veil/session.py b/veil/session.py index a8ed95f..66e36c2 100644 --- a/veil/session.py +++ b/veil/session.py @@ -16,6 +16,7 @@ from __future__ import annotations import json +import os import threading import time import webbrowser @@ -66,9 +67,14 @@ def from_dict(cls, d: dict) -> "NetworkConfig": class Session: """A persisted Chat session: tokens + network config, with auto-refresh.""" - def __init__(self, data: dict): + def __init__(self, data: dict, mtime: Optional[float] = None): self._data = data self.config = NetworkConfig.from_dict(data.get("config", {})) + # Serializes refresh/reload across the threaded server's request handlers. + self._lock = threading.Lock() + # mtime of session.json as we last read/wrote it, so a long-running server + # can notice when another process (e.g. `og-veil login`) rewrites it. + self._mtime = mtime # --- persistence ------------------------------------------------------- @classmethod @@ -77,7 +83,7 @@ def load(cls) -> "Session": if not path.exists(): raise AuthError("not logged in — run `og-veil login` first") try: - return cls(json.loads(path.read_text())) + return cls(json.loads(path.read_text()), mtime=_file_mtime(path)) except (OSError, json.JSONDecodeError) as exc: raise AuthError(f"could not read saved session: {exc}") from exc @@ -88,6 +94,28 @@ def save(self) -> None: path.chmod(0o600) # the file holds a live session token except OSError: pass + # Record our own write so reload-on-change doesn't re-read it needlessly. + self._mtime = _file_mtime(path) + + def _reload_if_changed(self) -> None: + """Re-read session.json if another process rewrote it (e.g. a fresh login). + + The background server loads the session once at startup and keeps it in + memory. Without this, a successful ``og-veil login`` — which writes a new + token to disk — would never reach the running server, so requests would + keep failing with the stale (expired/revoked) token until a restart. + """ + path = session_path() + disk_mtime = _file_mtime(path) + if disk_mtime is None or disk_mtime == self._mtime: + return + try: + data = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return # mid-write or unreadable; keep what we have and try again later + self._data = data + self.config = NetworkConfig.from_dict(data.get("config", {})) + self._mtime = disk_mtime # --- accessors --------------------------------------------------------- @property @@ -99,9 +127,13 @@ def auth_headers(self) -> dict: return {"Authorization": f"Bearer {self.access_token()}"} def access_token(self) -> str: - if self._is_expired(): - self._refresh() - token = self._data.get("access_token") + with self._lock: + # Pick up a fresh login written by another process before deciding + # whether the in-memory token needs refreshing. + self._reload_if_changed() + if self._is_expired(): + self._refresh() + token = self._data.get("access_token") if not token: raise AuthError("session has no access token — run `og-veil login` to sign in again") return token @@ -257,3 +289,11 @@ def _require(d: dict, key: str) -> str: if not value: raise AuthError(f"CLI-auth config is missing '{key}'") return value + + +def _file_mtime(path) -> Optional[float]: + """Last-modified time of ``path`` in nanoseconds, or None if it's missing.""" + try: + return os.stat(path).st_mtime_ns + except OSError: + return None