Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ func (s *Service) issueJWT(userID, email, jfUserID string) (string, error) {
return tok.SignedString(s.signKey)
}

// IssueAccessTokenForTest is a test seam; production code calls issueJWT internally.
func (s *Service) IssueAccessTokenForTest(userID, email, jfUserID string) (string, error) {
return s.issueJWT(userID, email, jfUserID)
}

func (s *Service) VerifyJWT(token string) (*Claims, error) {
parsed, err := jwt.ParseWithClaims(token, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
Expand Down
18 changes: 18 additions & 0 deletions internal/http/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package http

import (
"encoding/json"
stdhttp "net/http"

"github.com/Stoganet/api-proxy/internal/gen"
)

// writeError is used by middleware that has access to r but cannot return
// a typed response object (e.g. jwtStrictMiddleware). Handlers use apiError
// + typed response objects instead.
func writeError(w stdhttp.ResponseWriter, r *stdhttp.Request, status int, code gen.ErrorErrorCode, message string) {
e := apiError(r.Context(), code, message)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(e)
}
110 changes: 110 additions & 0 deletions internal/http/handlers_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package http

import (
"context"
"errors"
"fmt"

"github.com/Stoganet/api-proxy/internal/auth"
"github.com/Stoganet/api-proxy/internal/gen"
)

func (s *Server) GetHealthz(_ context.Context, _ gen.GetHealthzRequestObject) (gen.GetHealthzResponseObject, error) {
return gen.GetHealthz200JSONResponse{Status: gen.Ok}, nil
}

func (s *Server) PostAuthLogin(ctx context.Context, request gen.PostAuthLoginRequestObject) (gen.PostAuthLoginResponseObject, error) {
body := request.Body
pair, err := s.auth.Login(ctx, body.Username, body.Password, body.DeviceLabel)
switch {
case errors.Is(err, auth.ErrAccountLocked):
return gen.PostAuthLogin423JSONResponse(apiError(ctx, gen.AccountLocked, "too many failed attempts")), nil
case errors.Is(err, auth.ErrInvalidCredentials):
return gen.PostAuthLogin401JSONResponse(apiError(ctx, gen.InvalidCredentials, "username or password incorrect")), nil
case errors.Is(err, auth.ErrJellyfinUnavailable):
return gen.PostAuthLogin503JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil
case err != nil:
return nil, err
}
return gen.PostAuthLogin200JSONResponse(toGenTokenPair(pair)), nil
}

func (s *Server) PostAuthRefresh(ctx context.Context, request gen.PostAuthRefreshRequestObject) (gen.PostAuthRefreshResponseObject, error) {
pair, err := s.auth.Refresh(ctx, request.Body.RefreshToken)
switch {
case errors.Is(err, auth.ErrTokenExpired):
return gen.PostAuthRefresh401JSONResponse(apiError(ctx, gen.TokenExpired, "refresh token expired")), nil
case errors.Is(err, auth.ErrTokenInvalid), errors.Is(err, auth.ErrTokenReused):
return gen.PostAuthRefresh401JSONResponse(apiError(ctx, gen.TokenInvalid, "refresh token invalid")), nil
case err != nil:
return nil, err
}
return gen.PostAuthRefresh200JSONResponse(toGenTokenPair(pair)), nil
}

func (s *Server) PostAuthLogout(ctx context.Context, request gen.PostAuthLogoutRequestObject) (gen.PostAuthLogoutResponseObject, error) {
err := s.auth.Logout(ctx, request.Body.RefreshToken)
if err != nil && !errors.Is(err, auth.ErrTokenInvalid) {
return nil, err
}
return gen.PostAuthLogout204Response{}, nil
}

func (s *Server) PostAuthLogoutAll(ctx context.Context, _ gen.PostAuthLogoutAllRequestObject) (gen.PostAuthLogoutAllResponseObject, error) {
uid, ok := ctx.Value(ctxUserID).(string)
if !ok || uid == "" {
return nil, fmt.Errorf("ctxUserID missing — JWT middleware must run before this handler")
}
if err := s.auth.LogoutAll(ctx, uid); err != nil {
return nil, err
}
return gen.PostAuthLogoutAll204Response{}, nil
}

func (s *Server) PostAuthQuickConnectStart(ctx context.Context, _ gen.PostAuthQuickConnectStartRequestObject) (gen.PostAuthQuickConnectStartResponseObject, error) {
out, err := s.auth.QuickConnectStart(ctx)
if errors.Is(err, auth.ErrJellyfinUnavailable) {
return gen.PostAuthQuickConnectStart503JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil
}
if err != nil {
return nil, err
}
return gen.PostAuthQuickConnectStart200JSONResponse{Code: out.Code, PollToken: out.PollToken}, nil
}

func (s *Server) PostAuthQuickConnectPoll(ctx context.Context, request gen.PostAuthQuickConnectPollRequestObject) (gen.PostAuthQuickConnectPollResponseObject, error) {
pair, err := s.auth.QuickConnectPoll(ctx, request.Body.PollToken)
switch {
case errors.Is(err, auth.ErrQuickConnectPending):
return gen.PostAuthQuickConnectPoll202Response{}, nil
case errors.Is(err, auth.ErrQuickConnectExpired), errors.Is(err, auth.ErrTokenInvalid):
return gen.PostAuthQuickConnectPoll410JSONResponse(apiError(ctx, gen.TokenExpired, "quick connect expired")), nil
case errors.Is(err, auth.ErrJellyfinUnavailable):
return gen.PostAuthQuickConnectPoll410JSONResponse(apiError(ctx, gen.BackendUnavailable, "jellyfin unavailable")), nil
case err != nil:
return nil, err
}
return gen.PostAuthQuickConnectPoll200JSONResponse(toGenTokenPair(pair)), nil
}

// apiError builds a gen.Error with the request ID from context.
func apiError(ctx context.Context, code gen.ErrorErrorCode, message string) gen.Error {
var e gen.Error
e.Error.Code = code
e.Error.Message = message
e.RequestId = requestIDFromCtx(ctx)
return e
}

// toGenTokenPair converts an auth.TokenPair to the generated API type.
func toGenTokenPair(p *auth.TokenPair) gen.TokenPair {
return gen.TokenPair{
AccessToken: p.AccessToken,
RefreshToken: p.RefreshToken,
User: gen.User{
Id: p.User.ID,
Email: p.User.Email,
DisplayName: p.User.DisplayName,
},
}
}
218 changes: 218 additions & 0 deletions internal/http/handlers_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package http

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/Stoganet/api-proxy/internal/auth"
"github.com/Stoganet/api-proxy/internal/gen"
)

// fakeAuth is a configurable stub for authService used in handler tests.
type fakeAuth struct {
loginOut *auth.TokenPair
loginErr error
refreshOut *auth.TokenPair
refreshErr error
logoutErr error
logoutAllErr error
qcStartOut *auth.QuickConnectStartOut
qcStartErr error
qcPollOut *auth.TokenPair
qcPollErr error
}

func (f *fakeAuth) Login(_ context.Context, _, _ string, _ *string) (*auth.TokenPair, error) {
return f.loginOut, f.loginErr
}
func (f *fakeAuth) Refresh(_ context.Context, _ string) (*auth.TokenPair, error) {
return f.refreshOut, f.refreshErr
}
func (f *fakeAuth) Logout(_ context.Context, _ string) error { return f.logoutErr }
func (f *fakeAuth) LogoutAll(_ context.Context, _ string) error { return f.logoutAllErr }
func (f *fakeAuth) QuickConnectStart(_ context.Context) (*auth.QuickConnectStartOut, error) {
return f.qcStartOut, f.qcStartErr
}
func (f *fakeAuth) QuickConnectPoll(_ context.Context, _ string) (*auth.TokenPair, error) {
return f.qcPollOut, f.qcPollErr
}
func (f *fakeAuth) VerifyJWT(_ string) (*auth.Claims, error) { return nil, nil }

var testTokenPair = &auth.TokenPair{
AccessToken: "access",
RefreshToken: "refresh",
User: auth.User{ID: "u1", Email: "a@b.com", DisplayName: "A"},
}

func newTestServer(t *testing.T, fa *fakeAuth) http.Handler {
t.Helper()
s := &Server{auth: fa}
strict := gen.NewStrictHandlerWithOptions(s, nil, gen.StrictHTTPServerOptions{})
return gen.Handler(strict)
}

func do(t *testing.T, h http.Handler, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
var buf *bytes.Buffer
if body != "" {
buf = bytes.NewBufferString(body)
} else {
buf = bytes.NewBuffer(nil)
}
req := httptest.NewRequest(method, path, buf)
if body != "" {
req.Header.Set("Content-Type", "application/json")
}
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
return w
}

func decodeError(t *testing.T, w *httptest.ResponseRecorder) gen.Error {
t.Helper()
var e gen.Error
if err := json.NewDecoder(w.Body).Decode(&e); err != nil {
t.Fatalf("decode error body: %v", err)
}
return e
}

func TestGetHealthz(t *testing.T) {
h := newTestServer(t, &fakeAuth{})
w := do(t, h, http.MethodGet, "/healthz", "")
if w.Code != http.StatusOK {
t.Fatalf("got %d, want 200", w.Code)
}
}

func TestPostAuthLogin(t *testing.T) {
body := `{"username":"u","password":"p"}`
tests := []struct {
name string
fa *fakeAuth
wantCode int
wantErr gen.ErrorErrorCode
}{
{"ok", &fakeAuth{loginOut: testTokenPair}, http.StatusOK, ""},
{"invalid creds", &fakeAuth{loginErr: auth.ErrInvalidCredentials}, http.StatusUnauthorized, gen.InvalidCredentials},
{"account locked", &fakeAuth{loginErr: auth.ErrAccountLocked}, http.StatusLocked, gen.AccountLocked},
{"backend unavailable", &fakeAuth{loginErr: auth.ErrJellyfinUnavailable}, http.StatusServiceUnavailable, gen.BackendUnavailable},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/login", body)
if w.Code != tc.wantCode {
t.Fatalf("got %d, want %d", w.Code, tc.wantCode)
}
if tc.wantErr != "" {
e := decodeError(t, w)
if e.Error.Code != tc.wantErr {
t.Fatalf("got error code %q, want %q", e.Error.Code, tc.wantErr)
}
}
})
}
}

