Skip to content

Commit 5a66d6f

Browse files
authored
Merge pull request #2448 from github/koesie10/auto-model
Add initial implementation of auto-modeling
2 parents 3b4f236 + d27f3d2 commit 5a66d6f

File tree

9 files changed

+772
-1
lines changed

9 files changed

+772
-1
lines changed

extensions/ql-vscode/src/config.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,3 +711,10 @@ const QUERIES_PANEL = new Setting("queriesPanel", ROOT_SETTING);
711711
/**
 * Reports whether the `queriesPanel` setting is enabled.
 *
 * @returns `true` when the stored setting value is truthy, `false` otherwise
 * (including when no value is set).
 */
export function showQueriesPanel(): boolean {
  const configuredValue = QUERIES_PANEL.getValue<boolean>();
  return Boolean(configuredValue);
}
714+
715+
const DATA_EXTENSIONS = new Setting("dataExtensions", ROOT_SETTING);
716+
const LLM_GENERATION = new Setting("llmGeneration", DATA_EXTENSIONS);
717+
718+
/**
 * Reports whether the `llmGeneration` setting (nested under `dataExtensions`)
 * is enabled.
 *
 * @returns `true` when the stored setting value is truthy, `false` otherwise
 * (including when no value is set).
 */
export function showLlmGeneration(): boolean {
  const configuredValue = LLM_GENERATION.getValue<boolean>();
  return Boolean(configuredValue);
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import { Credentials } from "../common/authentication";
2+
import { OctokitResponse } from "@octokit/types";
3+
4+
/**
 * String-valued classification kinds that a `Classification` can carry in its
 * `type` field.
 */
export enum ClassificationType {
5+
Unknown = "CLASSIFICATION_TYPE_UNKNOWN",
6+
Neutral = "CLASSIFICATION_TYPE_NEUTRAL",
7+
Source = "CLASSIFICATION_TYPE_SOURCE",
8+
Sink = "CLASSIFICATION_TYPE_SINK",
9+
Summary = "CLASSIFICATION_TYPE_SUMMARY",
10+
}
11+
12+
/**
 * Classification attached to a `Method`: the classification type, a `kind`
 * string and an `explanation` string.
 */
export interface Classification {
13+
type: ClassificationType;
14+
kind: string;
15+
explanation: string;
16+
}
17+
18+
/**
 * A single method entry as used in `ModelRequest.candidates`,
 * `ModelRequest.samples` and `ModelResponse.predicted`.
 *
 * `package`, `type`, `name` and `signature` together identify the method.
 * `classification`, `input` and `output` are optional.
 */
export interface Method {
19+
package: string;
20+
type: string;
21+
name: string;
22+
signature: string;
23+
usages: string[];
24+
classification?: Classification;
25+
input?: string;
26+
output?: string;
27+
}
28+
29+
/**
 * Payload accepted by `autoModel`: a language plus two lists of methods,
 * `candidates` and `samples`.
 */
export interface ModelRequest {
30+
language: string;
31+
candidates: Method[];
32+
samples: Method[];
33+
}
34+
35+
/**
 * Result returned by `autoModel`: a language plus the list of `predicted`
 * methods.
 */
export interface ModelResponse {
36+
language: string;
37+
predicted: Method[];
38+
}
39+
40+
/**
 * Sends an auto-model request through the Octokit client obtained from the
 * given credentials and returns the body of the response.
 *
 * @param credentials Provides the Octokit client used to perform the request.
 * @param request The payload, passed as the `data` option of the request.
 * @returns The `data` of the response.
 */
export async function autoModel(
  credentials: Credentials,
  request: ModelRequest,
): Promise<ModelResponse> {
  const client = await credentials.getOctokit();

  const { data }: OctokitResponse<ModelResponse> = await client.request(
    "POST /repos/github/codeql/code-scanning/codeql/auto-model",
    { data: request },
  );

  return data;
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import { ExternalApiUsage } from "./external-api-usage";
2+
import { ModeledMethod, ModeledMethodType } from "./modeled-method";
3+
import {
4+
Classification,
5+
ClassificationType,
6+
Method,
7+
ModelRequest,
8+
} from "./auto-model-api";
9+
10+
/**
 * Builds the auto-model request for a set of external API usages.
 *
 * One `Method` entry is produced per argument position of every usage, with
 * `input` set to `Argument[<index>]`. Usages whose parameter list is `()`
 * therefore produce no entries. Entries for usages that have no model (or a
 * model of type "none") become candidates; all others become samples and carry
 * the classification derived from their existing model.
 *
 * @param language The language placed in the request.
 * @param externalApiUsages The usages to convert; the array is not modified.
 * @param modeledMethods Existing models, looked up by usage signature.
 * @returns A request holding at most 20 candidates and at most 100 samples.
 */
export function createAutoModelRequest(
  language: string,
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): ModelRequest {
  const candidates: Method[] = [];
  const samples: Method[] = [];

  // Work on a copy ordered by descending usage count, so that the most used
  // methods are the ones that survive the truncation at the end.
  const mostUsedFirst = Array.from(externalApiUsages);
  mostUsedFirst.sort((left, right) => right.usages.length - left.usages.length);

  mostUsedFirst.forEach((apiUsage) => {
    const existingModel: ModeledMethod = modeledMethods[
      apiUsage.signature
    ] ?? {
      type: "none",
    };
    const isUnmodeled = existingModel.type === "none";
    const destination = isUnmodeled ? candidates : samples;

    const argumentCount =
      apiUsage.methodParameters !== "()"
        ? apiUsage.methodParameters.split(",").length
        : 0;

    let argumentIndex = 0;
    while (argumentIndex < argumentCount) {
      const entry: Method = {
        package: apiUsage.packageName,
        type: apiUsage.typeName,
        name: apiUsage.methodName,
        signature: apiUsage.methodParameters,
        classification: isUnmodeled
          ? undefined
          : toMethodClassification(existingModel),
        // Only the labels of the first ten usages are sent
        usages: apiUsage.usages.slice(0, 10).map((usage) => usage.label),
        input: `Argument[${argumentIndex}]`,
      };
      destination.push(entry);
      argumentIndex += 1;
    }
  });

  return {
    language,
    samples: samples.slice(0, 100),
    candidates: candidates.slice(0, 20),
  };
}
70+
71+
/**
 * Converts predicted methods into modeled methods keyed by full method
 * signature (`<package>.<type>#<name><signature>`). Predictions without a
 * classification are ignored.
 *
 * For now, we have a simplified model that only models methods as sinks. It does not model methods as neutral,
 * so we aren't actually able to correctly determine that a method is neutral; it could still be a source or summary.
 * However, to keep this method simple and give output to the user, we will model any method for which none of its
 * arguments are modeled as sinks as neutral.
 *
 * If there are multiple arguments which are modeled as sinks, we will only model the first one, where "first" is
 * determined by `compareInputOutput` (lowest numeric argument first).
 *
 * @param predicted The predicted methods, typically one entry per method argument.
 * @returns One modeled method (of type "sink" or "neutral") per distinct full signature.
 */
export function parsePredictedClassifications(
  predicted: Method[],
): Record<string, ModeledMethod> {
  // Group the classified predictions by full signature. A Map is used instead of a plain object so that
  // lookups never consult properties inherited from Object.prototype.
  const predictedBySignature = new Map<string, Method[]>();
  for (const method of predicted) {
    if (!method.classification) {
      continue;
    }

    const signature = toFullMethodSignature(method);

    const group = predictedBySignature.get(signature);
    if (group === undefined) {
      predictedBySignature.set(signature, [method]);
    } else {
      group.push(method);
    }
  }

  const modeledMethods: Record<string, ModeledMethod> = {};

  for (const [signature, predictedMethods] of predictedBySignature) {
    const sinks = predictedMethods.filter(
      (method) => method.classification?.type === ClassificationType.Sink,
    );
    if (sinks.length === 0) {
      // For now, model any method for which none of its arguments are modeled as sinks as neutral
      modeledMethods[signature] = {
        type: "neutral",
        kind: "",
        input: "",
        output: "",
      };
      continue;
    }

    // Pick the sink whose input comes first according to compareInputOutput, i.e. the lowest numeric
    // argument: if we get back "Argument[1]" and "Argument[3]", "Argument[1]" is chosen.
    // On ties the earliest prediction wins, which matches what a stable sort would put first.
    const sink = sinks.reduce((best, current) =>
      compareInputOutput(current.input ?? "", best.input ?? "") < 0
        ? current
        : best,
    );

    modeledMethods[signature] = {
      type: "sink",
      kind: sink.classification?.kind ?? "",
      input: sink.input ?? "",
      output: sink.output ?? "",
    };
  }

  return modeledMethods;
}
133+
134+
/**
 * Maps a modeled method type onto the corresponding classification type.
 *
 * @param type The modeled method type.
 * @returns The matching classification type for "source", "sink", "summary"
 * and "neutral"; `ClassificationType.Unknown` for anything else.
 */
function toMethodClassificationType(
  type: ModeledMethodType,
): ClassificationType {
  if (type === "source") {
    return ClassificationType.Source;
  }
  if (type === "sink") {
    return ClassificationType.Sink;
  }
  if (type === "summary") {
    return ClassificationType.Summary;
  }
  if (type === "neutral") {
    return ClassificationType.Neutral;
  }
  return ClassificationType.Unknown;
}
150+
151+
/**
 * Builds the classification for an already modeled method: the type is
 * converted with `toMethodClassificationType`, the kind is copied over and the
 * explanation is left empty.
 */
function toMethodClassification(modeledMethod: ModeledMethod): Classification {
  const classificationType = toMethodClassificationType(modeledMethod.type);
  const { kind } = modeledMethod;

  return { type: classificationType, kind, explanation: "" };
}
158+
159+
/**
 * Formats a method as `<package>.<type>#<name><signature>`.
 */
function toFullMethodSignature(method: Method): string {
  const { package: packageName, type: typeName, name, signature } = method;
  return `${packageName}.${typeName}#${name}${signature}`;
}
162+
163+
// Matches a positional argument such as `Argument[0]` and captures its digits
const argumentRegex = /^Argument\[(\d+)]$/;

// Argument[this] is before ReturnValue
const nonNumericArgumentOrder = ["Argument[this]", "ReturnValue"];

/**
 * Returns the index of a positional argument (`Argument[<number>]`), or
 * `undefined` when the value has a different form.
 */
function parseArgumentIndex(value: string): number | undefined {
  const match = argumentRegex.exec(value);
  // Always parse as decimal; the regex guarantees the capture is digits only
  return match ? Number.parseInt(match[1], 10) : undefined;
}

/**
 * Compare two inputs or outputs matching `Argument[<number>]`, `Argument[this]`, or `ReturnValue`.
 * If they are the same, return 0. If a is less than b, returns a negative number.
 * If a is greater than b, returns a positive number.
 *
 * Ordering rules:
 * - numeric arguments come before everything else and are ordered by their index;
 * - `Argument[this]` comes before `ReturnValue`;
 * - any other value comes last; two such values are ordered with `localeCompare`.
 *
 * @param a The first input/output.
 * @param b The second input/output.
 * @returns A negative number, zero, or a positive number, suitable for `Array.prototype.sort`.
 */
export function compareInputOutput(a: string, b: string): number {
  if (a === b) {
    return 0;
  }

  const aArgumentIndex = parseArgumentIndex(a);
  const bArgumentIndex = parseArgumentIndex(b);

  // Both are numeric arguments
  if (aArgumentIndex !== undefined && bArgumentIndex !== undefined) {
    return aArgumentIndex - bArgumentIndex;
  }

  // Numeric arguments are always first
  if (aArgumentIndex !== undefined) {
    return -1;
  }
  if (bArgumentIndex !== undefined) {
    return 1;
  }

  // Neither is a numeric argument
  const aRank = nonNumericArgumentOrder.indexOf(a);
  const bRank = nonNumericArgumentOrder.indexOf(b);

  // If either one is unknown, it is sorted last
  if (aRank === -1 && bRank === -1) {
    return a.localeCompare(b);
  }
  if (aRank === -1) {
    return 1;
  }
  if (bRank === -1) {
    return -1;
  }

  return aRank - bRank;
}

extensions/ql-vscode/src/data-extensions-editor/data-extensions-editor-view.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ import { createDataExtensionYaml, loadDataExtensionYaml } from "./yaml";
3939
import { ExternalApiUsage } from "./external-api-usage";
4040
import { ModeledMethod } from "./modeled-method";
4141
import { ExtensionPackModelFile } from "./shared/extension-pack";
42+
import { autoModel } from "./auto-model-api";
43+
import {
44+
createAutoModelRequest,
45+
parsePredictedClassifications,
46+
} from "./auto-model";
47+
import { showLlmGeneration } from "../config";
4248

4349
function getQlSubmoduleFolder(): WorkspaceFolder | undefined {
4450
const workspaceFolder = workspace.workspaceFolders?.find(
@@ -127,6 +133,13 @@ export class DataExtensionsEditorView extends AbstractWebview<
127133
case "generateExternalApi":
128134
await this.generateModeledMethods();
129135

136+
break;
137+
case "generateExternalApiFromLlm":
138+
await this.generateModeledMethodsFromLlm(
139+
msg.externalApiUsages,
140+
msg.modeledMethods,
141+
);
142+
130143
break;
131144
default:
132145
assertNever(msg);
@@ -149,6 +162,7 @@ export class DataExtensionsEditorView extends AbstractWebview<
149162
viewState: {
150163
extensionPackModelFile: this.modelFile,
151164
modelFileExists: await pathExists(this.modelFile.filename),
165+
showLlmButton: showLlmGeneration(),
152166
},
153167
});
154168
}
@@ -367,6 +381,29 @@ export class DataExtensionsEditorView extends AbstractWebview<
367381
await this.clearProgress();
368382
}
369383

