Skip to content

Commit 9e0c70d

Browse files
skottmckayCopilotCopilot
authored
Use IModel in the public API. (#556)
Use IModel in the public API. Changes allow ICatalog and IModel to be stubbed for testing as you no longer need a concrete Model or ModelVariant class. - Make Model and ModelVariant implementation details - Add variant info and selection to IModel so it works with either Model or ModelVariant - Move GetLatestVersion to Catalog and take IModel as input - ModelVariant has insufficient info to implement this and intuitively the catalog should know this information. - Update tests - fix usage of test config file for shared test data path --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: skottmckay <979079+skottmckay@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent d1a9e3c commit 9e0c70d

16 files changed

Lines changed: 272 additions & 149 deletions

File tree

samples/cs/GettingStarted/src/ModelManagementExample/Program.cs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,39 +51,35 @@
5151
// Get a model using an alias from the catalog
5252
var model = await catalog.GetModelAsync("qwen2.5-0.5b") ?? throw new Exception("Model not found");
5353

54-
// `model.SelectedVariant` indicates which variant will be used by default.
55-
//
5654
// Variants in IModel.Variants are ordered by priority, with the highest priority first.
5755
// The first downloaded model is selected by default.
5856
// The highest priority is selected if no models have been downloaded.
5957
// If the selected variant is not the highest priority, it means that Foundry Local
6058
// has found a locally cached variant for you to improve performance (remove need to download).
6159
Console.WriteLine("\nThe default selected model variant is: " + model.Id);
62-
if (model.SelectedVariant != model.Variants.First())
60+
if (model.Id != model.Variants.First().Id)
6361
{
64-
Debug.Assert(await model.SelectedVariant.IsCachedAsync());
62+
Debug.Assert(await model.IsCachedAsync());
6563
Console.WriteLine("The model variant was selected due to being locally cached.");
6664
}
6765

6866

69-
// OPTIONAL: `model` can be used directly and `model.SelectedVariant` will be used as the default.
70-
// You can explicitly select or use a specific ModelVariant if you want more control
71-
// over the device and/or execution provider used.
72-
// Model and ModelVariant can be used interchangeably in methods such as
73-
// DownloadAsync, LoadAsync, UnloadAsync and GetChatClientAsync.
67+
// OPTIONAL: `model` can be used directly with its currently selected variant.
68+
// You can explicitly select (`model.SelectVariant`) or use a specific variant from `model.Variants`
69+
// if you want more control over the device and/or execution provider used.
7470
//
7571
// Choices:
76-
// - Use a ModelVariant directly from the catalog if you know the variant Id
72+
// - Use a model variant directly from the catalog if you know the variant Id
7773
// - `var modelVariant = await catalog.GetModelVariantAsync("qwen2.5-0.5b-instruct-generic-gpu:3")`
7874
//
79-
// - Get the ModelVariant from Model.Variants
75+
// - Get the model variant from IModel.Variants
8076
// - `var modelVariant = model.Variants.First(v => v.Id == "qwen2.5-0.5b-instruct-generic-cpu:4")`
8177
// - `var modelVariant = model.Variants.First(v => v.Info.Runtime?.DeviceType == DeviceType.GPU)`
8278
// - optional: update selected variant in `model` using `model.SelectVariant(modelVariant);` if you wish to use
8379
// `model` in your code.
8480

8581
// For this example we explicitly select the CPU variant, and call SelectVariant so all the following example code
86-
// uses the `model` instance.
82+
// uses the `model` instance. It would be equally valid to use `modelVariant` directly.
8783
Console.WriteLine("Selecting CPU variant of model");
8884
var modelVariant = model.Variants.First(v => v.Info.Runtime?.DeviceType == DeviceType.CPU);
8985
model.SelectVariant(modelVariant);

sdk/cs/src/Catalog.cs

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,51 +52,59 @@ internal static async Task<Catalog> CreateAsync(IModelLoadManager modelManager,
5252
return catalog;
5353
}
5454

55-
public async Task<List<Model>> ListModelsAsync(CancellationToken? ct = null)
55+
public async Task<List<IModel>> ListModelsAsync(CancellationToken? ct = null)
5656
{
5757
return await Utils.CallWithExceptionHandling(() => ListModelsImplAsync(ct),
5858
"Error listing models.", _logger).ConfigureAwait(false);
5959
}
6060

61-
public async Task<List<ModelVariant>> GetCachedModelsAsync(CancellationToken? ct = null)
61+
public async Task<List<IModel>> GetCachedModelsAsync(CancellationToken? ct = null)
6262
{
6363
return await Utils.CallWithExceptionHandling(() => GetCachedModelsImplAsync(ct),
6464
"Error getting cached models.", _logger).ConfigureAwait(false);
6565
}
6666

67-
public async Task<List<ModelVariant>> GetLoadedModelsAsync(CancellationToken? ct = null)
67+
public async Task<List<IModel>> GetLoadedModelsAsync(CancellationToken? ct = null)
6868
{
6969
return await Utils.CallWithExceptionHandling(() => GetLoadedModelsImplAsync(ct),
7070
"Error getting loaded models.", _logger).ConfigureAwait(false);
7171
}
7272

73-
public async Task<Model?> GetModelAsync(string modelAlias, CancellationToken? ct = null)
73+
public async Task<IModel?> GetModelAsync(string modelAlias, CancellationToken? ct = null)
7474
{
7575
return await Utils.CallWithExceptionHandling(() => GetModelImplAsync(modelAlias, ct),
7676
$"Error getting model with alias '{modelAlias}'.", _logger)
7777
.ConfigureAwait(false);
7878
}
7979

80-
public async Task<ModelVariant?> GetModelVariantAsync(string modelId, CancellationToken? ct = null)
80+
public async Task<IModel?> GetModelVariantAsync(string modelId, CancellationToken? ct = null)
8181
{
8282
return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct),
8383
$"Error getting model variant with ID '{modelId}'.", _logger)
8484
.ConfigureAwait(false);
8585
}
8686

