From 2a953d6f10a3a20cf2e08f2f83c75d8eced1210c Mon Sep 17 00:00:00 2001 From: Aaro Koinsaari <89689072+koinsaari@users.noreply.github.com> Date: Sat, 30 May 2026 00:46:14 +0300 Subject: [PATCH 1/3] feat: HTTP layer with strict-server handlers and JWT middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire up the generated StrictServerInterface with typed request/response objects. JWT auth runs as a StrictMiddlewareFunc applied selectively to protected operations (postAuthLogout, postAuthLogoutAll) rather than a blanket route group. - server.go: NewServer wires strict handler + middleware chain (RequestID → Logging → gen.Handler) - middleware.go: RequestID and Logging middleware; context keys for request ID, logger, userID, jfUserID - handlers_auth.go: full StrictServerInterface implementation mapping auth.Service errors to typed JSON responses - errors.go: writeError helper using gen.ErrorErrorCode constants - middleware_test.go: JWT middleware unit tests + compile-time interface check (var _ gen.StrictServerInterface = (*Server)(nil)) - auth/jwt.go: IssueAccessTokenForTest seam for middleware tests Co-Authored-By: Claude Sonnet 4.6 --- internal/auth/jwt.go | 5 ++ internal/http/errors.go | 18 ++++++ internal/http/handlers_auth.go | 106 +++++++++++++++++++++++++++++++ internal/http/middleware.go | 67 +++++++++++++++++++ internal/http/middleware_test.go | 91 ++++++++++++++++++++++++++ internal/http/server.go | 69 ++++++++++++++++++++ 6 files changed, 356 insertions(+) create mode 100644 internal/http/errors.go create mode 100644 internal/http/handlers_auth.go create mode 100644 internal/http/middleware.go create mode 100644 internal/http/middleware_test.go create mode 100644 internal/http/server.go diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 73a8757..c2ecbea 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -32,6 +32,11 @@ func (s *Service) issueJWT(userID, email, jfUserID string) (string, error) { return tok.SignedString(s.signKey) } +// IssueAccessTokenForTest is a test seam; production code calls issueJWT internally. +func (s *Service) IssueAccessTokenForTest(userID, email, jfUserID string) (string, error) { + return s.issueJWT(userID, email, jfUserID) +} + func (s *Service) VerifyJWT(token string) (*Claims, error) { parsed, err := jwt.ParseWithClaims(token, &Claims{}, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { diff --git a/internal/http/errors.go b/internal/http/errors.go new file mode 100644 index 0000000..816c7fb --- /dev/null +++ b/internal/http/errors.go @@ -0,0 +1,18 @@ +package http + +import ( + "encoding/json" + stdhttp "net/http" + + "github.com/Stoganet/api-proxy/internal/gen" +) + +// writeError is used by middleware that has access to r but cannot return +// a typed response object (e.g. jwtStrictMiddleware). Handlers use apiError +// + typed response objects instead. +func writeError(w stdhttp.ResponseWriter, r *stdhttp.Request, status int, code gen.ErrorErrorCode, message string) { + e := apiError(r.Context(), code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(e) +} diff --git a/internal/http/handlers_auth.go b/internal/http/handlers_auth.go new file mode 100644 index 0000000..2a3741a --- /dev/null +++ b/internal/http/handlers_auth.go @@ -0,0 +1,106 @@ +package http + +import ( + "context" + "errors" + + "github.com/Stoganet/api-proxy/internal/auth" + "github.com/Stoganet/api-proxy/internal/gen" +) + +func (s *Server) GetHealthz(ctx context.Context, _ gen.GetHealthzRequestObject) (gen.GetHealthzResponseObject, error) { + return gen.GetHealthz200JSONResponse{Status: gen.Ok}, nil +} + +func (s *Server) PostAuthLogin(ctx context.Context, request gen.PostAuthLoginRequestObject) (gen.PostAuthLoginResponseObject, error) { + body := request.Body + pair, err := s.auth.Login(ctx, body.Username, body.Password, body.DeviceLabel) + switch { + case errors.Is(err, auth.ErrAccountLocked): + return gen.PostAuthLogin423JSONResponse(apiError(ctx, gen.AccountLocked, "too many failed attempts")), nil + case errors.Is(err, auth.ErrInvalidCredentials): + return gen.PostAuthLogin401JSONResponse(apiError(ctx, gen.InvalidCredentials, "username or password incorrect")), nil + case errors.Is(err, auth.ErrJellyfinUnavailable): + return gen.PostAuthLogin503JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil + case err != nil: + return nil, err + } + return gen.PostAuthLogin200JSONResponse(toGenTokenPair(pair)), nil +} + +func (s *Server) PostAuthRefresh(ctx context.Context, request gen.PostAuthRefreshRequestObject) (gen.PostAuthRefreshResponseObject, error) { + pair, err := s.auth.Refresh(ctx, request.Body.RefreshToken) + switch { + case errors.Is(err, auth.ErrTokenExpired): + return gen.PostAuthRefresh401JSONResponse(apiError(ctx, gen.TokenExpired, "refresh token expired")), nil + case errors.Is(err, auth.ErrTokenInvalid), errors.Is(err, auth.ErrTokenReused): + return gen.PostAuthRefresh401JSONResponse(apiError(ctx, gen.TokenInvalid, "refresh token invalid")), nil + case err != nil: + return nil, err + } + return gen.PostAuthRefresh200JSONResponse(toGenTokenPair(pair)), nil +} + +func (s *Server) PostAuthLogout(ctx context.Context, request gen.PostAuthLogoutRequestObject) (gen.PostAuthLogoutResponseObject, error) { + err := s.auth.Logout(ctx, request.Body.RefreshToken) + if err != nil && !errors.Is(err, auth.ErrTokenInvalid) { + return nil, err + } + return gen.PostAuthLogout204Response{}, nil +} + +func (s *Server) PostAuthLogoutAll(ctx context.Context, _ gen.PostAuthLogoutAllRequestObject) (gen.PostAuthLogoutAllResponseObject, error) { + uid, _ := ctx.Value(ctxUserID).(string) + if err := s.auth.LogoutAll(ctx, uid); err != nil { + return nil, err + } + return gen.PostAuthLogoutAll204Response{}, nil +} + +func (s *Server) PostAuthQuickConnectStart(ctx context.Context, _ gen.PostAuthQuickConnectStartRequestObject) (gen.PostAuthQuickConnectStartResponseObject, error) { + out, err := s.auth.QuickConnectStart(ctx) + if errors.Is(err, auth.ErrJellyfinUnavailable) { + return gen.PostAuthQuickConnectStart503JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil + } + if err != nil { + return nil, err + } + return gen.PostAuthQuickConnectStart200JSONResponse{Code: out.Code, PollToken: out.PollToken}, nil +} + +func (s *Server) PostAuthQuickConnectPoll(ctx context.Context, request gen.PostAuthQuickConnectPollRequestObject) (gen.PostAuthQuickConnectPollResponseObject, error) { + pair, err := s.auth.QuickConnectPoll(ctx, request.Body.PollToken) + switch { + case errors.Is(err, auth.ErrQuickConnectPending): + return gen.PostAuthQuickConnectPoll202Response{}, nil + case errors.Is(err, auth.ErrQuickConnectExpired), errors.Is(err, auth.ErrTokenInvalid): + return gen.PostAuthQuickConnectPoll410JSONResponse(apiError(ctx, gen.TokenExpired, "quick connect expired")), nil + case errors.Is(err, auth.ErrJellyfinUnavailable): + return gen.PostAuthQuickConnectPoll410JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil + case err != nil: + return nil, err + } + return gen.PostAuthQuickConnectPoll200JSONResponse(toGenTokenPair(pair)), nil +} + +// apiError builds a gen.Error with the request ID from context. +func apiError(ctx context.Context, code gen.ErrorErrorCode, message string) gen.Error { + var e gen.Error + e.Error.Code = code + e.Error.Message = message + e.RequestId = requestIDFromCtx(ctx) + return e +} + +// toGenTokenPair converts an auth.TokenPair to the generated API type. +func toGenTokenPair(p *auth.TokenPair) gen.TokenPair { + return gen.TokenPair{ + AccessToken: p.AccessToken, + RefreshToken: p.RefreshToken, + User: gen.User{ + Id: p.User.ID, + Email: p.User.Email, + DisplayName: p.User.DisplayName, + }, + } +} diff --git a/internal/http/middleware.go b/internal/http/middleware.go new file mode 100644 index 0000000..5997d1c --- /dev/null +++ b/internal/http/middleware.go @@ -0,0 +1,67 @@ +package http + +import ( + "context" + "log/slog" + stdhttp "net/http" + "time" + + "github.com/google/uuid" +) + +type ctxKey int + +const ( + ctxRequestID ctxKey = iota + ctxLogger + ctxUserID + ctxJFUserID +) + +func RequestID(next stdhttp.Handler) stdhttp.Handler { + return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + id := r.Header.Get("X-Request-Id") + if id == "" { + id = uuid.NewString() + } + w.Header().Set("X-Request-Id", id) + ctx := context.WithValue(r.Context(), ctxRequestID, id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func Logging(base *slog.Logger) func(stdhttp.Handler) stdhttp.Handler { + return func(next stdhttp.Handler) stdhttp.Handler { + return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + start := time.Now() + rid, _ := r.Context().Value(ctxRequestID).(string) + logger := base.With("request_id", rid) + ctx := context.WithValue(r.Context(), ctxLogger, logger) + + rec := &statusRecorder{ResponseWriter: w, status: 200} + next.ServeHTTP(rec, r.WithContext(ctx)) + + logger.Info("http", + "method", r.Method, + "path", r.URL.Path, + "status", rec.status, + "latency_ms", time.Since(start).Milliseconds(), + ) + }) + } +} + +type statusRecorder struct { + stdhttp.ResponseWriter + status int +} + +func (s *statusRecorder) WriteHeader(code int) { + s.status = code + s.ResponseWriter.WriteHeader(code) +} + +func requestIDFromCtx(ctx context.Context) string { + v, _ := ctx.Value(ctxRequestID).(string) + return v +} diff --git a/internal/http/middleware_test.go b/internal/http/middleware_test.go new file mode 100644 index 0000000..0445bb9 --- /dev/null +++ b/internal/http/middleware_test.go @@ -0,0 +1,91 @@ +package http + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Stoganet/api-proxy/internal/auth" + "github.com/Stoganet/api-proxy/internal/gen" +) + +func TestRequestIDMiddleware_SetsHeader(t *testing.T) { + h := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + if w.Header().Get("X-Request-Id") == "" { + t.Fatal("missing X-Request-Id") + } +} + +func TestJWTMiddleware_RejectsMissingHeader(t *testing.T) { + svc := newTestAuthSvc(t) + mw := jwtStrictMiddleware(svc) + inner := mw(func(_ context.Context, w http.ResponseWriter, _ *http.Request, _ any) (any, error) { + w.WriteHeader(http.StatusOK) + return nil, nil + }, "postAuthLogout") + + req := httptest.NewRequest(http.MethodPost, "/auth/logout", nil) + w := httptest.NewRecorder() + _, _ = inner(req.Context(), w, req, nil) + if w.Code != http.StatusUnauthorized { + t.Fatalf("got %d, want 401", w.Code) + } +} + +func TestJWTMiddleware_AcceptsValidToken(t *testing.T) { + svc := newTestAuthSvc(t) + tok, err := svc.IssueAccessTokenForTest("user-1", "a@b", "jf-1") + if err != nil { + t.Fatalf("issue: %v", err) + } + mw := jwtStrictMiddleware(svc) + inner := mw(func(_ context.Context, w http.ResponseWriter, _ *http.Request, _ any) (any, error) { + w.WriteHeader(http.StatusOK) + return nil, nil + }, "postAuthLogout") + + req := httptest.NewRequest(http.MethodPost, "/auth/logout", nil) + req.Header.Set("Authorization", "Bearer "+tok) + w := httptest.NewRecorder() + _, _ = inner(req.Context(), w, req, nil) + if w.Code != http.StatusOK { + t.Fatalf("got %d, want 200", w.Code) + } +} + +func TestJWTMiddleware_PassesThroughNonProtectedRoutes(t *testing.T) { + svc := newTestAuthSvc(t) + mw := jwtStrictMiddleware(svc) + called := false + inner := mw(func(_ context.Context, w http.ResponseWriter, _ *http.Request, _ any) (any, error) { + called = true + w.WriteHeader(http.StatusOK) + return nil, nil + }, "postAuthLogin") + + req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) + w := httptest.NewRecorder() + _, _ = inner(req.Context(), w, req, nil) + if !called { + t.Fatal("handler should be called for non-protected route") + } +} + +func newTestAuthSvc(t *testing.T) *auth.Service { + t.Helper() + return auth.NewService(auth.Options{ + SignKey: []byte("01234567890123456789012345678901"), + Clock: func() time.Time { return time.Unix(1_700_000_000, 0) }, + AccessTTL: time.Hour, + }) +} + +// Ensure Server satisfies the generated StrictServerInterface at compile time. +var _ gen.StrictServerInterface = (*Server)(nil) diff --git a/internal/http/server.go b/internal/http/server.go new file mode 100644 index 0000000..7c6d056 --- /dev/null +++ b/internal/http/server.go @@ -0,0 +1,69 @@ +package http + +import ( + "context" + "encoding/json" + "log/slog" + stdhttp "net/http" + "strings" + + "github.com/Stoganet/api-proxy/internal/auth" + "github.com/Stoganet/api-proxy/internal/gen" +) + +type Server struct { + auth *auth.Service + logger *slog.Logger +} + +func NewServer(authSvc *auth.Service, logger *slog.Logger) stdhttp.Handler { + s := &Server{auth: authSvc, logger: logger} + + strict := gen.NewStrictHandlerWithOptions(s, []gen.StrictMiddlewareFunc{ + jwtStrictMiddleware(authSvc), + }, gen.StrictHTTPServerOptions{ + ResponseErrorHandlerFunc: func(w stdhttp.ResponseWriter, r *stdhttp.Request, err error) { + var e gen.Error + e.Error.Code = gen.Internal + e.Error.Message = "internal error" + e.RequestId = requestIDFromCtx(r.Context()) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(stdhttp.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(e) + }, + }) + + return RequestID(Logging(logger)(gen.Handler(strict))) +} + +// jwtStrictMiddleware enforces Bearer JWT auth on logout operations only. +// It runs before the handler and writes a 401 directly if the token is +// missing or invalid, then returns nil to stop the strict handler from +// writing a second response. +func jwtStrictMiddleware(authSvc *auth.Service) gen.StrictMiddlewareFunc { + protected := map[string]bool{ + "postAuthLogout": true, + "postAuthLogoutAll": true, + } + return func(f gen.StrictHandlerFunc, operationID string) gen.StrictHandlerFunc { + if !protected[operationID] { + return f + } + return func(ctx context.Context, w stdhttp.ResponseWriter, r *stdhttp.Request, req any) (any, error) { + h := r.Header.Get("Authorization") + if !strings.HasPrefix(h, "Bearer ") { + writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenExpired, "missing bearer token") + return nil, nil + } + tok := strings.TrimPrefix(h, "Bearer ") + claims, err := authSvc.VerifyJWT(tok) + if err != nil { + writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenExpired, "invalid or expired token") + return nil, nil + } + ctx = context.WithValue(ctx, ctxUserID, claims.UserID) + ctx = context.WithValue(ctx, ctxJFUserID, claims.JFUserID) + return f(ctx, w, r, req) + } + } +} From e352b01f70c8e5cf97ebc8373fba4029fa902072 Mon Sep 17 00:00:00 2001 From: Aaro Koinsaari <89689072+koinsaari@users.noreply.github.com> Date: Sat, 30 May 2026 10:56:43 +0300 Subject: [PATCH 2/3] fix: HTTP layer corrections and handler tests - gen.TokenInvalid for missing bearer header (was gen.TokenExpired) - handler tests for all routes via fakeAuth stub - authService interface on Server for testability without a real DB - guard PostAuthLogoutAll against missing ctxUserID - log errors in ResponseErrorHandlerFunc instead of discarding them Co-Authored-By: Claude Sonnet 4.6 --- internal/http/handlers_auth.go | 8 +- internal/http/handlers_auth_test.go | 218 ++++++++++++++++++++++++++++ internal/http/middleware_test.go | 2 +- internal/http/server.go | 24 ++- 4 files changed, 242 insertions(+), 10 deletions(-) create mode 100644 internal/http/handlers_auth_test.go diff --git a/internal/http/handlers_auth.go b/internal/http/handlers_auth.go index 2a3741a..b4d1ed0 100644 --- a/internal/http/handlers_auth.go +++ b/internal/http/handlers_auth.go @@ -3,12 +3,13 @@ package http import ( "context" "errors" + "fmt" "github.com/Stoganet/api-proxy/internal/auth" "github.com/Stoganet/api-proxy/internal/gen" ) -func (s *Server) GetHealthz(ctx context.Context, _ gen.GetHealthzRequestObject) (gen.GetHealthzResponseObject, error) { +func (s *Server) GetHealthz(_ context.Context, _ gen.GetHealthzRequestObject) (gen.GetHealthzResponseObject, error) { return gen.GetHealthz200JSONResponse{Status: gen.Ok}, nil } @@ -50,7 +51,10 @@ func (s *Server) PostAuthLogout(ctx context.Context, request gen.PostAuthLogoutR } func (s *Server) PostAuthLogoutAll(ctx context.Context, _ gen.PostAuthLogoutAllRequestObject) (gen.PostAuthLogoutAllResponseObject, error) { - uid, _ := ctx.Value(ctxUserID).(string) + uid, ok := ctx.Value(ctxUserID).(string) + if !ok || uid == "" { + return nil, fmt.Errorf("ctxUserID missing — JWT middleware must run before this handler") + } if err := s.auth.LogoutAll(ctx, uid); err != nil { return nil, err } diff --git a/internal/http/handlers_auth_test.go b/internal/http/handlers_auth_test.go new file mode 100644 index 0000000..bdd203e --- /dev/null +++ b/internal/http/handlers_auth_test.go @@ -0,0 +1,218 @@ +package http + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Stoganet/api-proxy/internal/auth" + "github.com/Stoganet/api-proxy/internal/gen" +) + +// fakeAuth is a configurable stub for authService used in handler tests. +type fakeAuth struct { + loginOut *auth.TokenPair + loginErr error + refreshOut *auth.TokenPair + refreshErr error + logoutErr error + logoutAllErr error + qcStartOut *auth.QuickConnectStartOut + qcStartErr error + qcPollOut *auth.TokenPair + qcPollErr error +} + +func (f *fakeAuth) Login(_ context.Context, _, _ string, _ *string) (*auth.TokenPair, error) { + return f.loginOut, f.loginErr +} +func (f *fakeAuth) Refresh(_ context.Context, _ string) (*auth.TokenPair, error) { + return f.refreshOut, f.refreshErr +} +func (f *fakeAuth) Logout(_ context.Context, _ string) error { return f.logoutErr } +func (f *fakeAuth) LogoutAll(_ context.Context, _ string) error { return f.logoutAllErr } +func (f *fakeAuth) QuickConnectStart(_ context.Context) (*auth.QuickConnectStartOut, error) { + return f.qcStartOut, f.qcStartErr +} +func (f *fakeAuth) QuickConnectPoll(_ context.Context, _ string) (*auth.TokenPair, error) { + return f.qcPollOut, f.qcPollErr +} +func (f *fakeAuth) VerifyJWT(_ string) (*auth.Claims, error) { return nil, nil } + +var testTokenPair = &auth.TokenPair{ + AccessToken: "access", + RefreshToken: "refresh", + User: auth.User{ID: "u1", Email: "a@b.com", DisplayName: "A"}, +} + +func newTestServer(t *testing.T, fa *fakeAuth) http.Handler { + t.Helper() + s := &Server{auth: fa} + strict := gen.NewStrictHandlerWithOptions(s, nil, gen.StrictHTTPServerOptions{}) + return gen.Handler(strict) +} + +func do(t *testing.T, h http.Handler, method, path, body string) *httptest.ResponseRecorder { + t.Helper() + var buf *bytes.Buffer + if body != "" { + buf = bytes.NewBufferString(body) + } else { + buf = bytes.NewBuffer(nil) + } + req := httptest.NewRequest(method, path, buf) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + return w +} + +func decodeError(t *testing.T, w *httptest.ResponseRecorder) gen.Error { + t.Helper() + var e gen.Error + if err := json.NewDecoder(w.Body).Decode(&e); err != nil { + t.Fatalf("decode error body: %v", err) + } + return e +} + +func TestGetHealthz(t *testing.T) { + h := newTestServer(t, &fakeAuth{}) + w := do(t, h, http.MethodGet, "/healthz", "") + if w.Code != http.StatusOK { + t.Fatalf("got %d, want 200", w.Code) + } +} + +func TestPostAuthLogin(t *testing.T) { + body := `{"username":"u","password":"p"}` + tests := []struct { + name string + fa *fakeAuth + wantCode int + wantErr gen.ErrorErrorCode + }{ + {"ok", &fakeAuth{loginOut: testTokenPair}, http.StatusOK, ""}, + {"invalid creds", &fakeAuth{loginErr: auth.ErrInvalidCredentials}, http.StatusUnauthorized, gen.InvalidCredentials}, + {"account locked", &fakeAuth{loginErr: auth.ErrAccountLocked}, http.StatusLocked, gen.AccountLocked}, + {"backend unavailable", &fakeAuth{loginErr: auth.ErrJellyfinUnavailable}, http.StatusServiceUnavailable, gen.BackendUnavailable}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/login", body) + if w.Code != tc.wantCode { + t.Fatalf("got %d, want %d", w.Code, tc.wantCode) + } + if tc.wantErr != "" { + e := decodeError(t, w) + if e.Error.Code != tc.wantErr { + t.Fatalf("got error code %q, want %q", e.Error.Code, tc.wantErr) + } + } + }) + } +} + +func TestPostAuthRefresh(t *testing.T) { + body := `{"refresh_token":"tok"}` + tests := []struct { + name string + fa *fakeAuth + wantCode int + wantErr gen.ErrorErrorCode + }{ + {"ok", &fakeAuth{refreshOut: testTokenPair}, http.StatusOK, ""}, + {"expired", &fakeAuth{refreshErr: auth.ErrTokenExpired}, http.StatusUnauthorized, gen.TokenExpired}, + {"invalid", &fakeAuth{refreshErr: auth.ErrTokenInvalid}, http.StatusUnauthorized, gen.TokenInvalid}, + {"reused", &fakeAuth{refreshErr: auth.ErrTokenReused}, http.StatusUnauthorized, gen.TokenInvalid}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/refresh", body) + if w.Code != tc.wantCode { + t.Fatalf("got %d, want %d", w.Code, tc.wantCode) + } + if tc.wantErr != "" { + e := decodeError(t, w) + if e.Error.Code != tc.wantErr { + t.Fatalf("got error code %q, want %q", e.Error.Code, tc.wantErr) + } + } + }) + } +} + +func TestPostAuthLogout(t *testing.T) { + body := `{"refresh_token":"tok"}` + w := do(t, newTestServer(t, &fakeAuth{}), http.MethodPost, "/auth/logout", body) + if w.Code != http.StatusNoContent { + t.Fatalf("got %d, want 204", w.Code) + } +} + +func TestPostAuthLogoutAll_MissingUID(t *testing.T) { + // Without JWT middleware the ctxUserID key is absent and handler must return an error. + s := &Server{auth: &fakeAuth{}} + _, err := s.PostAuthLogoutAll(context.Background(), gen.PostAuthLogoutAllRequestObject{}) + if err == nil || !strings.Contains(err.Error(), "ctxUserID missing") { + t.Fatalf("expected ctxUserID error, got %v", err) + } +} + +func TestPostAuthLogoutAll_WithUID(t *testing.T) { + s := &Server{auth: &fakeAuth{}} + ctx := context.WithValue(context.Background(), ctxUserID, "user-1") + resp, err := s.PostAuthLogoutAll(ctx, gen.PostAuthLogoutAllRequestObject{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := resp.(gen.PostAuthLogoutAll204Response); !ok { + t.Fatalf("got %T, want PostAuthLogoutAll204Response", resp) + } +} + +func TestPostAuthQuickConnectStart(t *testing.T) { + tests := []struct { + name string + fa *fakeAuth + wantCode int + }{ + {"ok", &fakeAuth{qcStartOut: &auth.QuickConnectStartOut{Code: "ABC", PollToken: "tok"}}, http.StatusOK}, + {"backend unavailable", &fakeAuth{qcStartErr: auth.ErrJellyfinUnavailable}, http.StatusServiceUnavailable}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/quick-connect/start", "") + if w.Code != tc.wantCode { + t.Fatalf("got %d, want %d", w.Code, tc.wantCode) + } + }) + } +} + +func TestPostAuthQuickConnectPoll(t *testing.T) { + body := `{"poll_token":"tok"}` + tests := []struct { + name string + fa *fakeAuth + wantCode int + }{ + {"ok", &fakeAuth{qcPollOut: testTokenPair}, http.StatusOK}, + {"pending", &fakeAuth{qcPollErr: auth.ErrQuickConnectPending}, http.StatusAccepted}, + {"expired", &fakeAuth{qcPollErr: auth.ErrQuickConnectExpired}, http.StatusGone}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/quick-connect/poll", body) + if w.Code != tc.wantCode { + t.Fatalf("got %d, want %d", w.Code, tc.wantCode) + } + }) + } +} diff --git a/internal/http/middleware_test.go b/internal/http/middleware_test.go index 0445bb9..6d48645 100644 --- a/internal/http/middleware_test.go +++ b/internal/http/middleware_test.go @@ -12,7 +12,7 @@ import ( ) func TestRequestIDMiddleware_SetsHeader(t *testing.T) { - h := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := RequestID(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/internal/http/server.go b/internal/http/server.go index 7c6d056..3ed1615 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -11,8 +11,18 @@ import ( "github.com/Stoganet/api-proxy/internal/gen" ) +type authService interface { + Login(ctx context.Context, username, password string, deviceLabel *string) (*auth.TokenPair, error) + Refresh(ctx context.Context, plaintext string) (*auth.TokenPair, error) + Logout(ctx context.Context, plaintext string) error + LogoutAll(ctx context.Context, userID string) error + QuickConnectStart(ctx context.Context) (*auth.QuickConnectStartOut, error) + QuickConnectPoll(ctx context.Context, pollToken string) (*auth.TokenPair, error) + VerifyJWT(token string) (*auth.Claims, error) +} + type Server struct { - auth *auth.Service + auth authService logger *slog.Logger } @@ -23,6 +33,7 @@ func NewServer(authSvc *auth.Service, logger *slog.Logger) stdhttp.Handler { jwtStrictMiddleware(authSvc), }, gen.StrictHTTPServerOptions{ ResponseErrorHandlerFunc: func(w stdhttp.ResponseWriter, r *stdhttp.Request, err error) { + s.logger.ErrorContext(r.Context(), "handler error", "err", err, "request_id", requestIDFromCtx(r.Context())) var e gen.Error e.Error.Code = gen.Internal e.Error.Message = "internal error" @@ -37,10 +48,9 @@ func NewServer(authSvc *auth.Service, logger *slog.Logger) stdhttp.Handler { } // jwtStrictMiddleware enforces Bearer JWT auth on logout operations only. -// It runs before the handler and writes a 401 directly if the token is -// missing or invalid, then returns nil to stop the strict handler from +// Writes 401 directly and returns nil to prevent the strict handler from // writing a second response. -func jwtStrictMiddleware(authSvc *auth.Service) gen.StrictMiddlewareFunc { +func jwtStrictMiddleware(svc authService) gen.StrictMiddlewareFunc { protected := map[string]bool{ "postAuthLogout": true, "postAuthLogoutAll": true, @@ -52,14 +62,14 @@ func jwtStrictMiddleware(authSvc *auth.Service) gen.StrictMiddlewareFunc { return func(ctx context.Context, w stdhttp.ResponseWriter, r *stdhttp.Request, req any) (any, error) { h := r.Header.Get("Authorization") if !strings.HasPrefix(h, "Bearer ") { - writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenExpired, "missing bearer token") + writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenInvalid, "missing bearer token") return nil, nil } tok := strings.TrimPrefix(h, "Bearer ") - claims, err := authSvc.VerifyJWT(tok) + claims, err := svc.VerifyJWT(tok) if err != nil { writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenExpired, "invalid or expired token") - return nil, nil + return nil, nil //nolint:nilerr // error handled by writing 401 directly } ctx = context.WithValue(ctx, ctxUserID, claims.UserID) ctx = context.WithValue(ctx, ctxJFUserID, claims.JFUserID) From 7cc820b0813f95fb89d85069a78280008dc0f1dc Mon Sep 17 00:00:00 2001 From: Aaro Koinsaari <89689072+koinsaari@users.noreply.github.com> Date: Sat, 30 May 2026 11:06:38 +0300 Subject: [PATCH 3/3] fix: distinguish expired vs malformed JWT in middleware Use errors.Is(err, jwt.ErrTokenExpired) so expired tokens get TokenExpired and malformed/bad-signature tokens get TokenInvalid. Add tests for both cases. Co-Authored-By: Claude Sonnet 4.6 --- internal/http/middleware_test.go | 61 ++++++++++++++++++++++++++++++++ internal/http/server.go | 8 ++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/internal/http/middleware_test.go b/internal/http/middleware_test.go index 6d48645..734432f 100644 --- a/internal/http/middleware_test.go +++ b/internal/http/middleware_test.go @@ -2,6 +2,7 @@ package http import ( "context" + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -60,6 +61,66 @@ func TestJWTMiddleware_AcceptsValidToken(t *testing.T) { } } +func TestJWTMiddleware_RejectsMalformedToken(t *testing.T) { + svc := newTestAuthSvc(t) + mw := jwtStrictMiddleware(svc) + inner := mw(func(_ context.Context, w http.ResponseWriter, _ *http.Request, _ any) (any, error) { + w.WriteHeader(http.StatusOK) + return nil, nil + }, "postAuthLogout") + + req := httptest.NewRequest(http.MethodPost, "/auth/logout", nil) + req.Header.Set("Authorization", "Bearer not.a.jwt") + w := httptest.NewRecorder() + _, _ = inner(req.Context(), w, req, nil) + if w.Code != http.StatusUnauthorized { + t.Fatalf("got %d, want 401", w.Code) + } + var e gen.Error + if err := json.NewDecoder(w.Body).Decode(&e); err != nil { + t.Fatalf("decode: %v", err) + } + if e.Error.Code != gen.TokenInvalid { + t.Fatalf("got error code %q, want %q", e.Error.Code, gen.TokenInvalid) + } +} + +func TestJWTMiddleware_RejectsExpiredToken(t *testing.T) { + // Clock at issue time is in the past relative to the verification clock. + pastSvc := auth.NewService(auth.Options{ + SignKey: []byte("01234567890123456789012345678901"), + Clock: func() time.Time { return time.Unix(1_000_000_000, 0) }, + AccessTTL: time.Second, + }) + tok, err := pastSvc.IssueAccessTokenForTest("user-1", "a@b", "jf-1") + if err != nil { + t.Fatalf("issue: %v", err) + } + + // Verify with a service whose clock is well past the token's expiry. + verifySvc := newTestAuthSvc(t) // clock at 1_700_000_000, far beyond 1_000_000_001 + mw := jwtStrictMiddleware(verifySvc) + inner := mw(func(_ context.Context, w http.ResponseWriter, _ *http.Request, _ any) (any, error) { + w.WriteHeader(http.StatusOK) + return nil, nil + }, "postAuthLogout") + + req := httptest.NewRequest(http.MethodPost, "/auth/logout", nil) + req.Header.Set("Authorization", "Bearer "+tok) + w := httptest.NewRecorder() + _, _ = inner(req.Context(), w, req, nil) + if w.Code != http.StatusUnauthorized { + t.Fatalf("got %d, want 401", w.Code) + } + var e gen.Error + if err := json.NewDecoder(w.Body).Decode(&e); err != nil { + t.Fatalf("decode: %v", err) + } + if e.Error.Code != gen.TokenExpired { + t.Fatalf("got error code %q, want %q", e.Error.Code, gen.TokenExpired) + } +} + func TestJWTMiddleware_PassesThroughNonProtectedRoutes(t *testing.T) { svc := newTestAuthSvc(t) mw := jwtStrictMiddleware(svc) diff --git a/internal/http/server.go b/internal/http/server.go index 3ed1615..b9fc7e7 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -3,12 +3,14 @@ package http import ( "context" "encoding/json" + "errors" "log/slog" stdhttp "net/http" "strings" "github.com/Stoganet/api-proxy/internal/auth" "github.com/Stoganet/api-proxy/internal/gen" + "github.com/golang-jwt/jwt/v5" ) type authService interface { @@ -68,7 +70,11 @@ func jwtStrictMiddleware(svc authService) gen.StrictMiddlewareFunc { tok := strings.TrimPrefix(h, "Bearer ") claims, err := svc.VerifyJWT(tok) if err != nil { - writeError(w, r, stdhttp.StatusUnauthorized, gen.TokenExpired, "invalid or expired token") + code := gen.TokenInvalid + if errors.Is(err, jwt.ErrTokenExpired) { + code = gen.TokenExpired + } + writeError(w, r, stdhttp.StatusUnauthorized, code, "invalid or expired token") return nil, nil //nolint:nilerr // error handled by writing 401 directly } ctx = context.WithValue(ctx, ctxUserID, claims.UserID)