refac: tools

This commit is contained in:
Timothy J. Baek 2024-06-11 11:31:14 -07:00
parent 9d16dd997a
commit 0bb26ae504
2 changed files with 12 additions and 20 deletions
backend
apps/socket
main.py

View File

@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3
@sio.event @sio.event
async def connect(sid, environ, auth): async def connect(sid, environ, auth):
print("connect ", sid)
user = None user = None
if auth and "token" in auth: if auth and "token" in auth:
data = decode_token(auth["token"]) data = decode_token(auth["token"])
@ -37,7 +35,6 @@ async def connect(sid, environ, auth):
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
@ -64,13 +61,11 @@ async def user_join(sid, data):
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count") @sio.on("user-count")
async def user_count(sid): async def user_count(sid):
print("user-count", sid)
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
@ -79,14 +74,12 @@ def get_models_in_use():
models_in_use = [] models_in_use = []
for model_id, data in USAGE_POOL.items(): for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id) models_in_use.append(model_id)
print(f"Models in use: {models_in_use}")
return models_in_use return models_in_use
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
print(f'Received "usage" event from {sid}: {data}')
model_id = data["model"] model_id = data["model"]
@ -114,7 +107,6 @@ async def usage(sid, data):
async def remove_after_timeout(sid, model_id): async def remove_after_timeout(sid, model_id):
try: try:
print("remove_after_timeout", sid, model_id)
await asyncio.sleep(TIMEOUT_DURATION) await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL: if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"]) print(USAGE_POOL[model_id]["sids"])
@ -124,7 +116,6 @@ async def remove_after_timeout(sid, model_id):
if len(USAGE_POOL[model_id]["sids"]) == 0: if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id] del USAGE_POOL[model_id]
print(f"Removed usage data for {model_id} due to timeout")
# Broadcast the usage data to all clients # Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError: except asyncio.CancelledError:
@ -143,9 +134,6 @@ async def disconnect(sid):
if len(USER_POOL[user_id]) == 0: if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id] del USER_POOL[user_id]
print(f"user {user_id} disconnected with session ID {sid}")
print(USER_POOL)
await sio.emit("user-count", {"count": len(USER_POOL)}) await sio.emit("user-count", {"count": len(USER_POOL)})
else: else:
print(f"Unknown session ID {sid} disconnected") print(f"Unknown session ID {sid} disconnected")

View File

@ -178,7 +178,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
"History:\n" "History:\n"
+ "\n".join( + "\n".join(
[ [
f"{message['role']}: {message['content']}" f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4] for message in messages[::-1][:4]
] ]
) )
@ -209,17 +209,21 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
response = await generate_openai_chat_completion(payload, user=user) response = await generate_openai_chat_completion(payload, user=user)
content = None content = None
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary if hasattr(response, "body_iterator"):
if response.background is not None: async for chunk in response.body_iterator:
await response.background() data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response # Parse the function response
if content is not None: if content is not None:
print(content) print(f"content: {content}")
result = json.loads(content) result = json.loads(content)
print(result) print(result)