func TestPostAuthRefresh(t *testing.T) {
body := `{"refresh_token":"tok"}`
tests := []struct {
name string
fa *fakeAuth
wantCode int
wantErr gen.ErrorErrorCode
}{
{"ok", &fakeAuth{refreshOut: testTokenPair}, http.StatusOK, ""},
{"expired", &fakeAuth{refreshErr: auth.ErrTokenExpired}, http.StatusUnauthorized, gen.TokenExpired},
{"invalid", &fakeAuth{refreshErr: auth.ErrTokenInvalid}, http.StatusUnauthorized, gen.TokenInvalid},
{"reused", &fakeAuth{refreshErr: auth.ErrTokenReused}, http.StatusUnauthorized, gen.TokenInvalid},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/refresh", body)
if w.Code != tc.wantCode {
t.Fatalf("got %d, want %d", w.Code, tc.wantCode)
}
if tc.wantErr != "" {
e := decodeError(t, w)
if e.Error.Code != tc.wantErr {
t.Fatalf("got error code %q, want %q", e.Error.Code, tc.wantErr)
}
}
})
}
}

func TestPostAuthLogout(t *testing.T) {
body := `{"refresh_token":"tok"}`
w := do(t, newTestServer(t, &fakeAuth{}), http.MethodPost, "/auth/logout", body)
if w.Code != http.StatusNoContent {
t.Fatalf("got %d, want 204", w.Code)
}
}