87-
private async Task<List<Model>> ListModelsImplAsync(CancellationToken? ct = null)
87+
public async Task<IModel> GetLatestVersionAsync(IModel modelOrModelVariant, CancellationToken? ct = null)
88+
{
89+
return await Utils.CallWithExceptionHandling(
90+
() => GetLatestVersionImplAsync(modelOrModelVariant, ct),
91+
$"Error getting latest version for model with name '{modelOrModelVariant.Info.Name}'.",
92+
_logger).ConfigureAwait(false);
93+
}
94+
95+
private async Task<List<IModel>> ListModelsImplAsync(CancellationToken? ct = null)
8896
{
8997
await UpdateModels(ct).ConfigureAwait(false);
9098

9199
using var disposable = await _lock.LockAsync().ConfigureAwait(false);
92-
return _modelAliasToModel.Values.OrderBy(m => m.Alias).ToList();
100+
return _modelAliasToModel.Values.OrderBy(m => m.Alias).Cast<IModel>().ToList();
93101
}
94102

95-
private async Task<List<ModelVariant>> GetCachedModelsImplAsync(CancellationToken? ct = null)
103+
private async Task<List<IModel>> GetCachedModelsImplAsync(CancellationToken? ct = null)
96104
{
97105
var cachedModelIds = await Utils.GetCachedModelIdsAsync(_coreInterop, ct).ConfigureAwait(false);
98106

99-
List<ModelVariant> cachedModels = new();
107+
List<IModel> cachedModels = [];
100108
foreach (var modelId in cachedModelIds)
101109
{
102110
if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant))
@@ -108,10 +116,10 @@ private async Task<List<ModelVariant>> GetCachedModelsImplAsync(CancellationToke
108116
return cachedModels;
109117
}
110118

111-
private async Task<List<ModelVariant>> GetLoadedModelsImplAsync(CancellationToken? ct = null)
119+
private async Task<List<IModel>> GetLoadedModelsImplAsync(CancellationToken? ct = null)
112120
{
113121
var loadedModelIds = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false);
114-
List<ModelVariant> loadedModels = new();
122+
List<IModel> loadedModels = [];
115123

