Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"test:update-snapshots": "npm run build && node scripts/test.mjs --test-update-snapshots",
"prepare": "node --experimental-strip-types scripts/prepare.ts",
"verify-server-json-version": "node --experimental-strip-types scripts/verify-server-json-version.ts",
"eval": "npm run build && node --experimental-strip-types --test scripts/eval_gemini.ts"
"eval": "npm run build && node --experimental-strip-types scripts/eval_gemini.ts"
},
"files": [
"build/src",
Expand Down
127 changes: 87 additions & 40 deletions scripts/eval_gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import fs from 'node:fs';
import path from 'node:path';
import {describe, test, before, after, afterEach} from 'node:test';
import {parseArgs} from 'node:util';

import {
GoogleGenerativeAI,
Expand All @@ -16,9 +16,10 @@ import {
import {Client} from '@modelcontextprotocol/sdk/client/index.js';
import {StdioClientTransport} from '@modelcontextprotocol/sdk/client/stdio.js';

import {TestServer} from '../build/tests/server.js';

const ROOT_DIR = path.resolve(import.meta.dirname, '..');
const SCENARIOS_DIR = path.join(import.meta.dirname, 'eval_scenarios');
import {TestServer} from '../build/tests/server.js';

// Define schema for our test scenarios
export interface CapturedFunctionCall {
Expand All @@ -37,8 +38,6 @@ export interface TestScenario {
}

async function loadScenario(scenarioPath: string): Promise<TestScenario> {
// Dynamic import of the test scenario
// We assume the test file exports a 'scenario' object
const module = await import(scenarioPath);
if (!module.scenario) {
throw new Error(
Expand All @@ -48,7 +47,6 @@ async function loadScenario(scenarioPath: string): Promise<TestScenario> {
return module.scenario;
}

// Helper to sanitize schema for Gemini
function isRecord(v: unknown): v is Record<string, unknown> {
return typeof v === 'object' && v !== null && !Array.isArray(v);
}
Expand Down Expand Up @@ -84,9 +82,18 @@ async function runSingleScenario(
scenarioPath: string,
apiKey: string,
server: TestServer,
modelId: string,
debug: boolean,
): Promise<void> {
const debugLog = (...args: unknown[]) => {
if (debug) {
console.log(...args);
}
};
const absolutePath = path.resolve(scenarioPath);
console.log(`\n### Running Scenario: ${absolutePath} ###`);
debugLog(
`\n### Running Scenario: ${path.relative(ROOT_DIR, absolutePath)} ###`,
);

let client: Client | undefined;
let transport: StdioClientTransport | undefined;
Expand Down Expand Up @@ -121,10 +128,16 @@ async function runSingleScenario(
}
});

const args = [serverPath];
if (!debug) {
args.push('--headless');
}

transport = new StdioClientTransport({
command: 'node',
args: [serverPath],
args,
env,
stderr: debug ? 'inherit' : 'ignore',
});

client = new Client(
Expand Down Expand Up @@ -165,7 +178,7 @@ async function runSingleScenario(

const genAI = new GoogleGenerativeAI(apiKey);
const model = genAI.getGenerativeModel({
model: 'gemini-2.5-flash',
model: modelId,
tools: [{functionDeclarations}],
});

Expand All @@ -181,8 +194,8 @@ async function runSingleScenario(

// Execute turns
let turnCount = 0;
console.log(`\n--- Turn 1 (User) ---`);
console.log(scenario.prompt);
debugLog(`\n--- Turn 1 (User) ---`);
debugLog(scenario.prompt);

let result = await chat.sendMessage(scenario.prompt, {
timeout: 5000,
Expand All @@ -191,17 +204,15 @@ async function runSingleScenario(

while (turnCount < scenario.maxTurns) {
turnCount++;
console.log(`\n--- Turn ${turnCount} (Model) ---`);
debugLog(`\n--- Turn ${turnCount} (Model) ---`);
const text = response.text();
if (text) {
console.log(`Text: ${text}`);
debugLog(`Text: ${text}`);
}

const functionCalls = response.functionCalls();
if (functionCalls && functionCalls.length > 0) {
console.log(
`Function Calls: ${JSON.stringify(functionCalls, null, 2)}`,
);
debugLog(`Function Calls: ${JSON.stringify(functionCalls, null, 2)}`);

const functionResponses = [];
for (const call of functionCalls) {
Expand All @@ -219,7 +230,7 @@ async function runSingleScenario(

const safeArgs = isRecord(call.args) ? call.args : {};

console.log(
debugLog(
`Executing tool: ${originalName} with args: ${JSON.stringify(call.args)}`,
);

Expand Down Expand Up @@ -253,49 +264,85 @@ async function runSingleScenario(
}

// Send tool results back
console.log(`Sending ${functionResponses.length} tool outputs back...`);
debugLog(`Sending ${functionResponses.length} tool outputs back...`);
result = await chat.sendMessage(functionResponses);
response = result.response;
} else {
console.log('No tool calls. Interaction finished.');
debugLog('No tool calls. Interaction finished.');
break;
}
}

console.log('\nVerifying expectations...');
debugLog('\nVerifying expectations...');
expectations(allCalls);
} finally {
await client?.close();
await transport?.close();
}
}
const apiKey = process.env.GEMINI_API_KEY;
if (!apiKey) {
throw new Error('GEMINI_API_KEY environment variable is required.');
}

void describe('Gemini Eval Scenarios', () => {
const server = new TestServer(TestServer.randomPort());
async function main() {
const apiKey = process.env.GEMINI_API_KEY;
if (!apiKey) {
throw new Error('GEMINI_API_KEY environment variable is required.');
}

before(async () => {
await server.start();
const {values, positionals} = parseArgs({
options: {
model: {
type: 'string',
default: 'gemini-2.5-flash',
},
debug: {
type: 'boolean',
default: false,
},
},
allowPositionals: true,
});

after(async () => {
await server.stop();
});
const modelId = values.model;
const debug = values.debug;
const scenarioFiles =
positionals.length > 0
? positionals.map(p => path.resolve(p))
: fs
.readdirSync(SCENARIOS_DIR)
.filter(file => file.endsWith('.ts') || file.endsWith('.js'))
.map(file => path.join(SCENARIOS_DIR, file));

afterEach(() => {
server.restore();
});
const server = new TestServer(TestServer.randomPort());
await server.start();

const files = fs.readdirSync(SCENARIOS_DIR).filter(file => {
return file.endsWith('.ts') || file.endsWith('.js');
});
let successCount = 0;
let failureCount = 0;

for (const file of files) {
void test(file, {timeout: 60_000}, async () => {
await runSingleScenario(path.join(SCENARIOS_DIR, file), apiKey, server);
});
try {
for (const scenarioPath of scenarioFiles) {
try {
await runSingleScenario(scenarioPath, apiKey, server, modelId, debug);
console.log(`✔ ${path.relative(ROOT_DIR, scenarioPath)}`);
successCount++;
} catch (e) {
console.error(`✖ ${path.relative(ROOT_DIR, scenarioPath)}`);
console.error(e);
failureCount++;
} finally {
server.restore();
}
}
} finally {
await server.stop();
}

console.log(`\nSummary: ${successCount} passed, ${failureCount} failed`);

if (failureCount > 0) {
process.exit(1);
}
}

main().catch(error => {
console.error('Fatal error:', error);
process.exit(1);
});