mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 02:46:53 +00:00
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
|
import logging
|
||
|
from argparse import (
|
||
|
ArgumentDefaultsHelpFormatter,
|
||
|
ArgumentParser,
|
||
|
ArgumentTypeError,
|
||
|
)
|
||
|
|
||
|
from pymongo import MongoClient
|
||
|
from pymongo.collection import Collection
|
||
|
from pymongo.database import Database
|
||
|
|
||
|
|
||
|
# Lower the root logger threshold to INFO so that the logging.info() progress
# messages emitted by this script are not filtered out.
logging.getLogger().setLevel(logging.INFO)
|
||
|
|
||
|
|
||
|
def fix_mongo_urls(mongo_host: str, host_source: str, host_target: str):
    """Replace a fileserver URL prefix in the model and task collections.

    Connects to Mongo and, in the ``backend`` database:

    * rewrites the ``uri`` of every ``model`` document whose uri starts with
      ``host_source``;
    * rewrites the ``uri`` of every entry under ``execution.artifacts`` of
      every ``task`` document whose uri starts with ``host_source``.

    In both cases the ``host_source`` prefix is replaced by ``host_target``,
    joined to the remainder of the uri with exactly one ``/``.

    Args:
        mongo_host: Value passed as ``host`` to ``MongoClient``.
        host_source: URL prefix to search for.
        host_target: URL prefix that replaces ``host_source``.
    """
    # Local import: only needed to escape the prefix used in the Mongo query.
    import re

    def get_updated_uri(uri: str):
        """Return ``uri`` re-rooted on ``host_target``, or None when it does not start with ``host_source``."""
        if not uri or not uri.startswith(host_source):
            return None
        relative_url = uri[len(host_source):]
        return f"{host_target.rstrip('/')}/{relative_url.lstrip('/')}"

    # The prefix must be matched literally. Unescaped, "." would match any
    # character and a host such as "http://[::1]:8081" would be parsed as a
    # character class and never match the stored uris.
    source_prefix_filter = {"uri": {"$regex": "^" + re.escape(host_source)}}

    logging.info(f"Connecting to Mongo on {mongo_host}")
    client = MongoClient(host=mongo_host)
    try:
        backend_db: Database = client.backend

        model_collection: Collection = backend_db.get_collection("model")
        logging.info("Updating model uris")
        models_count = model_collection.count_documents({})
        updated_models = 0
        for model in model_collection.find(source_prefix_filter, projection=["uri"]):
            updated_uri = get_updated_uri(model.get("uri"))
            if not updated_uri:
                continue
            result = model_collection.update_one(
                {"_id": model["_id"]}, {"$set": {"uri": updated_uri}}
            )
            updated_models += result.modified_count

        logging.info(f"Updated {updated_models} models from {models_count}")

        task_collection: Collection = backend_db.get_collection("task")
        logging.info("Updating task uris")
        tasks_count = task_collection.count_documents({})
        updated_tasks = 0
        for task in task_collection.find(
            {"execution.artifacts": {"$exists": 1, "$ne": {}}},
            projection=["execution.artifacts"],
        ):
            artifacts = task.get("execution", {}).get("artifacts")
            if not artifacts:
                continue

            # The artifacts value is treated as a mapping whose values are
            # artifact documents; matching uris are patched in place so the
            # whole mapping can be written back with a single $set.
            uri_updated = False
            for artifact in artifacts.values():
                updated_uri = get_updated_uri(artifact.get("uri"))
                if updated_uri:
                    artifact["uri"] = updated_uri
                    uri_updated = True

            if uri_updated:
                result = task_collection.update_one(
                    {"_id": task["_id"]},
                    {"$set": {"execution.artifacts": artifacts}},
                )
                updated_tasks += result.modified_count

        logging.info(f"Updated {updated_tasks} tasks from {tasks_count}")
    finally:
        # Release the client's connections even when a query or update fails.
        client.close()
|
||
|
|
||
|
|
||
|
def normalise_host(host):
    """Return ``host`` without its final ``/`` when it ends with one.

    At most one trailing slash is removed; a value that does not end with
    ``/`` is returned unchanged.
    """
    return host[:-1] if host.endswith("/") else host
|
||
|
|
||
|
|
||
|
def main():
    """Parse the command-line options and run the Mongo URL fix with them."""

    def valid_url_prefix(url: str):
        # argparse ``type`` callback: only accept values containing "://".
        if "://" in url:
            return url
        raise ArgumentTypeError("url schema is missing")

    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter, description=__doc__
    )
    parser.add_argument(
        "--mongo-host",
        "-mh",
        type=str,
        default="mongodb://mongo:27017",
        help="Mongo server host. The default is mongodb://mongo:27017",
    )

    # Both host options share the same validation and are mandatory; only the
    # flag names and help text differ.
    host_options = (
        (
            "--host-source",
            "-hs",
            "Source host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
        ),
        (
            "--host-target",
            "-ht",
            "Target host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
        ),
    )
    for long_flag, short_flag, help_text in host_options:
        parser.add_argument(
            long_flag,
            short_flag,
            type=valid_url_prefix,
            required=True,
            help=help_text,
        )

    options = parser.parse_args()

    fix_mongo_urls(
        mongo_host=options.mongo_host,
        host_source=options.host_source,
        host_target=options.host_target,
    )
    logging.info("Completed successfully")
|
||
|
|
||
|
|
||
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|