|
| 1 | +package middleware |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "net/http" |
| 7 | + "regexp" |
| 8 | + "strings" |
| 9 | + |
| 10 | + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" |
| 11 | +) |
| 12 | + |
| 13 | +type authType int |
| 14 | + |
| 15 | +const ( |
| 16 | + authTypeUnknown authType = iota |
| 17 | + authTypeIDE |
| 18 | + authTypeGhToken |
| 19 | +) |
| 20 | + |
| 21 | +var ( |
| 22 | + errMissingAuthorizationHeader = errors.New("missing Authorization header") |
| 23 | + errBadAuthorizationHeader = errors.New("bad Authorization header format") |
| 24 | + errUnsupportedAuthorizationHeader = errors.New("unsupported Authorization header format") |
| 25 | +) |
| 26 | + |
| 27 | +var supportedThirdPartyTokenPrefixes = []string{ |
| 28 | + "ghp_", // Personal access token (classic) |
| 29 | + "github_pat_", // Fine-grained personal access token |
| 30 | + "gho_", // OAuth access token |
| 31 | + "ghu_", // User access token for a GitHub App |
| 32 | + "ghs_", // Installation access token for a GitHub App (a.k.a. server-to-server token) |
| 33 | +} |
| 34 | + |
| 35 | +// oldPatternRegexp is the regular expression for the old pattern of the token. |
| 36 | +// Until 2021, GitHub API tokens did not have an identifiable prefix. They |
| 37 | +// were 40 characters long and only contained the characters a-f and 0-9. |
| 38 | +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) |
| 39 | + |
| 40 | +type tokenCtxKey string |
| 41 | + |
| 42 | +var tokenContextKey tokenCtxKey = "tokenctx" |
| 43 | + |
| 44 | +type TokenData struct { |
| 45 | + Token string |
| 46 | +} |
| 47 | + |
| 48 | +// AddToken adds the given token data to the context. |
| 49 | +func AddToken(ctx context.Context, data *TokenData) context.Context { |
| 50 | + return context.WithValue(ctx, tokenContextKey, data) |
| 51 | +} |
| 52 | + |
| 53 | +// ReqData returns the request data from the context. It will panic if there is |
| 54 | +// no data in the context (which should never happen in production). |
| 55 | +func Token(ctx context.Context) *TokenData { |
| 56 | + d, ok := ctx.Value(tokenContextKey).(*TokenData) |
| 57 | + if !ok || d == nil { |
| 58 | + // This should never happen in production, so making it a panic saves us a lot of unnecessary error handling. |
| 59 | + panic(errors.New("context does not contain request context token data")) |
| 60 | + } |
| 61 | + return d |
| 62 | +} |
| 63 | + |
| 64 | +// ExtractUserToken is a middleware that extracts the user token from the request |
| 65 | +// and adds it to the request context. It also validates the token format. |
| 66 | +func ExtractUserToken() func(next http.Handler) http.Handler { |
| 67 | + return func(next http.Handler) http.Handler { |
| 68 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 69 | + _, token, err := parseAuthorizationHeader(r) |
| 70 | + if err != nil { |
| 71 | + // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec |
| 72 | + if errors.Is(err, errMissingAuthorizationHeader) { |
| 73 | + // sendAuthChallenge(w, r, cfg, obsv) |
| 74 | + return |
| 75 | + } |
| 76 | + // For other auth errors (bad format, unsupported), return 400 |
| 77 | + http.Error(w, err.Error(), http.StatusBadRequest) |
| 78 | + return |
| 79 | + } |
| 80 | + |
| 81 | + // Add token info to context |
| 82 | + ctx := r.Context() |
| 83 | + ctx = AddToken(ctx, &TokenData{Token: token}) |
| 84 | + |
| 85 | + next.ServeHTTP(w, r.WithContext(ctx)) |
| 86 | + }) |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { |
| 91 | + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) |
| 92 | + if authHeader == "" { |
| 93 | + return 0, "", errMissingAuthorizationHeader |
| 94 | + } |
| 95 | + |
| 96 | + switch { |
| 97 | + // decrypt dotcom token and set it as token |
| 98 | + case strings.HasPrefix(authHeader, "GitHub-Bearer "): |
| 99 | + return 0, "", errUnsupportedAuthorizationHeader |
| 100 | + default: |
| 101 | + // support both "Bearer" and "bearer" to conform to api.github.com |
| 102 | + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { |
| 103 | + token = authHeader[7:] |
| 104 | + } else { |
| 105 | + token = authHeader |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. |
| 110 | + // ex: tid=1;exp=25145314523;chat=1:<hmac> |
| 111 | + if strings.Contains(token, ":") { |
| 112 | + return authTypeIDE, token, nil |
| 113 | + } |
| 114 | + |
| 115 | + for _, prefix := range supportedThirdPartyTokenPrefixes { |
| 116 | + if strings.HasPrefix(token, prefix) { |
| 117 | + return authTypeGhToken, token, nil |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) |
| 122 | + if matchesOldTokenPattern { |
| 123 | + return authTypeGhToken, token, nil |
| 124 | + } |
| 125 | + |
| 126 | + return authTypeUnknown, "", errBadAuthorizationHeader |
| 127 | +} |
0 commit comments