func TestPostAuthLogoutAll_MissingUID(t *testing.T) {
// Without JWT middleware the ctxUserID key is absent and handler must return an error.
s := &Server{auth: &fakeAuth{}}
_, err := s.PostAuthLogoutAll(context.Background(), gen.PostAuthLogoutAllRequestObject{})
if err == nil || !strings.Contains(err.Error(), "ctxUserID missing") {
t.Fatalf("expected ctxUserID error, got %v", err)
}
}

func TestPostAuthLogoutAll_WithUID(t *testing.T) {
s := &Server{auth: &fakeAuth{}}
ctx := context.WithValue(context.Background(), ctxUserID, "user-1")
resp, err := s.PostAuthLogoutAll(ctx, gen.PostAuthLogoutAllRequestObject{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, ok := resp.(gen.PostAuthLogoutAll204Response); !ok {
t.Fatalf("got %T, want PostAuthLogoutAll204Response", resp)
}
}

func TestPostAuthQuickConnectStart(t *testing.T) {
tests := []struct {
name string
fa *fakeAuth
wantCode int
}{
{"ok", &fakeAuth{qcStartOut: &auth.QuickConnectStartOut{Code: "ABC", PollToken: "tok"}}, http.StatusOK},
{"backend unavailable", &fakeAuth{qcStartErr: auth.ErrJellyfinUnavailable}, http.StatusServiceUnavailable},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/quick-connect/start", "")
if w.Code != tc.wantCode {
t.Fatalf("got %d, want %d", w.Code, tc.wantCode)
}
})
}
}

func TestPostAuthQuickConnectPoll(t *testing.T) {
body := `{"poll_token":"tok"}`
tests := []struct {
name string
fa *fakeAuth
wantCode int
}{
{"ok", &fakeAuth{qcPollOut: testTokenPair}, http.StatusOK},
{"pending", &fakeAuth{qcPollErr: auth.ErrQuickConnectPending}, http.StatusAccepted},
{"expired", &fakeAuth{qcPollErr: auth.ErrQuickConnectExpired}, http.StatusGone},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
w := do(t, newTestServer(t, tc.fa), http.MethodPost, "/auth/quick-connect/poll", body)
if w.Code != tc.wantCode {
t.Fatalf("got %d, want %d", w.Code, tc.wantCode)
}
})
}
}
Loading