feat: ollama auth support

This commit is contained in:
Timothy Jaeryang Baek 2024-11-11 22:33:18 -08:00
parent 607a8b2109
commit 99446c4b76

View File

@ -209,10 +209,18 @@ async def post_streaming_url(
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = await session.post(
url,
data=payload,
headers={"Content-Type": "application/json"},
headers=headers,
)
r.raise_for_status()
@ -275,9 +283,10 @@ async def get_all_models():
else:
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
tasks.append(aiohttp_get(f"{url}/api/tags"))
tasks.append(aiohttp_get(f"{url}/api/tags", key))
else:
tasks.append(None)
@ -341,9 +350,16 @@ async def get_ollama_tags(
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {}
if key:
headers["Authorization"] = f"Bearer {key}"
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
r.raise_for_status()
return r.json()
@ -371,7 +387,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx is None:
# returns lowest version
tasks = [
aiohttp_get(f"{url}/api/version")
aiohttp_get(
f"{url}/api/version",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
)
for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
@ -511,10 +530,18 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/copy",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -560,11 +587,18 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(),
headers=headers,
)
try:
r.raise_for_status()
@ -601,10 +635,17 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/show",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
@ -686,10 +727,17 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
@ -743,10 +791,17 @@ def generate_ollama_batch_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request(
method="POST",
url=f"{url}/api/embed",
headers={"Content-Type": "application/json"},
headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: