mirror of
https://github.com/matatonic/openedai-speech
synced 2025-06-26 18:16:32 +00:00
182 lines
5.7 KiB
Python
182 lines
5.7 KiB
Python
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import PlainTextResponse, JSONResponse
|
|
from loguru import logger
|
|
|
|
class OpenAIError(Exception):
    """Root of the exception hierarchy defined in this module."""
|
|
|
|
class APIError(OpenAIError):
    """Error carrying the fields of a JSON error response.

    Attributes:
        message: Text describing the error; also passed to ``Exception``.
        code: Integer code for the error (defaults to 500).
        param: Optional name of the parameter the error relates to.
        type: Name of the concrete exception class, as a plain string.
        internal_message: Additional text kept on the instance, separate
            from ``message``.
    """
    message: str
    code: int = None
    param: str = None
    type: str = None

    def __init__(self, message: str, code: int = 500, param: str = None, internal_message: str = ''):
        """Store the error details on the instance.

        Args:
            message: Text describing the error.
            code: Integer code for the error.
            param: Optional name of the related parameter.
            internal_message: Optional extra text kept apart from ``message``.
        """
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        # No trailing comma here: with one, the attribute becomes a 1-tuple
        # such as ('APIError',) instead of the class-name string.
        self.type = self.__class__.__name__
        self.internal_message = internal_message

    def __repr__(self):
        """Return ``ClassName(message=..., code=..., param=...)``."""
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )
|
|
|
|
class InternalServerError(APIError):
    """An ``APIError`` subclass that adds no fields or behavior of its own."""
|
|
|
|
class ServiceUnavailableError(APIError):
    """``APIError`` preset with a "service unavailable" message and code 503."""

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        """Build the error with 503 defaults.

        Args:
            message: Text describing the error.
            code: Integer code for the error.
            internal_message: Optional extra text kept apart from ``message``.
        """
        # internal_message must be passed by keyword: the third positional
        # parameter of APIError.__init__ is `param`, so a positional call
        # would store the internal text as `param` and drop internal_message.
        super().__init__(message, code, internal_message=internal_message)
|
|
|
|
class APIStatusError(APIError):
    """``APIError`` whose code comes from the class-level ``status_code``.

    Subclasses override ``status_code``; the constructor forwards it to
    ``APIError`` as ``code``.
    """
    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        """Build the error using this class's ``status_code`` as its code.

        Args:
            message: Text describing the error.
            param: Optional name of the related parameter.
            internal_message: Optional extra text kept apart from ``message``.
        """
        super().__init__(
            message,
            code=self.status_code,
            param=param,
            internal_message=internal_message,
        )
|
|
|
|
class BadRequestError(APIStatusError):
    """Status error with code 400 (Bad Request)."""
    status_code: int = 400
|
|
|
|
class AuthenticationError(APIStatusError):
    """Status error with code 401 (Unauthorized)."""
    status_code: int = 401
|
|
|
|
class PermissionDeniedError(APIStatusError):
    """Status error with code 403 (Forbidden)."""
    status_code: int = 403
|
|
|
|
class NotFoundError(APIStatusError):
    """Status error with code 404 (Not Found)."""
    status_code: int = 404
|
|
|
|
class ConflictError(APIStatusError):
    """Status error with code 409 (Conflict)."""
    status_code: int = 409
|
|
|
|
class UnprocessableEntityError(APIStatusError):
    """Status error with code 422 (Unprocessable Entity)."""
    status_code: int = 422
|
|
|
|
class RateLimitError(APIStatusError):
    """Status error with code 429 (Too Many Requests)."""
    status_code: int = 429
|
|
|
|
class OpenAIStub(FastAPI):
    """FastAPI application with stock middleware, error handlers and routes.

    Construction installs a CORS middleware, JSON exception handlers for
    ``Exception``, ``APIError`` and ``APIStatusError``, a debug-level
    request/response logging middleware, and routes for billing usage, the
    root path, health, and model listing. Registered models are kept in
    ``self.models``, a dict mapping a registered name to a model identifier.
    """

    def __init__(self, **kwargs) -> None:
        """Create the app and register middleware, handlers and routes.

        Args:
            **kwargs: Forwarded unchanged to ``FastAPI.__init__``.
        """
        super().__init__(**kwargs)
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"]
        )

        def error_response(exc: APIError) -> JSONResponse:
            # Shared JSON body for both APIError handlers; the exception's
            # code is also used as the HTTP status code.
            return JSONResponse(status_code=exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Generic server errors
            #logger.opt(exception=exc).error("Logging exception traceback")

            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            # Server error
            logger.opt(exception=exc).error("Logging exception traceback")

            if exc.internal_message:
                logger.info(exc.internal_message)

            return error_response(exc)

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Client side error
            logger.info(repr(exc))

            if exc.internal_message:
                logger.info(exc.internal_message)

            return error_response(exc)

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            # NOTE(review): the f-string awaits request.body() whether or not
            # debug logging is enabled, so every request body is read here.
            logger.debug(f"Request body: {await request.body()}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")

            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
            return { 'total_usage': 0 }

        @self.get("/", response_class=PlainTextResponse)
        @self.head("/", response_class=PlainTextResponse)
        @self.options("/", response_class=PlainTextResponse)
        async def root():
            # Empty body; 200 once at least one model is registered, else 503.
            return PlainTextResponse(content="", status_code=200 if self.models else 503)

        @self.get("/health")
        async def health():
            return {"status": "ok" if self.models else "unk" }

        @self.get("/v1/models")
        async def get_model_list():
            return self.model_list()

        @self.get("/v1/models/{model}")
        async def get_model_info(model: str):
            # The parameter name must match the `{model}` placeholder in the
            # route. Under any other name FastAPI expects a required query
            # parameter and ignores the path segment.
            return self.model_info(model)

    def register_model(self, name: str, model: str = None) -> None:
        """Register ``name``, mapped to ``model`` (or to itself if falsy).

        Args:
            name: Key under which the model is stored.
            model: Identifier stored for ``name``; ``name`` is used when this
                is ``None`` or otherwise falsy.
        """
        self.models[name] = model if model else name

    def deregister_model(self, name: str) -> None:
        """Remove ``name`` from the registry; unknown names are ignored."""
        self.models.pop(name, None)

    def model_info(self, model: str) -> dict:
        """Return the model description dict for ``model``.

        The ``id`` field is ``model``; ``object``, ``created`` and
        ``owned_by`` are fixed values. The registry is not consulted.
        """
        return {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }

    def model_list(self) -> dict:
        """Return the model listing, or ``{}`` when nothing is registered.

        The listing contains one ``model_info`` entry for each distinct,
        truthy registered name or mapped identifier, sorted by string value
        so the order is stable between calls.
        """
        if not self.models:
            return {}

        # dict keys views support `|` with any iterable and yield a set, so
        # names and identifiers are de-duplicated in one step.
        identifiers = self.models.keys() | self.models.values()

        return {
            "object": "list",
            "data": [self.model_info(ident) for ident in sorted(identifiers, key=str) if ident]
        }
|