feat: ollama auth support

This commit is contained in:
Timothy Jaeryang Baek 2024-11-11 22:33:18 -08:00
parent 607a8b2109
commit 99446c4b76

View File

@ -209,10 +209,18 @@ async def post_streaming_url(
session = aiohttp.ClientSession( session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = await session.post( r = await session.post(
url, url,
data=payload, data=payload,
headers={"Content-Type": "application/json"}, headers=headers,
) )
r.raise_for_status() r.raise_for_status()
@ -275,9 +283,10 @@ async def get_all_models():
else: else:
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True) enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable: if enable:
tasks.append(aiohttp_get(f"{url}/api/tags")) tasks.append(aiohttp_get(f"{url}/api/tags", key))
else: else:
tasks.append(None) tasks.append(None)
@ -341,9 +350,16 @@ async def get_ollama_tags(
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {}
if key:
headers["Authorization"] = f"Bearer {key}"
r = None r = None
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
@ -371,7 +387,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx is None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ tasks = [
aiohttp_get(f"{url}/api/version") aiohttp_get(
f"{url}/api/version",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
)
for url in app.state.config.OLLAMA_BASE_URLS for url in app.state.config.OLLAMA_BASE_URLS
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
@ -511,10 +530,18 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/copy", url=f"{url}/api/copy",
headers={"Content-Type": "application/json"}, headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
@ -560,11 +587,18 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request( r = requests.request(
method="DELETE", method="DELETE",
url=f"{url}/api/delete", url=f"{url}/api/delete",
headers={"Content-Type": "application/json"},
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
headers=headers,
) )
try: try:
r.raise_for_status() r.raise_for_status()
@ -601,10 +635,17 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/show", url=f"{url}/api/show",
headers={"Content-Type": "application/json"}, headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try: try:
@ -686,10 +727,17 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embeddings", url=f"{url}/api/embeddings",
headers={"Content-Type": "application/json"}, headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try: try:
@ -743,10 +791,17 @@ def generate_ollama_batch_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embed", url=f"{url}/api/embed",
headers={"Content-Type": "application/json"}, headers=headers,
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try: try: