Skip to content

Commit f52ad04

Browse files
committed
Improve parsing of predicted classifications
1 parent 4b54e4f commit f52ad04

3 files changed

Lines changed: 163 additions & 32 deletions

File tree

extensions/ql-vscode/src/data-extensions-editor/auto-model.ts

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,61 @@ export function createAutoModelRequest(
6868
return request;
6969
}
7070

71+
export function parsePredictedClassifications(
72+
predicted: Method[],
73+
): Record<string, ModeledMethod> {
74+
const predictedBySignature: Record<string, Method[]> = {};
75+
for (const method of predicted) {
76+
if (!method.classification) {
77+
continue;
78+
}
79+
80+
const signature = toFullMethodSignature(method);
81+
82+
if (!(signature in predictedBySignature)) {
83+
predictedBySignature[signature] = [];
84+
}
85+
86+
predictedBySignature[signature].push(method);
87+
}
88+
89+
const modeledMethods: Record<string, ModeledMethod> = {};
90+
91+
for (const signature in predictedBySignature) {
92+
const predictedMethods = predictedBySignature[signature];
93+
94+
const sinks = predictedMethods.filter(
95+
(method) => method.classification?.type === ClassificationType.Sink,
96+
);
97+
if (sinks.length === 0) {
98+
// For now, model any method for which none of its arguments are modeled as sinks as neutral
99+
modeledMethods[signature] = {
100+
type: "neutral",
101+
kind: "",
102+
input: "",
103+
output: "",
104+
};
105+
continue;
106+
}
107+
108+
// Order the sinks by the input alphabetically. This will ensure that the first argument is always
109+
// first in the list of sinks, the second argument is always second, etc.
110+
// If we get back "Argument[1]" and "Argument[3]", "Argument[1]" should always be first
111+
sinks.sort((a, b) => (a.input ?? "").localeCompare(b.input ?? ""));
112+
113+
const sink = sinks[0];
114+
115+
modeledMethods[signature] = {
116+
type: "sink",
117+
kind: sink.classification?.kind ?? "",
118+
input: sink.input ?? "",
119+
output: sink.output ?? "",
120+
};
121+
}
122+
123+
return modeledMethods;
124+
}
125+
71126
function toMethodClassificationType(
72127
type: ModeledMethodType,
73128
): ClassificationType {
@@ -93,19 +148,6 @@ function toMethodClassification(modeledMethod: ModeledMethod): Classification {
93148
};
94149
}
95150

96-
export function classificationTypeToModeledMethodType(
97-
type: ClassificationType,
98-
): ModeledMethodType {
99-
switch (type) {
100-
case ClassificationType.Source:
101-
return "source";
102-
case ClassificationType.Sink:
103-
return "sink";
104-
case ClassificationType.Summary:
105-
return "summary";
106-
case ClassificationType.Neutral:
107-
return "neutral";
108-
default:
109-
return "none";
110-
}
151+
function toFullMethodSignature(method: Method): string {
152+
return `${method.package}.${method.type}.${method.name}${method.signature}`;
111153
}

extensions/ql-vscode/src/data-extensions-editor/data-extensions-editor-view.ts

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ import { ModeledMethod } from "./modeled-method";
4141
import { ExtensionPackModelFile } from "./shared/extension-pack";
4242
import { autoModel } from "./auto-model-api";
4343
import {
44-
classificationTypeToModeledMethodType,
4544
createAutoModelRequest,
45+
parsePredictedClassifications,
4646
} from "./auto-model";
4747
import { showLlmGeneration } from "../config";
4848

@@ -393,24 +393,13 @@ export class DataExtensionsEditorView extends AbstractWebview<
393393

394394
const response = await autoModel(this.app.credentials, request);
395395

396-
const modeledMethodsByName: Record<string, ModeledMethod> = {};
397-
398-
for (const method of response.predicted) {
399-
if (method.classification === undefined) {
400-
continue;
401-
}
402-
403-
modeledMethodsByName[method.signature] = {
404-
type: classificationTypeToModeledMethodType(method.classification.type),
405-
kind: method.classification.kind,
406-
input: method.input ?? "",
407-
output: method.output ?? "",
408-
};
409-
}
396+
const predictedModeledMethods = parsePredictedClassifications(
397+
response.predicted,
398+
);
410399

411400
await this.postMessage({
412401
t: "addModeledMethods",
413-
modeledMethods: modeledMethodsByName,
402+
modeledMethods: predictedModeledMethods,
414403
overrideNone: true,
415404
});
416405
}

