diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 069e2fb67..cb4b88c32 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -124,6 +124,10 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES +app.state.YOUTUBE_LOADER_LANGUAGE = ["en"] +app.state.YOUTUBE_LOADER_TRANSLATION = None + + def update_embedding_model( embedding_model: str, update_model: bool = False, @@ -314,6 +318,10 @@ async def get_rag_config(user=Depends(get_admin_user)): "chunk_overlap": app.state.CHUNK_OVERLAP, }, "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "youtube": { + "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + }, } @@ -322,10 +330,16 @@ class ChunkParamUpdateForm(BaseModel): chunk_overlap: int +class YoutubeLoaderConfig(BaseModel): + language: List[str] + translation: Optional[str] = None + + class ConfigUpdateForm(BaseModel): pdf_extract_images: Optional[bool] = None chunk: Optional[ChunkParamUpdateForm] = None web_loader_ssl_verification: Optional[bool] = None + youtube: Optional[YoutubeLoaderConfig] = None @app.post("/config/update") @@ -352,6 +366,18 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) + app.state.YOUTUBE_LOADER_LANGUAGE = ( + form_data.youtube.language + if form_data.youtube != None + else app.state.YOUTUBE_LOADER_LANGUAGE + ) + + app.state.YOUTUBE_LOADER_TRANSLATION = ( + form_data.youtube.translation + if form_data.youtube != None + else app.state.YOUTUBE_LOADER_TRANSLATION + ) + return { "status": True, "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, @@ -360,6 +386,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "chunk_overlap": app.state.CHUNK_OVERLAP, }, "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "youtube": { + "language": app.state.YOUTUBE_LOADER_LANGUAGE, + "translation": app.state.YOUTUBE_LOADER_TRANSLATION, + }, } @@ -486,7 +516,12 @@ def query_collection_handler( @app.post("/youtube") def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): try: - loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False) + loader = YoutubeLoader.from_youtube_url( + form_data.url, + add_video_info=True, + language=app.state.YOUTUBE_LOADER_LANGUAGE, + translation=app.state.YOUTUBE_LOADER_TRANSLATION, + ) data = loader.load() collection_name = form_data.collection_name diff --git a/backend/requirements.txt b/backend/requirements.txt index 357fd1e86..c8b699447 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -57,3 +57,4 @@ PyJWT[crypto]==2.8.0 black==24.4.2 langfuse==2.27.3 youtube-transcript-api==0.6.2 +pytube \ No newline at end of file diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ccf166dab..b792b7f00 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -32,10 +32,16 @@ type ChunkConfigForm = { chunk_overlap: number; }; +type YoutubeConfigForm = { + language: string[]; + translation?: string | null; +}; + type RAGConfigForm = { pdf_extract_images?: boolean; chunk?: ChunkConfigForm; web_loader_ssl_verification?: boolean; + youtube?: YoutubeConfigForm; }; export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => { diff --git a/src/lib/components/documents/Settings/WebParams.svelte b/src/lib/components/documents/Settings/WebParams.svelte index bb78686d1..2ca2f3ace 100644 --- a/src/lib/components/documents/Settings/WebParams.svelte +++ b/src/lib/components/documents/Settings/WebParams.svelte @@ -11,9 +11,16 @@ let webLoaderSSLVerification = true; + let youtubeLanguage = 'en'; + let youtubeTranslation = null; + const submitHandler = async () => { const res = await updateRAGConfig(localStorage.token, { - web_loader_ssl_verification: webLoaderSSLVerification + web_loader_ssl_verification: webLoaderSSLVerification, + youtube: { + language: youtubeLanguage.split(',').map((lang) => lang.trim()), + translation: youtubeTranslation + } }); }; @@ -22,6 +29,8 @@ if (res) { webLoaderSSLVerification = res.web_loader_ssl_verification; + youtubeLanguage = res.youtube.language.join(','); + youtubeTranslation = res.youtube.translation; } }); @@ -36,7 +45,7 @@