Merge pull request #7326 from bnodnarb/fix/ollama-authentication

fix: Include Authorization header in /api/pull and /api/chat requests
This commit is contained in:
Timothy Jaeryang Baek 2024-11-25 15:38:43 -08:00 committed by GitHub
commit da4676de2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -195,7 +195,10 @@ async def post_streaming_url(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -210,13 +213,13 @@ async def post_streaming_url(
r.raise_for_status()
if stream:
headers = dict(r.headers)
response_headers = dict(r.headers)
if content_type:
headers["Content-Type"] = content_type
response_headers["Content-Type"] = content_type
return StreamingResponse(
r.content,
status_code=r.status,
headers=headers,
headers=response_headers,
background=BackgroundTask(
cleanup_response, response=r, session=session
),
@ -324,7 +327,10 @@ async def get_ollama_tags(
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {}
@ -525,7 +531,10 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -584,7 +593,10 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -730,7 +745,10 @@ async def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -974,7 +995,10 @@ async def generate_chat_completion(
log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")