mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
feat: external rag example
This commit is contained in:
parent
2b7844d634
commit
c6a791d1e1
1
servers/external-rag/.gitignore
vendored
Normal file
1
servers/external-rag/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
memory.json
|
67
servers/external-rag/main.py
Normal file
67
servers/external-rag/main.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
from fastapi import FastAPI, HTTPException, Query
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
# --- RAG Libraries ---
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="RAG Retriever API",
|
||||||
|
version="1.0.0",
|
||||||
|
description="Retrieval-Only API: Queries to vectorstore using LangChain, FAISS, and sentence-transformers.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalQueryInput(BaseModel):
|
||||||
|
queries: List[str] = Field(
|
||||||
|
..., description="List of queries to retrieve from the vectorstore"
|
||||||
|
)
|
||||||
|
k: int = Field(3, description="Number of results per query", example=3)
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievedDoc(BaseModel):
|
||||||
|
query: str
|
||||||
|
results: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalResponse(BaseModel):
|
||||||
|
responses: List[RetrievedDoc]
|
||||||
|
|
||||||
|
|
||||||
|
# --------- Initialize Retriever (on app startup) --------
|
||||||
|
VECTORSTORE_PATH = "faiss_index" # Path to your FAISS vector store
|
||||||
|
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" # Widely used, fast
|
||||||
|
|
||||||
|
|
||||||
|
def get_retriever():
|
||||||
|
embedder = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
|
||||||
|
vectorstore = FAISS.load_local(VECTORSTORE_PATH, embeddings=embedder)
|
||||||
|
retriever = vectorstore.as_retriever()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
retriever = get_retriever()
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
"/retrieve",
|
||||||
|
response_model=RetrievalResponse,
|
||||||
|
summary="Retrieve top-k docs for each query",
|
||||||
|
)
|
||||||
|
def retrieve_docs(input: RetrievalQueryInput):
|
||||||
|
"""
|
||||||
|
Given a list of user queries, returns top-k retrieved documents per query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
out = []
|
||||||
|
for q in input.queries:
|
||||||
|
docs = retriever.get_relevant_documents(q, k=input.k)
|
||||||
|
results = [doc.page_content for doc in docs]
|
||||||
|
out.append(RetrievedDoc(query=q, results=results))
|
||||||
|
return RetrievalResponse(responses=out)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
8
servers/external-rag/requirements.txt
Normal file
8
servers/external-rag/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
|
pydantic
|
||||||
|
python-multipart
|
||||||
|
|
||||||
|
langchain
|
||||||
|
langchain_community
|
||||||
|
sentence_transformers
|
Loading…
Reference in New Issue
Block a user