Skip to content

Commit ac72310

Browse files
committed
Handle SAML errors when fetching repository context
1 parent dd3a4a1 commit ac72310

3 files changed

Lines changed: 60 additions & 33 deletions

File tree

src/api/handleSamlError.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import {Octokit} from "@octokit/rest";
2+
import {getSession} from "../auth/auth";
3+
import {logDebug} from "../log";
4+
import {getClient} from "./api";
5+
6+
export async function handleSamlError<T>(request: (client: Octokit) => Promise<T>): Promise<T> {
7+
try {
8+
const session = await getSession();
9+
const client = getClient(session.accessToken);
10+
return await request(client);
11+
} catch (error) {
12+
if ((error as Error).message.includes("Resource protected by organization SAML enforcement.")) {
13+
logDebug("SAML error, re-authenticating");
14+
const session = await getSession(
15+
"Your organization is protected by SAML enforcement. Please sign-in again to continue."
16+
);
17+
const client = getClient(session.accessToken);
18+
return await request(client);
19+
} else {
20+
throw error;
21+
}
22+
}
23+
}

src/auth/auth.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@ import * as vscode from "vscode";
33
const AUTH_PROVIDER_ID = "github";
44
const DEFAULT_SCOPES = ["repo", "workflow"];
55

6-
export async function getSession(): Promise<vscode.AuthenticationSession> {
6+
/**
7+
* Retrieves a session from the GitHub authentication provider
8+
* @param forceMessage Force a new session with a prompt to the user
9+
* @returns A {@link vscode.AuthenticationSession}
10+
*/
11+
export async function getSession(forceMessage?: string): Promise<vscode.AuthenticationSession> {
712
const existingSession = await vscode.authentication.getSession(AUTH_PROVIDER_ID, getScopes(), {
8-
createIfNone: true
13+
createIfNone: true,
14+
forceNewSession: forceMessage ? {detail: forceMessage} : false
915
});
1016

1117
if (!existingSession) {

src/git/repository.ts

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import * as vscode from "vscode";
2-
3-
import {logDebug, logError} from "../log";
4-
import {API, GitExtension, RefType, RepositoryState} from "../typings/git";
5-
62
import {Octokit} from "@octokit/rest";
7-
import {getClient} from "../api/api";
8-
import {getSession} from "../auth/auth";
3+
4+
import {handleSamlError} from "../api/handleSamlError";
95
import {getRemoteName} from "../configuration/configuration";
106
import {Protocol} from "../external/protocol";
7+
import {logDebug, logError} from "../log";
8+
import {API, GitExtension, RefType, RepositoryState} from "../typings/git";
119

1210
interface GitHubUrls {
1311
workspaceUri: vscode.Uri;
@@ -146,8 +144,6 @@ export async function getGitHubContext(): Promise<GitHubContext | undefined> {
146144

147145
try {
148146
const git = await getGitExtension();
149-
const session = await getSession();
150-
const client = getClient(session.accessToken);
151147

152148
const protocolInfos = await getGitHubUrls();
153149
if (!protocolInfos) {
@@ -157,29 +153,31 @@ export async function getGitHubContext(): Promise<GitHubContext | undefined> {
157153

158154
logDebug("Found protocol infos", protocolInfos.length.toString());
159155

160-
const repos = await Promise.all(
161-
protocolInfos.map(async (protocolInfo): Promise<GitHubRepoContext> => {
162-
logDebug("Getting infos for repository", protocolInfo.url);
163-
164-
const repoInfo = await client.repos.get({
165-
repo: protocolInfo.protocol.repositoryName,
166-
owner: protocolInfo.protocol.owner
167-
});
168-
169-
const repo = git && git.getRepository(protocolInfo.workspaceUri);
170-
171-
return {
172-
workspaceUri: protocolInfo.workspaceUri,
173-
client,
174-
repositoryState: repo?.state,
175-
name: protocolInfo.protocol.repositoryName,
176-
owner: protocolInfo.protocol.owner,
177-
id: repoInfo.data.id,
178-
defaultBranch: `refs/heads/${repoInfo.data.default_branch}`,
179-
organizationOwned: repoInfo.data.owner.type === "Organization"
180-
};
181-
})
182-
);
156+
const repos = await handleSamlError(async (client: Octokit) => {
157+
return await Promise.all(
158+
protocolInfos.map(async (protocolInfo): Promise<GitHubRepoContext> => {
159+
logDebug("Getting infos for repository", protocolInfo.url);
160+
161+
const repoInfo = await client.repos.get({
162+
repo: protocolInfo.protocol.repositoryName,
163+
owner: protocolInfo.protocol.owner
164+
});
165+
166+
const repo = git && git.getRepository(protocolInfo.workspaceUri);
167+
168+
return {
169+
workspaceUri: protocolInfo.workspaceUri,
170+
client,
171+
repositoryState: repo?.state,
172+
name: protocolInfo.protocol.repositoryName,
173+
owner: protocolInfo.protocol.owner,
174+
id: repoInfo.data.id,
175+
defaultBranch: `refs/heads/${repoInfo.data.default_branch}`,
176+
organizationOwned: repoInfo.data.owner.type === "Organization"
177+
};
178+
})
179+
);
180+
});
183181

184182
gitHubContext = Promise.resolve({
185183
repos,

0 commit comments

Comments
 (0)