Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Session persistence + auto-reload.

The long-running background server loads the session once at startup. These tests
pin down the behavior that lets a fresh ``og-veil login`` reach that running
server: the in-memory session re-reads ``session.json`` when another process
rewrites it, instead of forever using its stale (expired/revoked) token.
"""

from __future__ import annotations

import json
import time

import pytest

from veil.config import session_path
from veil.session import AuthError, Session


@pytest.fixture
def home(tmp_path, monkeypatch):
monkeypatch.setenv("OG_VEIL_HOME", str(tmp_path))
return tmp_path


def _bundle(access_token: str, *, expires_at: float) -> dict:
return {
"type": "opengradient-cli-auth",
"access_token": access_token,
"refresh_token": "refresh-xyz",
"expires_at": expires_at,
"user": {"email": "me@example.com"},
"config": {
"supabase_url": "https://supabase.example",
"supabase_anon_key": "anon",
"chat_api_base_url": "https://chat.example",
},
}


def _write_session(data: dict) -> None:
session_path().write_text(json.dumps(data))


def test_fresh_login_on_disk_is_picked_up_without_restart(home):
# The server loaded a session whose token is already expired and whose refresh
# token would be rejected upstream — the exact state behind the reported 401.
_write_session(_bundle("stale-token", expires_at=time.time() - 3600))
session = Session.load()

# The user runs `og-veil login`, which writes a brand-new, valid session to
# disk from another process. The running server must adopt it.
_write_session(_bundle("fresh-token", expires_at=time.time() + 3600))

# No refresh network call should be needed: the reloaded token is still valid.
assert session.access_token() == "fresh-token"
assert session.user_email == "me@example.com"


def test_unchanged_session_is_not_reloaded_and_no_refresh(home):
_write_session(_bundle("good-token", expires_at=time.time() + 3600))
session = Session.load()

# Make any refresh attempt explode so the test fails loudly if one happens.
def _boom():
raise AssertionError("should not refresh a still-valid token")

session._refresh = _boom # type: ignore[assignment]
assert session.access_token() == "good-token"


def test_refresh_writes_back_without_triggering_self_reload(home):
# A token that's expired in memory but has no fresh login on disk must still go
# through refresh; the write-back must not look like an external change.
_write_session(_bundle("expired-token", expires_at=time.time() - 3600))
session = Session.load()

calls = {"n": 0}

def _fake_refresh():
calls["n"] += 1
session._data["access_token"] = "refreshed-token"
session._data["expires_at"] = time.time() + 3600
session.save()

session._refresh = _fake_refresh # type: ignore[assignment]
assert session.access_token() == "refreshed-token"
# A second call sees its own write (unchanged mtime), so it neither reloads nor
# refreshes again.
assert session.access_token() == "refreshed-token"
assert calls["n"] == 1


def test_load_without_session_file_raises(home):
with pytest.raises(AuthError):
Session.load()
50 changes: 45 additions & 5 deletions veil/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

import json
import os
import threading
import time
import webbrowser
Expand Down Expand Up @@ -66,9 +67,14 @@ def from_dict(cls, d: dict) -> "NetworkConfig":
class Session:
"""A persisted Chat session: tokens + network config, with auto-refresh."""

def __init__(self, data: dict):
def __init__(self, data: dict, mtime: Optional[float] = None):
self._data = data
self.config = NetworkConfig.from_dict(data.get("config", {}))
# Serializes refresh/reload across the threaded server's request handlers.
self._lock = threading.Lock()
# mtime of session.json as we last read/wrote it, so a long-running server
# can notice when another process (e.g. `og-veil login`) rewrites it.
self._mtime = mtime

# --- persistence -------------------------------------------------------
@classmethod
Expand All @@ -77,7 +83,7 @@ def load(cls) -> "Session":
if not path.exists():
raise AuthError("not logged in — run `og-veil login` first")
try:
return cls(json.loads(path.read_text()))
return cls(json.loads(path.read_text()), mtime=_file_mtime(path))
except (OSError, json.JSONDecodeError) as exc:
raise AuthError(f"could not read saved session: {exc}") from exc

Expand All @@ -88,6 +94,28 @@ def save(self) -> None:
path.chmod(0o600) # the file holds a live session token
except OSError:
pass
# Record our own write so reload-on-change doesn't re-read it needlessly.
self._mtime = _file_mtime(path)

def _reload_if_changed(self) -> None:
"""Re-read session.json if another process rewrote it (e.g. a fresh login).

The background server loads the session once at startup and keeps it in
memory. Without this, a successful ``og-veil login`` — which writes a new
token to disk — would never reach the running server, so requests would
keep failing with the stale (expired/revoked) token until a restart.
"""
path = session_path()
disk_mtime = _file_mtime(path)
if disk_mtime is None or disk_mtime == self._mtime:
return
try:
data = json.loads(path.read_text())
except (OSError, json.JSONDecodeError):
return # mid-write or unreadable; keep what we have and try again later
self._data = data
self.config = NetworkConfig.from_dict(data.get("config", {}))
self._mtime = disk_mtime

# --- accessors ---------------------------------------------------------
@property
Expand All @@ -99,9 +127,13 @@ def auth_headers(self) -> dict:
return {"Authorization": f"Bearer {self.access_token()}"}

def access_token(self) -> str:
if self._is_expired():
self._refresh()
token = self._data.get("access_token")
with self._lock:
# Pick up a fresh login written by another process before deciding
# whether the in-memory token needs refreshing.
self._reload_if_changed()
if self._is_expired():
self._refresh()
token = self._data.get("access_token")
if not token:
raise AuthError("session has no access token — run `og-veil login` to sign in again")
return token
Expand Down Expand Up @@ -257,3 +289,11 @@ def _require(d: dict, key: str) -> str:
if not value:
raise AuthError(f"CLI-auth config is missing '{key}'")
return value


def _file_mtime(path) -> Optional[float]:
"""Last-modified time of ``path`` in nanoseconds, or None if it's missing."""
try:
return os.stat(path).st_mtime_ns
except OSError:
return None
Loading