enh: ollama /v1/completions endpoint support

This commit is contained in:
Timothy Jaeryang Baek 2024-12-07 13:46:46 -08:00
parent 976676a482
commit 1439f6862d

View File

@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel):
model_config = ConfigDict(extra="allow")
class OpenAICompletionForm(BaseModel):
    """Request body for the OpenAI-compatible ``/v1/completions`` endpoint.

    Only ``model`` and ``prompt`` are required; any extra fields (e.g.
    ``stream``, ``max_tokens``) are accepted and forwarded downstream.
    """

    model: str  # target model id; ":latest" is appended downstream when no tag is given
    prompt: str  # the text prompt to complete
    # allow unknown fields so arbitrary OpenAI parameters pass through unvalidated
    model_config = ConfigDict(extra="allow")
@app.post("/v1/completions")
@app.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    """Proxy an OpenAI-compatible ``/v1/completions`` request to an Ollama backend.

    Validates the body, resolves the model (following ``base_model_id`` and
    applying any stored model params), enforces per-user access control, picks
    the backend URL, and streams the upstream response back to the caller.

    Raises:
        HTTPException: 400 on an invalid body, 403 when the user may not use
            the requested model (reported as "Model not found").
    """
    # Validate the raw body against the completion schema.
    try:
        completion_form = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    # Build the upstream payload; metadata is internal and never forwarded.
    payload = dict(completion_form.model_dump(exclude_none=True, exclude=["metadata"]))
    payload.pop("metadata", None)

    # Normalize the id used for the registry lookup (untagged -> ":latest").
    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        # A workspace model may be an alias over a concrete base model.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Non-admin users need ownership or explicit read access.
        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif user.role != "admin":
        # Unregistered models are reserved for admins.
        raise HTTPException(
            status_code=403,
            detail="Model not found",
        )

    # The forwarded model name must also carry an explicit tag.
    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")

    # Strip a per-backend prefix from the model name if one is configured.
    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/v1/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
    )
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(