From c0055afdb376661f0a164194bd5dcf7dcd157e19 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 20 Nov 2024 10:01:58 -0800 Subject: [PATCH] refac: youtube loader --- .../apps/retrieval/loaders/youtube.py | 98 +++++++++++++++++++ backend/requirements.txt | 2 +- pyproject.toml | 2 +- 3 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 backend/open_webui/apps/retrieval/loaders/youtube.py diff --git a/backend/open_webui/apps/retrieval/loaders/youtube.py b/backend/open_webui/apps/retrieval/loaders/youtube.py new file mode 100644 index 000000000..ad1088be0 --- /dev/null +++ b/backend/open_webui/apps/retrieval/loaders/youtube.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, Generator, List, Optional, Sequence, Union +from urllib.parse import parse_qs, urlparse +from langchain_core.documents import Document + + +ALLOWED_SCHEMES = {"http", "https"} +ALLOWED_NETLOCS = { + "youtu.be", + "m.youtube.com", + "youtube.com", + "www.youtube.com", + "www.youtube-nocookie.com", + "vid.plus", +} + + +def _parse_video_id(url: str) -> Optional[str]: + """Parse a YouTube URL and return the video ID if valid, otherwise None.""" + parsed_url = urlparse(url) + + if parsed_url.scheme not in ALLOWED_SCHEMES: + return None + + if parsed_url.netloc not in ALLOWED_NETLOCS: + return None + + path = parsed_url.path + + if path.endswith("/watch"): + query = parsed_url.query + parsed_query = parse_qs(query) + if "v" in parsed_query: + ids = parsed_query["v"] + video_id = ids if isinstance(ids, str) else ids[0] + else: + return None + else: + path = parsed_url.path.lstrip("/") + video_id = path.split("/")[-1] + + if len(video_id) != 11: # Video IDs are 11 characters long + return None + + return video_id + + +class YoutubeLoader: + """Load `YouTube` video transcripts.""" + + def __init__( + self, + video_id: str, + language: Union[str, Sequence[str]] = "en", + ): + """Initialize with YouTube video ID.""" + _video_id = _parse_video_id(video_id) + self.video_id = _video_id if _video_id is not None else video_id + self._metadata = {"source": video_id} + self.language = language + if isinstance(language, str): + self.language = [language] + else: + self.language = language + + def load(self) -> List[Document]: + """Load YouTube transcripts into `Document` objects.""" + try: + from youtube_transcript_api import ( + NoTranscriptFound, + TranscriptsDisabled, + YouTubeTranscriptApi, + ) + except ImportError: + raise ImportError( + 'Could not import "youtube_transcript_api" Python package. ' + "Please install it with `pip install youtube-transcript-api`." + ) + + try: + transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id) + except Exception as e: + print(e) + return [] + + try: + transcript = transcript_list.find_transcript(self.language) + except NoTranscriptFound: + transcript = transcript_list.find_transcript(["en"]) + + transcript_pieces: List[Dict[str, Any]] = transcript.fetch() + + transcript = " ".join( + map( + lambda transcript_piece: transcript_piece["text"].strip(" "), + transcript_pieces, + ) + ) + return [Document(page_content=transcript, metadata=self._metadata)] diff --git a/backend/requirements.txt b/backend/requirements.txt index 368613b22..258f69e25 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -82,7 +82,7 @@ authlib==1.3.2 black==24.8.0 langfuse==2.44.0 -youtube-transcript-api==0.6.2 +youtube-transcript-api==0.6.3 pytube==15.0.0 extract_msg diff --git a/pyproject.toml b/pyproject.toml index e425a70f4..9a1c2bb03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dependencies = [ "black==24.8.0", "langfuse==2.44.0", - "youtube-transcript-api==0.6.2", + "youtube-transcript-api==0.6.3", "pytube==15.0.0", "extract_msg",