This commit is contained in:
Timothy J. Baek
2024-07-01 17:11:09 -07:00
parent 3c1ea24374
commit a392865615
4 changed files with 40 additions and 37 deletions

View File

@@ -91,7 +91,7 @@ from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
TEXT_EXTRACTION_ENGINE,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
@@ -148,7 +148,7 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.TEXT_EXTRACTION_ENGINE = TEXT_EXTRACTION_ENGINE
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.CHUNK_SIZE = CHUNK_SIZE
@@ -395,8 +395,8 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"text_extraction": {
"engine": app.state.config.TEXT_EXTRACTION_ENGINE,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
@@ -428,7 +428,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
}
class TextExtractionConfig(BaseModel):
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
@@ -466,7 +466,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
text_extraction: Optional[TextExtractionConfig] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
@@ -480,10 +480,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
if form_data.text_extraction is not None:
log.info(f"Updating text settings: {form_data.text_extraction}")
app.state.config.TEXT_EXTRACTION_ENGINE = form_data.text_extraction.engine
app.state.config.TIKA_SERVER_URL = form_data.text_extraction.tika_server_url
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
@@ -521,8 +521,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"text_extraction": {
"engine": app.state.config.TEXT_EXTRACTION_ENGINE,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
@@ -1017,7 +1017,7 @@ class TikaLoader:
self.mime_type = mime_type
def load(self) -> List[Document]:
with (open(self.file_path, "rb") as f):
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
@@ -1096,9 +1096,12 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"msg",
]
if app.state.config.TEXT_EXTRACTION_ENGINE == "tika" and app.state.config.TIKA_SERVER_URL:
if (
app.state.config.CONTENT_EXTRACTION_ENGINE == "tika"
and app.state.config.TIKA_SERVER_URL
):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:

View File

@@ -886,13 +886,13 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document text extraction
# RAG document content extraction
####################################
TEXT_EXTRACTION_ENGINE = PersistentConfig(
"TEXT_EXTRACTION_ENGINE",
"rag.text_extraction_engine",
os.environ.get("TEXT_EXTRACTION_ENGINE", "").lower()
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(