Fix: Add authorization header with bearer token for remote Ollama server endpoints

This commit is contained in:
bnodnarb 2024-11-24 20:29:54 -10:00
parent 840437e58f
commit 8dc73e8744

View File

@@ -195,7 +195,10 @@ async def post_streaming_url(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -210,13 +213,13 @@ async def post_streaming_url(
r.raise_for_status() r.raise_for_status()
if stream: if stream:
headers = dict(r.headers) response_headers = dict(r.headers)
if content_type: if content_type:
headers["Content-Type"] = content_type response_headers["Content-Type"] = content_type
return StreamingResponse( return StreamingResponse(
r.content, r.content,
status_code=r.status, status_code=r.status,
headers=headers, headers=response_headers,
background=BackgroundTask( background=BackgroundTask(
cleanup_response, response=r, session=session cleanup_response, response=r, session=session
), ),
@@ -324,7 +327,10 @@ async def get_ollama_tags(
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {} headers = {}
@@ -525,7 +531,10 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -584,7 +593,10 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -730,7 +745,10 @@ async def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -974,7 +995,10 @@ async def generate_chat_completion(
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}") log.debug(f"generate_chat_completion() - 2.payload = {payload}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")