diff --git a/examples/pipelines/rag/llamaindex_ollama_pipeline.py b/examples/pipelines/rag/llamaindex_ollama_pipeline.py index efafe67..ef588e1 100644 --- a/examples/pipelines/rag/llamaindex_ollama_pipeline.py +++ b/examples/pipelines/rag/llamaindex_ollama_pipeline.py @@ -10,23 +10,43 @@ requirements: llama-index, llama-index-llms-ollama, llama-index-embeddings-ollam from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage +import os + +from pydantic import BaseModel class Pipeline: + + class Valves(BaseModel): + LLAMAINDEX_OLLAMA_BASE_URL: str + LLAMAINDEX_MODEL_NAME: str + LLAMAINDEX_EMBEDDING_MODEL_NAME: str + def __init__(self): self.documents = None self.index = None + self.valves = self.Valves( + **{ + "LLAMAINDEX_OLLAMA_BASE_URL": os.getenv("LLAMAINDEX_OLLAMA_BASE_URL", "http://localhost:11434"), + "LLAMAINDEX_MODEL_NAME": os.getenv("LLAMAINDEX_MODEL_NAME", "llama3"), + "LLAMAINDEX_EMBEDDING_MODEL_NAME": os.getenv("LLAMAINDEX_EMBEDDING_MODEL_NAME", "nomic-embed-text"), + } + ) + async def on_startup(self): from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.llms.ollama import Ollama from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader Settings.embed_model = OllamaEmbedding( - model_name="nomic-embed-text", - base_url="http://localhost:11434", + model_name=self.valves.LLAMAINDEX_EMBEDDING_MODEL_NAME, + base_url=self.valves.LLAMAINDEX_OLLAMA_BASE_URL, + ) + Settings.llm = Ollama( + model=self.valves.LLAMAINDEX_MODEL_NAME, + base_url=self.valves.LLAMAINDEX_OLLAMA_BASE_URL, ) - Settings.llm = Ollama(model="llama3") # This function is called when the server is started. global documents, index