diff --git a/main.py b/main.py
index ca9c6c5..c3ff1bd 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -27,7 +28,7 @@ import json
 import uuid
 import sys
 import subprocess
-
+import inspect

 from config import API_KEY, PIPELINES_DIR, LOG_LEVELS

@@ -670,8 +671,19 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=f"Pipeline {form_data.model} not found",
         )
+
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            for item in await pipe(*args, **kwargs):
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item

-    def job():
+    async def job():
         print(form_data.model)

         pipeline = app.state.PIPELINES[form_data.model]
@@ -687,39 +699,30 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):

         if form_data.stream:

-            def stream_content():
-                res = pipe(
+            async def stream_content():
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
+                async for line in res:
+                    if isinstance(line, BaseModel):
+                        line = line.model_dump_json()
+                        line = f"data: {line}"

-                logging.info(f"stream:true:{res}")
+                    try:
+                        line = line.decode("utf-8")
+                    except:
+                        pass

-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
+                    logging.info(f"stream_content:Generator:{line}")

-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        logging.info(f"stream_content:Generator:{line}")
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
+                    if line.startswith("data:"):
+                        yield f"{line}\n\n"
+                    else:
+                        line = stream_message_template(form_data.model, line)
+                        yield f"data: {json.dumps(line)}\n\n"

                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
@@ -736,52 +739,40 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                             }
                         ],
                     }
-
+                    yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"

             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
-            res = pipe(
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
-            logging.info(f"stream:false:{res}")

-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
+            message = ""
+            async for stream in res:
+                message = f"{message}{stream}"

-                message = ""
+            logging.info(f"stream:false:{message}")
+            return {
+                "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": form_data.model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": message,
+                        },
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                    }
+                ],
+            }

-                if isinstance(res, str):
-                    message = res
-
-                if isinstance(res, Generator):
-                    for stream in res:
-                        message = f"{message}{stream}"
-
-                logging.info(f"stream:false:{message}")
-                return {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-    return await run_in_threadpool(job)
+    return await job()