384+
/**
 * Builds an auto-model request from the given usages and existing models,
 * sends it with the app's credentials, converts the predicted methods into
 * modeled methods and posts them in an `addModeledMethods` message with
 * `overrideNone` set.
 *
 * @param externalApiUsages The usages to build the request from.
 * @param modeledMethods The existing models, passed through to the request builder.
 */
private async generateModeledMethodsFromLlm(
  externalApiUsages: ExternalApiUsage[],
  modeledMethods: Record<string, ModeledMethod>,
): Promise<void> {
  const { language } = this.databaseItem;
  const autoModelRequest = createAutoModelRequest(
    language,
    externalApiUsages,
    modeledMethods,
  );

  const { predicted } = await autoModel(this.app.credentials, autoModelRequest);

  await this.postMessage({
    t: "addModeledMethods",
    modeledMethods: parsePredictedClassifications(predicted),
    overrideNone: true,
  });
}
406+
370407
/*
371408
* Progress in this class is a bit weird. Most of the progress is based on running the query.
372409
* Query progress is always between 0 and 1000. However, we still have some steps that need

extensions/ql-vscode/src/data-extensions-editor/shared/view-state.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ import { ExtensionPackModelFile } from "./extension-pack";
33
/**
 * View state of the data extensions editor: the extension pack model file,
 * whether that model file exists, and whether the LLM button should be shown.
 */
