feat: external rag example

This commit is contained in:
Timothy Jaeryang Baek 2025-06-05 11:03:35 +04:00
parent 2b7844d634
commit c6a791d1e1
3 changed files with 76 additions and 0 deletions

1
servers/external-rag/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
memory.json

View File

@ -0,0 +1,67 @@
import os
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel, Field
from typing import List
# --- RAG Libraries ---
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from langchain.embeddings import HuggingFaceEmbeddings
# Application object; the route defined further down registers itself on it.
app = FastAPI(
    title="RAG Retriever API",
    description="Retrieval-Only API: Queries to vectorstore using LangChain, FAISS, and sentence-transformers.",
    version="1.0.0",
)
class RetrievalQueryInput(BaseModel):
    """Request body accepted by the ``/retrieve`` endpoint."""

    # One result list is produced per entry, in the order given.
    queries: List[str] = Field(
        ..., description="List of queries to retrieve from the vectorstore"
    )
    # ge=1: a zero or negative result count is meaningless, so it is rejected
    # during request validation instead of being passed on to the vector store.
    k: int = Field(
        3,
        ge=1,
        description="Number of results per query",
        example=3,
    )
class RetrievedDoc(BaseModel):
    """Retrieval output for one query."""

    # The query string this entry answers.
    query: str
    # Text content (``page_content``) of each document retrieved for the query.
    results: List[str]
class RetrievalResponse(BaseModel):
    """Response body of the ``/retrieve`` endpoint."""

    # One RetrievedDoc per submitted query.
    responses: List[RetrievedDoc]
# --------- Initialize Retriever (at module import) --------
# Location handed to FAISS.load_local() in get_retriever(). It is a relative
# path, so it is resolved against the process's current working directory.
VECTORSTORE_PATH = "faiss_index" # Path to your FAISS vector store
# Model name handed to HuggingFaceEmbeddings in get_retriever().
# NOTE(review): presumably this must be the same model the index was built
# with -- confirm against the indexing script.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" # Widely used, fast
def get_retriever():
    """Load the on-disk FAISS index and expose it as a LangChain retriever.

    The index is read from ``VECTORSTORE_PATH`` and queries are embedded with
    the ``EMBEDDING_MODEL_NAME`` model.

    Returns:
        The object returned by ``vectorstore.as_retriever()``.
    """
    embedder = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    # SECURITY: FAISS.load_local() unpickles data stored next to the index, and
    # current langchain_community releases refuse to do so unless the caller
    # opts in explicitly. Only point VECTORSTORE_PATH at an index you built
    # yourself or otherwise trust: a tampered pickle can run arbitrary code.
    vectorstore = FAISS.load_local(
        VECTORSTORE_PATH,
        embeddings=embedder,
        allow_dangerous_deserialization=True,
    )
    return vectorstore.as_retriever()
# Built once, when this module is imported, and then used by retrieve_docs()
# for every request. Any exception raised by get_retriever() propagates out of
# the import.
retriever = get_retriever()
# --------------------------------------------------------
@app.post(
    "/retrieve",
    response_model=RetrievalResponse,
    summary="Retrieve top-k docs for each query",
)
def retrieve_docs(input: RetrievalQueryInput):
    """
    Given a list of user queries, returns top-k retrieved documents per query.
    """
    # Any failure while searching or building the response is reported to the
    # client as HTTP 500 carrying the exception text.
    try:
        out = []
        for q in input.queries:
            # Query the underlying vector store directly so that the per-request
            # `k` is always honoured; the retriever's deprecated
            # get_relevant_documents() does not forward extra keyword arguments
            # on every langchain-core version.
            docs = retriever.vectorstore.similarity_search(q, k=input.k)
            out.append(
                RetrievedDoc(
                    query=q,
                    results=[doc.page_content for doc in docs],
                )
            )
        return RetrievalResponse(responses=out)
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

View File

@ -0,0 +1,8 @@
fastapi
uvicorn[standard]
pydantic
python-multipart
langchain
langchain_community
sentence_transformers