Add input parameters check to multiple APIs

This commit is contained in:
allegroai 2024-02-13 16:15:55 +02:00
parent 702b6dc9c8
commit a47e65d974
9 changed files with 225 additions and 138 deletions

View File

@ -13,6 +13,14 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
class TaskRequest(Base):
task: str = StringField(required=True)
class ModelRequest(Base):
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
model_events: bool = BoolField(default=False)
class GetMetricsAndVariantsRequest(Base):
task: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str,
@ -51,6 +64,12 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class LegacyMetricEventsRequest(TaskRequest):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
@ -59,7 +78,14 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
model_events: bool = BoolField(default=False)
class VectorMetricsIterHistogramRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class GetVariantSampleRequest(Base):
@ -110,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
model_events: bool = BoolField(default=False)
class LegacyLogEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
@ -160,6 +191,11 @@ class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
@ -177,6 +213,14 @@ class TaskPlotsRequest(Base):
model_events: bool = BoolField(default=False)
class GetScalarMetricDataRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()

View File

@ -42,6 +42,21 @@ class ModelRequest(models.Base):
model = fields.StringField(required=True)
class TaskRequest(models.Base):
task = fields.StringField(required=True)
class UpdateForTaskRequest(TaskRequest):
uri = fields.StringField()
iteration = fields.IntField()
override_model_id = fields.StringField()
class UpdateModelRequest(ModelRequest):
task = fields.StringField()
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)

View File