116124
foreach (var modelId in loadedModelIds)
117125
{
@@ -143,6 +151,45 @@ private async Task<List<ModelVariant>> GetLoadedModelsImplAsync(CancellationToke
143151
return modelVariant;
144152
}
145153

154+
private async Task<IModel> GetLatestVersionImplAsync(IModel modelOrModelVariant, CancellationToken? ct)
155+
{
156+
Model? model;
157+
158+
if (modelOrModelVariant is ModelVariant)
159+
{
160+
// For ModelVariant, resolve the owning Model via alias.
161+
model = await GetModelImplAsync(modelOrModelVariant.Alias, ct);
162+
}
163+
else
164+
{
165+
// Try to use the concrete Model instance if this is our SDK type.
166+
model = modelOrModelVariant as Model;
167+
168+
// If this is a different IModel implementation (e.g., a test stub),
169+
// fall back to resolving the Model via alias.
170+
if (model == null)
171+
{
172+
model = await GetModelImplAsync(modelOrModelVariant.Alias, ct);
173+
}
174+
}
175+
176+
if (model == null)
177+
{
178+
throw new FoundryLocalException($"Model with alias '{modelOrModelVariant.Alias}' not found in catalog.",
179+
_logger);
180+
}
181+
182+
// variants are sorted by version, so the first one matching the name is the latest version for that variant.
183+
var latest = model!.Variants.FirstOrDefault(v => v.Info.Name == modelOrModelVariant.Info.Name) ??
184+
// should not be possible given we internally manage all the state involved
185+
throw new FoundryLocalException($"Internal error. Mismatch between model (alias:{model.Alias}) and " +
186+
$"model variant (alias:{modelOrModelVariant.Alias}).", _logger);
187+
188+
// if input was the latest return the input (could be model or model variant)
189+
// otherwise return the latest model variant
190+
return latest.Id == modelOrModelVariant.Id ? modelOrModelVariant : latest;
191+
}
192+
146193
private async Task UpdateModels(CancellationToken? ct)
147194
{
148195
// TODO: make this configurable
Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ public class Model : IModel
1212
{
1313
private readonly ILogger _logger;
1414

15-
public List<ModelVariant> Variants { get; internal set; }
16-
public ModelVariant SelectedVariant { get; internal set; } = default!;
15+
private readonly List<IModel> _variants;
16+
public IReadOnlyList<IModel> Variants => _variants;
17+
internal IModel SelectedVariant { get; set; } = default!;
1718

1819
public string Alias { get; init; }
1920
public string Id => SelectedVariant.Id;
21+
public ModelInfo Info => SelectedVariant.Info;
2022

2123
/// <summary>
2224
/// Is the currently selected variant cached locally?
@@ -33,7 +35,7 @@ internal Model(ModelVariant modelVariant, ILogger logger)
3335
_logger = logger;
3436

3537
Alias = modelVariant.Alias;
36-
Variants = new() { modelVariant };
38+
_variants = [modelVariant];
3739

3840
// variants are sorted by Core, so the first one added is the default
3941
SelectedVariant = modelVariant;
@@ -48,7 +50,7 @@ internal void AddVariant(ModelVariant variant)
4850
_logger);
4951
}
5052

51-
Variants.Add(variant);
53+
_variants.Add(variant);
5254

5355
// prefer the highest priority locally cached variant
5456
if (variant.Info.Cached && !SelectedVariant.Info.Cached)
@@ -62,31 +64,15 @@ internal void AddVariant(ModelVariant variant)
6264
/// </summary>
6365
/// <param name="variant">Model variant to select. Must be one of the variants in <see cref="Variants"/>.</param>
6466
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
65-
public void SelectVariant(ModelVariant variant)
67+
public void SelectVariant(IModel variant)
6668
{
6769
_ = Variants.FirstOrDefault(v => v == variant) ??
68-
// user error so don't log
69-
throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");
70+
// user error so don't log.
71+
throw new FoundryLocalException($"Input variant was not found in Variants.");
7072

7173
SelectedVariant = variant;
7274
}
7375

74-
/// <summary>
75-
/// Get the latest version of the specified model variant.
76-
/// </summary>
77-
/// <param name="variant">Model variant.</param>
78-
/// <returns>ModelVariant for latest version. Same as `variant` if that is the latest version.</returns>
79-
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
80-
public ModelVariant GetLatestVersion(ModelVariant variant)
81-
{
82-
// variants are sorted by version, so the first one matching the name is the latest version for that variant.
83-
var latest = Variants.FirstOrDefault(v => v.Info.Name == variant.Info.Name) ??
84-
// user error so don't log
85-
throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");
86-
87-
return latest;
88-
}
89-
9076
public async Task<string> GetPathAsync(CancellationToken? ct = null)
9177
{
9278
return await SelectedVariant.GetPathAsync(ct).ConfigureAwait(false);
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Microsoft.AI.Foundry.Local;
99
using Microsoft.AI.Foundry.Local.Detail;
1010
using Microsoft.Extensions.Logging;
1111

12-
public class ModelVariant : IModel
12+
internal class ModelVariant : IModel
1313
{
1414
private readonly IModelLoadManager _modelLoadManager;
1515
private readonly ICoreInterop _coreInterop;
@@ -22,6 +22,8 @@ public class ModelVariant : IModel
2222
public string Alias => Info.Alias;
2323
public int Version { get; init; } // parsed from Info.Version if possible, else 0
2424

25+
public IReadOnlyList<IModel> Variants => [this];
26+
2527
internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, ICoreInterop coreInterop,
2628
ILogger logger)
2729
{
@@ -190,4 +192,11 @@ private async Task<OpenAIAudioClient> GetAudioClientImplAsync(CancellationToken?
190192

191193
return new OpenAIAudioClient(Id);
192194
}
195+
196+
public void SelectVariant(IModel variant)
197+
{
198+
throw new FoundryLocalException(
199+
$"SelectVariant is not supported on a ModelVariant. " +
200+
$"Call Catalog.GetModelAsync(\"{Alias}\") to get an IModel with all variants available.");
201+
}
193202
}

sdk/cs/src/ICatalog.cs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,36 +18,46 @@ public interface ICatalog
1818
/// List the available models in the catalog.
1919
/// </summary>
2020
/// <param name="ct">Optional CancellationToken.</param>
21-
/// <returns>List of Model instances.</returns>
22-
Task<List<Model>> ListModelsAsync(CancellationToken? ct = null);
21+
/// <returns>List of IModel instances.</returns>
22+
Task<List<IModel>> ListModelsAsync(CancellationToken? ct = null);
2323

2424
/// <summary>
2525
/// Lookup a model by its alias.
2626
/// </summary>
2727
/// <param name="modelAlias">Model alias.</param>
2828
/// <param name="ct">Optional CancellationToken.</param>
29-
/// <returns>The matching Model, or null if no model with the given alias exists.</returns>
30-
Task<Model?> GetModelAsync(string modelAlias, CancellationToken? ct = null);
29+
/// <returns>The matching IModel, or null if no model with the given alias exists.</returns>
30+
Task<IModel?> GetModelAsync(string modelAlias, CancellationToken? ct = null);
3131

3232
/// <summary>
3333
/// Lookup a model variant by its unique model id.
34+
/// NOTE: This will return an IModel with a single variant. Use GetModelAsync to get an IModel with all available
35+
/// variants.
3436
/// </summary>
3537
/// <param name="modelId">Model id.</param>
3638
/// <param name="ct">Optional CancellationToken.</param>
37-
/// <returns>The matching ModelVariant, or null if no variant with the given id exists.</returns>
38-
Task<ModelVariant?> GetModelVariantAsync(string modelId, CancellationToken? ct = null);
39+
/// <returns>The matching IModel, or null if no variant with the given id exists.</returns>
40+
Task<IModel?> GetModelVariantAsync(string modelId, CancellationToken? ct = null);
3941

4042
/// <summary>
4143
/// Get a list of currently downloaded models from the model cache.
4244
/// </summary>
4345
/// <param name="ct">Optional CancellationToken.</param>
44-
/// <returns>List of ModelVariant instances.</returns>
45-
Task<List<ModelVariant>> GetCachedModelsAsync(CancellationToken? ct = null);
46+
/// <returns>List of IModel instances.</returns>
47+
Task<List<IModel>> GetCachedModelsAsync(CancellationToken? ct = null);
4648

4749
/// <summary>
4850
/// Get a list of the currently loaded models.
4951
/// </summary>
5052
/// <param name="ct">Optional CancellationToken.</param>
51-
/// <returns>List of ModelVariant instances.</returns>
52-
Task<List<ModelVariant>> GetLoadedModelsAsync(CancellationToken? ct = null);
53+
/// <returns>List of IModel instances.</returns>
54+
Task<List<IModel>> GetLoadedModelsAsync(CancellationToken? ct = null);
55+
56+
/// <summary>
57+
/// Get the latest version of a model.
58+
/// This is used to check if a newer version of a model is available in the catalog for download.
59+
/// </summary>
60+
/// <param name="model">The model to check for the latest version.</param>
61+
/// <returns>The latest version of the model. Will match the input if it is the latest version.</returns>
62+
Task<IModel> GetLatestVersionAsync(IModel model, CancellationToken? ct = null);
5363
}

sdk/cs/src/IModel.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public interface IModel
1616
Justification = "Alias is a suitable name in this context.")]
1717
string Alias { get; }
1818

