dominik 2025-03-17 08:45:05 +01:00 committed by GitHub
parent f89ab37f53
commit 3acd4d620c

main.py (265 changed lines)

@@ -27,7 +27,8 @@ import json
 import uuid
 import sys
 import subprocess
+import inspect
 from config import API_KEY, PIPELINES_DIR
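
The new import backs the dispatch that follows: the endpoint now inspects each pipeline's pipe callable to decide how to run it. A minimal sketch of that classification, with illustrative pipe functions that are not part of this commit:

    import inspect

    def sync_pipe(**kwargs):        # plain function: offloaded via run_in_threadpool
        return "text"

    async def coro_pipe(**kwargs):  # coroutine function: awaited directly
        return "text"

    async def gen_pipe(**kwargs):   # async generator function: consumed with async for
        yield "text"

    assert not inspect.iscoroutinefunction(sync_pipe)
    assert inspect.iscoroutinefunction(coro_pipe)
    assert inspect.isasyncgenfunction(gen_pipe)
    # An async generator function is NOT a coroutine function, which is
    # why the is_async_gen check below must come before the is_async check.
    assert not inspect.iscoroutinefunction(gen_pipe)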
@@ -667,56 +667,148 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Pipeline {form_data.model} not found",
         )
 
-    def job():
-        print(form_data.model)
-        pipeline = app.state.PIPELINES[form_data.model]
-        pipeline_id = form_data.model
-        print(pipeline_id)
-
-        if pipeline["type"] == "manifold":
-            manifold_id, pipeline_id = pipeline_id.split(".", 1)
-            pipe = PIPELINE_MODULES[manifold_id].pipe
-        else:
-            pipe = PIPELINE_MODULES[pipeline_id].pipe
-
-        if form_data.stream:
-
-            def stream_content():
-                res = pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:true:{res}")
+    pipeline = app.state.PIPELINES[form_data.model]
+    pipeline_id = form_data.model
+
+    if pipeline["type"] == "manifold":
+        manifold_id, pipeline_id = pipeline_id.split(".", 1)
+        pipe = PIPELINE_MODULES[manifold_id].pipe
+    else:
+        pipe = PIPELINE_MODULES[pipeline_id].pipe
+
+    is_async = inspect.iscoroutinefunction(pipe)
+    is_async_gen = inspect.isasyncgenfunction(pipe)
+
+    # Helper function to ensure line is a string
+    def ensure_string(line):
+        if isinstance(line, bytes):
+            return line.decode("utf-8")
+        return str(line)
+
+    if form_data.stream:
+
+        async def stream_content():
+            if is_async_gen:
+                pipe_gen = pipe(
+                    user_message=user_message,
+                    model_id=pipeline_id,
+                    messages=messages,
+                    body=form_data.model_dump(),
+                )
+                async for line in pipe_gen:
+                    if isinstance(line, BaseModel):
+                        line = line.model_dump_json()
+                        line = f"data: {line}"
+                    line = ensure_string(line)
+                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
+                    if line.startswith("data:"):
+                        yield f"{line}\n\n"
+                    else:
+                        line = stream_message_template(form_data.model, line)
+                        yield f"data: {json.dumps(line)}\n\n"
+
+                finish_message = {
+                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": form_data.model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {},
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield f"data: [DONE]"
+            elif is_async:
+                res = await pipe(
+                    user_message=user_message,
+                    model_id=pipeline_id,
+                    messages=messages,
+                    body=form_data.model_dump(),
+                )
+                logging.info(f"stream:true:async:{res}")
 
                 if isinstance(res, str):
                     message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
+                    logging.info(f"stream_content:str:async:{message}")
                     yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
+                elif inspect.isasyncgen(res):
+                    async for line in res:
                         if isinstance(line, BaseModel):
                             line = line.model_dump_json()
                             line = f"data: {line}"
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-                        logging.info(f"stream_content:Generator:{line}")
+                        line = ensure_string(line)
+                        logging.info(f"stream_content:AsyncGenerator:{line}")
                         if line.startswith("data:"):
                             yield f"{line}\n\n"
                         else:
                             line = stream_message_template(form_data.model, line)
                             yield f"data: {json.dumps(line)}\n\n"
+
+                if isinstance(res, str) or inspect.isasyncgen(res):
+                    finish_message = {
+                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": form_data.model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "delta": {},
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield f"data: [DONE]"
+            else:
+
+                def sync_job():
+                    res = pipe(
+                        user_message=user_message,
+                        model_id=pipeline_id,
+                        messages=messages,
+                        body=form_data.model_dump(),
+                    )
+                    return res
+
+                res = await run_in_threadpool(sync_job)
+                logging.info(f"stream:true:sync:{res}")
+
+                if isinstance(res, str):
+                    message = stream_message_template(form_data.model, res)
+                    logging.info(f"stream_content:str:{message}")
+                    yield f"data: {json.dumps(message)}\n\n"
+
+                if isinstance(res, Iterator):
+                    for line in res:
+                        if isinstance(line, BaseModel):
+                            line = line.model_dump_json()
+                            line = f"data: {line}"
+                        line = ensure_string(line)
+                        logging.info(f"stream_content:Generator:{line}")
+                        if line.startswith("data:"):
+                            yield f"{line}\n\n"
+                        else:
+                            line = stream_message_template(form_data.model, line)
+                            yield f"data: {json.dumps(line)}\n\n"
 
                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
                         "id": f"{form_data.model}-{str(uuid.uuid4())}",
@@ -732,36 +824,68 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                         }
                     ],
                 }
                 yield f"data: {json.dumps(finish_message)}\n\n"
                 yield f"data: [DONE]"
 
-            return StreamingResponse(stream_content(), media_type="text/event-stream")
-        else:
-            res = pipe(
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        if is_async_gen:
+            pipe_gen = pipe(
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
-            logging.info(f"stream:false:{res}")
+
+            message = ""
+            async for stream in pipe_gen:
+                stream = ensure_string(stream)
+                message = f"{message}{stream}"
+
+            logging.info(f"stream:false:async_gen_function:{message}")
+            return {
+                "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": form_data.model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": message,
+                        },
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                    }
+                ],
+            }
+        elif is_async:
+            res = await pipe(
+                user_message=user_message,
+                model_id=pipeline_id,
+                messages=messages,
+                body=form_data.model_dump(),
+            )
+            logging.info(f"stream:false:async:{res}")
 
             if isinstance(res, dict):
                 return res
             elif isinstance(res, BaseModel):
                 return res.model_dump()
             else:
                 message = ""
                 if isinstance(res, str):
                     message = res
-                if isinstance(res, Generator):
-                    for stream in res:
+                elif inspect.isasyncgen(res):
+                    async for stream in res:
+                        stream = ensure_string(stream)
                         message = f"{message}{stream}"
-                logging.info(f"stream:false:{message}")
+                logging.info(f"stream:false:async:{message}")
                 return {
                     "id": f"{form_data.model}-{str(uuid.uuid4())}",
                     "object": "chat.completion",
@@ -779,5 +903,48 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     }
                 ],
             }
-
-    return await run_in_threadpool(job)
+        else:
+
+            def job():
+                res = pipe(
+                    user_message=user_message,
+                    model_id=pipeline_id,
+                    messages=messages,
+                    body=form_data.model_dump(),
+                )
+                logging.info(f"stream:false:sync:{res}")
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            stream = ensure_string(stream)
+                            message = f"{message}{stream}"
+                    logging.info(f"stream:false:sync:{message}")
+                    return {
+                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data.model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+            return await run_in_threadpool(job)
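
Taken together, a pipeline module can now define pipe as a coroutine or an async generator and stream without occupying a worker thread. A hypothetical module sketch (the class shape follows the pipelines convention, but this example is not part of the commit):

    import asyncio


    class Pipeline:
        def __init__(self):
            self.name = "Async Echo"

        async def pipe(self, user_message, model_id, messages, body):
            # Echo the prompt word by word; a real pipeline would await an
            # async upstream client (httpx, aiohttp, ...) here instead.
            for word in user_message.split():
                await asyncio.sleep(0)  # cooperatively yield to the event loop
                yield word + " "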