export interface DataExtensionEditorViewState {
44
extensionPackModelFile: ExtensionPackModelFile;
55
modelFileExists: boolean;
6+
showLlmButton: boolean;
67
}

extensions/ql-vscode/src/pure/interface-types.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,12 @@ export interface GenerateExternalApiMessage {
544544
t: "generateExternalApi";
545545
}
546546

547+
/**
 * Message tagged `generateExternalApiFromLlm`. It carries the external API
 * usages and the current modeled methods (a record with string keys).
 */
export interface GenerateExternalApiFromLlmMessage {
548+
t: "generateExternalApiFromLlm";
549+
externalApiUsages: ExternalApiUsage[];
550+
modeledMethods: Record<string, ModeledMethod>;
551+
}
552+
547553
export type ToDataExtensionsEditorMessage =
548554
| SetExtensionPackStateMessage
549555
| SetExternalApiUsagesMessage
@@ -556,4 +562,5 @@ export type FromDataExtensionsEditorMessage =
556562
| OpenExtensionPackMessage
557563
| JumpToUsageMessage
558564
| SaveModeledMethods
559-
| GenerateExternalApiMessage;
565+
| GenerateExternalApiMessage
566+
| GenerateExternalApiFromLlmMessage;

extensions/ql-vscode/src/stories/data-extensions-editor/DataExtensionsEditor.stories.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ DataExtensionsEditor.args = {
3030
"/home/user/vscode-codeql-starter/codeql-custom-queries-java/sql2o/models/sql2o.yml",
3131
},
3232
modelFileExists: true,
33+
showLlmButton: true,
3334
},
3435
initialExternalApiUsages: [
3536
{

0 commit comments

Comments
 (0)