Add internal script to fix fileserver URLs in mongodb

This commit is contained in:
clearml 2024-12-05 22:30:03 +02:00
parent 83dbf0fcb8
commit 17fcaba2cb

122
apiserver/fix_mongo_urls.py Normal file
View File

@ -0,0 +1,122 @@
import logging
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
)
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
logging.getLogger().setLevel(logging.INFO)
def fix_mongo_urls(mongo_host: str, host_source: str, host_target: str):
logging.info(f"Connecting to Mongo on {mongo_host}")
client = MongoClient(host=mongo_host)
backend_db: Database = client.backend
def get_updated_uri(uri: str):
if not uri or not uri.startswith(host_source):
return
relative_url = uri[len(host_source) :]
return f"{host_target.rstrip('/')}/{relative_url.lstrip('/')}"
host_source = host_source
host_target = host_target
model_collection: Collection = backend_db.get_collection("model")
if model_collection is not None:
logging.info("Updating model uris")
models_count = model_collection.count_documents({})
updated_models = 0
for model in model_collection.find(
{"uri": {"$regex": "^{}".format(host_source)}}, projection=["uri"]
):
updated_uri = get_updated_uri(model.get("uri"))
if updated_uri:
result = model_collection.update_one(
{"_id": model["_id"]}, {"$set": {"uri": updated_uri}}
)
updated_models += result.modified_count
logging.info(f"Updated {updated_models} models from {models_count}")
task_collection: Collection = backend_db.get_collection("task")
if task_collection is not None:
logging.info("Updating task uris")
tasks_count = task_collection.count_documents({})
updated_tasks = 0
for task in task_collection.find(
{"execution.artifacts": {"$exists": 1, "$ne": {}}},
projection=["execution.artifacts"],
):
artifacts = task.get("execution", {}).get("artifacts")
if not artifacts:
continue
uri_updated = False
for artifact in artifacts.values():
updated_uri = get_updated_uri(artifact.get("uri"))
if updated_uri:
artifact["uri"] = updated_uri
uri_updated = True
if uri_updated:
result = task_collection.update_one(
{"_id": task["_id"]}, {"$set": {"execution.artifacts": artifacts}}
)
updated_tasks += result.modified_count
logging.info(f"Updated {updated_tasks} tasks from {tasks_count}")
def normalise_host(host):
if not host.endswith("/"):
return host
return host[:-1]
def main():
def valid_url_prefix(url: str):
if "://" not in url:
raise ArgumentTypeError("url schema is missing")
return url
parser = ArgumentParser(
description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--mongo-host",
"-mh",
type=str,
default="mongodb://mongo:27017",
help="Mongo server host. The default is mongodb://mongo:27017",
)
parser.add_argument(
"--host-source",
"-hs",
type=valid_url_prefix,
required=True,
help="Source host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
)
parser.add_argument(
"--host-target",
"-ht",
type=valid_url_prefix,
required=True,
help="Target host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
)
args = parser.parse_args()
fix_mongo_urls(
mongo_host=args.mongo_host,
host_source=args.host_source,
host_target=args.host_target,
)
logging.info("Completed successfully")
if __name__ == "__main__":
main()