diff --git a/.changeset/fix-litellm-model-desync.md b/.changeset/fix-litellm-model-desync.md new file mode 100644 index 0000000000..3598784912 --- /dev/null +++ b/.changeset/fix-litellm-model-desync.md @@ -0,0 +1,28 @@ +--- +"zoo-code": patch +--- + +Fix LiteLLM provider cache key collision, credential priority, and model-selection fallback to non-existent default. + +Two bugs are addressed: + +1. **Cache key collision**: All URL-scoped providers (LiteLLM, Ollama, LM Studio, Poe, DeepSeek, + Requesty) previously shared one cache entry keyed only on the provider name. Switching between + profiles backed by different servers silently served the wrong model list and the stale list + persisted across VS Code restarts via the disk cache. Fixed with a compound cache key: + URL-scoped providers use `provider:baseUrl`; key-scoped providers (LiteLLM, Poe, Requesty) + additionally include a short, irreversible discriminator derived from the API key + (`provider:baseUrl:`) so that two different API keys on the same server never share + a cache entry (relevant when the server enforces per-key model allowlists). Both the discriminator + and the on-disk filename digest are derived via truncated PBKDF2 so neither can be reversed to + identify the API key written to the cache filename. The `RouterProvider.getModel()` cold-start + fallback is also corrected to pass the full options so it resolves the same compound key. + +2. **Silent fallback to hardcoded default**: When the LiteLLM model list was empty (due to the + collision above, a failed sync, or a transient error), `useSelectedModel` reset the configured + model ID to `claude-3-7-sonnet-20250219` -- a model that typically does not exist on user + LiteLLM servers. Four sub-fixes: preserve the configured model ID when the list is empty; + invalidate the React Query router-models cache after a successful "Sync Models" click; pass the + current LiteLLM credentials in the debounced `requestRouterModels` message; and correct the + credential priority in `webviewMessageHandler.ts` so that message values (current unsaved field + state) take precedence over stale saved config, matching the pattern already used for DeepSeek. diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 11395485a9..5b771f2bda 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -432,5 +432,187 @@ describe("empty cache protection", () => { expect(result1).toEqual(mockModels) expect(result2).toEqual(mockModels) }) + + it("scopes in-flight dedup by API key for key-scoped providers", async () => { + // In-flight dedup is keyed on the compound cache key, so concurrent refreshes for a + // key-scoped provider must dedup only when the API key matches. Two different keys + // (different compound keys) each trigger their own fetch; the same key shares one. + const mockModels = { + "requesty/model": { + maxTokens: 4096, + contextWindow: 200000, + supportsPromptCache: false, + description: "Requesty model", + }, + } + mockGetRequestyModels.mockResolvedValue(mockModels) + + const { refreshModels } = await import("../modelCache") + + // Different keys -> separate compound keys -> two distinct fetches. + const [a, b] = await Promise.all([ + refreshModels({ provider: "requesty", apiKey: "key-one" }), + refreshModels({ provider: "requesty", apiKey: "key-two" }), + ]) + expect(mockGetRequestyModels).toHaveBeenCalledTimes(2) + expect(a).toEqual(mockModels) + expect(b).toEqual(mockModels) + + mockGetRequestyModels.mockClear() + + // Same key -> same compound key -> a single shared in-flight fetch. + let resolveShared: (value: typeof mockModels) => void + mockGetRequestyModels.mockReturnValue( + new Promise((resolve) => { + resolveShared = resolve + }), + ) + + const shared1 = refreshModels({ provider: "requesty", apiKey: "same-key" }) + const shared2 = refreshModels({ provider: "requesty", apiKey: "same-key" }) + + expect(mockGetRequestyModels).toHaveBeenCalledTimes(1) + + resolveShared!(mockModels) + const [s1, s2] = await Promise.all([shared1, shared2]) + expect(s1).toEqual(mockModels) + expect(s2).toEqual(mockModels) + }) + }) +}) + +describe("key-scoped cache key derivation", () => { + // Exercises the per-API-key cache discriminator that all KEY_SCOPED_PROVIDERS share. + // Requesty is used only because it is a key-scoped provider with a mocked fetcher; the + // behavior under test is provider-agnostic. + const keyScopedProvider = "requesty" as const + + let mockCache: any + let mockSet: Mock + + const mockModels = { + "key-scoped/model": { + maxTokens: 4096, + contextWindow: 200000, + supportsPromptCache: false, + description: "Key-scoped provider model", + }, + } + + beforeEach(() => { + vi.clearAllMocks() + const MockedNodeCache = vi.mocked(NodeCache) + mockCache = new MockedNodeCache() + mockCache.get.mockReturnValue(undefined) + mockSet = mockCache.set + mockGetRequestyModels.mockResolvedValue(mockModels) + }) + + // Returns the cache key the result was written under (first arg of the matching set call). + const writtenCacheKey = (): string => { + const call = mockSet.mock.calls.find((c) => c[1] === mockModels) + return call?.[0] as string + } + + it("writes different cache keys for different API keys", async () => { + await getModels({ provider: keyScopedProvider, apiKey: "key-one" }) + const firstKey = writtenCacheKey() + + mockSet.mockClear() + await getModels({ provider: keyScopedProvider, apiKey: "key-two" }) + const secondKey = writtenCacheKey() + + expect(firstKey).toBeDefined() + expect(secondKey).toBeDefined() + expect(firstKey).not.toEqual(secondKey) + }) + + it("writes the same cache key for repeated calls with the same API key", async () => { + await getModels({ provider: keyScopedProvider, apiKey: "stable-key" }) + const firstKey = writtenCacheKey() + + mockSet.mockClear() + await getModels({ provider: keyScopedProvider, apiKey: "stable-key" }) + const secondKey = writtenCacheKey() + + expect(firstKey).toEqual(secondKey) + }) + + it("does not embed the raw API key in the cache key and truncates the discriminator", async () => { + const apiKey = "super-secret-api-key-value" + await getModels({ provider: keyScopedProvider, apiKey }) + const cacheKey = writtenCacheKey() + + // The raw secret must never appear in the on-disk-bound cache key. + expect(cacheKey).not.toContain(apiKey) + // The discriminator is the trailing key-component: an 8-char (32-bit) hex string. + const discriminator = cacheKey.split(":").pop() as string + expect(discriminator).toMatch(/^[0-9a-f]{8}$/) + }) +}) + +describe("compound cache key derivation across scoping dimensions", () => { + // Exercises every branch of getCacheKey via the public getModels() entry point. + // litellm is url-scoped AND key-scoped; openrouter is neither, so it hits the bare + // provider fallback. The fetcher mocks let us observe the cache key the result is + // written under (first arg of the matching memoryCache.set call). + const mockModels = { + "compound/model": { + maxTokens: 4096, + contextWindow: 200000, + supportsPromptCache: false, + description: "Compound cache key model", + }, + } + + let mockSet: Mock + + beforeEach(() => { + vi.clearAllMocks() + const MockedNodeCache = vi.mocked(NodeCache) + const mockCache = new MockedNodeCache() + ;(mockCache.get as Mock).mockReturnValue(undefined) + mockSet = mockCache.set as unknown as Mock + mockGetLiteLLMModels.mockResolvedValue(mockModels) + mockGetOpenRouterModels.mockResolvedValue(mockModels) + }) + + const writtenCacheKey = (): string => { + const call = mockSet.mock.calls.find((c) => c[1] === mockModels) + return call?.[0] as string + } + + it("includes both the server URL and the key discriminator for url+key-scoped providers", async () => { + await getModels({ provider: "litellm", apiKey: "compound-key", baseUrl: "http://host:4000" }) + const cacheKey = writtenCacheKey() + + // Expected shape: provider:url:keyDiscriminator + expect(cacheKey).toMatch(/^litellm:http:\/\/host:4000:[0-9a-f]{8}$/) + }) + + it("normalizes trailing slashes in the server URL so equivalent URLs share a cache key", async () => { + await getModels({ provider: "litellm", apiKey: "compound-key", baseUrl: "http://host:4000/" }) + const withSlash = writtenCacheKey() + + mockSet.mockClear() + await getModels({ provider: "litellm", apiKey: "compound-key", baseUrl: "http://host:4000" }) + const withoutSlash = writtenCacheKey() + + expect(withSlash).toEqual(withoutSlash) + }) + + it("includes only the server URL when a url-scoped provider has no API key", async () => { + await getModels({ provider: "litellm", baseUrl: "http://host:4000" }) + const cacheKey = writtenCacheKey() + + // No trailing key discriminator when apiKey is absent. + expect(cacheKey).toBe("litellm:http://host:4000") + }) + + it("falls back to the bare provider name for providers that are neither url- nor key-scoped", async () => { + await getModels({ provider: "openrouter", apiKey: "ignored-key", baseUrl: "http://ignored:4000" }) + const cacheKey = writtenCacheKey() + + expect(cacheKey).toBe("openrouter") }) }) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 404a60cd85..312f4f9382 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -1,6 +1,7 @@ import * as path from "path" import fs from "fs/promises" import * as fsSync from "fs" +import { pbkdf2Sync } from "crypto" import NodeCache from "node-cache" import { z } from "zod" @@ -34,9 +35,10 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) // Zod schema for validating ModelRecord structure from disk cache const modelRecordSchema = z.record(z.string(), modelInfoSchema) -// Track in-flight refresh requests to prevent concurrent API calls for the same provider -// This prevents race conditions where multiple calls might overwrite each other's results -const inFlightRefresh = new Map>() +// Track in-flight refresh requests to prevent concurrent API calls for the same provider+url. +// Keyed on the compound cache key (see getCacheKey) so that two different URL-scoped servers never +// deduplicate each other's in-flight refreshes. +const inFlightRefresh = new Map>() // Providers whose model lists are scoped to the signed-in user (e.g. per-account // allowlists or org policies). For these we MUST NOT cache results on disk or @@ -44,18 +46,127 @@ const inFlightRefresh = new Map>() // list to the next user, and stale data could mask backend allowlist updates. const AUTH_SCOPED_PROVIDERS: ReadonlySet = new Set(["zoo-gateway"]) +// Providers whose model list is determined by the server URL, not just by the provider name. +// Each unique baseUrl must be cached independently so that switching endpoints never serves +// stale results from a previously-cached server. +const URL_SCOPED_PROVIDERS: ReadonlySet = new Set([ + "litellm", + "poe", + "deepseek", + "ollama", + "lmstudio", + "requesty", +]) + +// Providers where the API key itself determines which models are visible (e.g. per-key +// allowlists). For these the cache key also includes a short hash of +// the API key so that two different keys on the same server never share a cache entry. +const KEY_SCOPED_PROVIDERS: ReadonlySet = new Set([ + "litellm", // Per-key model allowlists are a first-class LiteLLM proxy feature + "poe", // Per-account model availability + "requesty", // Per-account custom model policies +]) + function isAuthScopedProvider(provider: RouterName): boolean { return AUTH_SCOPED_PROVIDERS.has(provider) } -async function writeModels(router: RouterName, data: ModelRecord) { - const filename = `${router}_models.json` +// Memoize derived digests so the deliberately-structureless KDF runs at most once per +// distinct input per session (getCacheKey / cacheKeyToFilename run on every cache lookup). +const cacheDigestCache = new Map() + +// Fixed, non-secret application salt. This is NOT credential storage: it derives short, +// stable cache-key components from the API key and the compound cache key so that distinct +// inputs map to distinct cache entries / filenames. PBKDF2 is used (over a plain hash) only +// to obtain a uniform, structureless mapping with no exploitable internal structure; the +// iteration count is intentionally modest because security here rests on truncation, not on +// KDF slowness. Using a KDF rather than a plain digest also keeps API-key-derived values off +// CodeQL's js/insufficient-password-hash sink, which flags any password-tainted value flowing +// into a non-password hashing operation -- and that taint propagates to anything derived from +// the key, including the compound cache key hashed for the on-disk filename. +const CACHE_DIGEST_SALT = "zoo-model-cache-key-v1" +const CACHE_DIGEST_ITERATIONS = 10_000 + +/** + * Derive a short, irreversible, truncated digest of a cache input. + * + * The output is deliberately far smaller than the entropy of a real API key: collisions + * across the handful of keys/servers a single user configures are negligible (birthday bound + * ~ n^2 / 2^(8*bytes)), while the truncated output is small enough that any preimage search + * yields an astronomically large set of candidate inputs -- so a value written to an on-disk + * cache filename cannot be reversed to identify the API key it was derived from. + */ +function deriveCacheDigest(value: string, bytes: number): string { + const memoKey = `${bytes}:${value}` + const cached = cacheDigestCache.get(memoKey) + if (cached) return cached + const digest = pbkdf2Sync(value, CACHE_DIGEST_SALT, CACHE_DIGEST_ITERATIONS, bytes, "sha256").toString("hex") + cacheDigestCache.set(memoKey, digest) + return digest +} + +// 4 bytes (8 hex chars) = 32 bits for the per-API-key discriminator embedded in the cache key. +const API_KEY_DISCRIMINATOR_BYTES = 4 +// 8 bytes (16 hex chars) = 64 bits for the filename digest, preserving the prior filename width. +const FILENAME_DIGEST_BYTES = 8 + +/** + * Derive a short, irreversible, non-identifying cache-key discriminator from an API key. + */ +function deriveApiKeyDiscriminator(apiKey: string): string { + return deriveCacheDigest(apiKey, API_KEY_DISCRIMINATOR_BYTES) +} + +/** + * Build a cache key that is unique per provider+server+key combination. + * + * - URL-scoped providers include the normalized baseUrl so that two different servers + * of the same provider type never share a cache entry. + * - Key-scoped providers additionally fold in a short, irreversible discriminator derived + * from the API key so that two different API keys on the same server never share a cache + * entry (relevant when the server enforces per-key model allowlists, e.g. LiteLLM, Poe, + * Requesty). See deriveApiKeyDiscriminator for why the value cannot be reversed to the key. + */ +function getCacheKey(options: GetModelsOptions): string { + const { provider } = options + const isUrlScoped = URL_SCOPED_PROVIDERS.has(provider as RouterName) + const isKeyScoped = KEY_SCOPED_PROVIDERS.has(provider as RouterName) + + // Build URL and key components independently so that key-scoped providers + // without a custom baseUrl still get a per-key cache entry (otherwise two + // different keys on the default server would collapse to the same entry). + // Strip trailing slashes so "http://host:4000/" and "http://host:4000" map to the same key. + const urlPart = isUrlScoped && options.baseUrl ? options.baseUrl.replace(/\/+$/, "") : undefined + const keyPart = isKeyScoped && options.apiKey ? deriveApiKeyDiscriminator(options.apiKey) : undefined + + if (urlPart && keyPart) return `${provider}:${urlPart}:${keyPart}` + if (urlPart) return `${provider}:${urlPart}` + if (keyPart) return `${provider}:${keyPart}` + return provider +} + +/** + * Convert a cache key to a filesystem-safe filename component. + * Hashes the full key to guarantee uniqueness while preserving a readable + * provider prefix at the start of the filename. + */ +function cacheKeyToFilename(cacheKey: string): string { + const prefix = cacheKey.split(":")[0] // provider name -- always filesystem-safe + // The compound cache key embeds the API-key discriminator, so it is treated as + // password-tainted by static analysis; deriveCacheDigest keeps the filename derivation + // off the weak-hash sink while still producing a collision-free, irreversible component. + const hash = deriveCacheDigest(cacheKey, FILENAME_DIGEST_BYTES) + return `${prefix}_${hash}` +} + +async function writeModels(cacheKey: string, data: ModelRecord) { + const filename = `${cacheKeyToFilename(cacheKey)}_models.json` const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) await safeWriteJson(path.join(cacheDir, filename), data) } -async function readModels(router: RouterName): Promise { - const filename = `${router}_models.json` +async function readModels(cacheKey: string): Promise { + const filename = `${cacheKeyToFilename(cacheKey)}_models.json` const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) const filePath = path.join(cacheDir, filename) const exists = await fileExistsAtPath(filePath) @@ -86,8 +197,7 @@ async function fetchModelsFromProvider(options: GetModelsOptions): Promise => { const { provider } = options + const cacheKey = getCacheKey(options) const shouldSkipCache = isAuthScopedProvider(provider) - let models = shouldSkipCache ? undefined : getModelsFromCache(provider) + let models = shouldSkipCache ? undefined : getModelsFromCache(options) if (models) { return models @@ -149,10 +260,10 @@ export const getModels = async (options: GetModelsOptions): Promise // Only cache non-empty results so a failed API response doesn't get persisted // as if the provider had no models. Auth-scoped providers skip caching entirely. if (modelCount > 0 && !shouldSkipCache) { - memoryCache.set(provider, models) + memoryCache.set(cacheKey, models) - await writeModels(provider, models).catch((err) => - console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err), + await writeModels(cacheKey, models).catch((err) => + console.error(`[MODEL_CACHE] Error writing ${cacheKey} models to file cache:`, err), ) } else if (modelCount === 0) { TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, { @@ -182,23 +293,30 @@ export const getModels = async (options: GetModelsOptions): Promise */ export const refreshModels = async (options: GetModelsOptions): Promise => { const { provider } = options + const cacheKey = getCacheKey(options) const shouldSkipCache = isAuthScopedProvider(provider) - // Check if there's already an in-flight refresh for this provider. + // Check if there's already an in-flight refresh for this provider+url combination. // This prevents race conditions where multiple concurrent refreshes might // overwrite each other's results. Skip de-duplication for auth-scoped // providers because two concurrent calls may carry different tokens // (e.g., after a sign-out/sign-in within the same session) and we must // not return the first caller's results to the second caller. if (!shouldSkipCache) { - const existingRequest = inFlightRefresh.get(provider) + const existingRequest = inFlightRefresh.get(cacheKey) if (existingRequest) { return existingRequest } } - // Create the refresh promise and track it + // Create the refresh promise and track it. + // + // The `finally` cleanup below runs only after the first `await` inside this async + // function yields, which cannot happen until the current synchronous run -- including + // the `inFlightRefresh.set(cacheKey, ...)` registration below -- has completed. So the + // entry is always present in the map before `finally` can delete it; the registration + // can never be lost to a microtask race even if the fetch resolves immediately. const refreshPromise = (async (): Promise => { try { // Force fresh API fetch - skip getModelsFromCache() check @@ -206,7 +324,7 @@ export const refreshModels = async (options: GetModelsOptions): Promise - console.error(`[refreshModels] Error writing ${provider} models to disk:`, err), + await writeModels(cacheKey, models).catch((err) => + console.error(`[refreshModels] Error writing ${cacheKey} models to disk:`, err), ) } @@ -235,23 +353,23 @@ export const refreshModels = async (options: GetModelsOptions): Promise { * @param refresh - If true, immediately fetch fresh data from API */ export const flushModels = async (options: GetModelsOptions, refresh: boolean = false): Promise => { - const { provider } = options if (refresh) { // Don't delete memory cache - let refreshModels atomically replace it // This prevents a race condition where getModels() might be called @@ -298,8 +415,10 @@ export const flushModels = async (options: GetModelsOptions, refresh: boolean = // Await the refresh to ensure the cache is updated before returning await refreshModels(options) } else { - // Only delete memory cache when not refreshing - memoryCache.del(provider) + // Only delete memory cache when not refreshing. Use the compound cache key so that + // URL-scoped providers (litellm, poe, etc.) actually evict the per-server entry rather + // than a bare provider-name entry that was never written. + memoryCache.del(getCacheKey(options)) } } @@ -311,9 +430,20 @@ export const flushModels = async (options: GetModelsOptions, refresh: boolean = * @param provider - The provider to get models for. * @returns Models from memory cache, disk cache, or undefined if not cached. */ -export function getModelsFromCache(provider: ProviderName): ModelRecord | undefined { +export function getModelsFromCache( + options: GetModelsOptions | ProviderName, +): ModelRecord | undefined { + // Auth-scoped providers (e.g. zoo-gateway) must never be served from cache -- + // their model lists are user-specific and a stale file left over from a previous + // session could leak another user's list. Mirror the guards in getModels/refreshModels. + const providerName = typeof options === "string" ? options : options.provider + if (isAuthScopedProvider(providerName as RouterName)) { + return undefined + } + + const cacheKey = typeof options === "string" ? options : getCacheKey(options) // Check memory cache first (fast) - const memoryModels = memoryCache.get(provider) + const memoryModels = memoryCache.get(cacheKey) if (memoryModels) { return memoryModels } @@ -321,7 +451,7 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi // Memory cache miss - try to load from disk synchronously // This is acceptable because it only happens on cold start or after cache expiry try { - const filename = `${provider}_models.json` + const filename = `${cacheKeyToFilename(cacheKey)}_models.json` const cacheDir = getCacheDirectoryPathSync() if (!cacheDir) { return undefined @@ -339,19 +469,19 @@ export function getModelsFromCache(provider: ProviderName): ModelRecord | undefi const validation = modelRecordSchema.safeParse(models) if (!validation.success) { console.error( - `[MODEL_CACHE] Invalid disk cache data structure for ${provider}:`, + `[MODEL_CACHE] Invalid disk cache data structure for ${cacheKey}:`, validation.error.format(), ) return undefined } // Populate memory cache for future fast access - memoryCache.set(provider, validation.data) + memoryCache.set(cacheKey, validation.data) return validation.data } } catch (error) { - console.error(`[MODEL_CACHE] Error loading ${provider} models from disk:`, error) + console.error(`[MODEL_CACHE] Error loading ${cacheKey} models from disk:`, error) } return undefined diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index d04bd157c7..0567fa0aae 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -170,7 +170,10 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan } override getModel(): { id: string; info: ModelInfo } { - const models = getModelsFromCache("lmstudio") + const models = getModelsFromCache({ + provider: "lmstudio", + baseUrl: this.options.lmStudioBaseUrl, + }) if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) { return { id: this.options.lmStudioModelId, diff --git a/src/api/providers/poe.ts b/src/api/providers/poe.ts index 536d222acd..9513a1445d 100644 --- a/src/api/providers/poe.ts +++ b/src/api/providers/poe.ts @@ -38,7 +38,11 @@ export class PoeHandler extends BaseProvider implements SingleCompletionHandler override getModel() { const id = this.options.apiModelId ?? poeDefaultModelId - const cached = getModelsFromCache("poe") + const cached = getModelsFromCache({ + provider: "poe", + apiKey: this.options.poeApiKey, + baseUrl: this.options.poeBaseUrl, + }) const info: ModelInfo = cached?.[id] ?? getPoeDefaultModelInfo() return { id, info } } diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts index 20419bb45a..7a983ed12e 100644 --- a/src/api/providers/router-provider.ts +++ b/src/api/providers/router-provider.ts @@ -62,24 +62,38 @@ export abstract class RouterProvider extends BaseProvider { } override getModel(): { id: string; info: ModelInfo } { - const id = this.modelId ?? this.defaultModelId + // Use `||` (not `??`) so an empty-string modelId also falls back to the default, + // guaranteeing a non-empty id rather than forwarding "" to the API as an invalid + // request. Note this guarantees non-empty, not viable: defaultModelId is provider- + // supplied and may not be a model that actually exists on the user's server (e.g. + // OpenAI-compatible have no inherent default), so a configured-but-empty selection + // can still resolve to a model the server rejects. + const id = this.modelId || this.defaultModelId // First check instance models (populated by fetchModel) if (this.models[id]) { return { id, info: this.models[id] } } - // Fall back to global cache (synchronous disk/memory cache) - // This ensures models are available before fetchModel() is called - const cachedModels = getModelsFromCache(this.name) + // Fall back to global cache (synchronous disk/memory cache). + // Pass the full options so URL-scoped providers (litellm, ollama, etc.) + // resolve the same compound cache key that fetchModel() wrote under. + const cachedModels = getModelsFromCache({ + provider: this.name, + baseUrl: this.client.baseURL, + apiKey: this.client.apiKey, + }) if (cachedModels?.[id]) { // Also populate instance models for future calls this.models = cachedModels return { id, info: cachedModels[id] } } - // Last resort: return default model - return { id: this.defaultModelId, info: this.defaultModelInfo } + // Last resort: preserve the configured model ID (falling back to the default + // only when none is configured) so an as-yet-unfetched model isn't silently + // swapped for the hardcoded default. info still comes from defaults since we + // have no fetched or cached metadata for the configured model at this point. + return { id, info: this.defaultModelInfo } } protected supportsTemperature(modelId: string): boolean { diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 9704a7229d..f1dcb4598d 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -629,7 +629,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { }) }) - it("prefers config values over message values for LiteLLM", async () => { + it("prefers message values over config values for LiteLLM", async () => { const mockModels: ModelRecord = {} mockGetModels.mockResolvedValue(mockModels) @@ -641,11 +641,11 @@ describe("webviewMessageHandler - requestRouterModels", () => { }, }) - // Verify config values are used over message values + // Verify message values take precedence over saved config (current unsaved field state wins) expect(mockGetModels).toHaveBeenCalledWith({ provider: "litellm", - apiKey: "litellm-key", // From config - baseUrl: "http://localhost:4000", // From config + apiKey: "message-key", // From message.values + baseUrl: "http://message-url", // From message.values }) }) }) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 5a827d5126..79a3ced3a2 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -1006,7 +1006,7 @@ export const webviewMessageHandler = async ( // For providers that need credentials, use their specific handlers await flushModels({ provider: routerNameFlush } as GetModelsOptions, true) break - case "requestRouterModels": + case "requestRouterModels": { const { apiConfiguration } = await provider.getState() // Optional single provider filter from webview @@ -1074,9 +1074,11 @@ export const webviewMessageHandler = async ( }, ] - // LiteLLM is conditional on baseUrl+apiKey - const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey - const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl + // LiteLLM is conditional on baseUrl+apiKey. + // Prefer explicit values from message (current unsaved field state) over saved config, + // matching the pattern used for DeepSeek and other credential-carrying providers. + const litellmApiKey = message?.values?.litellmApiKey ?? apiConfiguration.litellmApiKey + const litellmBaseUrl = message?.values?.litellmBaseUrl ?? apiConfiguration.litellmBaseUrl if (litellmApiKey && litellmBaseUrl) { // If explicit credentials are provided in message.values (from Refresh Models button), @@ -1185,6 +1187,7 @@ export const webviewMessageHandler = async ( values: providerFilter ? { provider: requestedProvider } : undefined, }) break + } case "requestOllamaModels": { // Specific handler for Ollama models only. const { apiConfiguration: ollamaApiConfig } = await provider.getState() diff --git a/src/shared/api.ts b/src/shared/api.ts index c0db55f661..fe2042bd5f 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -179,7 +179,7 @@ const dynamicProviderExtras = { openrouter: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type "vercel-ai-gateway": {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type "zoo-gateway": {} as { apiKey?: string; baseUrl?: string }, - litellm: {} as { apiKey: string; baseUrl: string }, + litellm: {} as { apiKey?: string; baseUrl: string }, requesty: {} as { apiKey?: string; baseUrl?: string }, unbound: {} as { apiKey?: string }, ollama: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 70617a1ee6..d54e0b634e 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -222,7 +222,15 @@ const ApiOptions = ({ requestLmStudioModels(apiConfiguration?.lmStudioBaseUrl) } else if (selectedProvider === "vscode-lm") { vscode.postMessage({ type: "requestVsCodeLmModels" }) - } else if (selectedProvider === "litellm" || selectedProvider === "poe") { + } else if (selectedProvider === "litellm") { + vscode.postMessage({ + type: "requestRouterModels", + values: { + litellmApiKey: apiConfiguration?.litellmApiKey, + litellmBaseUrl: apiConfiguration?.litellmBaseUrl, + }, + }) + } else if (selectedProvider === "poe") { vscode.postMessage({ type: "requestRouterModels" }) } }, diff --git a/webview-ui/src/components/settings/providers/LiteLLM.tsx b/webview-ui/src/components/settings/providers/LiteLLM.tsx index 38ae1f3a96..2a8dcb8d67 100644 --- a/webview-ui/src/components/settings/providers/LiteLLM.tsx +++ b/webview-ui/src/components/settings/providers/LiteLLM.tsx @@ -1,5 +1,6 @@ import { useCallback, useState, useEffect, useRef } from "react" import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react" +import { useQueryClient } from "@tanstack/react-query" import { type ProviderSettings, @@ -34,6 +35,7 @@ export const LiteLLM = ({ simplifySettings, }: LiteLLMProps) => { const { t } = useAppTranslation() + const queryClient = useQueryClient() const { routerModels } = useExtensionState() const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") const [refreshError, setRefreshError] = useState() @@ -55,6 +57,12 @@ export const LiteLLM = ({ if (refreshStatus === "loading") { if (!litellmErrorJustReceived.current) { setRefreshStatus("success") + // Invalidate only the LiteLLM router-models query so useSelectedModel + // picks up the refreshed list. useSelectedModel reads LiteLLM under the + // compound key ["routerModels", "litellm"] (see useRouterModels), so we + // target that exact key rather than the bare ["routerModels"] prefix, + // which would needlessly invalidate every other provider's query too. + queryClient.invalidateQueries({ queryKey: ["routerModels", "litellm"] }) } // If litellmErrorJustReceived.current is true, status is already (or will be) "error". } @@ -65,7 +73,7 @@ export const LiteLLM = ({ return () => { window.removeEventListener("message", handleMessage) } - }, [refreshStatus, refreshError, setRefreshStatus, setRefreshError]) + }, [refreshStatus, queryClient]) const handleInputChange = useCallback( ( diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index 0dc42129c0..9e84ec364b 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -571,12 +571,85 @@ describe("useSelectedModel", () => { const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) expect(result.current.provider).toBe("litellm") - // Should fall back to default model ID since "some-model" doesn't exist in empty litellm models - expect(result.current.id).toBe("claude-3-7-sonnet-20250219") + // Should preserve configured model ID since "some-model" doesn't exist in empty litellm models + expect(result.current.id).toBe("some-model") // Should use litellmDefaultModelInfo as fallback expect(result.current.info).toEqual(litellmDefaultModelInfo) }) + it("should return an empty model ID when the list is empty and no model is configured", () => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + litellm: {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "litellm", + // litellmModelId intentionally omitted + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("litellm") + // LiteLLM has no inherent default; with nothing configured the ID is empty rather than a phantom model + expect(result.current.id).toBe("") + expect(result.current.info).toEqual(litellmDefaultModelInfo) + }) + + it("preserves the selected model when the list transitions from populated to empty", () => { + // Primary user-visible scenario: a "Sync Models" click momentarily empties the + // router-models list before the refreshed list arrives. The selection must be held + // across that transition rather than reset. + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + litellm: { + "my-custom-model": { + maxTokens: 4096, + contextWindow: 8192, + supportsImages: false, + supportsPromptCache: false, + }, + }, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "litellm", + litellmModelId: "my-custom-model", + } + + const wrapper = createWrapper() + const { result, rerender } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + // Initially the configured model resolves from the populated list. + expect(result.current.id).toBe("my-custom-model") + + // Simulate the list emptying mid-sync. + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + litellm: {}, + }, + isLoading: false, + isError: false, + } as any) + rerender() + + // Selection is preserved through the empty window. + expect(result.current.id).toBe("my-custom-model") + }) + it("should use litellmDefaultModelInfo when selected model not found in routerModels", () => { mockUseRouterModels.mockReturnValue({ data: { diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index d3ebb6c0dd..7f60f5f0a0 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -167,7 +167,15 @@ function getSelectedModel({ return { id, info: routerInfo } } case "litellm": { - const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId) + // When the model list is empty (not yet loaded or still loading), + // preserve the configured model ID. LiteLLM is a proxy with no inherent + // default model, so we never substitute a hardcoded default here -- when + // nothing is configured we return an empty ID so the picker shows "no + // selection" rather than a phantom model that does not exist on the server. + const hasModels = routerModels.litellm && Object.keys(routerModels.litellm).length > 0 + const id = hasModels + ? getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId) + : (apiConfiguration.litellmModelId ?? "") const routerInfo = routerModels.litellm?.[id] return { id, info: routerInfo ?? litellmDefaultModelInfo } }