Update main.py

This commit is contained in:
mindspawn 2024-06-07 21:18:04 -07:00 committed by GitHub
parent a8d80f936a
commit 4ecc1c06d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,6 +9,7 @@ from fastapi import (
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import os, shutil, logging, re import os, shutil, logging, re
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Union, Sequence from typing import List, Union, Sequence
@ -30,6 +31,7 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader, UnstructuredExcelLoader,
UnstructuredPowerPointLoader, UnstructuredPowerPointLoader,
YoutubeLoader, YoutubeLoader,
OutlookMessageLoader,
) )
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
@ -879,6 +881,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
metadata[key] = str(value)
try: try:
if overwrite: if overwrite:
for collection in CHROMA_CLIENT.list_collections(): for collection in CHROMA_CLIENT.list_collections():
@ -965,6 +974,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"swift", "swift",
"vue", "vue",
"svelte", "svelte",
"msg"
] ]
if file_ext == "pdf": if file_ext == "pdf":
@ -999,6 +1009,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.presentationml.presentation", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]: ] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or ( elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0 file_content_type and file_content_type.find("text/") >= 0
): ):