diff --git a/main.py b/main.py
index 63e48fb..50744d4 100644
--- a/main.py
+++ b/main.py
@@ -191,7 +191,16 @@ async def filter(form_data: FilterForm):
         )
 
     pipeline = PIPELINE_MODULES[form_data.model]
-    return await pipeline.filter(form_data.body, form_data.user)
+
+    try:
+        body = await pipeline.filter(form_data.body, form_data.user)
+        return body
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error in filter {form_data.model}: {str(e)}",
+        )
 
 
 @app.post("/chat/completions")
diff --git a/pipelines/examples/rate_limit_filter_pipeline.py b/pipelines/examples/rate_limit_filter_pipeline.py
new file mode 100644
index 0000000..a3bea03
--- /dev/null
+++ b/pipelines/examples/rate_limit_filter_pipeline.py
@@ -0,0 +1,94 @@
+from typing import List, Optional
+from schemas import OpenAIChatMessage
+import time
+
+
+class Pipeline:
+    def __init__(self):
+        # Pipeline filters are only compatible with Open WebUI
+        # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
+        self.type = "filter"
+
+        self.id = "rate_limit_filter_pipeline"
+        self.name = "Rate Limit Filter"
+
+        # Assign a priority level to the filter pipeline.
+        # The priority level determines the order in which the filter pipelines are executed.
+        # The lower the number, the higher the priority.
+        self.priority = 0
+
+        # List target pipelines (models) that this filter will be connected to.
+        self.pipelines = ["*"]
+
+        pass
+
+        # Initialize rate limits
+        self.requests_per_minute: Optional[int] = 60
+        self.requests_per_hour: Optional[int] = 1000
+        self.sliding_window_limit: Optional[int] = 100
+        self.sliding_window_minutes: Optional[int] = 15
+
+        # Tracking data - user_id -> (timestamps of requests)
+        self.user_requests = {}
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+        print(f"on_shutdown:{__name__}")
+        pass
+
+    def prune_requests(self, user_id: str):
+        """Prune old requests that are outside of the sliding window period."""
+        now = time.time()
+        if user_id in self.user_requests:
+            self.user_requests[user_id] = [
+                req
+                for req in self.user_requests[user_id]
+                if now - req < self.sliding_window_minutes * 60
+            ]
+
+    def log_request(self, user_id: str):
+        """Log a new request for a user."""
+        now = time.time()
+        if user_id not in self.user_requests:
+            self.user_requests[user_id] = []
+        self.user_requests[user_id].append(now)
+
+    def rate_limited(self, user_id: str) -> bool:
+        """Check if a user is rate limited."""
+        self.prune_requests(user_id)
+
+        user_reqs = self.user_requests.get(user_id, [])
+
+        if self.requests_per_minute is not None:
+            requests_last_minute = sum(1 for req in user_reqs if time.time() - req < 60)
+            if requests_last_minute >= self.requests_per_minute:
+                return True
+
+        if self.requests_per_hour is not None:
+            requests_last_hour = sum(1 for req in user_reqs if time.time() - req < 3600)
+            if requests_last_hour >= self.requests_per_hour:
+                return True
+
+        if self.sliding_window_limit is not None:
+            requests_in_window = len(user_reqs)
+            if requests_in_window >= self.sliding_window_limit:
+                return True
+
+        return False
+
+    async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
+        print(f"pipe:{__name__}")
+        print(body)
+        print(user)
+
+        user_id = user["id"] if user and "id" in user else "default_user"
+        if self.rate_limited(user_id):
+            raise Exception("Rate limit exceeded. Please try again later.")
+
+        self.log_request(user_id)
+        return body
diff --git a/pipelines/filter_pipeline.py b/pipelines/filter_pipeline.py
deleted file mode 100644
index 3590e4a..0000000
--- a/pipelines/filter_pipeline.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from typing import List, Optional
-from schemas import OpenAIChatMessage
-
-
-class Pipeline:
-    def __init__(self):
-        # Pipeline filters are only compatible with Open WebUI
-        # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
-        self.type = "filter"
-
-        self.id = "filter_pipeline"
-        self.name = "Filter"
-
-        # Assign a priority level to the filter pipeline.
-        # The priority level determines the order in which the filter pipelines are executed.
-        # The lower the number, the higher the priority.
-        self.priority = 0
-
-        # List target pipelines (models) that this filter will be connected to.
-        self.pipelines = [
-            {"id": "llama3:latest"},
-        ]
-        pass
-
-    async def on_startup(self):
-        # This function is called when the server is started.
-        print(f"on_startup:{__name__}")
-        pass
-
-    async def on_shutdown(self):
-        # This function is called when the server is stopped.
-        print(f"on_shutdown:{__name__}")
-        pass
-
-    async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
-        print(f"pipe:{__name__}")
-
-        print(body)
-        print(user)
-
-        return body
diff --git a/pipelines/rate_limit_filter_pipeline.py b/pipelines/rate_limit_filter_pipeline.py
new file mode 100644
index 0000000..a3bea03
--- /dev/null
+++ b/pipelines/rate_limit_filter_pipeline.py
@@ -0,0 +1,94 @@
+from typing import List, Optional
+from schemas import OpenAIChatMessage
+import time
+
+
+class Pipeline:
+    def __init__(self):
+        # Pipeline filters are only compatible with Open WebUI
+        # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
+        self.type = "filter"
+
+        self.id = "rate_limit_filter_pipeline"
+        self.name = "Rate Limit Filter"
+
+        # Assign a priority level to the filter pipeline.
+        # The priority level determines the order in which the filter pipelines are executed.
+        # The lower the number, the higher the priority.
+        self.priority = 0
+
+        # List target pipelines (models) that this filter will be connected to.
+        self.pipelines = ["*"]
+
+        pass
+
+        # Initialize rate limits
+        self.requests_per_minute: Optional[int] = 60
+        self.requests_per_hour: Optional[int] = 1000
+        self.sliding_window_limit: Optional[int] = 100
+        self.sliding_window_minutes: Optional[int] = 15
+
+        # Tracking data - user_id -> (timestamps of requests)
+        self.user_requests = {}
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+        print(f"on_shutdown:{__name__}")
+        pass
+
+    def prune_requests(self, user_id: str):
+        """Prune old requests that are outside of the sliding window period."""
+        now = time.time()
+        if user_id in self.user_requests:
+            self.user_requests[user_id] = [
+                req
+                for req in self.user_requests[user_id]
+                if now - req < self.sliding_window_minutes * 60
+            ]
+
+    def log_request(self, user_id: str):
+        """Log a new request for a user."""
+        now = time.time()
+        if user_id not in self.user_requests:
+            self.user_requests[user_id] = []
+        self.user_requests[user_id].append(now)
+
+    def rate_limited(self, user_id: str) -> bool:
+        """Check if a user is rate limited."""
+        self.prune_requests(user_id)
+
+        user_reqs = self.user_requests.get(user_id, [])
+
+        if self.requests_per_minute is not None:
+            requests_last_minute = sum(1 for req in user_reqs if time.time() - req < 60)
+            if requests_last_minute >= self.requests_per_minute:
+                return True
+
+        if self.requests_per_hour is not None:
+            requests_last_hour = sum(1 for req in user_reqs if time.time() - req < 3600)
+            if requests_last_hour >= self.requests_per_hour:
+                return True
+
+        if self.sliding_window_limit is not None:
+            requests_in_window = len(user_reqs)
+            if requests_in_window >= self.sliding_window_limit:
+                return True
+
+        return False
+
+    async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
+        print(f"pipe:{__name__}")
+        print(body)
+        print(user)
+
+        user_id = user["id"] if user and "id" in user else "default_user"
+        if self.rate_limited(user_id):
+            raise Exception("Rate limit exceeded. Please try again later.")
+
+        self.log_request(user_id)
+        return body