diff --git a/package.json b/package.json index b7a655a9a..de5b2081b 100644 --- a/package.json +++ b/package.json @@ -22,7 +22,7 @@ "test:update-snapshots": "npm run build && node scripts/test.mjs --test-update-snapshots", "prepare": "node --experimental-strip-types scripts/prepare.ts", "verify-server-json-version": "node --experimental-strip-types scripts/verify-server-json-version.ts", - "eval": "npm run build && node --experimental-strip-types --test scripts/eval_gemini.ts" + "eval": "npm run build && node --experimental-strip-types scripts/eval_gemini.ts" }, "files": [ "build/src", diff --git a/scripts/eval_gemini.ts b/scripts/eval_gemini.ts index c099399ca..462e6af7d 100644 --- a/scripts/eval_gemini.ts +++ b/scripts/eval_gemini.ts @@ -6,7 +6,7 @@ import fs from 'node:fs'; import path from 'node:path'; -import {describe, test, before, after, afterEach} from 'node:test'; +import {parseArgs} from 'node:util'; import { GoogleGenerativeAI, @@ -16,9 +16,10 @@ import { import {Client} from '@modelcontextprotocol/sdk/client/index.js'; import {StdioClientTransport} from '@modelcontextprotocol/sdk/client/stdio.js'; +import {TestServer} from '../build/tests/server.js'; + const ROOT_DIR = path.resolve(import.meta.dirname, '..'); const SCENARIOS_DIR = path.join(import.meta.dirname, 'eval_scenarios'); -import {TestServer} from '../build/tests/server.js'; // Define schema for our test scenarios export interface CapturedFunctionCall { @@ -37,8 +38,6 @@ export interface TestScenario { } async function loadScenario(scenarioPath: string): Promise { - // Dynamic import of the test scenario - // We assume the test file exports a 'scenario' object const module = await import(scenarioPath); if (!module.scenario) { throw new Error( @@ -48,7 +47,6 @@ async function loadScenario(scenarioPath: string): Promise { return module.scenario; } -// Helper to sanitize schema for Gemini function isRecord(v: unknown): v is Record { return typeof v === 'object' && v !== null && !Array.isArray(v); } @@ -84,9 +82,18 @@ async function runSingleScenario( scenarioPath: string, apiKey: string, server: TestServer, + modelId: string, + debug: boolean, ): Promise { + const debugLog = (...args: unknown[]) => { + if (debug) { + console.log(...args); + } + }; const absolutePath = path.resolve(scenarioPath); - console.log(`\n### Running Scenario: ${absolutePath} ###`); + debugLog( + `\n### Running Scenario: ${path.relative(ROOT_DIR, absolutePath)} ###`, + ); let client: Client | undefined; let transport: StdioClientTransport | undefined; @@ -121,10 +128,16 @@ async function runSingleScenario( } }); + const args = [serverPath]; + if (!debug) { + args.push('--headless'); + } + transport = new StdioClientTransport({ command: 'node', - args: [serverPath], + args, env, + stderr: debug ? 'inherit' : 'ignore', }); client = new Client( @@ -165,7 +178,7 @@ async function runSingleScenario( const genAI = new GoogleGenerativeAI(apiKey); const model = genAI.getGenerativeModel({ - model: 'gemini-2.5-flash', + model: modelId, tools: [{functionDeclarations}], }); @@ -181,8 +194,8 @@ async function runSingleScenario( // Execute turns let turnCount = 0; - console.log(`\n--- Turn 1 (User) ---`); - console.log(scenario.prompt); + debugLog(`\n--- Turn 1 (User) ---`); + debugLog(scenario.prompt); let result = await chat.sendMessage(scenario.prompt, { timeout: 5000, @@ -191,17 +204,15 @@ async function runSingleScenario( while (turnCount < scenario.maxTurns) { turnCount++; - console.log(`\n--- Turn ${turnCount} (Model) ---`); + debugLog(`\n--- Turn ${turnCount} (Model) ---`); const text = response.text(); if (text) { - console.log(`Text: ${text}`); + debugLog(`Text: ${text}`); } const functionCalls = response.functionCalls(); if (functionCalls && functionCalls.length > 0) { - console.log( - `Function Calls: ${JSON.stringify(functionCalls, null, 2)}`, - ); + debugLog(`Function Calls: ${JSON.stringify(functionCalls, null, 2)}`); const functionResponses = []; for (const call of functionCalls) { @@ -219,7 +230,7 @@ async function runSingleScenario( const safeArgs = isRecord(call.args) ? call.args : {}; - console.log( + debugLog( `Executing tool: ${originalName} with args: ${JSON.stringify(call.args)}`, ); @@ -253,49 +264,85 @@ async function runSingleScenario( } // Send tool results back - console.log(`Sending ${functionResponses.length} tool outputs back...`); + debugLog(`Sending ${functionResponses.length} tool outputs back...`); result = await chat.sendMessage(functionResponses); response = result.response; } else { - console.log('No tool calls. Interaction finished.'); + debugLog('No tool calls. Interaction finished.'); break; } } - console.log('\nVerifying expectations...'); + debugLog('\nVerifying expectations...'); expectations(allCalls); } finally { await client?.close(); await transport?.close(); } } -const apiKey = process.env.GEMINI_API_KEY; -if (!apiKey) { - throw new Error('GEMINI_API_KEY environment variable is required.'); -} -void describe('Gemini Eval Scenarios', () => { - const server = new TestServer(TestServer.randomPort()); +async function main() { + const apiKey = process.env.GEMINI_API_KEY; + if (!apiKey) { + throw new Error('GEMINI_API_KEY environment variable is required.'); + } - before(async () => { - await server.start(); + const {values, positionals} = parseArgs({ + options: { + model: { + type: 'string', + default: 'gemini-2.5-flash', + }, + debug: { + type: 'boolean', + default: false, + }, + }, + allowPositionals: true, }); - after(async () => { - await server.stop(); - }); + const modelId = values.model; + const debug = values.debug; + const scenarioFiles = + positionals.length > 0 + ? positionals.map(p => path.resolve(p)) + : fs + .readdirSync(SCENARIOS_DIR) + .filter(file => file.endsWith('.ts') || file.endsWith('.js')) + .map(file => path.join(SCENARIOS_DIR, file)); - afterEach(() => { - server.restore(); - }); + const server = new TestServer(TestServer.randomPort()); + await server.start(); - const files = fs.readdirSync(SCENARIOS_DIR).filter(file => { - return file.endsWith('.ts') || file.endsWith('.js'); - }); + let successCount = 0; + let failureCount = 0; - for (const file of files) { - void test(file, {timeout: 60_000}, async () => { - await runSingleScenario(path.join(SCENARIOS_DIR, file), apiKey, server); - }); + try { + for (const scenarioPath of scenarioFiles) { + try { + await runSingleScenario(scenarioPath, apiKey, server, modelId, debug); + console.log(`✔ ${path.relative(ROOT_DIR, scenarioPath)}`); + successCount++; + } catch (e) { + console.error(`✖ ${path.relative(ROOT_DIR, scenarioPath)}`); + console.error(e); + failureCount++; + } finally { + server.restore(); + } + } + } finally { + await server.stop(); } + + console.log(`\nSummary: ${successCount} passed, ${failureCount} failed`); + + if (failureCount > 0) { + process.exit(1); + } +} + +main().catch(error => { + console.error('Fatal error:', error); + process.exit(1); });