19+
ModelInfo Info { get; }
20+
1921
Task<bool> IsCachedAsync(CancellationToken? ct = null);
2022
Task<bool> IsLoadedAsync(CancellationToken? ct = null);
2123

@@ -67,4 +69,17 @@ Task DownloadAsync(Action<float>? downloadProgress = null,
6769
/// <param name="ct">Optional cancellation token.</param>
6870
/// <returns>OpenAI.AudioClient</returns>
6971
Task<OpenAIAudioClient> GetAudioClientAsync(CancellationToken? ct = null);
72+
73+
/// <summary>
74+
/// Variants of the model that are available. Variants of the model are optimized for different devices.
75+
/// </summary>
76+
IReadOnlyList<IModel> Variants { get; }
77+
78+
/// <summary>
79+
/// Select a model variant from <see cref="Variants"/> to use for <see cref="IModel"/> operations.
80+
/// An IModel from `Variants` can also be used directly.
81+
/// </summary>
82+
/// <param name="variant">Model variant to select. Must be one of the variants in <see cref="Variants"/>.</param>
83+
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
84+
void SelectVariant(IModel variant);
7085
}

sdk/cs/test/FoundryLocal.Tests/AudioClientTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace Microsoft.AI.Foundry.Local.Tests;
1212

1313
internal sealed class AudioClientTests
1414
{
15-
private static Model? model;
15+
private static IModel? model;
1616

1717
[Before(Class)]
1818
public static async Task Setup()

0 commit comments

Comments
 (0)