From 3acd4d620c128f3b306e82ad11d4d502b6073049 Mon Sep 17 00:00:00 2001
From: dominik
Date: Mon, 17 Mar 2025 08:45:05 +0100
Subject: [PATCH] fixes https://github.com/open-webui/pipelines/issues/411 and
 https://github.com/open-webui/pipelines/issues/359

---
 main.py | 265 +++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 216 insertions(+), 49 deletions(-)

diff --git a/main.py b/main.py
index cff3335..ae2c30b 100644
--- a/main.py
+++ b/main.py
@@ -27,7 +27,7 @@
 import json
 import uuid
 import sys
 import subprocess
-
+import inspect

 from config import API_KEY, PIPELINES_DIR
@@ -667,56 +667,148 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Pipeline {form_data.model} not found",
         )

-    def job():
-        print(form_data.model)
+    pipeline = app.state.PIPELINES[form_data.model]
+    pipeline_id = form_data.model

-        pipeline = app.state.PIPELINES[form_data.model]
-        pipeline_id = form_data.model
+    if pipeline["type"] == "manifold":
+        manifold_id, pipeline_id = pipeline_id.split(".", 1)
+        pipe = PIPELINE_MODULES[manifold_id].pipe
+    else:
+        pipe = PIPELINE_MODULES[pipeline_id].pipe

-        print(pipeline_id)
-
-        if pipeline["type"] == "manifold":
-            manifold_id, pipeline_id = pipeline_id.split(".", 1)
-            pipe = PIPELINE_MODULES[manifold_id].pipe
-        else:
-            pipe = PIPELINE_MODULES[pipeline_id].pipe
-
-        if form_data.stream:
-
-            def stream_content():
-                res = pipe(
+    is_async = inspect.iscoroutinefunction(pipe)
+    is_async_gen = inspect.isasyncgenfunction(pipe)
+
+    # Helper function to ensure line is a string
+    def ensure_string(line):
+        if isinstance(line, bytes):
+            return line.decode("utf-8")
+        return str(line)
+
+    if form_data.stream:
+        async def stream_content():
+            if is_async_gen:
+                pipe_gen = pipe(
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
-
-                logging.info(f"stream:true:{res}")
-
+
+                async for line in pipe_gen:
+                    if isinstance(line, BaseModel):
+                        line = line.model_dump_json()
+                        line = f"data: {line}"
+
+                    line = ensure_string(line)
+                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
+
+                    if line.startswith("data:"):
+                        yield f"{line}\n\n"
+                    else:
+                        line = stream_message_template(form_data.model, line)
+                        yield f"data: {json.dumps(line)}\n\n"
+
+                finish_message = {
+                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": form_data.model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {},
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield f"data: [DONE]"
+
+            elif is_async:
+                res = await pipe(
+                    user_message=user_message,
+                    model_id=pipeline_id,
+                    messages=messages,
+                    body=form_data.model_dump(),
+                )
+
+                logging.info(f"stream:true:async:{res}")
+
                 if isinstance(res, str):
                     message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
+                    logging.info(f"stream_content:str:async:{message}")
                     yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
+
+                elif inspect.isasyncgen(res):
+                    async for line in res:
                         if isinstance(line, BaseModel):
                             line = line.model_dump_json()
                             line = f"data: {line}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        logging.info(f"stream_content:Generator:{line}")
-
+
+                        line = ensure_string(line)
+                        logging.info(f"stream_content:AsyncGenerator:{line}")
+
                         if line.startswith("data:"):
                             yield f"{line}\n\n"
                         else:
                             line = stream_message_template(form_data.model, line)
                             yield f"data: {json.dumps(line)}\n\n"
-
+
+                if isinstance(res, str) or inspect.isasyncgen(res):
+                    finish_message = {
+                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": form_data.model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "delta": {},
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield f"data: [DONE]"
+
+            else:
+                def sync_job():
+                    res = pipe(
+                        user_message=user_message,
+                        model_id=pipeline_id,
+                        messages=messages,
+                        body=form_data.model_dump(),
+                    )
+                    return res
+
+                res = await run_in_threadpool(sync_job)
+                logging.info(f"stream:true:sync:{res}")
+
+                if isinstance(res, str):
+                    message = stream_message_template(form_data.model, res)
+                    logging.info(f"stream_content:str:{message}")
+                    yield f"data: {json.dumps(message)}\n\n"
+
+                if isinstance(res, Iterator):
+                    for line in res:
+                        if isinstance(line, BaseModel):
+                            line = line.model_dump_json()
+                            line = f"data: {line}"
+
+                        line = ensure_string(line)
+                        logging.info(f"stream_content:Generator:{line}")
+
+                        if line.startswith("data:"):
+                            yield f"{line}\n\n"
+                        else:
+                            line = stream_message_template(form_data.model, line)
+                            yield f"data: {json.dumps(line)}\n\n"
+
                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
                         "id": f"{form_data.model}-{str(uuid.uuid4())}",
@@ -732,36 +824,68 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                             }
                         ],
                     }
-
+
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"
-
-            return StreamingResponse(stream_content(), media_type="text/event-stream")
-        else:
-            res = pipe(
+
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        if is_async_gen:
+            pipe_gen = pipe(
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
-            logging.info(f"stream:false:{res}")
-
+
+            message = ""
+            async for stream in pipe_gen:
+                stream = ensure_string(stream)
+                message = f"{message}{stream}"
+
+            logging.info(f"stream:false:async_gen_function:{message}")
+            return {
+                "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": form_data.model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": message,
+                        },
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                    }
+                ],
+            }
+        elif is_async:
+            res = await pipe(
+                user_message=user_message,
+                model_id=pipeline_id,
+                messages=messages,
+                body=form_data.model_dump(),
+            )
+            logging.info(f"stream:false:async:{res}")
+
             if isinstance(res, dict):
                 return res
             elif isinstance(res, BaseModel):
                 return res.model_dump()
             else:
-                message = ""
-
+
                 if isinstance(res, str):
                     message = res
-
-                if isinstance(res, Generator):
-                    for stream in res:
+
+                elif inspect.isasyncgen(res):
+                    async for stream in res:
+                        stream = ensure_string(stream)
                         message = f"{message}{stream}"
-
-                logging.info(f"stream:false:{message}")
+
+                logging.info(f"stream:false:async:{message}")
                 return {
                     "id": f"{form_data.model}-{str(uuid.uuid4())}",
                     "object": "chat.completion",
@@ -779,5 +903,48 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     }
                 ],
             }
-
-    return await run_in_threadpool(job)
+        else:
+            def job():
+                res = pipe(
+                    user_message=user_message,
+                    model_id=pipeline_id,
+                    messages=messages,
+                    body=form_data.model_dump(),
+                )
+                logging.info(f"stream:false:sync:{res}")
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+
+                    if isinstance(res, str):
+                        message = res
+
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            stream = ensure_string(stream)
+                            message = f"{message}{stream}"
+
+                    logging.info(f"stream:false:sync:{message}")
+                    return {
+                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data.model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
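
Note: with this patch applied, a pipeline's pipe may be a plain function, a sync
generator, a coroutine, or an async generator; inspect.iscoroutinefunction and
inspect.isasyncgenfunction select the dispatch path, and inspect.isasyncgen covers
coroutines that return an async generator. Below is a minimal sketch of a pipeline
that exercises the new async-generator streaming path. It follows the usual
Pipeline-class convention of this repo; the class name, pipeline name, and tokens
are illustrative only, not part of the patch:

    class Pipeline:
        def __init__(self):
            # Illustrative display name; any pipeline id works.
            self.name = "Async Streaming Example"

        async def pipe(self, user_message, model_id, messages, body):
            # An async generator function: main.py detects it via
            # inspect.isasyncgenfunction(), wraps each yielded chunk with
            # stream_message_template(), and emits it as an SSE "data:" event.
            for token in ["Hello", ", ", "world", "!"]:
                yield token

An async def pipe that instead returns a string (or an async generator object) is
handled by the is_async branch of the same dispatch.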