Merge branch 'refs/heads/dev' into feat/sqlalchemy-instead-of-peewee

# Conflicts:
#	backend/requirements.txt
This commit is contained in:
Jonathan Rohde
2024-07-01 10:37:56 +02:00
23 changed files with 1089 additions and 342 deletions

View File

@@ -14,7 +14,6 @@ from fastapi import (
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
@@ -277,6 +276,8 @@ def transcribe(
f.close()
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,

View File

@@ -12,7 +12,6 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (

View File

@@ -153,7 +153,7 @@ async def cleanup_response(
await session.close()
async def post_streaming_url(url: str, payload: str):
async def post_streaming_url(url: str, payload: str, stream: bool = True):
r = None
try:
session = aiohttp.ClientSession(
@@ -162,12 +162,20 @@ async def post_streaming_url(url: str, payload: str):
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
if stream:
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
res = await r.json()
await cleanup_response(r, session)
return res
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
@@ -963,7 +971,11 @@ async def generate_openai_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
return await post_streaming_url(
f"{url}/v1/chat/completions",
json.dumps(payload),
stream=payload.get("stream", False),
)
@app.get("/v1/models")

View File

@@ -48,8 +48,6 @@ import mimetypes
import uuid
import json
import sentence_transformers
from apps.webui.models.documents import (
Documents,
DocumentForm,
@@ -190,6 +188,8 @@ def update_embedding_model(
update_model: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@@ -204,6 +204,8 @@ def update_reranking_model(
update_model: bool = False,
):
if reranking_model:
import sentence_transformers
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,

View File

@@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
@@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]

View File

@@ -259,6 +259,9 @@ async def generate_function_chat_completion(form_data, user):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except: