|
1 | 1 | package oauth |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "context" |
5 | 4 | "crypto/rand" |
6 | 5 | "encoding/base64" |
7 | 6 | "fmt" |
8 | 7 | "io" |
9 | | - "log" |
10 | 8 | "net" |
11 | 9 | "net/http" |
12 | 10 | "os" |
13 | 11 | "os/exec" |
14 | 12 | "runtime" |
15 | 13 | "strings" |
16 | 14 | "time" |
17 | | - |
18 | | - "golang.org/x/oauth2" |
19 | 15 | ) |
20 | 16 |
|
21 | 17 | const ( |
@@ -72,178 +68,6 @@ func isRunningInDocker() bool { |
72 | 68 | return false |
73 | 69 | } |
74 | 70 |
|
75 | | -// StartDeviceFlow initiates an OAuth device authorization flow |
76 | | -// This is suitable for environments without callback capabilities (like Docker containers) |
77 | | -func StartDeviceFlow(ctx context.Context, cfg Config) (*Result, error) { |
78 | | - oauth2Cfg := &oauth2.Config{ |
79 | | - ClientID: cfg.ClientID, |
80 | | - ClientSecret: cfg.ClientSecret, |
81 | | - Scopes: cfg.Scopes, |
82 | | - Endpoint: oauth2.Endpoint{ |
83 | | - AuthURL: cfg.AuthURL, |
84 | | - TokenURL: cfg.TokenURL, |
85 | | - DeviceAuthURL: cfg.DeviceAuthURL, |
86 | | - }, |
87 | | - } |
88 | | - |
89 | | - // Request device authorization |
90 | | - deviceAuth, err := oauth2Cfg.DeviceAuth(ctx) |
91 | | - if err != nil { |
92 | | - return nil, fmt.Errorf("failed to get device authorization: %w", err) |
93 | | - } |
94 | | - |
95 | | - // Display verification instructions to user |
96 | | - fmt.Fprint(os.Stderr, "\n"+strings.Repeat("=", 80)+"\n") |
97 | | - fmt.Fprint(os.Stderr, "GitHub OAuth Device Authorization\n") |
98 | | - fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n") |
99 | | - fmt.Fprintf(os.Stderr, "Please visit: %s\n\n", deviceAuth.VerificationURI) |
100 | | - fmt.Fprintf(os.Stderr, "And enter code: %s\n\n", deviceAuth.UserCode) |
101 | | - fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n") |
102 | | - |
103 | | - // Poll for token |
104 | | - token, err := oauth2Cfg.DeviceAccessToken(ctx, deviceAuth) |
105 | | - if err != nil { |
106 | | - return nil, fmt.Errorf("failed to get device access token: %w", err) |
107 | | - } |
108 | | - |
109 | | - fmt.Fprint(os.Stderr, "\n✓ Authorization successful!\n\n") |
110 | | - |
111 | | - return &Result{ |
112 | | - AccessToken: token.AccessToken, |
113 | | - RefreshToken: token.RefreshToken, |
114 | | - TokenType: token.TokenType, |
115 | | - Expiry: token.Expiry, |
116 | | - }, nil |
117 | | -} |
118 | | - |
119 | | -// StartOAuthFlow automatically selects the appropriate OAuth flow based on environment |
120 | | -// - Device flow for Docker containers (no callback server possible) |
121 | | -// - Interactive PKCE flow for native binaries (best UX with browser) |
122 | | -func StartOAuthFlow(ctx context.Context, cfg Config) (*Result, error) { |
123 | | - // Check if we're in Docker |
124 | | - if isRunningInDocker() && cfg.CallbackPort == 0 { |
125 | | - // Docker without explicit callback port - use device flow |
126 | | - log.Printf("Detected Docker environment, using device flow") |
127 | | - return StartDeviceFlow(ctx, cfg) |
128 | | - } |
129 | | - |
130 | | - // Use interactive PKCE flow (browser-based) |
131 | | - return StartInteractiveFlow(ctx, cfg) |
132 | | -} |
133 | | - |
134 | | -// StartInteractiveFlow initiates an interactive OAuth flow with PKCE |
135 | | -// This is intended for stdio mode only and opens a browser for user consent |
136 | | -func StartInteractiveFlow(ctx context.Context, cfg Config) (*Result, error) { |
137 | | - // Generate PKCE verifier |
138 | | - verifier, err := generatePKCEVerifier() |
139 | | - if err != nil { |
140 | | - return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err) |
141 | | - } |
142 | | - |
143 | | - // Create OAuth2 config |
144 | | - oauth2Cfg := &oauth2.Config{ |
145 | | - ClientID: cfg.ClientID, |
146 | | - ClientSecret: cfg.ClientSecret, |
147 | | - RedirectURL: cfg.RedirectURL, |
148 | | - Scopes: cfg.Scopes, |
149 | | - Endpoint: oauth2.Endpoint{ |
150 | | - AuthURL: cfg.AuthURL, |
151 | | - TokenURL: cfg.TokenURL, |
152 | | - }, |
153 | | - } |
154 | | - |
155 | | - // Generate state for CSRF protection |
156 | | - stateBytes := make([]byte, 16) |
157 | | - if _, err := rand.Read(stateBytes); err != nil { |
158 | | - return nil, fmt.Errorf("failed to generate state: %w", err) |
159 | | - } |
160 | | - state := base64.RawURLEncoding.EncodeToString(stateBytes) |
161 | | - |
162 | | - // Start local HTTP server for callback |
163 | | - listener, port, err := startLocalServer(cfg.CallbackPort) |
164 | | - if err != nil { |
165 | | - return nil, fmt.Errorf("failed to start local server: %w", err) |
166 | | - } |
167 | | - defer listener.Close() |
168 | | - |
169 | | - // Update redirect URL with actual port |
170 | | - oauth2Cfg.RedirectURL = fmt.Sprintf("http://localhost:%d/callback", port) |
171 | | - |
172 | | - // Channel to receive the authorization code |
173 | | - codeChan := make(chan string, 1) |
174 | | - errChan := make(chan error, 1) |
175 | | - |
176 | | - // Setup HTTP handler for callback |
177 | | - server := &http.Server{ |
178 | | - Handler: createCallbackHandler(state, codeChan, errChan), |
179 | | - ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks |
180 | | - } |
181 | | - |
182 | | - // Start server in background |
183 | | - go func() { |
184 | | - if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { |
185 | | - errChan <- fmt.Errorf("server error: %w", err) |
186 | | - } |
187 | | - }() |
188 | | - |
189 | | - // Build authorization URL with PKCE |
190 | | - authURL := oauth2Cfg.AuthCodeURL( |
191 | | - state, |
192 | | - oauth2.S256ChallengeOption(verifier), |
193 | | - ) |
194 | | - |
195 | | - // Display URL to user and try to open browser |
196 | | - fmt.Fprint(os.Stderr, "\n"+strings.Repeat("=", 80)+"\n") |
197 | | - fmt.Fprint(os.Stderr, "GitHub OAuth Authorization Required\n") |
198 | | - fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n") |
199 | | - fmt.Fprint(os.Stderr, "Opening your browser to complete authorization...\n\n") |
200 | | - fmt.Fprint(os.Stderr, "If your browser doesn't open automatically, please visit this URL:\n\n") |
201 | | - fmt.Fprintf(os.Stderr, " %s\n\n", authURL) |
202 | | - fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n") |
203 | | - |
204 | | - // Try to open browser |
205 | | - if err := openBrowser(authURL); err != nil { |
206 | | - log.Printf("Warning: Could not open browser automatically: %v", err) |
207 | | - } |
208 | | - |
209 | | - // Wait for callback with timeout |
210 | | - var code string |
211 | | - select { |
212 | | - case code = <-codeChan: |
213 | | - // Success |
214 | | - case err := <-errChan: |
215 | | - return nil, fmt.Errorf("callback error: %w", err) |
216 | | - case <-ctx.Done(): |
217 | | - return nil, fmt.Errorf("context cancelled: %w", ctx.Err()) |
218 | | - case <-time.After(DefaultAuthTimeout): |
219 | | - return nil, fmt.Errorf("authorization timeout after %v", DefaultAuthTimeout) |
220 | | - } |
221 | | - |
222 | | - // Shutdown server gracefully |
223 | | - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
224 | | - defer cancel() |
225 | | - _ = server.Shutdown(shutdownCtx) |
226 | | - |
227 | | - // Exchange authorization code for token with PKCE verifier |
228 | | - token, err := oauth2Cfg.Exchange( |
229 | | - ctx, |
230 | | - code, |
231 | | - oauth2.VerifierOption(verifier), |
232 | | - ) |
233 | | - if err != nil { |
234 | | - return nil, fmt.Errorf("failed to exchange code for token: %w", err) |
235 | | - } |
236 | | - |
237 | | - fmt.Fprint(os.Stderr, "\n✓ Authorization successful!\n\n") |
238 | | - |
239 | | - return &Result{ |
240 | | - AccessToken: token.AccessToken, |
241 | | - RefreshToken: token.RefreshToken, |
242 | | - TokenType: token.TokenType, |
243 | | - Expiry: token.Expiry, |
244 | | - }, nil |
245 | | -} |
246 | | - |
247 | 71 | // startLocalServer starts a local HTTP server on the specified port |
248 | 72 | // If port is 0, uses a random available port |
249 | 73 | func startLocalServer(port int) (net.Listener, int, error) { |
|
0 commit comments