Skip to content

Commit f0f1a90

Browse files
Merge PR #475: feat(flutter): config validation for LLM/STT/TTS (revived from #456, Fixes #450)
Adds Dart-layer validation for LLM/STT/TTS configuration so invalid values throw SDKError.validationFailed before crossing the FFI boundary, and threads the real model contextLength from the registry so maxTokens is bounded against the actual model window. Original author: @DevDesai-444 (via sanchitmonga22 revival) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents 0e72273 + f7c8ee2 commit f0f1a90

11 files changed

Lines changed: 255 additions & 34 deletions
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
2+
import 'package:runanywhere/core/types/model_types.dart';
3+
import 'package:runanywhere/foundation/error_types/sdk_error.dart';
4+
5+
/// Configuration for the LLM component.
6+
///
7+
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
8+
/// invalid parameters fail in Dart before crossing the FFI boundary.
9+
class LLMConfiguration implements ComponentConfiguration {
10+
final String? modelId;
11+
final InferenceFramework? preferredFramework;
12+
final int contextLength;
13+
final double temperature;
14+
final int maxTokens;
15+
final String? systemPrompt;
16+
final bool streamingEnabled;
17+
18+
const LLMConfiguration({
19+
this.modelId,
20+
this.preferredFramework,
21+
this.contextLength = 2048,
22+
this.temperature = 0.7,
23+
this.maxTokens = 100,
24+
this.systemPrompt,
25+
this.streamingEnabled = true,
26+
});
27+
28+
@override
29+
void validate() {
30+
if (contextLength <= 0) {
31+
throw SDKError.validationFailed(
32+
'Context length must be greater than 0',
33+
);
34+
}
35+
36+
if (!temperature.isFinite || temperature < 0 || temperature > 2.0) {
37+
throw SDKError.validationFailed(
38+
'Temperature must be between 0 and 2.0',
39+
);
40+
}
41+
42+
if (maxTokens <= 0 || maxTokens > contextLength) {
43+
throw SDKError.validationFailed(
44+
'Max tokens must be between 1 and context length',
45+
);
46+
}
47+
48+
// Guard against clearly oversized prompts (chars) — a system prompt larger
49+
// than the model's context window (in chars) is clearly invalid.
50+
// Uses ~4 chars per token as a generous char-level bound.
51+
final prompt = systemPrompt;
52+
if (prompt != null && prompt.length > contextLength * 4) {
53+
throw SDKError.validationFailed(
54+
"systemPrompt length (${prompt.length} chars) exceeds the model's context window",
55+
);
56+
}
57+
}
58+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
2+
import 'package:runanywhere/core/types/model_types.dart';
3+
import 'package:runanywhere/foundation/error_types/sdk_error.dart';
4+
5+
/// Configuration for the STT component.
6+
///
7+
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
8+
/// invalid parameters fail in Dart before crossing the FFI boundary.
9+
class STTConfiguration implements ComponentConfiguration {
10+
final String? modelId;
11+
final InferenceFramework? preferredFramework;
12+
final String language;
13+
final int sampleRate;
14+
final bool enablePunctuation;
15+
final bool enableDiarization;
16+
final List<String> vocabularyList;
17+
final int maxAlternatives;
18+
final bool enableTimestamps;
19+
20+
const STTConfiguration({
21+
this.modelId,
22+
this.preferredFramework,
23+
this.language = 'en-US',
24+
this.sampleRate = 16000,
25+
this.enablePunctuation = true,
26+
this.enableDiarization = false,
27+
this.vocabularyList = const <String>[],
28+
this.maxAlternatives = 1,
29+
this.enableTimestamps = true,
30+
});
31+
32+
@override
33+
void validate() {
34+
if (sampleRate <= 0 || sampleRate > 48000) {
35+
throw SDKError.validationFailed(
36+
'Sample rate must be between 1 and 48000 Hz',
37+
);
38+
}
39+
40+
if (maxAlternatives <= 0 || maxAlternatives > 10) {
41+
throw SDKError.validationFailed(
42+
'Max alternatives must be between 1 and 10',
43+
);
44+
}
45+
}
46+
}

sdk/runanywhere-flutter/packages/runanywhere/lib/features/tts/system_tts_service.dart

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,7 @@ import 'dart:async';
88
import 'dart:typed_data';
99

1010
import 'package:flutter_tts/flutter_tts.dart';
11-
12-
/// Configuration for TTS synthesis
13-
class TTSConfiguration {
14-
final String voice;
15-
final String language;
16-
final double speakingRate;
17-
final double pitch;
18-
final double volume;
19-
final String audioFormat;
20-
21-
const TTSConfiguration({
22-
this.voice = 'system',
23-
this.language = 'en-US',
24-
this.speakingRate = 0.5,
25-
this.pitch = 1.0,
26-
this.volume = 1.0,
27-
this.audioFormat = 'pcm',
28-
});
29-
}
11+
import 'package:runanywhere/features/tts/tts_configuration.dart';
3012

3113
/// Input for TTS synthesis
3214
class TTSSynthesisInput {
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
2+
import 'package:runanywhere/foundation/error_types/sdk_error.dart';
3+
4+
/// Configuration for TTS synthesis.
5+
class TTSConfiguration implements ComponentConfiguration {
6+
final String voice;
7+
final String language;
8+
final double speakingRate;
9+
final double pitch;
10+
final double volume;
11+
final String audioFormat;
12+
13+
const TTSConfiguration({
14+
this.voice = 'system',
15+
this.language = 'en-US',
16+
this.speakingRate = 1.0,
17+
this.pitch = 1.0,
18+
this.volume = 1.0,
19+
this.audioFormat = 'pcm',
20+
});
21+
22+
@override
23+
void validate() {
24+
if (!speakingRate.isFinite || speakingRate < 0.5 || speakingRate > 2.0) {
25+
throw SDKError.validationFailed(
26+
'Speaking rate must be between 0.5 and 2.0',
27+
);
28+
}
29+
30+
if (!pitch.isFinite || pitch < 0.5 || pitch > 2.0) {
31+
throw SDKError.validationFailed(
32+
'Pitch must be between 0.5 and 2.0',
33+
);
34+
}
35+
36+
if (!volume.isFinite || volume < 0.0 || volume > 1.0) {
37+
throw SDKError.validationFailed(
38+
'Volume must be between 0.0 and 1.0',
39+
);
40+
}
41+
}
42+
}

sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_llm.dart

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import 'dart:ffi';
1818
import 'dart:isolate'; // Keep for non-streaming generation
1919

2020
import 'package:ffi/ffi.dart';
21-
21+
import 'package:runanywhere/features/llm/llm_configuration.dart';
2222
import 'package:runanywhere/foundation/logging/sdk_logger.dart';
2323
import 'package:runanywhere/native/ffi_types.dart';
2424
import 'package:runanywhere/native/native_functions.dart';
@@ -49,6 +49,7 @@ class DartBridgeLLM {
4949

5050
RacHandle? _handle;
5151
String? _loadedModelId;
52+
int? _loadedContextLength;
5253
final _logger = SDKLogger('DartBridge.LLM');
5354

5455
/// Active stream subscription for cancellation
@@ -142,6 +143,7 @@ class DartBridgeLLM {
142143
String modelPath,
143144
String modelId,
144145
String modelName,
146+
int? contextLength,
145147
) async {
146148
final handle = getHandle();
147149

@@ -162,6 +164,7 @@ class DartBridgeLLM {
162164
}
163165

164166
_loadedModelId = modelId;
167+
_loadedContextLength = contextLength;
165168
_logger.info('LLM model loaded: $modelId');
166169
} finally {
167170
calloc.free(pathPtr);
@@ -177,6 +180,7 @@ class DartBridgeLLM {
177180
try {
178181
NativeFunctions.llmCleanup(_handle!);
179182
_loadedModelId = null;
183+
_loadedContextLength = null;
180184
_logger.info('LLM model unloaded');
181185
} catch (e) {
182186
_logger.error('Failed to unload LLM model: $e');
@@ -220,6 +224,13 @@ class DartBridgeLLM {
220224
throw StateError('No LLM model loaded. Call loadModel() first.');
221225
}
222226

227+
_validateGenerationParameters(
228+
contextLength: _requireLoadedContextLength(),
229+
maxTokens: maxTokens,
230+
temperature: temperature,
231+
systemPrompt: systemPrompt,
232+
);
233+
223234
// Run FFI call in a separate isolate to avoid heap corruption
224235
// from C++ background threads (Metal GPU operations)
225236
final handleAddress = handle.address;
@@ -263,6 +274,14 @@ class DartBridgeLLM {
263274
throw StateError('No LLM model loaded. Call loadModel() first.');
264275
}
265276

277+
_validateGenerationParameters(
278+
contextLength: _requireLoadedContextLength(),
279+
maxTokens: maxTokens,
280+
temperature: temperature,
281+
systemPrompt: systemPrompt,
282+
streamingEnabled: true,
283+
);
284+
266285
// Create stream controller for emitting tokens to the caller
267286
final controller = StreamController<String>();
268287

@@ -308,11 +327,11 @@ class DartBridgeLLM {
308327
controller.add(message);
309328
} else if (message is _StreamingMessage) {
310329
if (message.isComplete) {
311-
controller.close();
330+
unawaited(controller.close());
312331
receivePort.close();
313332
} else if (message.error != null) {
314333
controller.addError(StateError(message.error!));
315-
controller.close();
334+
unawaited(controller.close());
316335
receivePort.close();
317336
}
318337
}
@@ -340,6 +359,29 @@ class DartBridgeLLM {
340359
}
341360
}
342361

362+
int _requireLoadedContextLength() {
363+
final contextLength = _loadedContextLength;
364+
// Fall back to a generous ceiling when registry metadata is absent,
365+
// so generation is not blocked for models without explicit contextLength.
366+
return (contextLength != null && contextLength > 0) ? contextLength : 32768;
367+
}
368+
369+
void _validateGenerationParameters({
370+
required int contextLength,
371+
required int maxTokens,
372+
required double temperature,
373+
String? systemPrompt,
374+
bool streamingEnabled = false,
375+
}) {
376+
LLMConfiguration(
377+
contextLength: contextLength,
378+
maxTokens: maxTokens,
379+
temperature: temperature,
380+
systemPrompt: systemPrompt,
381+
streamingEnabled: streamingEnabled,
382+
).validate();
383+
}
384+
343385
// MARK: - Cleanup
344386

345387
/// Destroy the component and release resources.
@@ -349,6 +391,7 @@ class DartBridgeLLM {
349391
NativeFunctions.llmDestroy(_handle!);
350392
_handle = null;
351393
_loadedModelId = null;
394+
_loadedContextLength = null;
352395
_logger.debug('LLM component destroyed');
353396
} catch (e) {
354397
_logger.error('Failed to destroy LLM component: $e');

sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_assignment.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class DartBridgeModelAssignment {
253253
framework: struct.ref.framework,
254254
source: struct.ref.source,
255255
sizeBytes: struct.ref.sizeBytes,
256+
contextLength: struct.ref.contextLength,
256257
downloadURL: struct.ref.downloadURL != nullptr
257258
? struct.ref.downloadURL.toDartString()
258259
: null,

sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_registry.dart

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class DartBridgeModelRegistry {
151151
modelPtr.ref.localPath =
152152
pathDart != null ? strdupFn(pathDart) : nullptr;
153153
modelPtr.ref.downloadSize = model.sizeBytes;
154+
modelPtr.ref.contextLength = model.contextLength;
154155
modelPtr.ref.source = model.source;
155156

156157
final result = saveFn(_registryHandle!, modelPtr);
@@ -187,7 +188,15 @@ class DartBridgeModelRegistry {
187188
}
188189

189190
try {
190-
// Convert public ModelInfo to FFI ModelInfo
191+
// Convert public ModelInfo to FFI ModelInfo.
192+
//
193+
// Nullable -> non-nullable at the adapter boundary:
194+
// public_types.ModelInfo.downloadSize and .contextLength are `int?`
195+
// (null means "unknown"), while the internal FFI ModelInfo uses
196+
// non-nullable `int` to mirror the C struct (which uses 0 as the
197+
// sentinel for "unset"). The `?? 0` here encodes that null -> 0
198+
// convention; the reverse conversion in `_ffiModelToPublic` maps
199+
// `> 0 ? value : null` back to public types.
191200
final ffiModel = ModelInfo(
192201
id: model.id,
193202
name: model.name,
@@ -196,6 +205,7 @@ class DartBridgeModelRegistry {
196205
framework: _frameworkToFfi(model.framework),
197206
source: _sourceToFfi(model.source),
198207
sizeBytes: model.downloadSize ?? 0,
208+
contextLength: model.contextLength ?? 0,
199209
downloadURL: model.downloadURL?.toString(),
200210
localPath: model.localPath?.toFilePath(),
201211
version: null,
@@ -386,6 +396,7 @@ class DartBridgeModelRegistry {
386396
? Uri.file(ffiModel.localPath!)
387397
: null,
388398
downloadSize: ffiModel.sizeBytes > 0 ? ffiModel.sizeBytes : null,
399+
contextLength: ffiModel.contextLength > 0 ? ffiModel.contextLength : null,
389400
source: _sourceFromFfi(ffiModel.source),
390401
);
391402
}
@@ -806,6 +817,7 @@ class DartBridgeModelRegistry {
806817
framework: struct.ref.framework,
807818
source: struct.ref.source,
808819
sizeBytes: struct.ref.downloadSize,
820+
contextLength: struct.ref.contextLength,
809821
downloadURL: downloadURL,
810822
localPath: localPath,
811823
version: null,
@@ -1034,6 +1046,9 @@ base class RacModelInfoStruct extends Struct {
10341046
@Int64()
10351047
external int sizeBytes;
10361048

1049+
@Int32()
1050+
external int contextLength;
1051+
10371052
external Pointer<Utf8> downloadURL;
10381053
external Pointer<Utf8> localPath;
10391054
external Pointer<Utf8> version;
@@ -1094,6 +1109,7 @@ class ModelInfo {
10941109
final int framework;
10951110
final int source;
10961111
final int sizeBytes;
1112+
final int contextLength;
10971113
final String? downloadURL;
10981114
final String? localPath;
10991115
final String? version;
@@ -1106,6 +1122,7 @@ class ModelInfo {
11061122
required this.framework,
11071123
required this.source,
11081124
required this.sizeBytes,
1125+
required this.contextLength,
11091126
this.downloadURL,
11101127
this.localPath,
11111128
this.version,
@@ -1121,6 +1138,7 @@ class ModelInfo {
11211138
'framework': framework,
11221139
'source': source,
11231140
'sizeBytes': sizeBytes,
1141+
'contextLength': contextLength,
11241142
if (downloadURL != null) 'downloadURL': downloadURL,
11251143
if (localPath != null) 'localPath': localPath,
11261144
if (version != null) 'version': version,

0 commit comments

Comments
 (0)