diff --git a/main.py b/main.py
index ae2c30b..8204540 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -666,149 +667,59 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=f"Pipeline {form_data.model} not found",
         )
-
-    pipeline = app.state.PIPELINES[form_data.model]
-    pipeline_id = form_data.model
-
-    if pipeline["type"] == "manifold":
-        manifold_id, pipeline_id = pipeline_id.split(".", 1)
-        pipe = PIPELINE_MODULES[manifold_id].pipe
-    else:
-        pipe = PIPELINE_MODULES[pipeline_id].pipe
-
-    is_async = inspect.iscoroutinefunction(pipe)
-    is_async_gen = inspect.isasyncgenfunction(pipe)
-    # Helper function to ensure line is a string
-    def ensure_string(line):
-        if isinstance(line, bytes):
-            return line.decode("utf-8")
-        return str(line)
-
-    if form_data.stream:
-        async def stream_content():
-            if is_async_gen:
-                pipe_gen = pipe(
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            for item in await pipe(*args, **kwargs):
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item
+
+    async def job():
+        print(form_data.model)
+
+        pipeline = app.state.PIPELINES[form_data.model]
+        pipeline_id = form_data.model
+
+        print(pipeline_id)
+
+        if pipeline["type"] == "manifold":
+            manifold_id, pipeline_id = pipeline_id.split(".", 1)
+            pipe = PIPELINE_MODULES[manifold_id].pipe
+        else:
+            pipe = PIPELINE_MODULES[pipeline_id].pipe
+
+        if form_data.stream:
+
+            async def stream_content():
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
-
-                async for line in pipe_gen:
+                async for line in res:
                     if isinstance(line, BaseModel):
                         line = line.model_dump_json()
                         line = f"data: {line}"
-
-                    line = ensure_string(line)
-                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
-
+
+                    try:
+                        line = line.decode("utf-8")
+                    except:
+                        pass
+
+                    logging.info(f"stream_content:Generator:{line}")
+
                     if line.startswith("data:"):
                         yield f"{line}\n\n"
                     else:
                         line = stream_message_template(form_data.model, line)
                         yield f"data: {json.dumps(line)}\n\n"
-
-                finish_message = {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {},
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield f"data: [DONE]"
-
-            elif is_async:
-                res = await pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-
-                logging.info(f"stream:true:async:{res}")
-
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:async:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                elif inspect.isasyncgen(res):
-                    async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:AsyncGenerator:{line}")
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                if isinstance(res, str) or inspect.isasyncgen(res):
-                    finish_message = {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
-
-            else:
-                def sync_job():
-                    res = pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-                    return res
-
-                res = await run_in_threadpool(sync_job)
-                logging.info(f"stream:true:sync:{res}")
-
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:Generator:{line}")
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
+
                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
@@ -827,23 +738,21 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
 
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"
-
-        return StreamingResponse(stream_content(), media_type="text/event-stream")
-    else:
-        if is_async_gen:
-            pipe_gen = pipe(
+
+            return StreamingResponse(stream_content(), media_type="text/event-stream")
+        else:
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
-
+
             message = ""
-            async for stream in pipe_gen:
-                stream = ensure_string(stream)
+            async for stream in res:
                 message = f"{message}{stream}"
-
-            logging.info(f"stream:false:async_gen_function:{message}")
+
+            logging.info(f"stream:false:{message}")
             return {
                 "id": f"{form_data.model}-{str(uuid.uuid4())}",
                 "object": "chat.completion",
@@ -861,90 +770,5 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     }
                 ],
             }
-        elif is_async:
-            res = await pipe(
-                user_message=user_message,
-                model_id=pipeline_id,
-                messages=messages,
-                body=form_data.model_dump(),
-            )
-            logging.info(f"stream:false:async:{res}")
-
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
-                message = ""
-
-                if isinstance(res, str):
-                    message = res
-
-                elif inspect.isasyncgen(res):
-                    async for stream in res:
-                        stream = ensure_string(stream)
-                        message = f"{message}{stream}"
-
-                logging.info(f"stream:false:async:{message}")
-                return {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-        else:
-            def job():
-                res = pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:false:sync:{res}")
-
-                if isinstance(res, dict):
-                    return res
-                elif isinstance(res, BaseModel):
-                    return res.model_dump()
-                else:
-                    message = ""
-
-                    if isinstance(res, str):
-                        message = res
-
-                    if isinstance(res, Generator):
-                        for stream in res:
-                            stream = ensure_string(stream)
-                            message = f"{message}{stream}"
-
-                    logging.info(f"stream:false:sync:{message}")
-                    return {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": message,
-                                },
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-            return await run_in_threadpool(job)
+
+    return await job()
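
The heart of the refactor is `execute_pipe`, which collapses the three dispatch paths the old code handled separately (async generator function, coroutine function, plain function) into a single async generator that both the streaming and non-streaming branches consume with one `async for`. Below is a minimal, self-contained sketch of that pattern, runnable outside FastAPI; `execute_pipe` mirrors the helper in the diff, while `sync_pipe` and `async_gen_pipe` are hypothetical pipes invented for illustration:

```python
import asyncio
import inspect

from fastapi.concurrency import run_in_threadpool  # same import main.py uses


async def execute_pipe(pipe, *args, **kwargs):
    # Normalize every pipe flavor into an async generator of chunks.
    if inspect.isasyncgenfunction(pipe):
        async for res in pipe(*args, **kwargs):
            yield res
    elif inspect.iscoroutinefunction(pipe):
        for item in await pipe(*args, **kwargs):
            yield item
    else:
        # Plain function: run it off the event loop, then iterate the result.
        for item in await run_in_threadpool(pipe, *args, **kwargs):
            yield item


# Hypothetical pipes, for illustration only:
def sync_pipe(**kwargs):
    return ["Hello ", "from a sync pipe"]  # any iterable of chunks


async def async_gen_pipe(**kwargs):
    yield "Hello "
    yield "from an async generator pipe"


async def main():
    async for chunk in execute_pipe(sync_pipe):
        print(chunk, end="")
    print()
    async for chunk in execute_pipe(async_gen_pipe):
        print(chunk, end="")
    print()


asyncio.run(main())
```

Two consequences of this shape are worth noting. Only the blocking pipe call itself hops to the threadpool now: the old code ran the entire non-streaming `job()` via `run_in_threadpool`, whereas the new `job()` stays on the event loop. And because `for item in await pipe(...)` iterates whatever the coroutine returns, a coroutine pipe that returns a plain string is yielded character by character; the non-streaming branch reassembles those characters into one message, but the streaming branch appears to wrap each character in its own chunk via `stream_message_template`.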
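The streaming branch frames everything as OpenAI-style server-sent events: content chunks go through `stream_message_template`, lines that already begin with `data:` are forwarded untouched, and the stream is terminated by a `finish_reason: "stop"` chunk followed by `data: [DONE]`. The sketch below reproduces that wire format; the body of `stream_message_template` is an assumption inferred from the `finish_message` literals in the diff (a `chat.completion.chunk` whose delta carries the content), not copied from main.py:

```python
import json
import time
import uuid


def stream_message_template(model, message):
    # Assumed shape: one chat.completion.chunk carrying a content delta.
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"content": message},
                "logprobs": None,
                "finish_reason": None,
            }
        ],
    }


def sse_frames(model, pieces):
    # The frames the streaming branch yields, in order.
    for piece in pieces:
        yield f"data: {json.dumps(stream_message_template(model, piece))}\n\n"

    # Closing chunk, copied from the finish_message literal in the diff.
    finish_message = {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
        ],
    }
    yield f"data: {json.dumps(finish_message)}\n\n"
    yield "data: [DONE]"


for frame in sse_frames("demo-pipeline", ["Hel", "lo"]):
    print(frame, end="")
```

One reviewer-side observation: the surviving `if isinstance(res, str) or isinstance(res, Generator):` guard in the new `stream_content` tests `res`, which is now always the async generator returned by `execute_pipe`, so as written the closing chunk and `[DONE]` frame appear reachable only if that check is loosened.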