diff --git a/examples/pipelines/providers/azure_openai_manifold_pipeline.py b/examples/pipelines/providers/azure_openai_manifold_pipeline.py
index 6f77a44..65834a6 100644
--- a/examples/pipelines/providers/azure_openai_manifold_pipeline.py
+++ b/examples/pipelines/providers/azure_openai_manifold_pipeline.py
@@ -26,7 +26,6 @@ class Pipeline:
             }
         )
         self.set_pipelines()
-        pass

     def set_pipelines(self):
         models = self.valves.AZURE_OPENAI_MODELS.split(";")
@@ -35,27 +34,26 @@ class Pipeline:
             {"id": model, "name": name} for model, name in zip(models, model_names)
         ]
         print(f"azure_openai_manifold_pipeline - models: {self.pipelines}")
-        pass

     async def on_valves_updated(self):
-        self.set_pipelines()
+        self.set_pipelines()

     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
-        pass

     async def on_shutdown(self):
         # This function is called when the server is stopped.
         print(f"on_shutdown:{__name__}")
-        pass

     def pipe(
-        self, user_message: str, model_id: str, messages: List[dict], body: dict
-    ) -> Union[str, Generator, Iterator]:
-        # This is where you can add your custom pipelines like RAG.
+        self,
+        user_message: str,
+        model_id: str,
+        messages: List[dict],
+        body: dict
+    ) -> Union[str, Generator[str, None, None], Iterator[str]]:
         print(f"pipe:{__name__}")
-
         print(messages)
         print(user_message)

@@ -64,36 +62,152 @@ class Pipeline:
             "Content-Type": "application/json",
         }

-        url = f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/{model_id}/chat/completions?api-version={self.valves.AZURE_OPENAI_API_VERSION}"
+        # URL for Chat Completions in Azure OpenAI
+        url = (
+            f"{self.valves.AZURE_OPENAI_ENDPOINT}/openai/deployments/"
+            f"{model_id}/chat/completions?api-version={self.valves.AZURE_OPENAI_API_VERSION}"
+        )

-        allowed_params = {'messages', 'temperature', 'role', 'content', 'contentPart', 'contentPartImage',
-                          'enhancements', 'dataSources', 'n', 'stream', 'stop', 'max_tokens', 'presence_penalty',
-                          'frequency_penalty', 'logit_bias', 'user', 'function_call', 'funcions', 'tools',
-                          'tool_choice', 'top_p', 'log_probs', 'top_logprobs', 'response_format', 'seed'}
-        # remap user field
+        # --- Define the allowed parameter sets ---
+        # (1) Default allowed params (non-o1)
+        allowed_params_default = {
+            "messages",
+            "temperature",
+            "role",
+            "content",
+            "contentPart",
+            "contentPartImage",
+            "enhancements",
+            "dataSources",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "function_call",
+            "functions",
+            "tools",
+            "tool_choice",
+            "top_p",
+            "log_probs",
+            "top_logprobs",
+            "response_format",
+            "seed",
+        }
+
+        # (2) o1 models allowed params
+        allowed_params_o1 = {
+            "model",
+            "messages",
+            "top_p",
+            "n",
+            "max_completion_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+        }
+
+        # Simple helper to detect if it's an o1 model
+        def is_o1_model(m: str) -> bool:
+            # Adjust this check to your naming pattern for o1 models
+            return "o1" in m or m.startswith("o")
+
+        # Ensure user is a string
         if "user" in body and not isinstance(body["user"], str):
-            body["user"] = body["user"]["id"] if "id" in body["user"] else str(body["user"])
-        filtered_body = {k: v for k, v in body.items() if k in allowed_params}
-        # log fields that were filtered out as a single line
-        if len(body) != len(filtered_body):
-            print(f"Dropped params: {', '.join(set(body.keys()) - set(filtered_body.keys()))}")
+            body["user"] = body["user"].get("id", str(body["user"]))

-        try:
-            r = requests.post(
-                url=url,
-                json=filtered_body,
-                headers=headers,
-                stream=True,
-            )
+        # If it's an o1 model, do a "fake streaming" approach
+        if is_o1_model(model_id):
+            # We'll remove "stream" from the body if present (since we'll do manual streaming),
+            # then filter to the allowed params for o1 models.
+            body.pop("stream", None)
+            filtered_body = {k: v for k, v in body.items() if k in allowed_params_o1}

-            r.raise_for_status()
-            if body["stream"]:
-                return r.iter_lines()
-            else:
-                return r.json()
-        except Exception as e:
-            if r:
-                text = r.text
-                return f"Error: {e} ({text})"
-            else:
-                return f"Error: {e}"
+            # Log which fields were dropped
+            if len(body) != len(filtered_body):
+                dropped_keys = set(body.keys()) - set(filtered_body.keys())
+                print(f"Dropped params: {', '.join(dropped_keys)}")
+
+            try:
+                # We make a normal request (non-streaming)
+                r = requests.post(
+                    url=url,
+                    json=filtered_body,
+                    headers=headers,
+                    stream=False,
+                )
+                r.raise_for_status()
+
+                # Parse the full JSON response
+                data = r.json()
+
+                # Typically, the text content is in data["choices"][0]["message"]["content"]
+                # This may vary depending on your actual response shape.
+                # For safety, let's do a little fallback:
+                content = ""
+                if (
+                    isinstance(data, dict)
+                    and "choices" in data
+                    and isinstance(data["choices"], list)
+                    and len(data["choices"]) > 0
+                    and "message" in data["choices"][0]
+                    and "content" in data["choices"][0]["message"]
+                ):
+                    content = data["choices"][0]["message"]["content"]
+                else:
+                    # Fall back to the raw payload so the fake streaming below
+                    # still yields something useful for debugging.
+                    content = str(data)
+
+                # We will chunk the text to simulate streaming
+                def chunk_text(text: str, chunk_size: int = 30) -> Generator[str, None, None]:
+                    """Yield text in fixed-size chunks."""
+                    for i in range(0, len(text), chunk_size):
+                        yield text[i : i + chunk_size]
+
+                # Return a generator that yields chunks
+                def fake_stream() -> Generator[str, None, None]:
+                    for chunk in chunk_text(content):
+                        yield chunk
+
+                return fake_stream()
+
+            except Exception as e:
+                # If the request object exists, return its text
+                if "r" in locals() and r is not None:
+                    return f"Error: {e} ({r.text})"
+                else:
+                    return f"Error: {e}"
+
+        else:
+            # Normal pipeline for non-o1 models:
+            filtered_body = {k: v for k, v in body.items() if k in allowed_params_default}
+            if len(body) != len(filtered_body):
+                dropped_keys = set(body.keys()) - set(filtered_body.keys())
+                print(f"Dropped params: {', '.join(dropped_keys)}")
+
+            try:
+                r = requests.post(
+                    url=url,
+                    json=filtered_body,
+                    headers=headers,
+                    stream=True,
+                )
+                r.raise_for_status()
+
+                if filtered_body.get("stream"):
+                    # Real streaming
+                    return r.iter_lines()
+                else:
+                    # Just return the JSON
+                    return r.json()
+
+            except Exception as e:
+                if "r" in locals() and r is not None:
+                    return f"Error: {e} ({r.text})"
+                else:
+                    return f"Error: {e}"
\ No newline at end of file
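
Reviewer note (not part of the patch): the o1 branch above simulates streaming by making one non-streaming request and then re-emitting the full completion in fixed-size chunks. A minimal standalone sketch of that technique follows; the names (chunk_text, sample_completion) and the sample text are illustrative only, not from the patched file:

    from typing import Generator

    def chunk_text(text: str, chunk_size: int = 30) -> Generator[str, None, None]:
        # Yield the text in fixed-size slices, mirroring the helper in the patch.
        for i in range(0, len(text), chunk_size):
            yield text[i : i + chunk_size]

    # Hypothetical full response from a non-streaming o1 call.
    sample_completion = "An o1 deployment returns the whole answer at once; we slice it up."

    for chunk in chunk_text(sample_completion):
        print(repr(chunk))  # each slice arrives as if it were a streamed delta

Since pipe() may return a Generator, the Open WebUI pipelines server can iterate this the same way it iterates r.iter_lines() for real streaming, so callers see a uniform interface for both model families.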