mirror of https://github.com/open-webui/open-webui

feat: preset backend logic

parent 7d2ab168f1
commit 88d053833d
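
Overview note: the commit threads stored per-model presets into the request path. Each handler now builds a payload dict from the incoming form data, looks up the model's preset record via Models.get_model_by_id, reroutes to base_model_id when the preset wraps a base model, overlays the preset's sampling parameters, and injects the preset system prompt into the message list. A condensed, hypothetical sketch of that shared shape (build_payload is not a name from the commit):

def build_payload(form_data, model_info):
    # Start from the caller's request, dropping unset fields.
    payload = {**form_data.model_dump(exclude_none=True)}
    if model_info:
        if model_info.base_model_id:
            # A preset points at a concrete base model; route the call there.
            payload["model"] = model_info.base_model_id
        params = model_info.params.model_dump()
        # Overlay stored sampling parameters (key mapping is endpoint-specific).
        for key in ("temperature", "top_p", "seed"):
            payload[key] = params.get(key, None)
    return payload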
					
@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
 
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
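Reviewer note on the hunk above: the preset keys are copied into Ollama's "options" dict one assignment at a time, with the OpenAI-style "max_tokens" translated to Ollama's "num_predict". Note that "repeat_penalty" is read from the "frequency_penalty" key; that may be deliberate aliasing of the OpenAI-style name, but it is worth confirming it is not a copy-paste slip. A table-driven sketch of the same overlay (hypothetical helper, not code from the commit):

# Maps payload["options"][target] = params.get(source, None).
OLLAMA_OPTION_KEYS = {
    "mirostat": "mirostat",
    "mirostat_eta": "mirostat_eta",
    "mirostat_tau": "mirostat_tau",
    "num_ctx": "num_ctx",
    "repeat_last_n": "repeat_last_n",
    "repeat_penalty": "frequency_penalty",  # aliased in the diff; see note above
    "temperature": "temperature",
    "seed": "seed",
    "tfs_z": "tfs_z",
    "num_predict": "max_tokens",  # OpenAI-style name -> Ollama name
    "top_k": "top_k",
    "top_p": "top_p",
}

def apply_ollama_options(payload: dict, params: dict) -> None:
    payload["options"] = {
        target: params.get(source, None)
        for target, source in OLLAMA_OPTION_KEYS.items()
    }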
@@ -893,23 +966,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    print(payload)
+
     r = None
 
-    # payload = {
-    #     **form_data.model_dump_json(exclude_none=True).encode(),
-    #     "model": model,
-    #     "messages": form_data.messages,
-
-    # }
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
-
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -918,7 +980,7 @@ async def generate_chat_completion(
 
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream", None):
                         yield json.dumps({"id": request_id, "done": False}) + "\n"
 
                     for chunk in r.iter_content(chunk_size=8192):
@@ -936,7 +998,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
 
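The streaming branch now keys off the merged payload (payload.get("stream", None)) instead of the typed form field, and the handler forwards the upstream NDJSON stream after an initial {"id": ..., "done": false} frame. A minimal standalone sketch of that relay pattern (hypothetical function; uses requests' documented iter_content):

import json
import uuid

import requests

def relay(url: str, payload: dict):
    """Yield an initial NDJSON frame, then pass upstream chunks through."""
    request_id = str(uuid.uuid4())
    r = requests.post(url, data=json.dumps(payload), stream=True)
    if payload.get("stream", None):
        yield (json.dumps({"id": request_id, "done": False}) + "\n").encode()
    for chunk in r.iter_content(chunk_size=8192):
        yield chunk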
@@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
 
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
     if url_idx == None:
-        model = form_data.model
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
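The system-prompt injection repeated in both handlers leans on Python's for/else: the else clause runs only when the loop completes without break, i.e. when no existing system message was found. A self-contained sketch of the idiom (hypothetical function name):

def inject_system_prompt(messages: list, system: str) -> list:
    """Prepend to an existing system message, else insert one at the front."""
    for message in messages:
        if message.get("role") == "system":
            message["content"] = system + message["content"]
            break
    else:  # no break: the conversation had no system message yet
        messages.insert(0, {"role": "system", "content": system})
    return messages

# inject_system_prompt([{"role": "user", "content": "hi"}], "Be terse. ")
# -> [{"role": "system", "content": "Be terse. "}, {"role": "user", "content": "hi"}]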
@@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
     r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(
 
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream"):
                         yield json.dumps(
                             {"request_id": request_id, "done": False}
                         ) + "\n"
@@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
 

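One behavioral detail worth flagging for both endpoints above: the payload starts from model_dump(exclude_none=True), but the preset overlay then writes every key unconditionally, so a preset with unset fields reintroduces explicit None values that the initial dump had stripped. A minimal sketch of the effect, assuming pydantic v2's model_dump semantics:

from typing import Optional

from pydantic import BaseModel

class ChatForm(BaseModel):
    model: str
    temperature: Optional[float] = None

form = ChatForm(model="llama2")
payload = {**form.model_dump(exclude_none=True)}   # {'model': 'llama2'}

params = {}  # a preset with no sampling fields set
payload["temperature"] = params.get("temperature", None)
print(payload)  # {'model': 'llama2', 'temperature': None} -- None creeps back in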
@@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
 
+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)
 
-        print(app.state.MODELS)
+            payload = {**body}
 
-        model = app.state.MODELS[body.get("model")]
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
 
-        idx = model["urlIdx"]
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
 
-        if "pipeline" in model and model.get("pipeline"):
-            body["user"] = {"name": user.name, "id": user.id}
-            body["title"] = (
-                True if body["stream"] == False and body["max_tokens"] == 50 else False
-            )
+                model_info.params = model_info.params.model_dump()
 
-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    # TODO: add "stop" back in
+                    # payload["stop"] = model_info.params.get("stop", None)
 
-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )
+            else:
+                pass
+
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]
+
+            idx = model["urlIdx"]
+
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )
+
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)
+
-        # Convert the modified body back to JSON
-        body = json.dumps(body)
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)
+
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)
 
+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
 
@@ -368,7 +414,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = requests.request(
             method=request.method,
             url=target_url,
-            data=body,
+            data=payload if payload else body,
             headers=headers,
             stream=True,
         )
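Net effect of the proxy changes: payload is populated only for chat/completions paths, so every other route, plus any body that fails JSON decoding, falls through to the untouched raw bytes via data=payload if payload else body. A reduced sketch of that dispatch (hypothetical function, heavily simplified from the diff):

import json

def select_request_data(path: str, body: bytes):
    """Return preset-merged JSON for chat completions, raw bytes otherwise."""
    payload = None
    try:
        if "chat/completions" in path:
            data = json.loads(body.decode("utf-8"))
            # ... preset merge and gpt-4-vision max_tokens workaround go here ...
            payload = json.dumps(data)
    except json.JSONDecodeError:
        pass  # malformed JSON: forward the original body unchanged
    return payload if payload else body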