extensions/ql-vscode/test/unit-tests/data-extensions-editor/auto-model.test.ts

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
import { createAutoModelRequest } from "../../../src/data-extensions-editor/auto-model";
1+
import {
2+
createAutoModelRequest,
3+
parsePredictedClassifications,
4+
} from "../../../src/data-extensions-editor/auto-model";
25
import { ExternalApiUsage } from "../../../src/data-extensions-editor/external-api-usage";
36
import { ModeledMethod } from "../../../src/data-extensions-editor/modeled-method";
7+
import {
8+
ClassificationType,
9+
Method,
10+
} from "../../../src/data-extensions-editor/auto-model-api";
411

512
describe("createAutoModelRequest", () => {
613
const externalApiUsages: ExternalApiUsage[] = [
@@ -281,3 +288,96 @@ describe("createAutoModelRequest", () => {
281288
});
282289
});
283290
});
291+
292+
describe("parsePredictedClassifications", () => {
293+
const predictions: Method[] = [
294+
{
295+
package: "org.sql2o",
296+
type: "Sql2o",
297+
name: "createQuery",
298+
signature: "(String)",
299+
usages: ["createQuery(...)", "createQuery(...)"],
300+
input: "Argument[0]",
301+
classification: {
302+
type: ClassificationType.Sink,
303+
kind: "sql injection sink",
304+
explanation: "",
305+
},
306+
},
307+
{
308+
package: "org.sql2o",
309+
type: "Sql2o",
310+
name: "executeScalar",
311+
signature: "(Class)",
312+
usages: ["executeScalar(...)", "executeScalar(...)"],
313+
input: "Argument[0]",
314+
classification: {
315+
type: ClassificationType.Neutral,
316+
kind: "",
317+
explanation: "not a sink",
318+
},
319+
},
320+
{
321+
package: "org.sql2o",
322+
type: "Sql2o",
323+
name: "Sql2o",
324+
signature: "(String,String,String)",
325+
usages: ["new Sql2o(...)"],
326+
input: "Argument[0]",
327+
classification: {
328+
type: ClassificationType.Neutral,
329+
kind: "",
330+
explanation: "not a sink",
331+
},
332+
},
333+
{
334+
package: "org.sql2o",
335+
type: "Sql2o",
336+
name: "Sql2o",
337+
signature: "(String,String,String)",
338+
usages: ["new Sql2o(...)"],
339+
input: "Argument[1]",
340+
classification: {
341+
type: ClassificationType.Sink,
342+
kind: "sql injection sink",
343+
explanation: "not a sink",
344+
},
345+
},
346+
{
347+
package: "org.sql2o",
348+
type: "Sql2o",
349+
name: "Sql2o",
350+
signature: "(String,String,String)",
351+
usages: ["new Sql2o(...)"],
352+
input: "Argument[2]",
353+
classification: {
354+
type: ClassificationType.Sink,
355+
kind: "sql injection sink",
356+
explanation: "not a sink",
357+
},
358+
},
359+
];
360+
361+
it("correctly parses the output", () => {
362+
expect(parsePredictedClassifications(predictions)).toEqual({
363+
"org.sql2o.Sql2o.createQuery(String)": {
364+
type: "sink",
365+
kind: "sql injection sink",
366+
input: "Argument[0]",
367+
output: "",
368+
},
369+
"org.sql2o.Sql2o.executeScalar(Class)": {
370+
type: "neutral",
371+
kind: "",
372+
input: "",
373+
output: "",
374+
},
375+
"org.sql2o.Sql2o.Sql2o(String,String,String)": {
376+
type: "sink",
377+
kind: "sql injection sink",
378+
input: "Argument[1]",
379+
output: "",
380+
},
381+
});
382+
});
383+
});

0 commit comments

Comments
 (0)