@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class GetConfigRequest(Base):
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
class UserRequest(Base):
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@ -32,6 +32,14 @@ from apiserver.apimodels.events import (
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
LegacyLogEventsRequest,
TaskRequest,
GetMetricsAndVariantsRequest,
ModelRequest,
LegacyMetricEventsRequest,
GetScalarMetricDataRequest,
VectorMetricsIterHistogramRequest,
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@ -97,15 +105,15 @@ def add_batch(call: APICall, company_id, _):
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log")
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
order = request.order
scroll_id = request.scroll_id
batch_size = request.batch_size
events, scroll_id, total_events = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
@ -119,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log", min_version="1.7")
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
order = request.order
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_id = request.scroll_id
batch_size = request.batch_size
scroll_order = "asc" if (from_ == "head") else "desc"
@ -177,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.download_task_log")
def download_task_log(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@ -257,10 +265,12 @@ def download_task_log(call, company_id, _):
call.result.raw_data = generate()
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_vector_metrics_and_variants")
def get_vector_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@ -273,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_scalar_metrics_and_variants")
def get_scalar_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@ -292,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
def vector_metrics_iter_histogram(
call, company_id, request: VectorMetricsIterHistogramRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
metric = request.metric
variant = request.variant
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task_or_model.get_index_company(), task_id, metric, variant
)
@ -404,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
)
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
@endpoint("events.get_scalar_metric_data")
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
task_id = request.task
metric = request.metric
scroll_id = request.scroll_id
no_scroll = request.no_scroll
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
@ -435,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_latest_scalar_values")
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@ -558,11 +571,11 @@ def get_task_single_value_metrics(
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_multi_task_plots")
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
task_ids = request.tasks
iters = request.iters
scroll_id = request.scroll_id
companies = _get_task_or_model_index_companies(company_id, task_ids)
@ -644,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_task_plots")
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@ -766,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@endpoint("events.debug_images")
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@ -803,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _):
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
@endpoint("events.debug_images", min_version="1.8")
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
model_events = request.model_events
tasks_or_model = _assert_task_or_model_exists(
company_id,
@ -975,8 +988,7 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
return {"metrics": []}
metrics = event_bll.metrics.get_multi_task_metrics(
companies=companies,
event_type=request.event_type
companies=companies, event_type=request.event_type
)
res = [
{
@ -985,14 +997,12 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
}
for m, vars_ in metrics.items()
]
call.result.data = {
"metrics": sorted(res, key=itemgetter("metric"))
}
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
get_task_with_write_access(
@ -1005,9 +1015,9 @@ def delete_for_task(call, company_id, _):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
model_id = call.data["model"]
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)

View File

@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
ModelRequest,
TaskRequest,
UpdateForTaskRequest,
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
@ -67,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
unescape_metadata(call, model_data)
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.get_by_id")
def get_by_id(call: APICall, company_id, request: ModelRequest):
model_id = request.model
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many(
company=company_id,
@ -87,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id")
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
task_id = request.task
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
@ -157,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
@endpoint("models.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call.data)
@ -236,15 +240,15 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
@endpoint("models.update_for_task")
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
override_model_id = call.data.get("override_model_id")
task_id = request.task
uri = request.uri
iteration = request.iteration
override_model_id = request.override_model_id
if not (uri or override_model_id) or (uri and override_model_id):
raise errors.bad_request.MissingRequiredFields(
"exactly one field is required", fields=("uri", "override_model_id")
@ -411,9 +415,9 @@ def validate_task(company_id: str, identity: Identity, fields: dict):
)
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.edit", response_data_model=UpdateResponse)
def edit(call: APICall, company_id, request: UpdateModelRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
@ -428,7 +432,7 @@ def edit(call: APICall, company_id, _):
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
iteration = request.iteration
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
@ -460,13 +464,9 @@ def edit(call: APICall, company_id, _):
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
def _update_model(call: APICall, company_id, model_id):
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
@ -502,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint("models.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UpdateModelRequest):
call.result.data_model = _update_model(call, company_id, model_id=request.model)
@endpoint(
@ -629,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
func=partial(
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(

View File

@ -59,13 +59,12 @@ create_fields = {
}
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
@endpoint("projects.get_by_id")
def get_by_id(call: APICall, company: str, request: ProjectRequest):
project_id = request.project
with translate_errors_context():
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
query = Q(id=project_id) & get_company_or_none_constraint(company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
data,
shallow_search=request.shallow_search,
)
selected_project_ids = None
if request.active_users or request.children_type:
@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
company=company_id,
project_ids=project_ids,
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
def get_all(call: APICall, company: str, _):
data = call.data
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
data,
shallow_search=data.get("shallow_search", False),
)
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
company=company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
@ -277,9 +281,11 @@ def get_all(call: APICall):
@endpoint(
"projects.create", required_fields=["name"], response_data_model=IdResponse,
"projects.create",
required_fields=["name"],
response_data_model=IdResponse,
)
def create(call: APICall):
def create(call: APICall, company: str, _):
identity = call.identity
with translate_errors_context():
@ -288,15 +294,15 @@ def create(call: APICall):
return IdResponse(
id=ProjectBLL.create(
user=identity.user, company=identity.company, **fields,
user=identity.user,
company=company,
**fields,
)
)
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
@endpoint("projects.update", response_data_model=UpdateResponse)
def update(call: APICall, company: str, request: ProjectRequest):
"""
update
@ -309,9 +315,7 @@ def update(call: APICall):
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
updated = ProjectBLL.update(
company=call.identity.company, project_id=call.data["project"], **fields
)
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@ -375,7 +379,6 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
def get_unique_metric_variants(
call: APICall, company_id: str, request: GetUniqueMetricsRequest
):
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
@ -429,7 +432,6 @@ def get_model_metadata_values(
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,

View File

@ -3,7 +3,11 @@ from datetime import datetime
from pyhocon.config_tree import NoneValue
from apiserver.apierrors import errors
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from apiserver.apimodels.server import (
ReportStatsOptionRequest,
ReportStatsOptionResponse,
GetConfigRequest,
)
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config_repo import config
from apiserver.config.info import get_version, get_build_number, get_commit_number
@ -22,8 +26,8 @@ def get_stats(call: APICall):
@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
def get_config(call: APICall, _, request: GetConfigRequest):
path = request.path
if path:
c = dict(config.get(path))
else:

View File

@ -7,7 +7,11 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
from apiserver.apimodels.users import (
CreateRequest,
SetPreferencesRequest,
UserRequest,
)
from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
return res.to_proper_dict()
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
@endpoint("users.get_by_id")
def get_by_id(call: APICall, company_id, request: UserRequest):
user_id = request.user
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
@endpoint("users.get_all_ex")
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(company=company_id, query_dict=call.data)
@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
@endpoint("users.get_all_ex", min_version="2.8")
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
@endpoint("users.get_all")
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
@ -138,9 +142,9 @@ def create(call: APICall):
UserBLL.create(call.data_model)
@endpoint("users.delete", required_fields=["user"])
def delete(call: APICall):
UserBLL.delete(call.data["user"])
@endpoint("users.delete")
def delete(_: APICall, __, request: UserRequest):
UserBLL.delete(request.user)
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
@ -159,9 +163,9 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
return User.safe_update(company_id, user_id, partial_update_dict)
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
user_id = call.data["user"]
@endpoint("users.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UserRequest):
user_id = request.user
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)