diff --git a/src/engine/anthropic.ts b/src/engine/anthropic.ts index 708468ef..7dbc5936 100644 --- a/src/engine/anthropic.ts +++ b/src/engine/anthropic.ts @@ -8,6 +8,7 @@ import axios from 'axios'; import chalk from 'chalk'; import { OpenAI } from 'openai'; import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { removeContentTags } from '../utils/removeContentTags'; import { tokenCount } from '../utils/tokenCount'; import { AiEngine, AiEngineConfig } from './Engine'; @@ -54,8 +55,8 @@ export class AnthropicEngine implements AiEngine { const data = await this.client.messages.create(params); const message = data?.content[0].text; - - return message; + let content = message; + return removeContentTags(content, 'think'); } catch (error) { const err = error as Error; outro(`${chalk.red('✖')} ${err?.message || err}`); diff --git a/src/engine/azure.ts b/src/engine/azure.ts index db1b6562..acbf0d39 100644 --- a/src/engine/azure.ts +++ b/src/engine/azure.ts @@ -7,6 +7,7 @@ import axios from 'axios'; import chalk from 'chalk'; import { OpenAI } from 'openai'; import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { removeContentTags } from '../utils/removeContentTags'; import { tokenCount } from '../utils/tokenCount'; import { AiEngine, AiEngineConfig } from './Engine'; @@ -52,7 +53,9 @@ export class AzureEngine implements AiEngine { if (message?.content === null) { return undefined; } - return message?.content; + + let content = message?.content; + return removeContentTags(content, 'think'); } catch (error) { outro(`${chalk.red('✖')} ${this.config.model}`); diff --git a/src/engine/deepseek.ts b/src/engine/deepseek.ts index 9424a9aa..61b2cebc 100644 --- a/src/engine/deepseek.ts +++ b/src/engine/deepseek.ts @@ -1,8 +1,9 @@ import axios from 'axios'; import { OpenAI } from 'openai'; import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { removeContentTags } from '../utils/removeContentTags'; import { tokenCount } from '../utils/tokenCount'; -import { OpenAiEngine, OpenAiConfig } from './openAI'; +import { OpenAiEngine, OpenAiConfig } from './openAi'; export interface DeepseekConfig extends OpenAiConfig {} @@ -41,8 +42,8 @@ export class DeepseekEngine extends OpenAiEngine { const completion = await this.client.chat.completions.create(params); const message = completion.choices[0].message; - - return message?.content; + let content = message?.content; + return removeContentTags(content, 'think'); } catch (error) { const err = error as Error; if ( diff --git a/src/engine/flowise.ts b/src/engine/flowise.ts index 00890f5c..aecc2772 100644 --- a/src/engine/flowise.ts +++ b/src/engine/flowise.ts @@ -1,5 +1,6 @@ import axios, { AxiosInstance } from 'axios'; import { OpenAI } from 'openai'; +import { removeContentTags } from '../utils/removeContentTags'; import { AiEngine, AiEngineConfig } from './Engine'; interface FlowiseAiConfig extends AiEngineConfig {} @@ -36,7 +37,8 @@ export class FlowiseEngine implements AiEngine { try { const response = await this.client.post('', payload); const message = response.data; - return message?.text; + let content = message?.text; + return removeContentTags(content, 'think'); } catch (err: any) { const message = err.response?.data?.error ?? err.message; throw new Error('local model issues. details: ' + message); diff --git a/src/engine/gemini.ts b/src/engine/gemini.ts index e2e63acc..d56cf074 100644 --- a/src/engine/gemini.ts +++ b/src/engine/gemini.ts @@ -7,6 +7,7 @@ import { } from '@google/generative-ai'; import axios from 'axios'; import { OpenAI } from 'openai'; +import { removeContentTags } from '../utils/removeContentTags'; import { AiEngine, AiEngineConfig } from './Engine'; interface GeminiConfig extends AiEngineConfig {} @@ -71,7 +72,8 @@ export class GeminiEngine implements AiEngine { } }); - return result.response.text(); + const content = result.response.text(); + return removeContentTags(content, 'think'); } catch (error) { const err = error as Error; if ( diff --git a/src/engine/mistral.ts b/src/engine/mistral.ts index ce480f2e..8a7e932e 100644 --- a/src/engine/mistral.ts +++ b/src/engine/mistral.ts @@ -1,27 +1,21 @@ import axios from 'axios'; -import { Mistral } from '@mistralai/mistralai'; import { OpenAI } from 'openai'; import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { removeContentTags } from '../utils/removeContentTags'; import { tokenCount } from '../utils/tokenCount'; import { AiEngine, AiEngineConfig } from './Engine'; -import { - AssistantMessage as MistralAssistantMessage, - SystemMessage as MistralSystemMessage, - ToolMessage as MistralToolMessage, - UserMessage as MistralUserMessage -} from '@mistralai/mistralai/models/components'; +// Using any for Mistral types to avoid type declaration issues export interface MistralAiConfig extends AiEngineConfig {} -export type MistralCompletionMessageParam = Array< -| (MistralSystemMessage & { role: "system" }) -| (MistralUserMessage & { role: "user" }) -| (MistralAssistantMessage & { role: "assistant" }) -| (MistralToolMessage & { role: "tool" }) -> +export type MistralCompletionMessageParam = Array; + +// Import Mistral dynamically to avoid TS errors +// eslint-disable-next-line @typescript-eslint/no-var-requires +const Mistral = require('@mistralai/mistralai').Mistral; export class MistralAiEngine implements AiEngine { config: MistralAiConfig; - client: Mistral; + client: any; // Using any type for Mistral client to avoid TS errors constructor(config: MistralAiConfig) { this.config = config; @@ -64,7 +58,8 @@ export class MistralAiEngine implements AiEngine { if (!message || !message.content) throw Error('No completion choice available.') - return message.content as string; + let content = message.content as string; + return removeContentTags(content, 'think'); } catch (error) { const err = error as Error; if ( diff --git a/src/engine/mlx.ts b/src/engine/mlx.ts index 4e324e5d..f83a19a8 100644 --- a/src/engine/mlx.ts +++ b/src/engine/mlx.ts @@ -1,7 +1,7 @@ import axios, { AxiosInstance } from 'axios'; import { OpenAI } from 'openai'; +import { removeContentTags } from '../utils/removeContentTags'; import { AiEngine, AiEngineConfig } from './Engine'; -import { chown } from 'fs'; interface MLXConfig extends AiEngineConfig {} @@ -37,11 +37,11 @@ export class MLXEngine implements AiEngine { const choices = response.data.choices; const message = choices[0].message; - - return message?.content; + let content = message?.content; + return removeContentTags(content, 'think'); } catch (err: any) { const message = err.response?.data?.error ?? err.message; throw new Error(`MLX provider error: ${message}`); } } -} \ No newline at end of file +} diff --git a/src/engine/ollama.ts b/src/engine/ollama.ts index af3b927e..2d21d637 100644 --- a/src/engine/ollama.ts +++ b/src/engine/ollama.ts @@ -1,5 +1,6 @@ import axios, { AxiosInstance } from 'axios'; import { OpenAI } from 'openai'; +import { removeContentTags } from '../utils/removeContentTags'; import { AiEngine, AiEngineConfig } from './Engine'; interface OllamaConfig extends AiEngineConfig {} @@ -35,12 +36,7 @@ export class OllamaEngine implements AiEngine { const { message } = response.data; let content = message?.content; - - if (content && content.includes('')) { - return content.replace(/[\s\S]*?<\/think>/g, '').trim(); - } - - return content; + return removeContentTags(content, 'think'); } catch (err: any) { const message = err.response?.data?.error ?? err.message; throw new Error(`Ollama provider error: ${message}`); diff --git a/src/engine/openAi.ts b/src/engine/openAi.ts index ea5d9e9a..4e1c6a99 100644 --- a/src/engine/openAi.ts +++ b/src/engine/openAi.ts @@ -1,6 +1,7 @@ import axios from 'axios'; import { OpenAI } from 'openai'; import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { removeContentTags } from '../utils/removeContentTags'; import { tokenCount } from '../utils/tokenCount'; import { AiEngine, AiEngineConfig } from './Engine'; @@ -45,8 +46,8 @@ export class OpenAiEngine implements AiEngine { const completion = await this.client.chat.completions.create(params); const message = completion.choices[0].message; - - return message?.content; + let content = message?.content; + return removeContentTags(content, 'think'); } catch (error) { const err = error as Error; if ( diff --git a/src/utils/removeContentTags.ts b/src/utils/removeContentTags.ts new file mode 100644 index 00000000..d478434c --- /dev/null +++ b/src/utils/removeContentTags.ts @@ -0,0 +1,51 @@ +/** + * Removes content wrapped in specified tags from a string + * @param content The content string to process + * @param tag The tag name without angle brackets (e.g., 'think' for '') + * @returns The content with the specified tags and their contents removed, and trimmed + */ +export function removeContentTags(content: T, tag: string): T { + if (!content || typeof content !== 'string') { + return content; + } + + // Dynamic implementation for other cases + const openTag = `<${tag}>`; + const closeTag = ``; + + // Parse the content and remove tags + let result = ''; + let skipUntil: number | null = null; + let depth = 0; + + for (let i = 0; i < content.length; i++) { + // Check for opening tag + if (content.substring(i, i + openTag.length) === openTag) { + depth++; + if (depth === 1) { + skipUntil = content.indexOf(closeTag, i + openTag.length); + i = i + openTag.length - 1; // Skip the opening tag + continue; + } + } + // Check for closing tag + else if (content.substring(i, i + closeTag.length) === closeTag && depth > 0) { + depth--; + if (depth === 0) { + i = i + closeTag.length - 1; // Skip the closing tag + skipUntil = null; + continue; + } + } + + // Only add character if not inside a tag + if (skipUntil === null) { + result += content[i]; + } + } + + // Normalize spaces (replace multiple spaces with a single space) + result = result.replace(/\s+/g, ' ').trim(); + + return result as unknown as T; +} diff --git a/test/unit/removeContentTags.test.ts b/test/unit/removeContentTags.test.ts new file mode 100644 index 00000000..96f28fa1 --- /dev/null +++ b/test/unit/removeContentTags.test.ts @@ -0,0 +1,57 @@ +import { removeContentTags } from '../../src/utils/removeContentTags'; + +describe('removeContentTags', () => { + it('should remove content wrapped in specified tags', () => { + const content = 'This is something to hide visible content'; + const result = removeContentTags(content, 'think'); + expect(result).toBe('This is visible content'); + }); + + it('should handle multiple tag occurrences', () => { + const content = 'hidden visible also hidden text'; + const result = removeContentTags(content, 'think'); + expect(result).toBe('visible text'); + }); + + it('should handle multiline content within tags', () => { + const content = 'Start hidden\nover multiple\nlines End'; + const result = removeContentTags(content, 'think'); + expect(result).toBe('Start End'); + }); + + it('should return content as is when tag is not found', () => { + const content = 'Content without any tags'; + const result = removeContentTags(content, 'think'); + expect(result).toBe('Content without any tags'); + }); + + it('should work with different tag names', () => { + const content = 'This is something to hide visible content'; + const result = removeContentTags(content, 'custom'); + expect(result).toBe('This is visible content'); + }); + + it('should handle null content', () => { + const content = null; + const result = removeContentTags(content, 'think'); + expect(result).toBe(null); + }); + + it('should handle undefined content', () => { + const content = undefined; + const result = removeContentTags(content, 'think'); + expect(result).toBe(undefined); + }); + + it('should trim the result', () => { + const content = ' hidden visible '; + const result = removeContentTags(content, 'think'); + expect(result).toBe('visible'); + }); + + it('should handle nested tags correctly', () => { + const content = 'Outside Inside Nested End'; + const result = removeContentTags(content, 'think'); + expect(result).toBe('Outside End'); + }); +});