Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option

feat:Add Azure OpenAI embedding support
This commit is contained in:
Tim Jaeryang Baek
2025-05-30 00:18:55 +04:00
committed by GitHub
6 changed files with 315 additions and 51 deletions

View File

@@ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"azure_openai_config": {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
},
}
@@ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel):
key: str
class AzureOpenAIConfigForm(BaseModel):
url: str
key: str
deployment: str
version: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
azure_openai_config: Optional[AzureOpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@@ -271,7 +285,7 @@ async def update_embedding_config(
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
@@ -288,6 +302,20 @@ async def update_embedding_config(
form_data.ollama_config.key
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
form_data.azure_openai_config.deployment
)
request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
form_data.azure_openai_config.version
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
)
@@ -304,14 +332,32 @@ async def update_embedding_config(
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
else (
request.app.state.config.RAG_OLLAMA_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
else (
request.app.state.config.RAG_OLLAMA_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
(
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
(
request.app.state.config.RAG_AZURE_OPENAI_VERSION
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
return {
@@ -327,6 +373,12 @@ async def update_embedding_config(
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"azure_openai_config": {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
},
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@@ -1129,14 +1181,32 @@ def save_docs_to_vector_db(
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
else (
request.app.state.config.RAG_OLLAMA_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
else (
request.app.state.config.RAG_OLLAMA_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
(
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
(
request.app.state.config.RAG_AZURE_OPENAI_VERSION
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
embeddings = embedding_function(