refac: connections handling

Timothy Jaeryang Baek 2025-01-18 17:10:15 -08:00
parent 430854e223
commit ca0285fc91
3 changed files with 113 additions and 56 deletions
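This refactor keys each connection's API config by its position in the URL list instead of by its base URL, and keeps a fallback to the old URL key so configs saved by earlier versions still resolve. A minimal standalone sketch of that lookup pattern (the sample dict and values below are invented for illustration, not the app's actual state objects):

```python
from urllib.parse import urlparse

# Invented sample: new-style entries are keyed by stringified index ("0", "1", ...),
# while an entry written by an older version is still keyed by its base URL.
configs = {
    "0": {"key": "key-for-connection-0"},              # index-keyed (new scheme)
    "http://localhost:11434": {"key": "legacy-key"},   # URL-keyed (legacy scheme)
}


def get_api_key(idx, url, configs):
    # Same shape as the refactored helper: prefer the index key and fall back
    # to the base-URL key for configs created before this change.
    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
    return configs.get(idx, configs.get(base_url, {})).get("key", None)


print(get_api_key("0", "http://localhost:11434/api/tags", configs))  # key-for-connection-0
print(get_api_key("1", "http://localhost:11434/api/tags", configs))  # legacy-key (fallback)
```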

View File

@@ -152,10 +152,12 @@ async def send_post_request(
     )


-def get_api_key(url, configs):
+def get_api_key(idx, url, configs):
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-    return configs.get(base_url, {}).get("key", None)
+    return configs.get(idx, configs.get(base_url, {})).get(
+        "key", None
+    )  # Legacy support


 ##########################################
@@ -238,11 +240,13 @@ async def update_config(
     request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
     request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

-    # Remove any extra configs
-    config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
-    for url in list(request.app.state.config.OLLAMA_BASE_URLS):
-        if url not in config_urls:
-            request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
+    # Remove the API configs that are not in the API URLS
+    keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
+    request.app.state.config.OLLAMA_API_CONFIGS = {
+        key: value
+        for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
+        if key in keys
+    }

     return {
         "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
@@ -258,10 +262,18 @@ async def get_all_models(request: Request):
         request_tasks = []
         for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
-            if url not in request.app.state.config.OLLAMA_API_CONFIGS:
+            if (idx not in request.app.state.config.OLLAMA_API_CONFIGS) or (
+                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
+            ):
                 request_tasks.append(send_get_request(f"{url}/api/tags"))
             else:
-                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    idx,
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )

                 enable = api_config.get("enable", True)
                 key = api_config.get("key", None)
@@ -275,7 +287,12 @@ async def get_all_models(request: Request):
         for idx, response in enumerate(responses):
             if response:
                 url = request.app.state.config.OLLAMA_BASE_URLS[idx]
-                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    idx,
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )

                 prefix_id = api_config.get("prefix_id", None)
                 model_ids = api_config.get("model_ids", [])
@@ -349,7 +366,7 @@ async def get_ollama_tags(
         models = await get_all_models(request)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-        key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+        key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

         r = None
         try:
@@ -393,11 +410,14 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
         request_tasks = [
             send_get_request(
                 f"{url}/api/version",
-                request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
-                    "key", None
-                ),
+                request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    idx,
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                ).get("key", None),
             )
-            for url in request.app.state.config.OLLAMA_BASE_URLS
+            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
         ]
         responses = await asyncio.gather(*request_tasks)
         responses = list(filter(lambda x: x is not None, responses))
@@ -454,11 +474,14 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
     request_tasks = [
         send_get_request(
             f"{url}/api/ps",
-            request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
-                "key", None
-            ),
+            request.app.state.config.OLLAMA_API_CONFIGS.get(
+                idx,
+                request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    url, {}
+                ),  # Legacy support
+            ).get("key", None),
         )
-        for url in request.app.state.config.OLLAMA_BASE_URLS
+        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
     ]

     responses = await asyncio.gather(*request_tasks)
@@ -488,7 +511,7 @@ async def pull_model(
     return await send_post_request(
         url=f"{url}/api/pull",
         payload=json.dumps(payload),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
@@ -524,7 +547,7 @@ async def push_model(
     return await send_post_request(
         url=f"{url}/api/push",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
@@ -549,7 +572,7 @@ async def create_model(
     return await send_post_request(
         url=f"{url}/api/create",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
@@ -579,7 +602,7 @@ async def copy_model(
         )

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

     try:
         r = requests.request(
@@ -634,7 +657,7 @@ async def delete_model(
         )

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

     try:
         r = requests.request(
@@ -684,7 +707,7 @@ async def show_model_info(
         url_idx = random.choice(models[form_data.name]["urls"])

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

     try:
         r = requests.request(
@@ -753,7 +776,7 @@ async def embed(
         )

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

     try:
         r = requests.request(
@@ -822,7 +845,7 @@ async def embeddings(
         )

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
+    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

     try:
         r = requests.request(
@@ -897,7 +920,10 @@ async def generate_completion(
         )

     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+        url_idx,
+        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+    )

     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
@@ -906,7 +932,7 @@ async def generate_completion(
     return await send_post_request(
         url=f"{url}/api/generate",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
@@ -1005,7 +1031,10 @@ async def generate_chat_completion(
         payload["model"] = f"{payload['model']}:latest"

     url = await get_ollama_url(request, payload["model"], url_idx)
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+        url_idx,
+        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+    )

     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
@@ -1015,7 +1044,7 @@ async def generate_chat_completion(
         url=f"{url}/api/chat",
         payload=json.dumps(payload),
         stream=form_data.stream,
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
         content_type="application/x-ndjson",
     )
@@ -1104,7 +1133,10 @@ async def generate_openai_completion(
         payload["model"] = f"{payload['model']}:latest"

     url = await get_ollama_url(request, payload["model"], url_idx)
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+        url_idx,
+        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+    )

     prefix_id = api_config.get("prefix_id", None)
@@ -1115,7 +1147,7 @@ async def generate_openai_completion(
         url=f"{url}/v1/completions",
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
@@ -1178,7 +1210,10 @@ async def generate_openai_chat_completion(
        payload["model"] = f"{payload['model']}:latest"

     url = await get_ollama_url(request, payload["model"], url_idx)
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+        url_idx,
+        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
+    )

     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
@@ -1188,7 +1223,7 @@ async def generate_openai_chat_completion(
         url=f"{url}/v1/chat/completions",
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
-        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
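Both routers prune stale entries the same way when the connection settings are saved: once configs are keyed by index, any entry whose key is not a valid position in the current URL list is dropped. A small self-contained sketch of that pruning step (sample data below is invented):

```python
# Invented sample: two configured URLs, plus one stale index and one leftover
# legacy URL-keyed entry.
base_urls = ["http://a:11434", "http://b:11434"]
api_configs = {
    "0": {"enable": True},
    "1": {"key": "secret"},
    "5": {"enable": False},            # index that no longer exists
    "http://old:11434": {"key": "x"},  # legacy URL-keyed entry
}

# Same pattern as the refactored update_config handlers: valid keys are the
# stringified indices of the current URL list; everything else is removed.
keys = list(map(str, range(len(base_urls))))
api_configs = {key: value for key, value in api_configs.items() if key in keys}

print(api_configs)  # {'0': {'enable': True}, '1': {'key': 'secret'}}
```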

View File

@@ -145,11 +145,13 @@ async def update_config(
     request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS

-    # Remove any extra configs
-    config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
-    for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
-        if url not in config_urls:
-            request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
+    # Remove the API configs that are not in the API URLS
+    keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS))))
+    request.app.state.config.OPENAI_API_CONFIGS = {
+        key: value
+        for key, value in request.app.state.config.OPENAI_API_CONFIGS.items()
+        if key in keys
+    }

     return {
         "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
@@ -264,14 +266,21 @@ async def get_all_models_responses(request: Request) -> list:
     request_tasks = []
     for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
-        if url not in request.app.state.config.OPENAI_API_CONFIGS:
+        if (idx not in request.app.state.config.OPENAI_API_CONFIGS) or (
+            url not in request.app.state.config.OPENAI_API_CONFIGS  # Legacy support
+        ):
             request_tasks.append(
                 send_get_request(
                     f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
                 )
             )
         else:
-            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
+            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+                idx,
+                request.app.state.config.OPENAI_API_CONFIGS.get(
+                    url, {}
+                ),  # Legacy support
+            )

             enable = api_config.get("enable", True)
             model_ids = api_config.get("model_ids", [])
@@ -310,7 +319,12 @@ async def get_all_models_responses(request: Request) -> list:
     for idx, response in enumerate(responses):
         if response:
             url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
-            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
+            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+                idx,
+                request.app.state.config.OPENAI_API_CONFIGS.get(
+                    url, {}
+                ),  # Legacy support
+            )

             prefix_id = api_config.get("prefix_id", None)
@@ -585,7 +599,10 @@ async def generate_chat_completion(
     # Get the API config for the model
     api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+        idx,
+        request.app.state.config.OPENAI_API_CONFIGS.get(
             request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
+        ),  # Legacy support
     )

     prefix_id = api_config.get("prefix_id", None)

View File

@@ -104,14 +104,14 @@
 	const addOpenAIConnectionHandler = async (connection) => {
 		OPENAI_API_BASE_URLS = [...OPENAI_API_BASE_URLS, connection.url];
 		OPENAI_API_KEYS = [...OPENAI_API_KEYS, connection.key];
-		OPENAI_API_CONFIGS[connection.url] = connection.config;
+		OPENAI_API_CONFIGS[OPENAI_API_BASE_URLS.length] = connection.config;

 		await updateOpenAIHandler();
 	};

 	const addOllamaConnectionHandler = async (connection) => {
 		OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, connection.url];
-		OLLAMA_API_CONFIGS[connection.url] = connection.config;
+		OLLAMA_API_CONFIGS[OLLAMA_BASE_URLS.length] = connection.config;

 		await updateOllamaHandler();
 	};
@@ -141,15 +141,17 @@
 		OLLAMA_API_CONFIGS = ollamaConfig.OLLAMA_API_CONFIGS;

 		if (ENABLE_OPENAI_API) {
-			for (const url of OPENAI_API_BASE_URLS) {
-				if (!OPENAI_API_CONFIGS[url]) {
-					OPENAI_API_CONFIGS[url] = {};
+			// get url and idx
+			for (const [idx, url] of OPENAI_API_BASE_URLS.entries()) {
+				if (!OPENAI_API_CONFIGS[idx]) {
+					// Legacy support, url as key
+					OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[url] || {};
 				}
 			}

 			OPENAI_API_BASE_URLS.forEach(async (url, idx) => {
-				OPENAI_API_CONFIGS[url] = OPENAI_API_CONFIGS[url] || {};
-				if (!(OPENAI_API_CONFIGS[url]?.enable ?? true)) {
+				OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[idx] || {};
+				if (!(OPENAI_API_CONFIGS[idx]?.enable ?? true)) {
 					return;
 				}
 				const res = await getOpenAIModels(localStorage.token, idx);
@@ -160,9 +162,9 @@
 		}

 		if (ENABLE_OLLAMA_API) {
-			for (const url of OLLAMA_BASE_URLS) {
-				if (!OLLAMA_API_CONFIGS[url]) {
-					OLLAMA_API_CONFIGS[url] = {};
+			for (const [idx, url] of OLLAMA_BASE_URLS.entries()) {
+				if (!OLLAMA_API_CONFIGS[idx]) {
+					OLLAMA_API_CONFIGS[idx] = OLLAMA_API_CONFIGS[url] || {};
 				}
 			}
 		}
@@ -235,7 +237,7 @@
 					pipeline={pipelineUrls[url] ? true : false}
 					bind:url
 					bind:key={OPENAI_API_KEYS[idx]}
-					bind:config={OPENAI_API_CONFIGS[url]}
+					bind:config={OPENAI_API_CONFIGS[idx]}
 					onSubmit={() => {
 						updateOpenAIHandler();
 					}}
@@ -244,6 +246,8 @@
 						(url, urlIdx) => idx !== urlIdx
 					);
 					OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx);
+
+					delete OPENAI_API_CONFIGS[idx];
 				}}
 			/>
 		{/each}
@@ -294,13 +298,14 @@
 		{#each OLLAMA_BASE_URLS as url, idx}
 			<OllamaConnection
 				bind:url
-				bind:config={OLLAMA_API_CONFIGS[url]}
+				bind:config={OLLAMA_API_CONFIGS[idx]}
 				{idx}
 				onSubmit={() => {
 					updateOllamaHandler();
 				}}
 				onDelete={() => {
 					OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx);
+					delete OLLAMA_API_CONFIGS[idx];
 				}}
 			/>
 		{/each}