This commit is contained in:
Timothy J. Baek 2024-09-28 02:29:08 +02:00
parent af57a2c153
commit 2428878f42

View File

@ -246,10 +246,10 @@ app.add_middleware(
class CollectionNameForm(BaseModel): class CollectionNameForm(BaseModel):
collection_name: Optional[str] = "test" collection_name: Optional[str] = None
class UrlForm(CollectionNameForm): class ProcessUrlForm(CollectionNameForm):
url: str url: str
@ -636,7 +636,6 @@ def store_data_in_vector_db(
chunk_overlap=app.state.config.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
if len(docs) > 0: if len(docs) > 0:
@ -715,66 +714,6 @@ def store_docs_in_vector_db(
return False return False
@app.post("/doc")
def store_doc(
collection_name: Optional[str] = Form(None),
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_path = f"{UPLOAD_DIR}/{filename}"
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
f.close()
f = open(file_path, "rb")
if collection_name is None:
collection_name = calculate_sha256(f)[:63]
f.close()
loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
)
data = loader.load(filename, file.content_type, file_path)
try:
result = store_data_in_vector_db(data, collection_name)
if result:
return {
"status": True,
"collection_name": collection_name,
"filename": filename,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
class ProcessFileForm(BaseModel): class ProcessFileForm(BaseModel):
file_id: str file_id: str
collection_name: Optional[str] = None collection_name: Optional[str] = None
@ -796,11 +735,10 @@ def process_file(
) )
data = loader.load(file.filename, file.meta.get("content_type"), file_path) data = loader.load(file.filename, file.meta.get("content_type"), file_path)
f = open(file_path, "rb")
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name is None: if collection_name is None:
collection_name = calculate_sha256(f)[:63] with open(file_path, "rb") as f:
f.close() collection_name = calculate_sha256(f)[:63]
try: try:
result = store_data_in_vector_db( result = store_data_in_vector_db(
@ -813,11 +751,9 @@ def process_file(
) )
if result: if result:
return { return {
"status": True, "status": True,
"collection_name": collection_name, "collection_name": collection_name,
"known_type": known_type,
"filename": file.meta.get("name", file.filename), "filename": file.meta.get("name", file.filename),
} }
except Exception as e: except Exception as e:
@ -839,15 +775,15 @@ def process_file(
) )
class TextRAGForm(BaseModel): class ProcessTextForm(BaseModel):
name: str name: str
content: str content: str
collection_name: Optional[str] = None collection_name: Optional[str] = None
@app.post("/text") @app.post("/process/text")
def store_text( def process_text(
form_data: TextRAGForm, form_data: ProcessTextForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name
@ -878,9 +814,8 @@ def process_docs_dir(user=Depends(get_admin_user)):
filename = path.name filename = path.name
file_content_type = mimetypes.guess_type(path) file_content_type = mimetypes.guess_type(path)
f = open(path, "rb") with open(path, "rb") as f:
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close()
loader = Loader( loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE, engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
@ -933,7 +868,7 @@ def process_docs_dir(user=Depends(get_admin_user)):
@app.post("/process/youtube") @app.post("/process/youtube")
def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
try: try:
loader = YoutubeLoader.from_youtube_url( loader = YoutubeLoader.from_youtube_url(
form_data.url, form_data.url,
@ -944,10 +879,11 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
data = loader.load() data = loader.load()
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == "": if not collection_name:
collection_name = calculate_sha256_string(form_data.url)[:63] collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True) store_data_in_vector_db(data, collection_name, overwrite=True)
return { return {
"status": True, "status": True,
"collection_name": collection_name, "collection_name": collection_name,
@ -962,8 +898,7 @@ def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
@app.post("/process/web") @app.post("/process/web")
def process_web(form_data: UrlForm, user=Depends(get_verified_user)): def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader( loader = get_web_loader(
form_data.url, form_data.url,
@ -973,10 +908,11 @@ def process_web(form_data: UrlForm, user=Depends(get_verified_user)):
data = loader.load() data = loader.load()
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == "": if not collection_name:
collection_name = calculate_sha256_string(form_data.url)[:63] collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True) store_data_in_vector_db(data, collection_name, overwrite=True)
return { return {
"status": True, "status": True,
"collection_name": collection_name, "collection_name": collection_name,