mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
612 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b6f24b24d | ||
|
|
d03a931d84 | ||
|
|
5cc7199661 | ||
|
|
6537e9ef69 | ||
|
|
930aaff791 | ||
|
|
1999fb2479 | ||
|
|
9db14cc31d | ||
|
|
e3cc689528 | ||
|
|
9e0adc77dd | ||
|
|
58d9a64537 | ||
|
|
d397d2ae20 | ||
|
|
2d711e1500 | ||
|
|
97992b0d9e | ||
|
|
bc23f1b0cf | ||
|
|
6b3eff1426 | ||
|
|
caaf801cd0 | ||
|
|
c23e8a90d0 | ||
|
|
fa5b28ca0e | ||
|
|
bfb55a9463 | ||
|
|
37e485e1f2 | ||
|
|
3451ff441f | ||
|
|
53c9b5525e | ||
|
|
e5230edac3 | ||
|
|
a54dd8030c | ||
|
|
482a5c34bc | ||
|
|
ee2a72c70f | ||
|
|
a0d8aaf3b9 | ||
|
|
de1f823213 | ||
|
|
0c9e2f92ee | ||
|
|
6c49e96ff0 | ||
|
|
81e3fc6577 | ||
|
|
e6dc4b7557 | ||
|
|
238a47a197 | ||
|
|
04e7076628 | ||
|
|
0531612bf4 | ||
|
|
3ae410a1e9 | ||
|
|
98ed3075dd | ||
|
|
b871bf4224 | ||
|
|
8d4c02fc3c | ||
|
|
b986980c75 | ||
|
|
a4fa567be2 | ||
|
|
ddb91f226a | ||
|
|
7772f47773 | ||
|
|
9c118d14e0 | ||
|
|
efd56e085e | ||
|
|
4dff163af4 | ||
|
|
242a78a0fe | ||
|
|
78989fea91 | ||
|
|
5de7c12062 | ||
|
|
3f79c19079 | ||
|
|
fe29743c54 | ||
|
|
d760cf5835 | ||
|
|
3695f25a5f | ||
|
|
c6f1beafdd | ||
|
|
68a54c34f3 | ||
|
|
ab495ae586 | ||
|
|
b058770af1 | ||
|
|
f7e833bf6f | ||
|
|
36b9ab0453 | ||
|
|
ec0436d0da | ||
|
|
0f6c4e75b7 | ||
|
|
a41ae112a1 | ||
|
|
c28f478ea8 | ||
|
|
c18eb99d06 | ||
|
|
3a60f00d93 | ||
|
|
ee87778548 | ||
|
|
52c0c4d438 | ||
|
|
d117a4f022 | ||
|
|
6683d2d7a9 | ||
|
|
05357fe25e | ||
|
|
adc1825843 | ||
|
|
0c15169668 | ||
|
|
123dc1dcfb | ||
|
|
b2feafac09 | ||
|
|
b41ab8c550 | ||
|
|
62d5779bd5 | ||
|
|
f8b9d9802e | ||
|
|
dd8a1503b0 | ||
|
|
cff98ae900 | ||
|
|
9b108740da | ||
|
|
08a7bc7c9f | ||
|
|
fb256d7e5b | ||
|
|
710443b078 | ||
|
|
e0cde2f7c9 | ||
|
|
60b9c8de14 | ||
|
|
ecffe26be4 | ||
|
|
2570bd9e26 | ||
|
|
174f84514a | ||
|
|
65cb8d7b43 | ||
|
|
5f8ef808a3 | ||
|
|
4941ac70e0 | ||
|
|
67cd461145 | ||
|
|
92b5fc6f9a | ||
|
|
b90165b4e4 | ||
|
|
6c2dcb5c8a | ||
|
|
3efed32934 | ||
|
|
69737308fe | ||
|
|
a6dbea808a | ||
|
|
5131b17901 | ||
|
|
5f21c3a56d | ||
|
|
2350ac64ed | ||
|
|
d146127c18 | ||
|
|
abd65e103e | ||
|
|
bf65ea7bd0 | ||
|
|
73e278a8ed | ||
|
|
d92dfbbdb7 | ||
|
|
5c1e419eb5 | ||
|
|
124684f53f | ||
|
|
455b5d6758 | ||
|
|
c04e2e498b | ||
|
|
da8a45072f | ||
|
|
e1992e2054 | ||
|
|
c17cedd93a | ||
|
|
b6ad8f8790 | ||
|
|
5acc7eebc3 | ||
|
|
941927dfcd | ||
|
|
02933a9c93 | ||
|
|
e537651f29 | ||
|
|
af09fba755 | ||
|
|
04ea9018a3 | ||
|
|
ff7e1be24f | ||
|
|
fc4fd9e61c | ||
|
|
8908c7dcf9 | ||
|
|
b9996e2c1a | ||
|
|
afdc56f37c | ||
|
|
a25cd5dae8 | ||
|
|
447adb9090 | ||
|
|
92fd98d5ad | ||
|
|
c4001b4037 | ||
|
|
970a32287a | ||
|
|
17cd48dada | ||
|
|
ea3b6e955f | ||
|
|
843450bb9b | ||
|
|
e149af58b1 | ||
|
|
604a38035b | ||
|
|
cae38a365b | ||
|
|
e334246b46 | ||
|
|
36e013b40c | ||
|
|
f20cd6536e | ||
|
|
446bd35006 | ||
|
|
a377a7e315 | ||
|
|
3d046ac282 | ||
|
|
a08fa9a0e1 | ||
|
|
5856ed2836 | ||
|
|
d295355d99 | ||
|
|
77350f6119 | ||
|
|
bc2c2ebbfd | ||
|
|
1502e02a1a | ||
|
|
d0e2313a24 | ||
|
|
d8ba1a8ea7 | ||
|
|
ca7937fc4e | ||
|
|
df89bcceef | ||
|
|
cfccbe05c1 | ||
|
|
e352a6a1e7 | ||
|
|
8a3d992aaf | ||
|
|
c37f3d8d5b | ||
|
|
a96870e092 | ||
|
|
6bf1032237 | ||
|
|
3d816c747d | ||
|
|
3f2b96266b | ||
|
|
22b16d12eb | ||
|
|
c55b6f30df | ||
|
|
b7045d3d28 | ||
|
|
e31a404885 | ||
|
|
643588b71a | ||
|
|
a64c4d264d | ||
|
|
567780e188 | ||
|
|
1bc8529d83 | ||
|
|
6b480d7e87 | ||
|
|
083fd315e9 | ||
|
|
ef20e76174 | ||
|
|
8c8910808e | ||
|
|
f6ad379310 | ||
|
|
c5d6ce3e65 | ||
|
|
694dbc31c4 | ||
|
|
6488dc54e6 | ||
|
|
158da9b480 | ||
|
|
ec2e071ab7 | ||
|
|
465e270342 | ||
|
|
6705aff56f | ||
|
|
9069cfe1da | ||
|
|
677bb3ba6d | ||
|
|
cb253cff9e | ||
|
|
39ceb5ac5c | ||
|
|
d4edeaaf1b | ||
|
|
56aea1ffb8 | ||
|
|
09ab2af34c | ||
|
|
8bb26a6b0b | ||
|
|
3f2304549d | ||
|
|
ad72a435f1 | ||
|
|
f34332344e | ||
|
|
d324b57dd7 | ||
|
|
2216bfe875 | ||
|
|
9beefa7473 | ||
|
|
8ebc334889 | ||
|
|
e662c850af | ||
|
|
1e5163e530 | ||
|
|
1567774765 | ||
|
|
babfcbb707 | ||
|
|
027edd86bb | ||
|
|
cc83aadae6 | ||
|
|
8c18660a82 | ||
|
|
4fe61ee25c | ||
|
|
e18b21639c | ||
|
|
1cef03b8c2 | ||
|
|
d60d6dfe99 | ||
|
|
27d086bca2 | ||
|
|
add3f011a0 | ||
|
|
ee90b0b024 | ||
|
|
9bf107866f | ||
|
|
4d2f282950 | ||
|
|
b55fad1b59 | ||
|
|
ba77ff11e9 | ||
|
|
b67aa05d6f | ||
|
|
6b0c45a861 | ||
|
|
dc9623e964 | ||
|
|
3d73d60826 | ||
|
|
9f0c9c3690 | ||
|
|
1a3d3494ce | ||
|
|
b99f620073 | ||
|
|
e2f265b4bc | ||
|
|
251ee57ffd | ||
|
|
7e03104f1c | ||
|
|
f1a258208e | ||
|
|
66cc49313b | ||
|
|
9ae2943f7d | ||
|
|
54326f707b | ||
|
|
3a3b57c15f | ||
|
|
8ea8ad34e6 | ||
|
|
179661a0d4 | ||
|
|
3d22ca1888 | ||
|
|
fdf6798d0c | ||
|
|
9d9a44b927 | ||
|
|
dad935e81d | ||
|
|
a75534ec34 | ||
|
|
eab33de97e | ||
|
|
29de110abb | ||
|
|
2e7f418ee2 | ||
|
|
dadb996d22 | ||
|
|
174f692edf | ||
|
|
f4d5168a20 | ||
|
|
5a438e8435 | ||
|
|
ce4814dc47 | ||
|
|
ef42d0265d | ||
|
|
3c5195028e | ||
|
|
0d5174c453 | ||
|
|
c034c1a986 | ||
|
|
1b49da8748 | ||
|
|
26bda01a28 | ||
|
|
f5008d80ad | ||
|
|
8b464e7ae6 | ||
|
|
78e4a58c91 | ||
|
|
7a4a5eb03e | ||
|
|
d029d56508 | ||
|
|
6411954002 | ||
|
|
7f4ad0d1ca | ||
|
|
4cd4b2914d | ||
|
|
1d55710a0b | ||
|
|
8f646043bb | ||
|
|
4b11a6efcd | ||
|
|
cb3a7c90a8 | ||
|
|
074842a122 | ||
|
|
749ff4a44f | ||
|
|
7d6918ecb0 | ||
|
|
47184c2833 | ||
|
|
6434f1028e | ||
|
|
daade08940 | ||
|
|
a1d289822f | ||
|
|
1ce34f2c74 | ||
|
|
c2dc73a71f | ||
|
|
07bb3b5df8 | ||
|
|
067ef82576 | ||
|
|
59fc98e0c4 | ||
|
|
a936a210e8 | ||
|
|
be0cf0caa8 | ||
|
|
a8d90887e2 | ||
|
|
6f3257fed3 | ||
|
|
4bb8834551 | ||
|
|
286b8c3df5 | ||
|
|
16430a6636 | ||
|
|
d7ddfde26e | ||
|
|
e6c0f1b6d8 | ||
|
|
641ed1b510 | ||
|
|
e29ad4c9b2 | ||
|
|
3473d2bb02 | ||
|
|
ba03924cb4 | ||
|
|
6870d8aba9 | ||
|
|
64c63d2560 | ||
|
|
88836fae66 | ||
|
|
436883148b | ||
|
|
f9f2f0ccf0 | ||
|
|
f879f6924f | ||
|
|
b9cb587580 | ||
|
|
370e92c3dd | ||
|
|
03094076c8 | ||
|
|
bdf6c353bd | ||
|
|
23736efbc3 | ||
|
|
3c8e27dc94 | ||
|
|
ca890c7ae8 | ||
|
|
30909df73f | ||
|
|
b97a6084ce | ||
|
|
50438bd931 | ||
|
|
28daf49c91 | ||
|
|
4707647c92 | ||
|
|
6974aa3a99 | ||
|
|
e2deff4eef | ||
|
|
59994ccf9c | ||
|
|
29c792d459 | ||
|
|
df334d083e | ||
|
|
b548958c80 | ||
|
|
7bdf8fe30d | ||
|
|
c71c65be87 | ||
|
|
1cc6a8f787 | ||
|
|
e5b92f4a80 | ||
|
|
3272d0f31f | ||
|
|
618a0b9473 | ||
|
|
bca3a6e556 | ||
|
|
8b0afd47a6 | ||
|
|
0303c3525f | ||
|
|
563c451ac9 | ||
|
|
91b1b34a6b | ||
|
|
0ad0495733 | ||
|
|
03ae90c4a6 | ||
|
|
be788965e0 | ||
|
|
d198138c5b | ||
|
|
cf441987af | ||
|
|
b89de43373 | ||
|
|
0ef018c931 | ||
|
|
323b5db07c | ||
|
|
f084f6b9e7 | ||
|
|
eb4c9f0b13 | ||
|
|
018582ff8a | ||
|
|
7dcc0f6df2 | ||
|
|
5e0893dd80 | ||
|
|
ca81922651 | ||
|
|
07cc2fb08b | ||
|
|
842654d3fe | ||
|
|
00e5e2a0b1 | ||
|
|
37e5d8a7e0 | ||
|
|
5b1f468957 | ||
|
|
9103bf7984 | ||
|
|
e848d05677 | ||
|
|
1c7de3a86e | ||
|
|
e12fd8f3df | ||
|
|
29ef134b79 | ||
|
|
e24389fda9 | ||
|
|
f4ead86449 | ||
|
|
171969c5ea | ||
|
|
89f81bfe5a | ||
|
|
b8e62f27e2 | ||
|
|
c7bbac73d0 | ||
|
|
f832ea565a | ||
|
|
22e9c2b7eb | ||
|
|
c67a56eb8d | ||
|
|
df65e1c7ad | ||
|
|
01115c1223 | ||
|
|
6de88c3b93 | ||
|
|
9d77827252 | ||
|
|
76fb97624d | ||
|
|
20d6582f51 | ||
|
|
7ebda33793 | ||
|
|
953124aa37 | ||
|
|
ba3451ce5a | ||
|
|
b93591ec32 | ||
|
|
0abfd8da0d | ||
|
|
a9cc4e36c6 | ||
|
|
fe1c963eec | ||
|
|
111d80e88d | ||
|
|
6718862dbe | ||
|
|
0fe1bf8a61 | ||
|
|
10f326eda9 | ||
|
|
cd0d6c1a3d | ||
|
|
3205f2df97 | ||
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c | ||
|
|
5ce202cc99 | ||
|
|
d09528bc26 | ||
|
|
42d2a41dbe | ||
|
|
82be1840b0 | ||
|
|
27352c5cb6 | ||
|
|
1ea6408d41 | ||
|
|
5e095af3aa | ||
|
|
ab3dceed92 | ||
|
|
3bf5126d84 | ||
|
|
ab2ab7b23a | ||
|
|
c9184d125b | ||
|
|
0c0fdb72b9 | ||
|
|
86378053d4 | ||
|
|
b1cbba0cf1 | ||
|
|
f31526042d | ||
|
|
3f8d5bc346 | ||
|
|
11d76e7d8c | ||
|
|
e76c0fbc63 | ||
|
|
fdc9956da3 | ||
|
|
f4addaa653 | ||
|
|
667964cc82 | ||
|
|
e1309e30b7 | ||
|
|
9403942ef7 | ||
|
|
84a75d9e70 | ||
|
|
c85ab66ae6 | ||
|
|
bf7f0f646b | ||
|
|
dcdf2a3d58 | ||
|
|
f8d8fc40a6 | ||
|
|
45d434a123 | ||
|
|
1834abe5bc | ||
|
|
d6321588f3 | ||
|
|
c17b10ff1d | ||
|
|
b125a56f86 | ||
|
|
c43ce3a17b | ||
|
|
b0b09616a8 | ||
|
|
ede5586ccc | ||
|
|
a1dcdffa53 | ||
|
|
35a11db58e | ||
|
|
d9bdebefc7 | ||
|
|
f29884f05a | ||
|
|
0f72d662f8 | ||
|
|
6202219034 | ||
|
|
bb3218f65d | ||
|
|
cbcaa7c789 | ||
|
|
427322a424 | ||
|
|
0e7d7d36a9 | ||
|
|
06032a6d66 | ||
|
|
b48f4eb2eb | ||
|
|
383b2666c4 | ||
|
|
50c373cf0d | ||
|
|
394a9de5fa | ||
|
|
fb5c06e9c3 | ||
|
|
1a9bbc9420 | ||
|
|
294da32401 | ||
|
|
7f00672010 | ||
|
|
99bf89a360 | ||
|
|
6c8508eb7f | ||
|
|
69714d5b5c | ||
|
|
f9516ec7d3 | ||
|
|
6fdde93dee | ||
|
|
7afc71ec91 | ||
|
|
4595117d91 | ||
|
|
8630cc1021 | ||
|
|
135885b609 | ||
|
|
eb0865662c | ||
|
|
b7b94e7ae5 | ||
|
|
72be8bee19 | ||
|
|
0722b20c1c | ||
|
|
a392a0e6ff | ||
|
|
e22fa2f478 | ||
|
|
8b49c1ac06 | ||
|
|
da1182a405 | ||
|
|
53e995ee8c | ||
|
|
4732dc1a88 | ||
|
|
e325bcaf67 | ||
|
|
a7c30453db | ||
|
|
dedac3b2fe | ||
|
|
7d10bbdf8e | ||
|
|
72213dffa4 | ||
|
|
f778837d4b | ||
|
|
153ed6a7b7 | ||
|
|
5d279c8c5a | ||
|
|
ed910d5f6a | ||
|
|
87d2b6fa15 | ||
|
|
94cfb17291 | ||
|
|
3f641d37b7 | ||
|
|
551be12f01 | ||
|
|
b536020058 | ||
|
|
fb6fbc0a06 | ||
|
|
5ae64fd791 | ||
|
|
f9776e4319 | ||
|
|
75e736e7d5 | ||
|
|
1e4756aa1d | ||
|
|
52529d3c55 | ||
|
|
53296e8891 | ||
|
|
1c87ebc900 | ||
|
|
14d9924ea0 | ||
|
|
69f9b424c7 | ||
|
|
1a6da301a8 | ||
|
|
2728b3ed14 | ||
|
|
38284eef1f | ||
|
|
9debe1adcd | ||
|
|
cc93c15f8a | ||
|
|
2c3f0e4ba3 | ||
|
|
c48eb34d8d | ||
|
|
49515e06e1 | ||
|
|
4a1d97c02f | ||
|
|
6c6c1c3f41 | ||
|
|
0ad687008c | ||
|
|
fe3dbc92dc | ||
|
|
dc53970ff0 | ||
|
|
73592b991b | ||
|
|
47b981a993 | ||
|
|
b500bcab0b | ||
|
|
59e910db1a | ||
|
|
2ecb430f02 | ||
|
|
a08722e394 | ||
|
|
67c210d9d7 | ||
|
|
101ba540f4 | ||
|
|
82fc28d477 | ||
|
|
7b73f699d2 | ||
|
|
a7e5380f67 | ||
|
|
bcade31786 | ||
|
|
6b902f85f4 | ||
|
|
6d4c974045 | ||
|
|
2346c6f3f5 | ||
|
|
82e51b4d36 | ||
|
|
e63599254e | ||
|
|
8e7e234161 | ||
|
|
17d94b26c3 | ||
|
|
1e701becd3 | ||
|
|
18c8dd449d | ||
|
|
50031c4d6d | ||
|
|
6101dc4f11 | ||
|
|
5d17059cbe | ||
|
|
b93e843143 | ||
|
|
1a732ccd8e | ||
|
|
2ea25e498f | ||
|
|
1b1cdb34ad | ||
|
|
e171a8b523 | ||
|
|
539b76d362 | ||
|
|
64b5e1f1f0 | ||
|
|
6a1eb9cea0 | ||
|
|
24907b4eaa | ||
|
|
efc540b837 | ||
|
|
96ffc89c64 | ||
|
|
4f2564d33a | ||
|
|
70ae090cc0 | ||
|
|
4f01778961 | ||
|
|
596bdd06ec | ||
|
|
6c56d0fc33 | ||
|
|
5f0213d2de | ||
|
|
15eb00a931 | ||
|
|
becc4fb6a2 | ||
|
|
32476a216a | ||
|
|
a9ba1580dc | ||
|
|
cfcd0b22a0 | ||
|
|
780355250c | ||
|
|
fd65ad38bc | ||
|
|
e29973a0b2 | ||
|
|
c259d0883e | ||
|
|
9eab017a31 | ||
|
|
68c7f307a2 | ||
|
|
0aa5694b58 | ||
|
|
639d72c5d6 | ||
|
|
70708ecdcc | ||
|
|
dacdd5e965 | ||
|
|
c199976f70 | ||
|
|
c3e2bc5ad7 | ||
|
|
f0c900c174 | ||
|
|
1bdbc44720 | ||
|
|
c6e765bd07 | ||
|
|
c037ddd044 | ||
|
|
ffe4764f20 | ||
|
|
1681fd6bf4 | ||
|
|
e55ce5536a | ||
|
|
b714952ab1 | ||
|
|
07fd8b9f2f | ||
|
|
d24f633a8e | ||
|
|
bed714890d | ||
|
|
02671910b2 | ||
|
|
1a00f29415 | ||
|
|
b7614622fc | ||
|
|
bc2cbe9a91 | ||
|
|
4daf607ff7 | ||
|
|
fd789ef20c | ||
|
|
76962667a3 | ||
|
|
a33c94e24f | ||
|
|
566b28dc4c | ||
|
|
54e3a156c1 | ||
|
|
8605186a97 | ||
|
|
61fb6553e6 | ||
|
|
76418eec1b | ||
|
|
b5cc858494 | ||
|
|
5c8519be1e | ||
|
|
18392ad2fd | ||
|
|
30c8be79b5 | ||
|
|
7c47946645 | ||
|
|
5684a7877c | ||
|
|
1568549fcc | ||
|
|
62533792b5 | ||
|
|
e9d4141460 | ||
|
|
8986f75356 | ||
|
|
1e40fc7922 | ||
|
|
eeab15e78c | ||
|
|
57714203b4 | ||
|
|
c14d201300 | ||
|
|
c70cbe04c1 | ||
|
|
c8f2b2b319 | ||
|
|
4af3c65e5d | ||
|
|
022fa7ba19 | ||
|
|
4f8cb35f9a | ||
|
|
d84b3688ca | ||
|
|
b4a20ef414 | ||
|
|
c461471942 | ||
|
|
351ddb73e7 | ||
|
|
02257fa18f | ||
|
|
8ef26e49f7 | ||
|
|
0c930f75a1 | ||
|
|
1c419ebf50 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
syntax: glob
|
||||
.idea
|
||||
apierrors/errors
|
||||
static/build.json
|
||||
@@ -11,10 +12,10 @@ test-reports
|
||||
.pytest_cache
|
||||
venv
|
||||
*.noseids
|
||||
build
|
||||
*.egg-info
|
||||
.cache
|
||||
.mypy_cache
|
||||
dist
|
||||
code.tar.gz
|
||||
server/schema/services/_cache.json
|
||||
server/apierrors/errors/*
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,7 +1,7 @@
|
||||
Server Side Public License
|
||||
VERSION 1, OCTOBER 16, 2018
|
||||
|
||||
Copyright © 2018 MongoDB, Inc.
|
||||
Copyright © 2019 allegro.ai, Inc.
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
368
README.md
368
README.md
@@ -1,246 +1,244 @@
|
||||
# TRAINS Server
|
||||
<div align="center">
|
||||
|
||||
## Magic Version Control & Experiment Manager for AI
|
||||
<img src="docs/clearml_server_logo.png" width="250px">
|
||||
|
||||
## Introduction
|
||||
**ClearML - Auto-Magical Suite of tools to streamline your ML workflow
|
||||
</br>Experiment Manager, ML-Ops and Data-Management**
|
||||
|
||||
The **trains-server** is the infrastructure for [trains](https://github.com/allegroai/trains).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
|
||||
The **trains-server** contains the following components:
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://artifacthub.io/packages/search?repo=allegroai)
|
||||
|
||||
* the Web-App which is a single-page UI for experiment management and browsing
|
||||
* a REST interface for:
|
||||
* documenting and logging experiment information, statistics and results
|
||||
* querying experiments history, logs and results
|
||||
* a locally-hosted file server for storing images and models making them easily accessible using the Web-App
|
||||
</div>
|
||||
|
||||
You can quickly setup your **trains-server** using a pre-built Docker image (see [Installation](#installation)).
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
When new releases are available, you can upgrade your pre-built Docker image (see [Upgrade](#upgrade)).
|
||||
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
|
||||
|
||||
The **trains-server's** code is freely available [here](https://github.com/allegroai/trains-server).
|
||||
</div>
|
||||
|
||||
## System diagram
|
||||
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
|
||||
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
|
||||
due to Elasticsearch’s usage of the Java Security Manager.
|
||||
|
||||
<pre>
|
||||
TRAINS-server
|
||||
+--------------------------------------------------------------------+
|
||||
| |
|
||||
| Server Docker Elastic Docker Mongo Docker |
|
||||
| +-------------------------+ +---------------+ +------------+ |
|
||||
| | Pythonic Server | | | | | |
|
||||
| | +-----------------+ | | ElasticSearch | | MongoDB | |
|
||||
| | | WEB server | | | | | | |
|
||||
| | | Port 8080 | | | | | | |
|
||||
| | +--------+--------+ | | | | | |
|
||||
| | | | | | | | |
|
||||
| | +--------+--------+ | | | | | |
|
||||
| | | API server +----------------------------+ | |
|
||||
| | | Port 8008 +---------+ | | | |
|
||||
| | +-----------------+ | +-------+-------+ +-----+------+ |
|
||||
| | | | | |
|
||||
| | +-----------------+ | +---+----------------+------+ |
|
||||
| | | File Server +-------+ | Host Storage | |
|
||||
| | | Port 8081 | | +-----+ | |
|
||||
| | +-----------------+ | +---------------------------+ |
|
||||
| +------------+------------+ |
|
||||
+---------------|----------------------------------------------------+
|
||||
|HTTP
|
||||
+--------+
|
||||
GPU Machine |
|
||||
+------------------------|-------------------------------------------+
|
||||
| +------------------|--------------+ |
|
||||
| | Training | | +---------------------+ |
|
||||
| | Code +---+------------+ | | trains configuration| |
|
||||
| | | TRAINS | | | ~/trains.conf | |
|
||||
| | | +------+ | |
|
||||
| | +----------------+ | +---------------------+ |
|
||||
| +---------------------------------+ |
|
||||
+--------------------------------------------------------------------+
|
||||
</pre>
|
||||
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
|
||||
|
||||
## Installation
|
||||
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
This section contains the instructions to setup and launch a pre-built Docker image for the **trains-server**.
|
||||
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
|
||||
(which in any case **does not permit access to data within the Elasticsearch cluster**),
|
||||
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
|
||||
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
**Note**: This Docker image was tested with Linux, only. For Windows users, we recommend running the server
|
||||
on a Linux virtual machine.
|
||||
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
|
||||
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
|
||||
|
||||
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
|
||||
|
||||
---
|
||||
|
||||
## ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
The **ClearML Server** contains the following components:
|
||||
|
||||
* The **ClearML** Web-App, a single-page UI for experiment management and browsing
|
||||
* RESTful API for:
|
||||
* Documenting and logging experiment information, statistics and results
|
||||
* Querying experiments history, logs and results
|
||||
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
|
||||
|
||||
You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server** using Docker, AWS EC2 AMI, or Kubernetes.
|
||||
|
||||
## System design
|
||||
|
||||
|
||||

|
||||
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
- Web application on port 8080
|
||||
- API service on port 8008
|
||||
- File storage service on port 8081
|
||||
|
||||
- Sub-Domain configuration with default http/s ports (80 or 443)
|
||||
- Web application on sub-domain: app.\*.\*
|
||||
- API service on sub-domain: api.\*.\*
|
||||
- File storage service on sub-domain: files.\*.\*
|
||||
|
||||
## Launching The ClearML Server
|
||||
|
||||
### Prerequisites
|
||||
|
||||
You must be logged in as a user with sudo privileges.
|
||||
|
||||
### Setup
|
||||
The ports 8080/8081/8008 must be available for the **ClearML Server** services.
|
||||
|
||||
For example, to see if port `8080` is in use:
|
||||
|
||||
#### Step 1. Install Docker CE
|
||||
* Linux or macOS:
|
||||
|
||||
sudo lsof -Pn -i4 | grep :8080 | grep LISTEN
|
||||
|
||||
You must install Docker to run the pre-packaged **trains-server**.
|
||||
* Windows:
|
||||
|
||||
* For [Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/) / Mint (x86_64/amd64):
|
||||
netstat -an |find /i "8080"
|
||||
|
||||
### Launching
|
||||
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
```bash
|
||||
sudo apt-get install -y apt-transport-https ca-certificates curl software-properties-common
|
||||
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
|
||||
. /etc/os-release
|
||||
sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $UBUNTU_CODENAME stable"
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y docker-ce
|
||||
```
|
||||
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
|
||||
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
|
||||
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
|
||||
|
||||
* For other operating systems, see [Supported platforms](https://docs.docker.com/install//#support) in the Docker documentation for instructions.
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
#### Step 2. Setup the Docker daemon
|
||||
In order to set up the **ClearML** client to work with your **ClearML Server**:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
To run the ElasticSearch Docker container, you must setup the Docker daemon by modifing the default
|
||||
values required by Elastic in your Docker configuration file
|
||||
that are used by the **trains-server**. We provide instructions for the most common Docker configuration files.
|
||||
api {
|
||||
# API server on port 8008
|
||||
api_server: "http://localhost:8008"
|
||||
|
||||
You must edit or create a Docker configuration file:
|
||||
# web_server on port 8080
|
||||
web_server: "http://localhost:8080"
|
||||
|
||||
* If your Docker configuration file is `/etc/sysconfig/docker`, edit it.
|
||||
|
||||
Add the options in quotes to the available arguments in the `OPTIONS` section:
|
||||
|
||||
```bash
|
||||
OPTIONS="--default-ulimit nofile=1024:65536 --default-ulimit memlock=-1:-1"
|
||||
```
|
||||
|
||||
* Otherwise, edit `/etc/docker/daemon.json` (if it exists) or create it (if it does not exist).
|
||||
|
||||
Add or modify the `defaults-ulimits` section as shown below. Be sure your configuration file contains the `nofile` and `memlock` sub-sections and values shown.
|
||||
|
||||
**Note**: Your configuration file may contain other sections. If so, confirm that the sections are separated by commas. For more information about Docker configuration files, see an [Daemon configuration file](https://docs.docker.com/engine/reference/commandline/dockerd/#daemon-configuration-file) in the Docker documentation.
|
||||
|
||||
The **trains-server** required defaults values are:
|
||||
|
||||
```json
|
||||
{
|
||||
"default-ulimits": {
|
||||
"nofile": {
|
||||
"name": "nofile",
|
||||
"hard": 65536,
|
||||
"soft": 1024
|
||||
},
|
||||
"memlock":
|
||||
{
|
||||
"name": "memlock",
|
||||
"soft": -1,
|
||||
"hard": -1
|
||||
# file server on port 8081
|
||||
files_server: "http://localhost:8081"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Step 3. Restart the Docker daemon
|
||||
**Note**: If you have set up your **ClearML Server** in a sub-domain configuration, then there is no need to specify a port number,
|
||||
it will be inferred from the http/s scheme.
|
||||
|
||||
You must restart the Docker daemon after modifying the configuration file:
|
||||
After launching the **ClearML Server** and configuring the **ClearML** client to use the **ClearML Server**,
|
||||
you can [use](https://github.com/allegroai/clearml) **ClearML** in your experiments and view them in your **ClearML Server** web server,
|
||||
for example http://localhost:8080.
|
||||
For more information about the ClearML client, see [**ClearML**](https://github.com/allegroai/clearml).
|
||||
|
||||
```bash
|
||||
sudo service docker stop
|
||||
sudo service docker start
|
||||
```
|
||||
## ClearML-Agent Services <a name="services"></a>
|
||||
|
||||
#### Step 4. Set the Maximum Number of Memory Map Areas
|
||||
As of version 0.15 of **ClearML Server**, dockerized deployment includes a **ClearML-Agent Services** container running as
|
||||
part of the docker container collection.
|
||||
|
||||
The maximum number of memory map areas a process can use is defined
|
||||
using the `vm.max_map_count` kernel setting.
|
||||
ClearML-Agent Services is an extension of ClearML-Agent that provides the ability to launch long-lasting jobs
|
||||
that previously had to be executed on local / dedicated machines. It allows a single agent to
|
||||
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
|
||||
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
|
||||
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
|
||||
increased data transparency)
|
||||
|
||||
Elastic requires that `vm.max_map_count` to be at least 262144.
|
||||
ClearML-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by ClearML-Agent Services will be registered as a new node in the system,
|
||||
providing tracking and transparency capabilities.
|
||||
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/allegroai/clearml-agent#clearml-agent-services-mode-)
|
||||
|
||||
* For CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19 users, we tested the following commands to set
|
||||
`vm.max_map_count`:
|
||||
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
|
||||
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
|
||||
|
||||
```bash
|
||||
sudo echo "vm.max_map_count=262144" > /tmp/99-trains.conf
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
```
|
||||
## Advanced Functionality
|
||||
|
||||
* For information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
|
||||
|
||||
#### Step 5. Choose a Data Directory
|
||||
## Restarting ClearML Server
|
||||
|
||||
You must choose a directory on your system in which all data maintained by the **trains-server** is stored,
|
||||
create that directory, and set its permissions. The data stored in that directory includes the database, uploaded files and logs.
|
||||
To restart the **ClearML Server**, you must first stop the containers, and then restart them.
|
||||
|
||||
For example, if your data directory is `/opt/trains`, then use the following command:
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo mkdir -p /opt/trains/data/elastic && sudo chown -R 1000:1000 /opt/trains
|
||||
```
|
||||
## Upgrading <a name="upgrade"></a>
|
||||
|
||||
### Launching Docker Containers
|
||||
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker/docker-compose.yml).
|
||||
We strongly encourage you to keep your **ClearML Server** up to date, by keeping up with the current release.
|
||||
|
||||
Launch the Docker containers. For example, if your data directory is `\opt\trains`,
|
||||
then use the following commands:
|
||||
**Note**: The following upgrade instructions use the Linux OS as an example.
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-elastic" -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
```
|
||||
To upgrade your existing **ClearML Server** deployment:
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
|
||||
```
|
||||
1. Shut down the docker containers
|
||||
```bash
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
|
||||
```
|
||||
1. We highly recommend backing up your data directory before upgrading.
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains allegroai/trains:latest apiserver
|
||||
```
|
||||
Assuming your data directory is `/opt/clearml`, to archive all data into `~/clearml_backup.tgz` execute:
|
||||
|
||||
```bash
|
||||
sudo docker run -d --restart="always" --name="trains-webserver" --network="host" -v /opt/trains/logs:/var/log/trains allegroai/trains:latest webserver
|
||||
```
|
||||
```bash
|
||||
sudo tar czvf ~/clearml_backup.tgz /opt/clearml/data
|
||||
```
|
||||
|
||||
After the **trains-server** Dockers are up, the following are available:
|
||||
<details>
|
||||
<summary>Restore instructions:</summary>
|
||||
|
||||
* API server on port `8008`
|
||||
* Web server on port `8080`
|
||||
* File server on port `8081`
|
||||
To restore this example backup, execute:
|
||||
```bash
|
||||
sudo rm -R /opt/clearml/data
|
||||
sudo tar -xzf ~/clearml_backup.tgz -C /opt/clearml/data
|
||||
```
|
||||
</details>
|
||||
|
||||
## Upgrade
|
||||
1. Download the latest `docker-compose.yml` file.
|
||||
|
||||
We are constantly updating, improving and adding to the **trains-server**.
|
||||
New releases will include new pre-built Docker images.
|
||||
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
|
||||
```bash
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
|
||||
```
|
||||
|
||||
1. Shut down and remove each of your Docker instances using the following commands:
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export CLEARML_HOST_IP=server_host_ip_here
|
||||
export CLEARML_AGENT_GIT_USER=git_username_here
|
||||
export CLEARML_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
sudo docker stop <docker-name>
|
||||
sudo docker rm -v <docker-name>
|
||||
|
||||
The Docker names are (see [Launching Docker images](##launching-docker-images)):
|
||||
|
||||
* `trains-elastic`
|
||||
* `trains-mongo`
|
||||
* `trains-fileserver`
|
||||
* `trains-apiserver`
|
||||
* `trains-webserver`
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml pull
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
2. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
|
||||
|
||||
For example, if your data directory is `/opt/trains`, use the following command:
|
||||
|
||||
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
|
||||
|
||||
This back ups all data to an archive in your home directory.
|
||||
|
||||
To restore this example backup, use the following command:
|
||||
|
||||
sudo rm -R /opt/trains/data
|
||||
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
|
||||
|
||||
3. Launch the newly released Docker image (see [Launching Docker images](#Launching-docker-images)).
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
Additionally, you can always find us at *clearml@allegro.ai*
|
||||
|
||||
## License
|
||||
|
||||
[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)
|
||||
|
||||
**trains-server** relies *heavily* on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
|
||||
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our job as a community to support the projects we love and cherish.
|
||||
We feel the cause for the license change in both cases is more than just, and chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more general and flexible of the two.
|
||||
The **ClearML Server** relies on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
|
||||
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
|
||||
member of the community to support the projects we love and cherish.
|
||||
We believe the cause for the license change in both cases is more than just,
|
||||
and chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more general and flexible of the two licenses.
|
||||
|
||||
This is our way to say - we support you guys!
|
||||
|
||||
557
apiserver/LICENSE
Normal file
557
apiserver/LICENSE
Normal file
@@ -0,0 +1,557 @@
|
||||
Server Side Public License
|
||||
VERSION 1, OCTOBER 16, 2018
|
||||
|
||||
Copyright © 2019 allegro.ai, Inc.
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
“This License” refers to Server Side Public License.
|
||||
|
||||
“Copyright” also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
“The Program” refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as “you”. “Licensees” and
|
||||
“recipients” may be individuals or organizations.
|
||||
|
||||
To “modify” a work means to copy from or adapt all or part of the work in
|
||||
a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a “modified version” of the
|
||||
earlier work or a work “based on” the earlier work.
|
||||
|
||||
A “covered work” means either the unmodified Program or a work based on
|
||||
the Program.
|
||||
|
||||
To “propagate” a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To “convey” a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through a
|
||||
computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays “Appropriate Legal Notices” to the
|
||||
extent that it includes a convenient and prominently visible feature that
|
||||
(1) displays an appropriate copyright notice, and (2) tells the user that
|
||||
there is no warranty for the work (except to the extent that warranties
|
||||
are provided), that licensees may convey the work under this License, and
|
||||
how to view a copy of this License. If the interface presents a list of
|
||||
user commands or options, such as a menu, a prominent item in the list
|
||||
meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The “source code” for a work means the preferred form of the work for
|
||||
making modifications to it. “Object code” means any non-source form of a
|
||||
work.
|
||||
|
||||
A “Standard Interface” means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that is
|
||||
widely used among developers working in that language. The “System
|
||||
Libraries” of an executable work include anything, other than the work as
|
||||
a whole, that (a) is included in the normal form of packaging a Major
|
||||
Component, but which is not part of that Major Component, and (b) serves
|
||||
only to enable use of the work with that Major Component, or to implement
|
||||
a Standard Interface for which an implementation is available to the
|
||||
public in source code form. A “Major Component”, in this context, means a
|
||||
major essential component (kernel, window system, and so on) of the
|
||||
specific operating system (if any) on which the executable work runs, or
|
||||
a compiler used to produce the work, or an object code interpreter used
|
||||
to run it.
|
||||
|
||||
The “Corresponding Source” for a work in object code form means all the
|
||||
source code needed to generate, install, and (for an executable work) run
|
||||
the object code and to modify the work, including scripts to control
|
||||
those activities. However, it does not include the work's System
|
||||
Libraries, or general-purpose tools or generally available free programs
|
||||
which are used unmodified in performing those activities but which are
|
||||
not part of the work. For example, Corresponding Source includes
|
||||
interface definition files associated with source files for the work, and
|
||||
the source code for shared libraries and dynamically linked subprograms
|
||||
that the work is specifically designed to require, such as by intimate
|
||||
data communication or control flow between those subprograms and other
|
||||
parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users can
|
||||
regenerate automatically from other parts of the Corresponding Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program, subject to section 13. The
|
||||
output from running a covered work is covered by this License only if the
|
||||
output, given its content, constitutes a covered work. This License
|
||||
acknowledges your rights of fair use or other equivalent, as provided by
|
||||
copyright law. Subject to section 13, you may make, run and propagate
|
||||
covered works that you do not convey, without conditions so long as your
|
||||
license otherwise remains in force. You may convey covered works to
|
||||
others for the sole purpose of having them make modifications exclusively
|
||||
for you, or provide you with facilities for running those works, provided
|
||||
that you comply with the terms of this License in conveying all
|
||||
material for which you do not control copyright. Those thus making or
|
||||
running the covered works for you must do so exclusively on your
|
||||
behalf, under your direction and control, on terms that prohibit them
|
||||
from making any copies of your copyrighted material outside their
|
||||
relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under the
|
||||
conditions stated below. Sublicensing is not allowed; section 10 makes it
|
||||
unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article 11
|
||||
of the WIPO copyright treaty adopted on 20 December 1996, or similar laws
|
||||
prohibiting or restricting circumvention of such measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention is
|
||||
effected by exercising rights under this License with respect to the
|
||||
covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's users,
|
||||
your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice; keep
|
||||
intact all notices stating that this License and any non-permissive terms
|
||||
added in accord with section 7 apply to the code; keep intact all notices
|
||||
of the absence of any warranty; and give all recipients a copy of this
|
||||
License along with the Program. You may charge any price or no price for
|
||||
each copy that you convey, and you may offer support or warranty
|
||||
protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the terms
|
||||
of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified it,
|
||||
and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is released
|
||||
under this License and any conditions added under section 7. This
|
||||
requirement modifies the requirement in section 4 to “keep intact all
|
||||
notices”.
|
||||
|
||||
c) You must license the entire work, as a whole, under this License to
|
||||
anyone who comes into possession of a copy. This License will therefore
|
||||
apply, along with any applicable section 7 additional terms, to the
|
||||
whole of the work, and all its parts, regardless of how they are
|
||||
packaged. This License gives no permission to license the work in any
|
||||
other way, but it does not invalidate such permission if you have
|
||||
separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your work
|
||||
need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work, and
|
||||
which are not combined with it such as to form a larger program, in or on
|
||||
a volume of a storage or distribution medium, is called an “aggregate” if
|
||||
the compilation and its resulting copyright are not used to limit the
|
||||
access or legal rights of the compilation's users beyond what the
|
||||
individual works permit. Inclusion of a covered work in an aggregate does
|
||||
not cause this License to apply to the other parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms of
|
||||
sections 4 and 5, provided that you also convey the machine-readable
|
||||
Corresponding Source under the terms of this License, in one of these
|
||||
ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium customarily
|
||||
used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a written
|
||||
offer, valid for at least three years and valid for as long as you
|
||||
offer spare parts or customer support for that product model, to give
|
||||
anyone who possesses the object code either (1) a copy of the
|
||||
Corresponding Source for all the software in the product that is
|
||||
covered by this License, on a durable physical medium customarily used
|
||||
for software interchange, for a price no more than your reasonable cost
|
||||
of physically performing this conveying of source, or (2) access to
|
||||
copy the Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This alternative is
|
||||
allowed only occasionally and noncommercially, and only if you received
|
||||
the object code with such an offer, in accord with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated place
|
||||
(gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to copy
|
||||
the object code is a network server, the Corresponding Source may be on
|
||||
a different server (operated by you or a third party) that supports
|
||||
equivalent copying facilities, provided you maintain clear directions
|
||||
next to the object code saying where to find the Corresponding Source.
|
||||
Regardless of what server hosts the Corresponding Source, you remain
|
||||
obligated to ensure that it is available for as long as needed to
|
||||
satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided you
|
||||
inform other peers where the object code and Corresponding Source of
|
||||
the work are being offered to the general public at no charge under
|
||||
subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be included
|
||||
in conveying the object code work.
|
||||
|
||||
A “User Product” is either (1) a “consumer product”, which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, “normally used” refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
“Installation Information” for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as part
|
||||
of a transaction in which the right of possession and use of the User
|
||||
Product is transferred to the recipient in perpetuity or for a fixed term
|
||||
(regardless of how the transaction is characterized), the Corresponding
|
||||
Source conveyed under this section must be accompanied by the
|
||||
Installation Information. But this requirement does not apply if neither
|
||||
you nor any third party retains the ability to install modified object
|
||||
code on the User Product (for example, the work has been installed in
|
||||
ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access
|
||||
to a network may be denied when the modification itself materially
|
||||
and adversely affects the operation of the network or violates the
|
||||
rules and protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided, in
|
||||
accord with this section must be in a format that is publicly documented
|
||||
(and with an implementation available to the public in source code form),
|
||||
and must require no special password or key for unpacking, reading or
|
||||
copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
“Additional permissions” are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall be
|
||||
treated as though they were included in this License, to the extent that
|
||||
they are valid under applicable law. If additional permissions apply only
|
||||
to part of the Program, that part may be used separately under those
|
||||
permissions, but the entire Program remains governed by this License
|
||||
without regard to the additional permissions. When you convey a copy of
|
||||
a covered work, you may at your option remove any additional permissions
|
||||
from that copy, or from any part of it. (Additional permissions may be
|
||||
written to require their own removal in certain cases when you modify the
|
||||
work.) You may place additional permissions on material, added by you to
|
||||
a covered work, for which you have or can give appropriate copyright
|
||||
permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you add
|
||||
to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some trade
|
||||
names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that material
|
||||
by anyone who conveys the material (or modified versions of it) with
|
||||
contractual assumptions of liability to the recipient, for any
|
||||
liability that these contractual assumptions directly impose on those
|
||||
licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered “further
|
||||
restrictions” within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further restriction,
|
||||
you may remove that term. If a license document contains a further
|
||||
restriction but permits relicensing or conveying under this License, you
|
||||
may add to a covered work material governed by the terms of that license
|
||||
document, provided that the further restriction does not survive such
|
||||
relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you must
|
||||
place, in the relevant source files, a statement of the additional terms
|
||||
that apply to those files, or a notice indicating where to find the
|
||||
applicable terms. Additional terms, permissive or non-permissive, may be
|
||||
stated in the form of a separately written license, or stated as
|
||||
exceptions; the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or modify
|
||||
it is void, and will automatically terminate your rights under this
|
||||
License (including any patent licenses granted under the third paragraph
|
||||
of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your license
|
||||
from a particular copyright holder is reinstated (a) provisionally,
|
||||
unless and until the copyright holder explicitly and finally terminates
|
||||
your license, and (b) permanently, if the copyright holder fails to
|
||||
notify you of the violation by some reasonable means prior to 60 days
|
||||
after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is reinstated
|
||||
permanently if the copyright holder notifies you of the violation by some
|
||||
reasonable means, this is the first time you have received notice of
|
||||
violation of this License (for any work) from that copyright holder, and
|
||||
you cure the violation prior to 30 days after your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or run a
|
||||
copy of the Program. Ancillary propagation of a covered work occurring
|
||||
solely as a consequence of using peer-to-peer transmission to receive a
|
||||
copy likewise does not require acceptance. However, nothing other than
|
||||
this License grants you permission to propagate or modify any covered
|
||||
work. These actions infringe copyright if you do not accept this License.
|
||||
Therefore, by modifying or propagating a covered work, you indicate your
|
||||
acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically receives
|
||||
a license from the original licensors, to run, modify and propagate that
|
||||
work, subject to this License. You are not responsible for enforcing
|
||||
compliance by third parties with this License.
|
||||
|
||||
An “entity transaction” is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered work
|
||||
results from an entity transaction, each party to that transaction who
|
||||
receives a copy of the work also receives whatever licenses to the work
|
||||
the party's predecessor in interest had or could give under the previous
|
||||
paragraph, plus a right to possession of the Corresponding Source of the
|
||||
work from the predecessor in interest, if the predecessor has it or can
|
||||
get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the rights
|
||||
granted or affirmed under this License. For example, you may not impose a
|
||||
license fee, royalty, or other charge for exercise of rights granted
|
||||
under this License, and you may not initiate litigation (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that any patent claim
|
||||
is infringed by making, using, selling, offering for sale, or importing
|
||||
the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A “contributor” is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The work
|
||||
thus licensed is called the contributor's “contributor version”.
|
||||
|
||||
A contributor's “essential patent claims” are all patent claims owned or
|
||||
controlled by the contributor, whether already acquired or hereafter
|
||||
acquired, that would be infringed by some manner, permitted by this
|
||||
License, of making, using, or selling its contributor version, but do not
|
||||
include claims that would be infringed only as a consequence of further
|
||||
modification of the contributor version. For purposes of this definition,
|
||||
“control” includes the right to grant patent sublicenses in a manner
|
||||
consistent with the requirements of this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to make,
|
||||
use, sell, offer for sale, import and otherwise run, modify and propagate
|
||||
the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a “patent license” is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To “grant” such a patent license to a party
|
||||
means to make such an agreement or commitment not to enforce a patent
|
||||
against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license, and
|
||||
the Corresponding Source of the work is not available for anyone to copy,
|
||||
free of charge and under the terms of this License, through a publicly
|
||||
available network server or other readily accessible means, then you must
|
||||
either (1) cause the Corresponding Source to be so available, or (2)
|
||||
arrange to deprive yourself of the benefit of the patent license for this
|
||||
particular work, or (3) arrange, in a manner consistent with the
|
||||
requirements of this License, to extend the patent license to downstream
|
||||
recipients. “Knowingly relying” means you have actual knowledge that, but
|
||||
for the patent license, your conveying the covered work in a country, or
|
||||
your recipient's use of the covered work in a country, would infringe
|
||||
one or more identifiable patents in that country that you have reason
|
||||
to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties receiving
|
||||
the covered work authorizing them to use, propagate, modify or convey a
|
||||
specific copy of the covered work, then the patent license you grant is
|
||||
automatically extended to all recipients of the covered work and works
|
||||
based on it.
|
||||
|
||||
A patent license is “discriminatory” if it does not include within the
|
||||
scope of its coverage, prohibits the exercise of, or is conditioned on
|
||||
the non-exercise of one or more of the rights that are specifically
|
||||
granted under this License. You may not convey a covered work if you are
|
||||
a party to an arrangement with a third party that is in the business of
|
||||
distributing software, under which you make payment to the third party
|
||||
based on the extent of your activity of conveying the work, and under
|
||||
which the third party grants, to any of the parties who would receive the
|
||||
covered work from you, a discriminatory patent license (a) in connection
|
||||
with copies of the covered work conveyed by you (or copies made from
|
||||
those copies), or (b) primarily for and in connection with specific
|
||||
products or compilations that contain the covered work, unless you
|
||||
entered into that arrangement, or that patent license was granted, prior
|
||||
to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting any
|
||||
implied license or other defenses to infringement that may otherwise be
|
||||
available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot use,
|
||||
propagate or convey a covered work so as to satisfy simultaneously your
|
||||
obligations under this License and any other pertinent obligations, then
|
||||
as a consequence you may not use, propagate or convey it at all. For
|
||||
example, if you agree to terms that obligate you to collect a royalty for
|
||||
further conveying from those to whom you convey the Program, the only way
|
||||
you could satisfy both those terms and this License would be to refrain
|
||||
entirely from conveying the Program.
|
||||
|
||||
13. Offering the Program as a Service.
|
||||
|
||||
If you make the functionality of the Program or a modified version
|
||||
available to third parties as a service, you must make the Service Source
|
||||
Code available via network download to everyone at no charge, under the
|
||||
terms of this License. Making the functionality of the Program or
|
||||
modified version available to third parties as a service includes,
|
||||
without limitation, enabling third parties to interact with the
|
||||
functionality of the Program or modified version remotely through a
|
||||
computer network, offering a service the value of which entirely or
|
||||
primarily derives from the value of the Program or modified version, or
|
||||
offering a service that accomplishes for users the primary purpose of the
|
||||
Program or modified version.
|
||||
|
||||
“Service Source Code” means the Corresponding Source for the Program or
|
||||
the modified version, and the Corresponding Source for all programs that
|
||||
you use to make the Program or modified version available as a service,
|
||||
including, without limitation, management software, user interfaces,
|
||||
application program interfaces, automation software, monitoring software,
|
||||
backup software, storage software and hosting software, all such that a
|
||||
user could run an instance of the service using the Service Source Code
|
||||
you make available.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
MongoDB, Inc. may publish revised and/or new versions of the Server Side
|
||||
Public License from time to time. Such new versions will be similar in
|
||||
spirit to the present version, but may differ in detail to address new
|
||||
problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the Program
|
||||
specifies that a certain numbered version of the Server Side Public
|
||||
License “or any later version” applies to it, you have the option of
|
||||
following the terms and conditions either of that numbered version or of
|
||||
any later version published by MongoDB, Inc. If the Program does not
|
||||
specify a version number of the Server Side Public License, you may
|
||||
choose any version ever published by MongoDB, Inc.
|
||||
|
||||
If the Program specifies that a proxy can decide which future versions of
|
||||
the Server Side Public License can be used, that proxy's public statement
|
||||
of acceptance of a version permanently authorizes you to choose that
|
||||
version for the Program.
|
||||
|
||||
Later license versions may give you additional or different permissions.
|
||||
However, no additional obligations are imposed on any author or copyright
|
||||
holder as a result of your choosing to follow a later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM “AS IS” WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING
|
||||
ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
||||
THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO
|
||||
LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU
|
||||
OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
|
||||
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided above
|
||||
cannot be given local legal effect according to their terms, reviewing
|
||||
courts shall apply local law that most closely approximates an absolute
|
||||
waiver of all civil liability in connection with the Program, unless a
|
||||
warranty or assumption of liability accompanies a copy of the Program in
|
||||
return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
6
apiserver/apierrors/__init__.py
Normal file
6
apiserver/apierrors/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .apierror import APIError
|
||||
from .base import BaseError
|
||||
|
||||
from apiserver.apierrors_generator import ErrorsGenerator
|
||||
|
||||
ErrorsGenerator.generate_python_files()
|
||||
@@ -1,9 +1,10 @@
|
||||
class APIError(Exception):
|
||||
def __init__(self, msg, code=500, subcode=0, **_):
|
||||
def __init__(self, msg, code=500, subcode=0, error_data=None, **_):
|
||||
super(APIError, self).__init__()
|
||||
self._msg = msg
|
||||
self._code = code
|
||||
self._subcode = subcode
|
||||
self._error_data = error_data or {}
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
@@ -17,5 +18,9 @@ class APIError(Exception):
|
||||
def subcode(self):
|
||||
return self._subcode
|
||||
|
||||
@property
|
||||
def error_data(self):
|
||||
return self._error_data
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
@@ -1,9 +1,13 @@
|
||||
import six
|
||||
from boltons.typeutils import classproperty
|
||||
from typing import Tuple
|
||||
|
||||
import six
|
||||
from boltons.iterutils import is_collection, remap
|
||||
from boltons.typeutils import classproperty
|
||||
|
||||
from .apierror import APIError
|
||||
|
||||
jsonable_types = (dict, list, tuple, str, int, float, bool, type(None))
|
||||
|
||||
|
||||
class BaseError(APIError):
|
||||
_default_code = 500
|
||||
@@ -19,15 +23,26 @@ class BaseError(APIError):
|
||||
f"{k}={self._format_kwarg(v)}" for k, v in kwargs.items()
|
||||
)
|
||||
message += f": {kwargs_msg}"
|
||||
params = kwargs.copy()
|
||||
params.update(
|
||||
code=self._default_code, subcode=self._default_subcode, msg=message
|
||||
|
||||
super(BaseError, self).__init__(
|
||||
code=self._default_code,
|
||||
subcode=self._default_subcode,
|
||||
msg=message,
|
||||
error_data=self._to_safe_json_types(kwargs),
|
||||
)
|
||||
super(BaseError, self).__init__(**params)
|
||||
|
||||
@staticmethod
|
||||
def _to_safe_json_types(data):
|
||||
def visit(_, k, v):
|
||||
if not isinstance(v, jsonable_types):
|
||||
v = str(v)
|
||||
return k, v
|
||||
|
||||
return remap(data, visit=visit)
|
||||
|
||||
@staticmethod
|
||||
def _format_kwarg(value):
|
||||
if isinstance(value, (tuple, list)):
|
||||
if is_collection(value):
|
||||
return f'({", ".join(str(v) for v in value)})'
|
||||
elif isinstance(value, six.string_types):
|
||||
return value
|
||||
148
apiserver/apierrors/errors.conf
Normal file
148
apiserver/apierrors/errors.conf
Normal file
@@ -0,0 +1,148 @@
|
||||
301 {
|
||||
_: "moved_permanently"
|
||||
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
|
||||
}
|
||||
|
||||
400 {
|
||||
_: "bad_request"
|
||||
1: ["not_supported", "endpoint is not supported"]
|
||||
2: ["request_path_has_invalid_version", "request path has invalid version"]
|
||||
5: ["invalid_headers", "invalid headers"]
|
||||
6: ["impersonation_error", "impersonation error"]
|
||||
|
||||
10: ["invalid_id", "invalid object id"]
|
||||
11: ["missing_required_fields", "missing required fields"]
|
||||
12: ["validation_error", "validation error"]
|
||||
13: ["fields_not_allowed_for_role", "fields not allowed for role"]
|
||||
14: ["invalid fields", "fields not defined for object"]
|
||||
15: ["fields_conflict", "conflicting fields"]
|
||||
16: ["fields_value_error", "invalid value for fields"]
|
||||
17: ["batch_contains_no_items", "batch request contains no items"]
|
||||
18: ["batch_validation_error", "batch request validation error"]
|
||||
19: ["invalid_lucene_syntax", "malformed lucene query"]
|
||||
20: ["fields_type_error", "invalid type for fields"]
|
||||
21: ["invalid_regex_error", "malformed regular expression"]
|
||||
22: ["invalid_email_address", "malformed email address"]
|
||||
23: ["invalid_domain_name", "malformed domain name"]
|
||||
24: ["not_public_object", "object is not public"]
|
||||
|
||||
# Auth / Login
|
||||
75: ["invalid_access_key", "access key not found for user"]
|
||||
|
||||
# Tasks
|
||||
100: ["task_error", "general task error"]
|
||||
101: ["invalid_task_id", "invalid task id"]
|
||||
102: ["task_validation_error", "task validation error"]
|
||||
110: ["invalid_task_status", "invalid task status"]
|
||||
111: ["task_not_started", "task not started (invalid task status)"]
|
||||
112: ["task_in_progress", "task in progress (invalid task status)"]
|
||||
113: ["task_published", "task published (invalid task status)"]
|
||||
114: ["task_status_unknown", "task unknown (invalid task status)"]
|
||||
120: ["invalid_task_execution_progress", "invalid task execution progress"]
|
||||
121: ["failed_changing_task_status", "failed changing task status. probably someone changed it before you"]
|
||||
122: ["missing_task_fields", "task is missing expected fields"]
|
||||
123: ["task_cannot_be_deleted", "task cannot be deleted"]
|
||||
125: ["task_has_jobs_running", "task has jobs that haven't completed yet"]
|
||||
126: ["invalid_task_type", "invalid task type for this operations"]
|
||||
127: ["invalid_task_input", "invalid task output"]
|
||||
128: ["invalid_task_output", "invalid task output"]
|
||||
129: ["task_publish_in_progress", "Task publish in progress"]
|
||||
130: ["task_not_found", "task not found"]
|
||||
131: ["events_not_added", "events not added"]
|
||||
|
||||
# Models
|
||||
200: ["model_error", "general task error"]
|
||||
201: ["invalid_model_id", "invalid model id"]
|
||||
202: ["model_not_ready", "model is not ready"]
|
||||
203: ["model_is_ready", "model is ready"]
|
||||
204: ["invalid_model_uri", "invalid model URI"]
|
||||
205: ["model_in_use", "model is used by tasks"]
|
||||
206: ["model_creating_task_exists", "task that created this model exists"]
|
||||
|
||||
# Users
|
||||
300: ["invalid_user", "invalid user"]
|
||||
301: ["invalid_user_id", "invalid user id"]
|
||||
302: ["user_id_exists", "user id already exists"]
|
||||
305: ["invalid_preferences_update", "Malformed key and/or value"]
|
||||
|
||||
# Projects
|
||||
401: ["invalid_project_id", "invalid project id"]
|
||||
402: ["project_has_tasks", "project has associated tasks"]
|
||||
403: ["project_not_found", "project not found"]
|
||||
405: ["project_has_models", "project has associated models"]
|
||||
407: ["invalid_project_name", "invalid project name"]
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
|
||||
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
702: ["queue_not_empty", "queue is not empty"]
|
||||
703: ["invalid_queue_or_task_not_queued", "invalid queue id or task not in queue"]
|
||||
704: ["removed_during_reposition", "task was removed by another party during reposition"]
|
||||
705: ["failed_adding_during_reposition", "failed adding task back to queue during reposition"]
|
||||
706: ["task_already_queued", "failed adding task to queue since task is already queued"]
|
||||
707: ["no_default_queue", "no queue is tagged as the default queue for this company"]
|
||||
708: ["multiple_default_queues", "more than one queue is tagged as the default queue for this company"]
|
||||
|
||||
# Database
|
||||
800: ["data_validation_error", "data validation error"]
|
||||
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
|
||||
|
||||
# Workers
|
||||
1001: ["invalid_worker_id", "invalid worker id"]
|
||||
1002: ["worker_registration_failed", "worker registration failed"]
|
||||
1003: ["worker_registered", "worker is already registered"]
|
||||
1004: ["worker_not_registered", "worker is not registered"]
|
||||
1005: ["worker_stats_not_found", "worker stats not found"]
|
||||
|
||||
1104: ["invalid_scroll_id", "Invalid scroll id"]
|
||||
}
|
||||
|
||||
401 {
|
||||
_: "unauthorized"
|
||||
1: ["not_authorized", "unauthorized (not authorized for endpoint)"]
|
||||
2: ["entity_not_allowed", "unauthorized (entity not allowed)"]
|
||||
10: ["bad_auth_type", "unauthorized (bad authentication header type)"]
|
||||
20: ["no_credentials", "unauthorized (missing credentials)"]
|
||||
21: ["bad_credentials", "unauthorized (malformed credentials)"]
|
||||
22: ["invalid_credentials", "unauthorized (invalid credentials)"]
|
||||
30: ["invalid_token", "invalid token"]
|
||||
31: ["blocked_token", "token is blocked"]
|
||||
40: ["invalid_fixed_user", "fixed user ID was not found"]
|
||||
}
|
||||
|
||||
403: {
|
||||
_: "forbidden"
|
||||
10: ["routing_error", "forbidden (routing error)"]
|
||||
12: ["blocked_internal_endpoint", "forbidden (blocked internal endpoint)"]
|
||||
20: ["role_not_allowed", "forbidden (not allowed for role)"]
|
||||
21: ["no_write_permission", "forbidden (modification not allowed)"]
|
||||
}
|
||||
|
||||
410: {
|
||||
_: "gone"
|
||||
1: ["not_supported", "thus endpoint is not supported any more"]
|
||||
}
|
||||
|
||||
500 {
|
||||
_: "server_error"
|
||||
0: ["general_error", "general server error"]
|
||||
1: ["internal_error", "internal server error"]
|
||||
2: ["config_error", "configuration error"]
|
||||
3: ["build_info_error", "build info unavailable or corrupted"]
|
||||
4: ["low_disk_space", "Critical server error! Server reports low or insufficient disk space. Please resolve immediately by allocating additional disk space or freeing up storage space."]
|
||||
10: ["transaction_error", "a transaction call has returned with an error"]
|
||||
# Database-related issues
|
||||
100: ["data_error", "general data error"]
|
||||
101: ["inconsistent_data", "inconsistent data encountered in document"]
|
||||
102: ["database_unavailable", "database is temporarily unavailable"]
|
||||
110: ["update_failed", "update failed"]
|
||||
|
||||
# Index-related issues
|
||||
201: ["missing_index", "missing internal index"]
|
||||
|
||||
9999: ["not_implemented", "action is not yet implemented"]
|
||||
}
|
||||
1
apiserver/apierrors_generator/__init__.py
Normal file
1
apiserver/apierrors_generator/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .errors_generator import ErrorsGenerator
|
||||
4
apiserver/apierrors_generator/__main__.py
Normal file
4
apiserver/apierrors_generator/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .errors_generator import ErrorsGenerator
|
||||
|
||||
if __name__ == '__main__':
|
||||
ErrorsGenerator.generate_python_files()
|
||||
31
apiserver/apierrors_generator/errors_generator.py
Normal file
31
apiserver/apierrors_generator/errors_generator.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from pyhocon import ConfigFactory, ConfigTree
|
||||
|
||||
from .generator import Generator
|
||||
|
||||
|
||||
class ErrorsGenerator:
|
||||
_apierrors_path = Path(__file__).parents[1] / "apierrors"
|
||||
_files = [_apierrors_path / "errors.conf"]
|
||||
|
||||
@classmethod
|
||||
def _get_codes(cls):
|
||||
return {
|
||||
(k, v.pop("_")): v
|
||||
for k, v in reduce(
|
||||
ConfigTree.merge_configs, map(ConfigFactory.parse_file, cls._files),
|
||||
).items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def add_errors_file(cls, path: Union[Path, str]):
|
||||
cls._files.append(path)
|
||||
|
||||
@classmethod
|
||||
def generate_python_files(cls):
|
||||
Generator(cls._apierrors_path / "errors", format_pep8=False).make_errors(
|
||||
cls._get_codes()
|
||||
)
|
||||
@@ -8,9 +8,12 @@ from pathlib import Path
|
||||
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(str(Path(__file__).parent)),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=('py',), default_for_string=False),
|
||||
autoescape=jinja2.select_autoescape(
|
||||
disabled_extensions=("py",), default_for_string=False
|
||||
),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True)
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
|
||||
def env_filter(name=None):
|
||||
@@ -19,14 +22,14 @@ def env_filter(name=None):
|
||||
|
||||
@env_filter()
|
||||
def cls_name(name):
|
||||
delims = list(map(re.escape, (' ', '_')))
|
||||
parts = re.split('|'.join(delims), name)
|
||||
return ''.join(x.capitalize() for x in parts)
|
||||
delims = list(map(re.escape, (" ", "_")))
|
||||
parts = re.split("|".join(delims), name)
|
||||
return "".join(x.capitalize() for x in parts)
|
||||
|
||||
|
||||
class Generator(object):
|
||||
_base_class_name = 'BaseError'
|
||||
_base_class_module = 'apierrors.base'
|
||||
_base_class_name = "BaseError"
|
||||
_base_class_module = "apiserver.apierrors.base"
|
||||
|
||||
def __init__(self, path, format_pep8=True, use_md5=True):
|
||||
self._use_md5 = use_md5
|
||||
@@ -35,29 +38,37 @@ class Generator(object):
|
||||
self._path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_init_file(self, path):
|
||||
(self._path / path / '__init__.py').write_bytes('')
|
||||
(self._path / path / "__init__.py").write_bytes(b"")
|
||||
|
||||
def _do_render(self, file, template, context):
|
||||
with file.open('w') as f:
|
||||
with file.open("w") as f:
|
||||
result = template.render(
|
||||
base_class_name=self._base_class_name,
|
||||
base_class_module=self._base_class_module,
|
||||
**context)
|
||||
**context
|
||||
)
|
||||
if self._format_pep8:
|
||||
result = autopep8.fix_code(result, options={'aggressive': 1, 'verbose': 0, 'max_line_length': 120})
|
||||
import autopep8
|
||||
|
||||
result = autopep8.fix_code(
|
||||
result,
|
||||
options={"aggressive": 1, "verbose": 0, "max_line_length": 120},
|
||||
)
|
||||
f.write(result)
|
||||
|
||||
def _make_section(self, name, code, subcodes):
|
||||
self._do_render(
|
||||
file=(self._path / name).with_suffix('.py'),
|
||||
template=env.get_template('templates/section.jinja2'),
|
||||
context=dict(code=code, subcodes=list(subcodes.items()),))
|
||||
file=(self._path / name).with_suffix(".py"),
|
||||
template=env.get_template("templates/section.jinja2"),
|
||||
context=dict(code=code, subcodes=list(subcodes.items()),),
|
||||
)
|
||||
|
||||
def _make_init(self, sections):
|
||||
self._do_render(
|
||||
file=(self._path / '__init__.py'),
|
||||
template=env.get_template('templates/init.jinja2'),
|
||||
context=dict(sections=sections,))
|
||||
file=(self._path / "__init__.py"),
|
||||
template=env.get_template("templates/init.jinja2"),
|
||||
context=dict(sections=sections,),
|
||||
)
|
||||
|
||||
def _key_to_str(self, data):
|
||||
if isinstance(data, dict):
|
||||
@@ -66,11 +77,11 @@ class Generator(object):
|
||||
|
||||
def _calc_digest(self, data):
|
||||
data = json.dumps(self._key_to_str(data), sort_keys=True)
|
||||
return hashlib.md5(data.encode('utf8')).hexdigest()
|
||||
return hashlib.md5(data.encode("utf8")).hexdigest()
|
||||
|
||||
def make_errors(self, errors):
|
||||
digest = None
|
||||
digest_file = self._path / 'digest.md5'
|
||||
digest_file = self._path / "digest.md5"
|
||||
if self._use_md5:
|
||||
digest = self._calc_digest(errors)
|
||||
if digest_file.is_file():
|
||||
@@ -79,7 +90,7 @@ class Generator(object):
|
||||
|
||||
self._make_init(errors)
|
||||
for (code, section_name), subcodes in errors.items():
|
||||
self._make_section(section_name, code, subcodes)
|
||||
self._make_section(section_name, int(code), subcodes)
|
||||
|
||||
if self._use_md5:
|
||||
digest_file.write_text(digest)
|
||||
@@ -5,5 +5,5 @@ from {{ base_class_module }} import {{ base_class_name }}
|
||||
{% for subcode, (name, msg) in subcodes %}
|
||||
|
||||
|
||||
{{ error_class(name|cls_name, msg, code, subcode) -}}
|
||||
{{ error_class(name|cls_name, msg, code, subcode|int) -}}
|
||||
{% endfor %}
|
||||
297
apiserver/apimodels/__init__.py
Normal file
297
apiserver/apimodels/__init__.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from enum import Enum
|
||||
from typing import Union, Type, Iterable
|
||||
|
||||
import jsonmodels.errors
|
||||
import six
|
||||
from jsonmodels import fields
|
||||
from jsonmodels.fields import _LazyType, NotSet
|
||||
from jsonmodels.models import Base as ModelBase
|
||||
from jsonmodels.validators import Enum as EnumValidator
|
||||
from mongoengine.base import BaseDocument
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.utilities.json import loads, dumps
|
||||
|
||||
|
||||
class EmailField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if email_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidEmailAddress()
|
||||
|
||||
|
||||
class DomainField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if domain_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
class _FieldWithDefault(field_cls):
|
||||
def get_default_value(self):
|
||||
return default_value
|
||||
|
||||
return _FieldWithDefault
|
||||
|
||||
|
||||
class ListField(fields.ListField):
|
||||
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
|
||||
if default is not NotSet and callable(default):
|
||||
default = default()
|
||||
|
||||
super(ListField, self).__init__(items_types, *args, default=default, **kwargs)
|
||||
|
||||
def _cast_value(self, value):
|
||||
try:
|
||||
return super(ListField, self)._cast_value(value)
|
||||
except TypeError:
|
||||
if len(self.items_types) == 1 and issubclass(self.items_types[0], Enum):
|
||||
return self.items_types[0](value)
|
||||
return value
|
||||
|
||||
def validate_single_value(self, item):
|
||||
super(ListField, self).validate_single_value(item)
|
||||
if isinstance(item, ModelBase):
|
||||
item.validate()
|
||||
|
||||
|
||||
class DictField(fields.BaseField):
|
||||
types = (dict,)
|
||||
|
||||
def __init__(self, value_types=None, *args, **kwargs):
|
||||
self.value_types = self._assign_types(value_types)
|
||||
super(DictField, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_default_value(self):
|
||||
default = super(DictField, self).get_default_value()
|
||||
if default is None and not self.required:
|
||||
return {}
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _assign_types(value_types):
|
||||
if value_types:
|
||||
try:
|
||||
value_types = tuple(value_types)
|
||||
except TypeError:
|
||||
value_types = (value_types,)
|
||||
else:
|
||||
value_types = tuple()
|
||||
|
||||
return tuple(
|
||||
_LazyType(type_) if isinstance(type_, six.string_types) else type_
|
||||
for type_ in value_types
|
||||
)
|
||||
|
||||
def parse_value(self, values):
|
||||
"""Cast value to proper collection."""
|
||||
result = self.get_default_value()
|
||||
|
||||
if values is None:
|
||||
return result
|
||||
|
||||
if not self.value_types or not isinstance(values, dict):
|
||||
return values
|
||||
|
||||
return {key: self._cast_value(value) for key, value in values.items()}
|
||||
|
||||
def _cast_value(self, value):
|
||||
if isinstance(value, self.value_types):
|
||||
return value
|
||||
else:
|
||||
if len(self.value_types) != 1:
|
||||
tpl = 'Cannot decide which type to choose from "{types}".'
|
||||
raise jsonmodels.errors.ValidationError(
|
||||
tpl.format(
|
||||
types=', '.join([t.__name__ for t in self.value_types])
|
||||
)
|
||||
)
|
||||
return self.value_types[0](**value)
|
||||
|
||||
def validate(self, value):
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
if not self.value_types:
|
||||
return
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
for item in value.values():
|
||||
self.validate_single_value(item)
|
||||
|
||||
def validate_single_value(self, item):
|
||||
if not self.value_types:
|
||||
return
|
||||
|
||||
if not isinstance(item, self.value_types):
|
||||
raise jsonmodels.errors.ValidationError(
|
||||
"All items must be instances "
|
||||
'of "{types}", and not "{type}".'.format(
|
||||
types=", ".join([t.__name__ for t in self.value_types]),
|
||||
type=type(item).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _elem_to_struct(self, value):
|
||||
try:
|
||||
return value.to_struct()
|
||||
except AttributeError:
|
||||
return value
|
||||
|
||||
def to_struct(self, values):
|
||||
return {k: self._elem_to_struct(v) for k, v in values.items()}
|
||||
|
||||
|
||||
class IntField(fields.IntField):
|
||||
def parse_value(self, value):
|
||||
try:
|
||||
return super(IntField, self).parse_value(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
|
||||
|
||||
class NullableEnumValidator(EnumValidator):
|
||||
"""Validator for enums that allows a None value."""
|
||||
|
||||
def validate(self, value):
|
||||
if value is not None:
|
||||
super(NullableEnumValidator, self).validate(value)
|
||||
|
||||
|
||||
class EnumField(fields.StringField):
|
||||
def __init__(
|
||||
self,
|
||||
values_or_type: Union[Iterable, Type[Enum]],
|
||||
*args,
|
||||
required=False,
|
||||
default=None,
|
||||
**kwargs
|
||||
):
|
||||
choices = list(map(self.parse_value, values_or_type))
|
||||
validator_cls = EnumValidator if required else NullableEnumValidator
|
||||
kwargs.setdefault("validators", []).append(validator_cls(*choices))
|
||||
super().__init__(
|
||||
default=self.parse_value(default), required=required, *args, **kwargs
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if isinstance(value, Enum):
|
||||
return str(value.value)
|
||||
return super().parse_value(value)
|
||||
|
||||
|
||||
class ActualEnumField(fields.StringField):
|
||||
def __init__(
|
||||
self,
|
||||
enum_class: Type[Enum],
|
||||
*args,
|
||||
validators=None,
|
||||
required=False,
|
||||
default=None,
|
||||
**kwargs
|
||||
):
|
||||
self.__enum = enum_class
|
||||
self.types = (enum_class,)
|
||||
# noinspection PyTypeChecker
|
||||
choices = list(enum_class)
|
||||
validator_cls = EnumValidator if required else NullableEnumValidator
|
||||
validators = [*(validators or []), validator_cls(*choices)]
|
||||
super().__init__(
|
||||
default=self.parse_value(default) if default else NotSet,
|
||||
*args,
|
||||
required=required,
|
||||
validators=validators,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is NotSet and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
return self.__enum(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def to_struct(self, value):
|
||||
return super().to_struct(value.value)
|
||||
|
||||
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
|
||||
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
|
||||
class _Wrapped(cls):
|
||||
_callable_default = None
|
||||
|
||||
def get_default_value(self):
|
||||
if self._callable_default:
|
||||
return self._callable_default()
|
||||
return super(_Wrapped, self).get_default_value()
|
||||
|
||||
def __init__(self, *args, default=None, **kwargs):
|
||||
if default and callable(default):
|
||||
self._callable_default = default
|
||||
default = default()
|
||||
super(_Wrapped, self).__init__(*args, default=default, **kwargs)
|
||||
|
||||
return _Wrapped
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
DictField representing mongoengine field names/value mapping.
|
||||
Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update
|
||||
operators.
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_mongo_value(value):
|
||||
if isinstance(value, BaseDocument):
|
||||
return value.to_mongo()
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
return {
|
||||
k: v
|
||||
for k, v in (self._normalize_mongo_field_path(*p) for p in value.items())
|
||||
if k is not None
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField, DateTimeField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Max, Enum
|
||||
|
||||
from apimodels import ListField, EnumField
|
||||
from config import config
|
||||
from database.model.auth import Role
|
||||
from database.utils import get_options
|
||||
from apiserver.apimodels import ListField, EnumField
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class GetTokenRequest(Base):
|
||||
@@ -75,10 +75,17 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
@@ -89,6 +96,11 @@ class GetCredentialsResponse(Base):
|
||||
credentials = ListField(CredentialsResponse)
|
||||
|
||||
|
||||
class EditCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class RevokeCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
|
||||
28
apiserver/apimodels/base.py
Normal file
28
apiserver/apimodels/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import MongoengineFieldsDict, ListField
|
||||
|
||||
|
||||
class UpdateResponse(models.Base):
|
||||
updated = fields.IntField(required=True)
|
||||
fields = MongoengineFieldsDict()
|
||||
|
||||
|
||||
class PagedRequest(models.Base):
|
||||
page = fields.IntField()
|
||||
page_size = fields.IntField()
|
||||
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
|
||||
class MoveRequest(models.Base):
|
||||
ids = ListField([str], validators=Length(minimum_value=1))
|
||||
project = fields.StringField()
|
||||
project_name = fields.StringField()
|
||||
25
apiserver/apimodels/batch.py
Normal file
25
apiserver/apimodels/batch.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
|
||||
|
||||
class BatchRequest(Base):
|
||||
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class BatchResponse(Base):
|
||||
succeeded: Sequence[dict] = ListField([dict])
|
||||
failed: Sequence[dict] = ListField([dict])
|
||||
|
||||
|
||||
class UpdateBatchItem(UpdateResponse):
|
||||
id: str = StringField()
|
||||
|
||||
|
||||
class UpdateBatchResponse(BatchResponse):
|
||||
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)
|
||||
34
apiserver/apimodels/custom_validators/__init__.py
Normal file
34
apiserver/apimodels/custom_validators/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import validators
|
||||
from jsonmodels.errors import ValidationError
|
||||
|
||||
|
||||
class ForEach(object):
|
||||
def __init__(self, validator):
|
||||
self.validator = validator
|
||||
|
||||
def validate(self, values):
|
||||
for value in values:
|
||||
self.validator.validate(value)
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
return self.validator.modify_schema(field_schema)
|
||||
|
||||
|
||||
class Hostname(object):
|
||||
|
||||
def validate(self, value):
|
||||
if validators.domain(value) is not True:
|
||||
raise ValidationError(f"Value '{value}' is not a valid hostname")
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
field_schema["format"] = "hostname"
|
||||
|
||||
|
||||
class Email(object):
|
||||
|
||||
def validate(self, value):
|
||||
if validators.email(value) is not True:
|
||||
raise ValidationError(f"Value '{value}' is not a valid email address")
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
field_schema["format"] = "email"
|
||||
174
apiserver/apimodels/events.py
Normal file
174
apiserver/apimodels/events.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apiserver.apimodels import ListField, IntField, ActualEnumField
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class MetricVariants(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str,
|
||||
validators=[
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class MetricEventsRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class GetVariantSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetMetricSamplesRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextHistorySampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
next_iteration: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class TaskEventsRequestBase(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
|
||||
|
||||
class TaskEventsRequest(TaskEventsRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField()
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
class MetricEventsResponse(Base):
|
||||
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class MultiTasksRequestBase(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
pass
|
||||
|
||||
|
||||
class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class ClearTaskLogRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
threshold_sec = IntField()
|
||||
allow_locked = BoolField(default=False)
|
||||
34
apiserver/apimodels/login.py
Normal file
34
apiserver/apimodels/login.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField, ListField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
name = StringField()
|
||||
username = StringField()
|
||||
password = StringField()
|
||||
|
||||
|
||||
class BasicMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode)
|
||||
|
||||
|
||||
class ServerErrors(Base):
|
||||
missed_es_upgrade = BoolField(default=False)
|
||||
es_connection_error = BoolField(default=False)
|
||||
|
||||
|
||||
class GetSupportedModesResponse(Base):
|
||||
basic = EmbeddedField(BasicMode)
|
||||
server_errors = EmbeddedField(ServerErrors)
|
||||
sso = DictField([str, type(None)])
|
||||
sso_providers = ListField([dict])
|
||||
authenticated = BoolField(default=False)
|
||||
24
apiserver/apimodels/metadata.py
Normal file
24
apiserver/apimodels/metadata.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class MetadataItem(Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
class DeleteMetadata(Base):
|
||||
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
|
||||
|
||||
|
||||
class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
81
apiserver/apimodels/models.py
Normal file
81
apiserver/apimodels/models.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from jsonmodels import models, fields
|
||||
from six import string_types
|
||||
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.batch import BatchRequest
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types + (int,))
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
public = fields.BoolField(default=False)
|
||||
project = fields.StringField()
|
||||
parent = fields.StringField()
|
||||
framework = fields.StringField()
|
||||
design = DictField()
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
created = fields.BoolField(required=True)
|
||||
|
||||
|
||||
class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class ModelsDeleteManyRequest(BatchRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class PublishModelRequest(ModelRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ModelTaskPublishResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
data = fields.EmbeddedField(UpdateResponse)
|
||||
|
||||
|
||||
class PublishModelResponse(UpdateResponse):
|
||||
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
||||
|
||||
|
||||
class ModelsPublishManyRequest(BatchRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class ModelsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
23
apiserver/apimodels/organization.py
Normal file
23
apiserver/apimodels/organization.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
|
||||
|
||||
class EntitiesCountRequest(models.Base):
|
||||
projects = DictField()
|
||||
tasks = DictField()
|
||||
models = DictField()
|
||||
pipelines = DictField()
|
||||
datasets = DictField()
|
||||
active_users = fields.ListField(str)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
19
apiserver/apimodels/pipelines.py
Normal file
19
apiserver/apimodels/pipelines.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
69
apiserver/apimodels/projects.py
Normal file
69
apiserver/apimodels/projects.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
|
||||
class ProjectRequest(models.Base):
|
||||
project = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MergeRequest(ProjectRequest):
|
||||
destination_project = fields.StringField()
|
||||
|
||||
|
||||
class MoveRequest(ProjectRequest):
|
||||
new_location = fields.StringField()
|
||||
|
||||
|
||||
class DeleteRequest(ProjectRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_contents = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class ProjectOrNoneRequest(models.Base):
|
||||
project = fields.StringField()
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectRequest):
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
|
||||
|
||||
class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
key = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_dataset_stats = fields.BoolField(default=False)
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
91
apiserver/apimodels/queues.py
Normal file
91
apiserver/apimodels/queues.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetByIdRequest(QueueRequest):
|
||||
max_task_entries = IntField()
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
max_task_entries = IntField()
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
task = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class MoveTaskRequest(TaskRequest):
|
||||
count = IntField(default=1)
|
||||
|
||||
|
||||
class MoveTaskResponse(Base):
|
||||
position = IntField()
|
||||
|
||||
|
||||
class GetMetricsRequest(Base):
|
||||
queue_ids = ListField([str])
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
refresh = BoolField(default=False)
|
||||
|
||||
|
||||
class QueueMetrics(Base):
|
||||
queue = StringField()
|
||||
dates = ListField(int)
|
||||
avg_waiting_times = ListField([float, int])
|
||||
queue_lengths = ListField(int)
|
||||
|
||||
|
||||
class GetMetricsResponse(Base):
|
||||
queues = ListField(QueueMetrics)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
queue = StringField(required=True)
|
||||
15
apiserver/apimodels/server.py
Normal file
15
apiserver/apimodels/server.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from jsonmodels.fields import BoolField, DateTimeField, StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
|
||||
class ReportStatsOptionRequest(Base):
|
||||
enabled = BoolField(default=None, nullable=True)
|
||||
|
||||
|
||||
class ReportStatsOptionResponse(Base):
|
||||
supported = BoolField(default=True)
|
||||
enabled = BoolField()
|
||||
enabled_time = DateTimeField(nullable=True)
|
||||
enabled_version = StringField(nullable=True)
|
||||
enabled_user = StringField(nullable=True)
|
||||
current_version = StringField()
|
||||
320
apiserver/apimodels/tasks.py
Normal file
320
apiserver/apimodels/tasks.py
Normal file
@@ -0,0 +1,320 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class ArtifactTypeData(models.Base):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class Artifact(models.Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = IntField()
|
||||
timestamp = IntField()
|
||||
type_data = EmbeddedField(ArtifactTypeData)
|
||||
display_data = ListField([list])
|
||||
|
||||
|
||||
class StartedResponse(UpdateResponse):
|
||||
started = IntField()
|
||||
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class EnqueueBatchItem(UpdateBatchItem):
|
||||
queued: bool = BoolField()
|
||||
|
||||
|
||||
class EnqueueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
dequeued = IntField()
|
||||
|
||||
|
||||
class DequeueBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
|
||||
|
||||
class DequeueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
|
||||
|
||||
|
||||
class ResetResponse(UpdateResponse):
|
||||
dequeued = DictField()
|
||||
events = DictField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetManyResponse(BatchResponse):
|
||||
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class TaskUpdateRequest(TaskRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(TaskUpdateRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
requirements = DictField(required=True)
|
||||
|
||||
|
||||
class CompletedRequest(UpdateRequest):
|
||||
publish = BoolField(default=False)
|
||||
|
||||
|
||||
class CompletedResponse(UpdateResponse):
|
||||
published = IntField(default=0)
|
||||
|
||||
|
||||
class PublishRequest(UpdateRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
|
||||
|
||||
class TaskData(models.Base):
|
||||
"""
|
||||
This is a partial description of task can be updated incrementally
|
||||
"""
|
||||
|
||||
|
||||
class CreateRequest(TaskData):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskType)))
|
||||
|
||||
|
||||
class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class TaskInputModel(models.Base):
|
||||
name = StringField()
|
||||
model = StringField()
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
new_task_tags = ListField([str])
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
new_task_container = DictField()
|
||||
new_task_input_models = ListField([TaskInputModel])
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class ArtifactId(models.Base):
|
||||
key = StringField(required=True)
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
skip_empty = BoolField(default=True)
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
|
||||
|
||||
class TaskBatchRequest(BatchRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class StopManyRequest(TaskBatchRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueManyRequest(TaskBatchRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
validate_tasks = BoolField(default=False)
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteManyRequest(TaskBatchRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class ResetManyRequest(TaskBatchRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishManyRequest(TaskBatchRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class AddUpdateModelRequest(TaskRequest):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
iteration = IntField()
|
||||
|
||||
|
||||
class ModelItemKey(models.Base):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
|
||||
|
||||
class DeleteModelsRequest(TaskRequest):
|
||||
models: Sequence[ModelItemKey] = ListField(
|
||||
[ModelItemKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import DictField
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
182
apiserver/apimodels/workers.py
Normal file
182
apiserver/apimodels/workers.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import (
|
||||
StringField,
|
||||
EmbeddedField,
|
||||
DateTimeField,
|
||||
IntField,
|
||||
FloatField,
|
||||
BoolField,
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
timeout = make_default(
|
||||
IntField, DEFAULT_TIMEOUT
|
||||
)() # registration timeout in seconds (default is 10min)
|
||||
queues = ListField(six.string_types) # list of queues this worker listens to
|
||||
|
||||
|
||||
class MachineStats(Base):
|
||||
cpu_usage = ListField(six.integer_types + (float,))
|
||||
cpu_temperature = ListField(six.integer_types + (float,))
|
||||
gpu_usage = ListField(six.integer_types + (float,))
|
||||
gpu_temperature = ListField(six.integer_types + (float,))
|
||||
gpu_memory_free = ListField(six.integer_types + (float,))
|
||||
gpu_memory_used = ListField(six.integer_types + (float,))
|
||||
memory_used = FloatField()
|
||||
memory_free = FloatField()
|
||||
network_tx = FloatField()
|
||||
network_rx = FloatField()
|
||||
disk_free_home = FloatField()
|
||||
disk_free_temp = FloatField()
|
||||
disk_read = FloatField()
|
||||
disk_write = FloatField()
|
||||
|
||||
|
||||
class StatusReportRequest(WorkerRequest):
|
||||
task = StringField() # task the worker is running on
|
||||
queue = StringField() # queue from which task was taken
|
||||
queues = ListField(
|
||||
str
|
||||
) # list of queues this worker listens to. if None, this will not update the worker's queues list.
|
||||
timestamp = IntField(required=True)
|
||||
machine_stats = EmbeddedField(MachineStats)
|
||||
|
||||
|
||||
class IdNameEntry(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField()
|
||||
|
||||
|
||||
class WorkerEntry(Base, JsonSerializableMixin):
|
||||
key = StringField() # not required due to migration issues
|
||||
id = StringField(required=True)
|
||||
user = EmbeddedField(IdNameEntry)
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
running_time = IntField()
|
||||
last_iteration = IntField()
|
||||
|
||||
|
||||
class QueueEntry(IdNameEntry):
|
||||
next_task = EmbeddedField(IdNameEntry)
|
||||
num_tasks = IntField()
|
||||
|
||||
|
||||
class WorkerResponseEntry(WorkerEntry):
|
||||
task = EmbeddedField(CurrentTaskEntry)
|
||||
queue = EmbeddedField(QueueEntry)
|
||||
queues = ListField(QueueEntry)
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
workers = ListField(WorkerResponseEntry)
|
||||
|
||||
|
||||
class StatsBase(Base):
|
||||
worker_ids = ListField(str)
|
||||
|
||||
|
||||
class StatsReportBase(StatsBase):
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
|
||||
|
||||
class AggregationType(Enum):
|
||||
avg = "avg"
|
||||
min = "min"
|
||||
max = "max"
|
||||
|
||||
|
||||
class StatItem(Base):
|
||||
key = StringField(required=True)
|
||||
aggregation = EnumField(AggregationType, default=AggregationType.avg)
|
||||
|
||||
|
||||
class GetStatsRequest(StatsReportBase):
|
||||
items = ListField(
|
||||
StatItem, required=True, validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
split_by_variant = BoolField(default=False)
|
||||
|
||||
|
||||
class AggregationStats(Base):
|
||||
aggregation = EnumField(AggregationType)
|
||||
values = ListField(float)
|
||||
|
||||
|
||||
class MetricStats(Base):
|
||||
metric = StringField()
|
||||
variant = StringField()
|
||||
dates = ListField(int)
|
||||
stats = ListField(AggregationStats)
|
||||
|
||||
|
||||
class WorkerStatistics(Base):
|
||||
worker = StringField()
|
||||
metrics = ListField(MetricStats)
|
||||
|
||||
|
||||
class GetStatsResponse(Base):
|
||||
workers = ListField(WorkerStatistics)
|
||||
|
||||
|
||||
class GetMetricKeysRequest(StatsBase):
|
||||
pass
|
||||
|
||||
|
||||
class MetricCategory(Base):
|
||||
name = StringField()
|
||||
metric_keys = ListField(str)
|
||||
|
||||
|
||||
class GetMetricKeysResponse(Base):
|
||||
categories = ListField(MetricCategory)
|
||||
|
||||
|
||||
class GetActivityReportRequest(StatsReportBase):
|
||||
pass
|
||||
|
||||
|
||||
class ActivityReportSeries(Base):
|
||||
dates = ListField(int)
|
||||
counts = ListField(int)
|
||||
|
||||
|
||||
class GetActivityReportResponse(Base):
|
||||
total = EmbeddedField(ActivityReportSeries)
|
||||
active = EmbeddedField(ActivityReportSeries)
|
||||
@@ -1,25 +1,21 @@
|
||||
from datetime import datetime
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from apimodels.auth import (
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from bll.user import UserBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User, Role, Credentials
|
||||
from database.model.company import Company
|
||||
from service_repo import APICall
|
||||
from service_repo.auth import (
|
||||
Identity,
|
||||
Token,
|
||||
get_client_id,
|
||||
get_secret_key,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version, get_build_number
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Role, Credentials
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.service_repo import APICall, ServiceRepo
|
||||
from apiserver.service_repo.auth import Identity, Token, get_client_id, get_secret_key
|
||||
|
||||
log = config.logger("AuthBLL")
|
||||
|
||||
@@ -62,9 +58,13 @@ class AuthBLL:
|
||||
identity=identity,
|
||||
entities=entities,
|
||||
expiration_sec=expiration_sec,
|
||||
api_version=str(ServiceRepo.max_endpoint_version()),
|
||||
server_version=str(get_version()),
|
||||
server_build=str(get_build_number()),
|
||||
feature_set="basic",
|
||||
)
|
||||
|
||||
return GetTokenResponse(token=token.decode("ascii"))
|
||||
return GetTokenResponse(token=token)
|
||||
|
||||
@staticmethod
|
||||
def create_user(request: CreateUserRequest, call: APICall = None) -> str:
|
||||
@@ -149,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -158,9 +158,11 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
|
||||
)
|
||||
user.save()
|
||||
|
||||
1198
apiserver/bll/event/event_bll.py
Normal file
1198
apiserver/bll/event/event_bll.py
Normal file
File diff suppressed because it is too large
Load Diff
161
apiserver/bll/event/event_common.py
Normal file
161
apiserver/bll/event/event_common.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import base64
|
||||
import zlib
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence, Mapping, Tuple
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
all = "*"
|
||||
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
_max_es_allowed_aggregation_buckets = 10000
|
||||
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
|
||||
|
||||
@classproperty
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def max_es_buckets(self):
|
||||
percentage = (
|
||||
min(
|
||||
100,
|
||||
config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count_threshold",
|
||||
80,
|
||||
),
|
||||
)
|
||||
/ 100
|
||||
)
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id.lower()}"
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
if not es.indices.exists(es_index):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def search_company_events(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
body: dict,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.search(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
|
||||
|
||||
|
||||
def count_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_max_metric_and_variant_counts(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
query: dict,
|
||||
**kwargs,
|
||||
) -> Tuple[int, int]:
|
||||
dynamic = config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count", False
|
||||
)
|
||||
max_metrics_count = config.get(
|
||||
"services.events.events_retrieval.max_metrics_count", 100
|
||||
)
|
||||
max_variants_count = config.get(
|
||||
"services.events.events_retrieval.max_variants_count", 100
|
||||
)
|
||||
if not dynamic:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
metrics_count = safe_get(
|
||||
es_res, "aggregations/metrics_count/value", max_metrics_count
|
||||
)
|
||||
if not metrics_count:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"variant": variants}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if variants
|
||||
else {"term": {"metric": metric}}
|
||||
for metric, variants in metric_variants.items()
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
def uncompress_plot(event: dict):
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
475
apiserver/bll/event/event_metrics.py
Normal file
475
apiserver/bll/event/event_metrics.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average,
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
key=key,
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
tasks: Sequence[Task],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in tasks}
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids = [t.id for t in tasks]
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, tasks: Sequence[Task]
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
):
|
||||
return {}
|
||||
|
||||
task_ids = [t.id for t in tasks]
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
field: event.get(field)
|
||||
for field in ("metric", "variant", "value", "timestamp")
|
||||
}
|
||||
|
||||
return {
|
||||
task: [_get_value(e) for e in events]
|
||||
for task, events in bucketize(task_events, itemgetter("task")).items()
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
) -> Sequence[dict]:
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
)
|
||||
if not es_res["hits"]["total"]["value"]:
|
||||
return []
|
||||
|
||||
return [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = self._task_conditions(task_id)
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
index_range = max_index - min_index + 1
|
||||
interval = max(1, math.ceil(float(index_range) / samples))
|
||||
max_samples = math.ceil(float(index_range) / interval)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
interval,
|
||||
max_samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
"name": variant["key"],
|
||||
**key.get_iterations_data(variant),
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
return {
|
||||
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _task_conditions(task_id: str) -> list:
|
||||
return [
|
||||
{"term": {"task": task_id}},
|
||||
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
),
|
||||
task_ids,
|
||||
)
|
||||
return list(zip(task_ids, res))
|
||||
|
||||
def _get_task_metrics(
|
||||
self, task_id: str, company_id: str, event_type: EventType
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": self._task_conditions(task_id)}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_es_buckets,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
|
||||
]
|
||||
197
apiserver/bll/event/events_iterator.py
Normal file
197
apiserver/bll/event/events_iterator.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from typing import Optional, Tuple, Sequence, Any
|
||||
|
||||
import attr
|
||||
import jsonmodels.models
|
||||
import jwt
|
||||
from elasticsearch import Elasticsearch
|
||||
from jwt.algorithms import get_default_algorithms
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
count_company_events,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class EventsIterator:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_key_value: Optional[Any] = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
|
||||
**kwargs,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
from_key_value = kwargs.pop("from_timestamp", from_key_value)
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
event_type=event_type,
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=ScalarKey.resolve(key),
|
||||
)
|
||||
return res
|
||||
|
||||
def count_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return 0
|
||||
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = count_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_result["count"]
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
key: ScalarKey,
|
||||
from_key_value: Optional[Any],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If from_key_field is not set then start either from latest or earliest.
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": query,
|
||||
"sort": {key.field: "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must + [{"term": {key.field: events[-1][key.field]}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_id: str, metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
must = [{"term": {"task": task_id}}]
|
||||
query = {"term": {"task": task_id}}
|
||||
else:
|
||||
must = [
|
||||
{"term": {"task": task_id}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
return query, must
|
||||
|
||||
|
||||
class Scroll(jsonmodels.models.Base):
|
||||
def get_scroll_id(self) -> str:
|
||||
return jwt.encode(
|
||||
self.to_struct(),
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_scroll_id(cls, scroll_id: str):
|
||||
try:
|
||||
return cls(
|
||||
**jwt.decode(
|
||||
scroll_id,
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
algorithms=get_default_algorithms(),
|
||||
)
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise ValueError("Invalid Scroll ID")
|
||||
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class VariantSampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryDebugImageIterator:
|
||||
event_type = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = VariantSampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||
):
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@staticmethod
|
||||
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
|
||||
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in metric_variants
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
_get_variants_conditions(metric_variants),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, metric_variants in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = VariantSampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugImageSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugImageSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugImageSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import (
|
||||
EventType,
|
||||
uncompress_plot,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
name: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class PlotsSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
metric_states: Sequence[MetricState] = ListField([MetricState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricSamplesResult(object):
|
||||
scroll_id: str = None
|
||||
events: list = []
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryPlotsIterator:
|
||||
event_type = EventType.metrics_plot
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=PlotsSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the samples for next/prev metric on the current iteration
|
||||
If does not exist then try getting sample for the first/last metric from next/prev iteration
|
||||
"""
|
||||
res = MetricSamplesResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
]
|
||||
if state.navigate_current_metric:
|
||||
must_conditions.append({"term": {"metric": state.metric}})
|
||||
|
||||
next_iteration_condition = {
|
||||
"range": {"iter": {range_operator: state.iteration}}
|
||||
}
|
||||
if next_iteration or state.navigate_current_metric:
|
||||
must_conditions.append(next_iteration_condition)
|
||||
else:
|
||||
next_metric_condition = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"range": {"metric": {range_operator: state.metric}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
must_conditions.append(
|
||||
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
|
||||
)
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order=order,
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def get_samples_for_metric(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = MetricSamplesResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: PlotsSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: PlotsSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
state: PlotsSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
|
||||
if not metric_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append({"range": {"iter": {"lte": iteration}}})
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order="desc",
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
return res
|
||||
|
||||
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
|
||||
metrics = self._get_metric_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.metric_states = [
|
||||
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for metric, (min_iter, max_iter) in metrics.items()
|
||||
]
|
||||
|
||||
def _get_metric_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 5000,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
)
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: (
|
||||
int(metric_bucket["first_iter"]["value"]),
|
||||
int(metric_bucket["last_iter"]["value"]),
|
||||
)
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
|
||||
):
|
||||
for event in events:
|
||||
uncompress_plot(event)
|
||||
state.metric = events[0]["metric"]
|
||||
state.iteration = events[0]["iter"]
|
||||
res.events = events
|
||||
metric_state = first(
|
||||
ms for ms in state.metric_states if ms.name == state.metric
|
||||
)
|
||||
if metric_state:
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
def _get_metric_events_for_condition(
|
||||
self, company_id: str, task: str, order: str, must_conditions: Sequence
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1,
|
||||
"order": {"_key": order},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"variant": {"order": "asc"}},
|
||||
"size": 100,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
for level in ("iters", "metrics"):
|
||||
level_data = aggs_result[level]["buckets"]
|
||||
if not level_data:
|
||||
return []
|
||||
aggs_result = level_data[0]
|
||||
|
||||
return [
|
||||
hit["_source"]
|
||||
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
|
||||
]
|
||||
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import EventType
|
||||
from .metric_events_iterator import MetricEventsIterator, VariantState
|
||||
|
||||
|
||||
class MetricDebugImagesIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_image)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return [{"exists": {"field": "url"}}]
|
||||
|
||||
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
aggs = {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
|
||||
"""If the image urls get recycled then fill the last_invalid_iteration field"""
|
||||
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
|
||||
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
|
||||
|
||||
return aggs, fill_variant_state_data
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"url": {"order": "desc"}}
|
||||
442
apiserver/bll/event/metric_events_iterator.py
Normal file
442
apiserver/bll/event/metric_events_iterator.py
Normal file
@@ -0,0 +1,442 @@
|
||||
import abc
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Callable
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
variant: str = StringField(required=True)
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[VariantState] = ListField([VariantState], required=True)
|
||||
timestamp: int = IntField(default=0)
|
||||
|
||||
|
||||
class TaskScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class MetricEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricEventsResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class MetricEventsIterator:
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=MetricEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> MetricEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.event_type):
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = MetricEventsResult(next_scroll_id=state.id)
|
||||
specific_variants_requested = any(
|
||||
variants
|
||||
for t, metrics in task_metrics.items()
|
||||
if metrics
|
||||
for m, variants in metrics.items()
|
||||
)
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
company_id=company_id,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
),
|
||||
state.tasks,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return {}
|
||||
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.event_type.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.event_type.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
}
|
||||
|
||||
update_times = {
|
||||
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
}
|
||||
task_metric_states = {
|
||||
task_state.task: {
|
||||
metric_state.metric: metric_state for metric_state in task_state.metrics
|
||||
}
|
||||
for task_state in state.tasks
|
||||
}
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
) -> TaskScrollState:
|
||||
task = old_state.task
|
||||
updated_state = first(uts for uts in updates if uts.task == task)
|
||||
if not updated_state:
|
||||
old_state.reset()
|
||||
return old_state
|
||||
|
||||
updated_metrics = [m.metric for m in updated_state.metrics]
|
||||
return TaskScrollState(
|
||||
task=task,
|
||||
metrics=[
|
||||
*updated_state.metrics,
|
||||
*(
|
||||
old_metric
|
||||
for old_metric in old_state.metrics
|
||||
if old_metric.metric not in updated_metrics
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
state.tasks = [
|
||||
merge_with_updated_task_states(task_state, updated_task_states)
|
||||
for task_state in state.tasks
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
return [
|
||||
TaskScrollState(task=task, metrics=metric_states,)
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variant_state_aggs(
|
||||
self,
|
||||
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
**(
|
||||
{"aggs": variant_state_aggs}
|
||||
if variant_state_aggs
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant state for the passed variant bucket
|
||||
"""
|
||||
state = VariantState(variant=variant["key"])
|
||||
if fill_variant_state_data:
|
||||
fill_variant_state_data(variant, state)
|
||||
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
variants=[
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
pass
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update task scroll state
|
||||
"""
|
||||
if not task_state.metrics:
|
||||
return task_state.task, []
|
||||
|
||||
if task_state.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and task_state.last_min_iter is not None:
|
||||
range_condition = {"lt": task_state.last_min_iter}
|
||||
elif not navigate_earlier and task_state.last_max_iter is not None:
|
||||
range_condition = {"gt": task_state.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
metrics_count = len(task_state.metrics)
|
||||
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": self._get_same_variant_events_order()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return task_state.task, []
|
||||
|
||||
invalid_iterations = {
|
||||
(m.metric, v.variant): v.last_invalid_iteration
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
allow_uninitialized = (
|
||||
False
|
||||
if specific_variants_requested
|
||||
else config.get(
|
||||
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return allow_uninitialized
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
self._process_event(ev["_source"])
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
task_state.last_max_iter = iterations[0]["iter"]
|
||||
task_state.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
return task_state.task, iterations
|
||||
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .metric_events_iterator import MetricEventsIterator
|
||||
|
||||
|
||||
class MetricPlotsIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variant_state_aggs(self):
|
||||
return None, None
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"timestamp": {"order": "desc"}}
|
||||
173
apiserver/bll/event/scalar_key.py
Normal file
173
apiserver/bll/event/scalar_key.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Module for polymorphism over different types of X axes in scalar aggregations
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from typing import Any
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ScalarKeyEnum(StringEnum):
|
||||
"""
|
||||
String enum representing X axes key
|
||||
"""
|
||||
|
||||
iter = auto()
|
||||
timestamp = auto()
|
||||
iso_time = auto()
|
||||
|
||||
|
||||
class ScalarKey(ABC):
|
||||
"""
|
||||
Abstract scalar key
|
||||
"""
|
||||
|
||||
_enum_to_key = {}
|
||||
bucket_key_key = "key"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def enum_value(self) -> ScalarKeyEnum:
|
||||
"""
|
||||
Enum value accepted in API requests
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Key name. Used as arbitrary internal key in elasticsearch queries
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def field(self) -> str:
|
||||
"""
|
||||
Event key to aggregate by
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
"""
|
||||
Get aggregation for this type of key
|
||||
:param interval: elasticsearch aggregation interval
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Save a mapping from enum values to key class
|
||||
"""
|
||||
if cls.enum_value not in ScalarKeyEnum:
|
||||
raise ValueError(f"{cls.enum_value!r} not in {ScalarKeyEnum.__name__}")
|
||||
if cls.enum_value in cls._enum_to_key:
|
||||
log.warning(
|
||||
f"'{cls.enum_value.value}' is already registered to {ScalarKey.__name__}"
|
||||
)
|
||||
cls._enum_to_key[cls.enum_value] = cls
|
||||
|
||||
@classmethod
|
||||
def resolve(cls, key: ScalarKeyEnum):
|
||||
"""
|
||||
Create a key instance from enum instance
|
||||
"""
|
||||
return cls._enum_to_key[key]()
|
||||
|
||||
def get_iterations_data(self, iter_buckets: dict) -> dict:
|
||||
"""
|
||||
Convert a list of bucket entries to `x`s array and `y`s array
|
||||
"""
|
||||
return extract_properties_to_lists(
|
||||
("x", "y"),
|
||||
iter_buckets[self.name]["buckets"],
|
||||
self._get_iterations_data_single,
|
||||
)
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
"""
|
||||
Extract x value and y value from a single bucket item
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
def cast_value(self, value: Any) -> Any:
|
||||
"""Cast value to appropriate type"""
|
||||
return value
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by timestamp in milliseconds since epoch
|
||||
"""
|
||||
|
||||
name = "timestamp"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.timestamp
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by iteration number
|
||||
"""
|
||||
|
||||
name = "iters"
|
||||
field = "iter"
|
||||
enum_value = ScalarKeyEnum.iter
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"histogram": {"field": "iter", "interval": interval, "min_doc_count": 1}
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by time formatted as ISO strings
|
||||
"""
|
||||
|
||||
name = "iso_time"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.iso_time
|
||||
bucket_key_key = "key_as_string"
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
return iter_data[self.bucket_key_key], iter_data["avg_val"]["value"]
|
||||
189
apiserver/bll/model/__init__.py
Normal file
189
apiserver/bll/model/__init__.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple, Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
@classmethod
|
||||
def get_company_model_by_id(
|
||||
cls, company_id: str, model_id: str, only_fields=None
|
||||
) -> Model:
|
||||
query = dict(company=company_id, id=model_id)
|
||||
qs = Model.objects(**query)
|
||||
if only_fields:
|
||||
qs = qs.only(*only_fields)
|
||||
model = qs.first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(
|
||||
company_id,
|
||||
model_ids,
|
||||
only=None,
|
||||
allow_public=False,
|
||||
return_models=True,
|
||||
) -> Optional[Sequence[Model]]:
|
||||
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
|
||||
ids = set(model_ids)
|
||||
query = Q(id__in=ids)
|
||||
|
||||
q = Model.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
allow_public=allow_public,
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidModelId(ids=model_ids)
|
||||
|
||||
if return_models:
|
||||
return list(q)
|
||||
|
||||
@classmethod
|
||||
def publish_model(
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
||||
|
||||
published_task = None
|
||||
if model.task and publish_task_func:
|
||||
task = (
|
||||
Task.objects(id=model.task, company=company_id)
|
||||
.only("id", "status")
|
||||
.first()
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
)
|
||||
|
||||
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
|
||||
return updated, published_task
|
||||
|
||||
@classmethod
|
||||
def delete_model(
|
||||
cls, model_id: str, company_id: str, force: bool
|
||||
) -> Tuple[int, Model]:
|
||||
model = cls.get_company_model_by_id(
|
||||
company_id=company_id,
|
||||
model_id=model_id,
|
||||
only_fields=("id", "task", "project", "uri"),
|
||||
)
|
||||
deleted_model_id = f"{deleted_prefix}{model_id}"
|
||||
|
||||
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
||||
if using_tasks:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelInUse(
|
||||
"as execution model, use force=True to delete",
|
||||
num_tasks=len(using_tasks),
|
||||
)
|
||||
# update deleted model id in using tasks
|
||||
Task._get_collection().update_many(
|
||||
filter={"_id": {"$in": [t.id for t in using_tasks]}},
|
||||
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if model.task:
|
||||
task = Task.objects(id=model.task).first()
|
||||
if task and task.status == TaskStatus.published:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelCreatingTaskExists(
|
||||
"and published, use force=True to delete", task=model.task
|
||||
)
|
||||
if task.models.output and model_id in task.models.output:
|
||||
now = datetime.utcnow()
|
||||
Task._get_collection().update_one(
|
||||
filter={"_id": model.task, "models.output.model": model_id},
|
||||
update={
|
||||
"$set": {
|
||||
"models.output.$[elem].model": deleted_model_id,
|
||||
"output.error": f"model deleted on {now.isoformat()}",
|
||||
},
|
||||
"last_change": now,
|
||||
},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
del_count = Model.objects(id=model_id, company=company_id).delete()
|
||||
return del_count, model
|
||||
|
||||
@classmethod
|
||||
def archive_model(cls, model_id: str, company_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
archived = Model.objects(company=company_id, id=model_id).update(
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
|
||||
return archived
|
||||
|
||||
@classmethod
|
||||
def unarchive_model(cls, model_id: str, company_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
unarchived = Model.objects(company=company_id, id=model_id).update(
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
|
||||
return unarchived
|
||||
|
||||
@classmethod
|
||||
def get_model_stats(
|
||||
cls, company: str, model_ids: Sequence[str],
|
||||
) -> Dict[str, dict]:
|
||||
if not model_ids:
|
||||
return {}
|
||||
|
||||
result = Model.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": {"$in": model_ids},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {"labels_count": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
return {
|
||||
r.pop("_id"): r
|
||||
for r in result
|
||||
}
|
||||
108
apiserver/bll/model/metadata.py
Normal file
108
apiserver/bll/model/metadata.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
63
apiserver/bll/organization/__init__.py
Normal file
63
apiserver/bll/organization/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from .tags_cache import _TagsCache
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company_id, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company_id,
|
||||
include_system=include_system,
|
||||
filter_=filter_,
|
||||
project=project,
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company_id, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company_id, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
148
apiserver/bll/organization/tags_cache.py
Normal file
148
apiserver/bll/organization/tags_cache.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
_dummy_tag = "__dummy__"
|
||||
# prepend our list in redis with this tag since empty lists are auto deleted
|
||||
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
self.redis = redis
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
|
||||
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company_id: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company_id: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
):
|
||||
"""
|
||||
Project None means 'from all company projects'
|
||||
The key is built in the way that scanning company keys for 'all company projects'
|
||||
will not return the keys related to the particular company projects and vice versa.
|
||||
So that we can have a fine grain control on what redis keys to invalidate
|
||||
"""
|
||||
filter_str = None
|
||||
if filter_:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [field, company_id, project, self.db_cls.__name__, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
|
||||
ret = {}
|
||||
for field in fields:
|
||||
redis_key = self._get_tags_cache_key(
|
||||
company_id, field=field, project=project, filter_=filter_
|
||||
)
|
||||
cached_tags = self.redis.lrange(redis_key, 0, -1)
|
||||
if cached_tags:
|
||||
tags = [c.decode() for c in cached_tags[1:]]
|
||||
else:
|
||||
tags = list(
|
||||
self._get_tags_from_db(
|
||||
company_id, field=field, project=project, filter_=filter_
|
||||
)
|
||||
)
|
||||
self.redis.rpush(redis_key, self._dummy_tag, *tags)
|
||||
self.redis.expire(redis_key, self._tags_cache_expiration_seconds)
|
||||
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
fields = [
|
||||
field
|
||||
for field, update in (
|
||||
(self._tags_field, tags),
|
||||
(self._system_tags_field, system_tags),
|
||||
)
|
||||
if update is not None
|
||||
]
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company_id, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company_id: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company_id,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company_id: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company_id, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
)
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
3
apiserver/bll/project/__init__.py
Normal file
3
apiserver/bll/project/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
931
apiserver/bll/project/project_bll.py
Normal file
931
apiserver/bll/project/project_bll.py
Normal file
@@ -0,0 +1,931 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Dict,
|
||||
Set,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any,
|
||||
)
|
||||
|
||||
from boltons.iterutils import partition
|
||||
from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
from apiserver.database.utils import get_options, get_company_or_none_constraint
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .sub_projects import (
|
||||
_reposition_project_with_children,
|
||||
_ensure_project,
|
||||
_validate_project_name,
|
||||
_update_subproject_names,
|
||||
_save_under_parent,
|
||||
_get_sub_projects,
|
||||
_ids_with_children,
|
||||
_ids_with_parents,
|
||||
_get_project_depth,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def merge_project(
|
||||
cls, company, source_id: str, destination_id: str
|
||||
) -> Tuple[int, int, Set[str]]:
|
||||
"""
|
||||
Move all the tasks and sub projects from the source project to the destination
|
||||
Remove the source project
|
||||
Return the amounts of moved entities and subprojects + set of all the affected project ids
|
||||
"""
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
source=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
if destination_id:
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
else:
|
||||
destination = None
|
||||
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
if destination:
|
||||
cls.validate_projects_depth(
|
||||
projects=children,
|
||||
old_parent_depth=len(source.path) + 1,
|
||||
new_parent_depth=len(destination.path) + 1,
|
||||
)
|
||||
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
|
||||
return moved_entities, moved_sub_projects, affected
|
||||
|
||||
@staticmethod
|
||||
def validate_projects_depth(
|
||||
projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int
|
||||
):
|
||||
for current in projects:
|
||||
current_depth = len(current.path) + 1
|
||||
if current_depth - old_parent_depth + new_parent_depth > max_depth:
|
||||
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
|
||||
|
||||
@classmethod
|
||||
def move_project(
|
||||
cls, company: str, user: str, project_id: str, new_location: str
|
||||
) -> Tuple[int, Set[str]]:
|
||||
"""
|
||||
Move project with its sub projects from its current location to the target one.
|
||||
If the target location does not exist then it will be created. If it exists then
|
||||
it should be writable. The source location should be writable too.
|
||||
Return the number of moved projects + set of all the affected project ids
|
||||
"""
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
)
|
||||
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
cls.validate_projects_depth(
|
||||
projects=[project, *children],
|
||||
old_parent_depth=len(project.path),
|
||||
new_parent_depth=_get_project_depth(new_location),
|
||||
)
|
||||
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
if new_parent and (
|
||||
project_id == new_parent.id or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
return moved, affected
|
||||
|
||||
@classmethod
|
||||
def update(cls, company: str, project_id: str, **fields):
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
fields["basename"] = new_name.split("/")[-1]
|
||||
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
if new_name:
|
||||
old_name = project.name
|
||||
project.name = new_name
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
_update_subproject_names(
|
||||
project=project, children=children, old_name=old_name
|
||||
)
|
||||
|
||||
return updated
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
user: str,
|
||||
company: str,
|
||||
name: str,
|
||||
description: str = "",
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
Returns project ID
|
||||
"""
|
||||
if _get_project_depth(name) > max_depth:
|
||||
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
|
||||
|
||||
name, location = _validate_project_name(name)
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
company=company,
|
||||
name=name,
|
||||
basename=name.split("/")[-1],
|
||||
description=description,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
parent = _ensure_project(
|
||||
company=company,
|
||||
user=user,
|
||||
name=location,
|
||||
creation_params=parent_creation_params,
|
||||
)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project.id
|
||||
|
||||
@classmethod
|
||||
def find_or_create(
|
||||
cls,
|
||||
user: str,
|
||||
company: str,
|
||||
project_name: str,
|
||||
description: str,
|
||||
project_id: str = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise errors.bad_request.ValidationError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
return project_id
|
||||
|
||||
project_name, _ = _validate_project_name(project_name)
|
||||
project = Project.objects(company=company, name=project_name).only("id").first()
|
||||
if project:
|
||||
return project.id
|
||||
|
||||
return cls.create(
|
||||
user=user,
|
||||
company=company,
|
||||
name=project_name,
|
||||
description=description,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
parent_creation_params=parent_creation_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def move_under_project(
|
||||
cls,
|
||||
entity_cls: Type[Document],
|
||||
user: str,
|
||||
company: str,
|
||||
ids: Sequence[str],
|
||||
project: str = None,
|
||||
project_name: str = None,
|
||||
):
|
||||
"""
|
||||
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
|
||||
"""
|
||||
project = cls.find_or_create(
|
||||
user=user,
|
||||
company=company,
|
||||
project_id=project,
|
||||
project_name=project_name,
|
||||
description="",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
cls,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
"""
|
||||
return {
|
||||
"$addFields": {
|
||||
"system_tags": {
|
||||
"$cond": {
|
||||
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
|
||||
"then": [],
|
||||
"else": "$system_tags",
|
||||
}
|
||||
},
|
||||
"status": {"$ifNull": ["$status", "unknown"]},
|
||||
}
|
||||
}
|
||||
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"project": "$project",
|
||||
"status": "$status",
|
||||
archived: cls.archived_tasks_cond,
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
# for each project, create a list of (status, count, archived)
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.project",
|
||||
"counts": {
|
||||
"$push": {
|
||||
"status": "$_id.status",
|
||||
"count": "$count",
|
||||
archived: "$_id.%s" % archived,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def completed_after_subquery(additional_cond, time_thresh: datetime):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed after the time_thresh
|
||||
"if": {
|
||||
"$and": [
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def max_started_subquery(condition):
|
||||
return {
|
||||
"$max": {
|
||||
"$cond": {
|
||||
"if": condition,
|
||||
"then": "$started",
|
||||
"else": datetime.min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed and started and completed > started
|
||||
"if": {
|
||||
"$and": [
|
||||
"$started",
|
||||
"$completed",
|
||||
{"$gt": ["$completed", "$started"]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
# then: floor((completed - started) / 1000)
|
||||
"then": {
|
||||
"$floor": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$completed", "$started"]},
|
||||
1000.0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
time_thresh = datetime.utcnow() - timedelta(hours=24)
|
||||
for state in cls.visibility_states:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
cond = (
|
||||
cls.archived_tasks_cond
|
||||
if state == EntityVisibility.archived
|
||||
else {"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
group_step[state.value] = runtime_subquery(cond)
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||
|
||||
def add_state_to_filter(f: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
if not specific_state:
|
||||
return f
|
||||
|
||||
f = f or {}
|
||||
new_f = {k: v for k, v in f.items() if k != "system_tags"}
|
||||
system_tags = [
|
||||
tag
|
||||
for tag in f.get("system_tags", [])
|
||||
if tag
|
||||
not in (
|
||||
EntityVisibility.archived.value,
|
||||
f"-{EntityVisibility.archived.value}",
|
||||
)
|
||||
]
|
||||
|
||||
if specific_state == EntityVisibility.archived:
|
||||
system_tags.append(EntityVisibility.archived.value)
|
||||
else:
|
||||
system_tags.append(f"-{EntityVisibility.archived.value}")
|
||||
new_f["system_tags"] = system_tags
|
||||
|
||||
return new_f
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
filter_=add_state_to_filter(filter_),
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
"$group": group_step
|
||||
},
|
||||
]
|
||||
|
||||
return status_count_pipeline, runtime_pipeline
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@staticmethod
|
||||
def aggregate_project_data(
|
||||
func: Callable[[T, T], T],
|
||||
project_ids: Sequence[str],
|
||||
child_projects: Mapping[str, Sequence[Project]],
|
||||
data: Mapping[str, T],
|
||||
) -> Dict[str, T]:
|
||||
"""
|
||||
Given a list of project ids and data collected over these projects and their subprojects
|
||||
For each project aggregates the data from all of its subprojects
|
||||
"""
|
||||
aggregated = {}
|
||||
if not data:
|
||||
return aggregated
|
||||
for pid in project_ids:
|
||||
relevant_projects = {p.id for p in child_projects.get(pid, [])} | {pid}
|
||||
relevant_data = [data for p, data in data.items() if p in relevant_projects]
|
||||
if not relevant_data:
|
||||
continue
|
||||
aggregated[pid] = reduce(func, relevant_data)
|
||||
return aggregated
|
||||
|
||||
@classmethod
|
||||
def get_dataset_stats(
|
||||
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
task_runtime_pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
users=users,
|
||||
filter_={
|
||||
"system_tags": [f"-{EntityVisibility.archived.value}"]
|
||||
},
|
||||
),
|
||||
"runtime": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"project": 1, "runtime": 1, "last_update": 1}},
|
||||
{"$sort": {"project": 1, "last_update": 1}},
|
||||
{"$group": {"_id": "$project", "runtime": {"$last": "$runtime"}}},
|
||||
]
|
||||
|
||||
return {
|
||||
r["_id"]: {
|
||||
"file_count": r["runtime"].get("ds_file_count", 0),
|
||||
"total_size": r["runtime"].get("ds_total_size", 0),
|
||||
}
|
||||
for r in Task.aggregate(task_runtime_pipeline)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_project_stats(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
search_hidden: bool = False,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
user_active_project_ids: Sequence[str] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = (
|
||||
_get_sub_projects(
|
||||
project_ids,
|
||||
_only=("id", "name"),
|
||||
search_hidden=search_hidden,
|
||||
allowed_ids=user_active_project_ids,
|
||||
)
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
|
||||
def set_default_count(entry):
|
||||
return dict(default_counts, **entry)
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(EntityVisibility.archived.value)
|
||||
for result in Task.aggregate(status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
EntityVisibility.archived if k else EntityVisibility.active
|
||||
).value
|
||||
status_count[result["_id"]][section] = set_default_count(
|
||||
{
|
||||
count_entry["status"]: count_entry["count"]
|
||||
for count_entry in group
|
||||
}
|
||||
)
|
||||
|
||||
def sum_status_count(
|
||||
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: {
|
||||
status: nested_get(a, (section, status), default=0)
|
||||
+ nested_get(b, (section, status), default=0)
|
||||
for status in set(a.get(section, {})) | set(b.get(section, {}))
|
||||
}
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
status_count = cls.aggregate_project_data(
|
||||
func=sum_status_count,
|
||||
project_ids=project_ids,
|
||||
child_projects=child_projects,
|
||||
data=status_count,
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def sum_runtime(
|
||||
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
if not section.endswith("max_task_started")
|
||||
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
runtime = cls.aggregate_project_data(
|
||||
func=sum_runtime,
|
||||
project_ids=project_ids,
|
||||
child_projects=child_projects,
|
||||
data=runtime,
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
project_runtime = runtime.get(project_id, {})
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
|
||||
def get_time_or_none(value):
|
||||
return value if value != datetime.min else None
|
||||
|
||||
return {
|
||||
"status_count": project_section_statuses,
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks_24h": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s
|
||||
for s in cls.visibility_states
|
||||
if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
stats = {
|
||||
project: {
|
||||
task_state.value: get_status_counts(project, task_state.value)
|
||||
for task_state in report_for_states
|
||||
}
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
children = {
|
||||
project: sorted(
|
||||
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
|
||||
key=itemgetter("name"),
|
||||
)
|
||||
for project in project_ids
|
||||
}
|
||||
return stats, children
|
||||
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls,
|
||||
company,
|
||||
project_ids: Sequence[str],
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_project_tags(
|
||||
cls,
|
||||
company_id: str,
|
||||
include_system: bool,
|
||||
projects: Sequence[str] = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
company: str,
|
||||
users: Sequence[str],
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Get the projects ids where user created any tasks including all the parents of these projects
|
||||
If project ids are specified then filter the results by these project ids
|
||||
"""
|
||||
query = Q(user__in=users)
|
||||
|
||||
if allow_public:
|
||||
query &= get_company_or_none_constraint(company)
|
||||
else:
|
||||
query &= Q(company=company)
|
||||
|
||||
user_projects_query = query
|
||||
if project_ids:
|
||||
ids_with_children = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=ids_with_children)
|
||||
user_projects_query &= Q(id__in=ids_with_children)
|
||||
|
||||
res = {p.id for p in Project.objects(user_projects_query).only("id")}
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="project"))
|
||||
|
||||
res = list(res)
|
||||
if not res:
|
||||
return res, res
|
||||
|
||||
user_active_project_ids = _ids_with_parents(res)
|
||||
filtered_ids = (
|
||||
list(set(user_active_project_ids) & set(project_ids))
|
||||
if project_ids
|
||||
else list(user_active_project_ids)
|
||||
)
|
||||
|
||||
return filtered_ids, user_active_project_ids
|
||||
|
||||
@classmethod
|
||||
def get_task_parents(
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get list of unique parent tasks sorted by task name for the passed company projects
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
|
||||
if projects:
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
|
||||
parent_ids = set(Task.objects(query).distinct("parent"))
|
||||
if not parent_ids:
|
||||
return []
|
||||
|
||||
parents = Task.get_many_with_join(
|
||||
company_id,
|
||||
query=Q(id__in=parent_ids),
|
||||
allow_public=True,
|
||||
override_projection=("id", "name", "project.name"),
|
||||
)
|
||||
return sorted(parents, key=itemgetter("name"))
|
||||
|
||||
@classmethod
|
||||
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@classmethod
|
||||
def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@staticmethod
|
||||
def get_match_conditions(
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any],
|
||||
users: Sequence[str],
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
if users:
|
||||
conditions["user"] = {"$in": users}
|
||||
|
||||
if not filter_:
|
||||
return conditions
|
||||
|
||||
for field, field_filter in filter_.items():
|
||||
if not (
|
||||
field_filter
|
||||
and isinstance(field_filter, list)
|
||||
and all(isinstance(t, str) for t in field_filter)
|
||||
):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
|
||||
conditions[field] = {
|
||||
**({"$in": include} if include else {}),
|
||||
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
|
||||
}
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
aggregation on projects in order not to hit memory limits on large tasks
|
||||
"""
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
220
apiserver/bll/project/project_cleanup.py
Normal file
220
apiserver/bll/project/project_cleanup.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
TaskUrls,
|
||||
_schedule_for_delete,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DeleteProjectResult:
|
||||
deleted: int = 0
|
||||
disassociated_tasks: int = 0
|
||||
deleted_models: int = 0
|
||||
deleted_tasks: int = 0
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def validate_project_delete(company: str, project_id: str):
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
ret = {}
|
||||
for cls in (Task, Model):
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
|
||||
for cls in (Task, Model):
|
||||
query = dict(
|
||||
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
|
||||
)
|
||||
name = f"non_archived_{cls.__name__.lower()}s"
|
||||
if not is_pipeline:
|
||||
ret[name] = cls.objects(**query).count()
|
||||
else:
|
||||
ret[name] = (
|
||||
cls.objects(**query, type=TaskType.controller).count()
|
||||
if cls == Task
|
||||
else 0
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str,
|
||||
user: str,
|
||||
project_id: str,
|
||||
force: bool,
|
||||
delete_contents: bool,
|
||||
delete_external_artifacts=True,
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
if not force:
|
||||
query = dict(
|
||||
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
|
||||
)
|
||||
if not is_pipeline:
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
non_archived = cls.objects(**query).only("id")
|
||||
if non_archived:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
else:
|
||||
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
|
||||
if non_archived:
|
||||
raise errors.bad_request.ProjectHasTasks(
|
||||
"please archive all the runs inside the project", id=project_id
|
||||
)
|
||||
|
||||
if not delete_contents:
|
||||
for cls in (Model, Task):
|
||||
updated_count = cls.objects(project__in=project_ids).update(project=None)
|
||||
res = DeleteProjectResult(disassociated_tasks=updated_count)
|
||||
else:
|
||||
deleted_models, model_event_urls, model_urls = _delete_models(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
event_urls = task_event_urls | model_event_urls
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=project_id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=True,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
res = DeleteProjectResult(
|
||||
deleted_tasks=deleted_tasks,
|
||||
deleted_models=deleted_models,
|
||||
urls=TaskUrls(
|
||||
model_urls=list(model_urls),
|
||||
event_urls=list(event_urls),
|
||||
artifact_urls=list(artifact_urls),
|
||||
),
|
||||
)
|
||||
|
||||
affected = {*project_ids, *(project.path or [])}
|
||||
res.deleted = Project.objects(id__in=project_ids).delete()
|
||||
|
||||
return res, affected
|
||||
|
||||
|
||||
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
|
||||
"""
|
||||
Delete only the task themselves and their non published version.
|
||||
Child models under the same project are deleted separately.
|
||||
Children tasks should be deleted in the same api call.
|
||||
If any child entities are left in another projects then updated their parent task to None
|
||||
"""
|
||||
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
|
||||
if not tasks:
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = {t.id for t in tasks}
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
|
||||
event_urls, artifact_urls = set(), set()
|
||||
for task in tasks:
|
||||
event_urls.update(collect_debug_image_urls(company, task.id))
|
||||
event_urls.update(collect_plot_image_urls(company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls.update(
|
||||
{
|
||||
a.uri
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
)
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
company, list(task_ids), async_delete=async_events_delete
|
||||
)
|
||||
deleted = tasks.delete()
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(
|
||||
company: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set[str], Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
that reference them to reference None.
|
||||
"""
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set(), set()
|
||||
|
||||
model_ids = list({m.id for m in models})
|
||||
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
event_urls, model_urls = set(), set()
|
||||
for m in models:
|
||||
event_urls.update(collect_debug_image_urls(company, m.id))
|
||||
event_urls.update(collect_plot_image_urls(company, m.id))
|
||||
if m.uri:
|
||||
model_urls.add(m.uri)
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
company, model_ids, async_delete=async_events_delete
|
||||
)
|
||||
deleted = models.delete()
|
||||
return deleted, event_urls, model_urls
|
||||
370
apiserver/bll/project/project_queries.py
Normal file
370
apiserver/bll/project/project_queries.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
key: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
|
||||
last_updated_model = (
|
||||
Model.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_model:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
|
||||
last_update = last_updated_model.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.models.metadata_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
189
apiserver/bll/project/sub_projects.py
Normal file
189
apiserver/bll/project/sub_projects.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import itertools
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Optional, Sequence, Mapping
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
|
||||
name_separator = "/"
|
||||
|
||||
|
||||
def _get_project_depth(project_name: str) -> int:
|
||||
return len(list(filter(None, project_name.split(name_separator))))
|
||||
|
||||
|
||||
def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove redundant '/' characters. Ensure that the project name is not empty
|
||||
Return the cleaned up project name and location
|
||||
"""
|
||||
name_parts = list(filter(None, project_name.split(name_separator)))
|
||||
if not name_parts:
|
||||
raise errors.bad_request.InvalidProjectName(name=project_name)
|
||||
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
Return the project
|
||||
"""
|
||||
name = name.strip(name_separator)
|
||||
if not name:
|
||||
return None
|
||||
|
||||
project = _get_writable_project_from_name(company, name)
|
||||
if project:
|
||||
return project
|
||||
|
||||
now = datetime.utcnow()
|
||||
name, location = _validate_project_name(name)
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
company=company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
basename=name.split("/")[-1],
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def _save_under_parent(project: Project, parent: Optional[Project]):
|
||||
"""
|
||||
Save the project under the given parent project or top level (parent=None)
|
||||
Check that the project location matches the parent name
|
||||
"""
|
||||
location, _, _ = project.name.rpartition(name_separator)
|
||||
if not parent:
|
||||
if location:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match empty parent name"
|
||||
)
|
||||
project.parent = None
|
||||
project.path = []
|
||||
project.save()
|
||||
return
|
||||
|
||||
if location != parent.name:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match parent name {parent.name}"
|
||||
)
|
||||
project.parent = parent.id
|
||||
project.path = [*(parent.path or []), parent.id]
|
||||
project.save()
|
||||
|
||||
|
||||
def _get_writable_project_from_name(
|
||||
company,
|
||||
name,
|
||||
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Return a project from name. If the project not found then return None
|
||||
"""
|
||||
qs = Project.objects(company=company, name=name)
|
||||
if _only:
|
||||
qs = qs.only(*_only)
|
||||
return qs.first()
|
||||
|
||||
|
||||
def _get_sub_projects(
|
||||
project_ids: Sequence[str],
|
||||
_only: Sequence[str] = ("id", "path"),
|
||||
search_hidden=True,
|
||||
allowed_ids: Sequence[str] = None,
|
||||
) -> Mapping[str, Sequence[Project]]:
|
||||
"""
|
||||
Return the list of child projects of all the levels for the parent project ids
|
||||
"""
|
||||
query = dict(path__in=project_ids)
|
||||
if not search_hidden:
|
||||
query["system_tags__nin"] = [EntityVisibility.hidden.value]
|
||||
if allowed_ids:
|
||||
query["id__in"] = allowed_ids
|
||||
|
||||
qs = Project.objects(**query)
|
||||
if _only:
|
||||
_only = set(_only) | {"path"}
|
||||
qs = qs.only(*_only)
|
||||
subprojects = list(qs)
|
||||
|
||||
return {
|
||||
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
|
||||
}
|
||||
|
||||
|
||||
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with all the parent projects
|
||||
"""
|
||||
projects = Project.objects(id__in=project_ids).only("id", "path")
|
||||
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
|
||||
return list({*(p.id for p in projects), *parent_ids})
|
||||
|
||||
|
||||
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with the ids of all the subprojects
|
||||
"""
|
||||
subprojects = Project.objects(path__in=project_ids).only("id")
|
||||
return list({*project_ids, *(child.id for child in subprojects)})
|
||||
|
||||
|
||||
def _update_subproject_names(
|
||||
project: Project,
|
||||
children: Sequence[Project],
|
||||
old_name: str,
|
||||
update_path: bool = False,
|
||||
old_path: Sequence[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Update sub project names when the base project name changes
|
||||
Optionally update the paths
|
||||
"""
|
||||
updated = 0
|
||||
for child in children:
|
||||
child_suffix = name_separator.join(
|
||||
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
|
||||
)
|
||||
updates = {"name": name_separator.join((project.name, child_suffix))}
|
||||
if update_path:
|
||||
updates["path"] = project.path + child.path[len(old_path) :]
|
||||
updated += child.update(upsert=False, **updates)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def _reposition_project_with_children(
|
||||
project: Project, children: Sequence[Project], parent: Project
|
||||
) -> int:
|
||||
new_location = parent.name if parent else None
|
||||
old_name = project.name
|
||||
old_path = project.path
|
||||
project.name = name_separator.join(
|
||||
filter(None, (new_location, project.name.split(name_separator)[-1]))
|
||||
)
|
||||
_save_under_parent(project, parent=parent)
|
||||
|
||||
moved = 1 + _update_subproject_names(
|
||||
project=project,
|
||||
children=children,
|
||||
old_name=old_name,
|
||||
update_path=True,
|
||||
old_path=old_path,
|
||||
)
|
||||
return moved
|
||||
1
apiserver/bll/query/__init__.py
Normal file
1
apiserver/bll/query/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .builder import Builder
|
||||
36
apiserver/bll/query/builder.py
Normal file
36
apiserver/bll/query/builder.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, Sequence, Iterable, Union
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
RANGE_IGNORE_VALUE = -1
|
||||
|
||||
|
||||
class Builder:
|
||||
@staticmethod
|
||||
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
|
||||
return {
|
||||
"range": {
|
||||
"timestamp": {
|
||||
"gte": int(from_date),
|
||||
"lte": int(to_date),
|
||||
"format": "epoch_second",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def terms(field: str, values: Iterable[str]) -> dict:
|
||||
return {"terms": {field: list(values)}}
|
||||
|
||||
@staticmethod
|
||||
def normalize_range(
|
||||
range_: Sequence[Union[int, float]],
|
||||
ignore_value: Union[int, float] = RANGE_IGNORE_VALUE,
|
||||
) -> Optional[Sequence[Union[int, float]]]:
|
||||
if not range_ or set(range_) == {ignore_value}:
|
||||
return None
|
||||
if len(range_) < 2:
|
||||
return [range_[0]] * 2
|
||||
return range_
|
||||
1
apiserver/bll/queue/__init__.py
Normal file
1
apiserver/bll/queue/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .queue_bll import QueueBLL
|
||||
379
apiserver/bll/queue/queue_bll.py
Normal file
379
apiserver/bll/queue/queue_bll.py
Normal file
@@ -0,0 +1,379 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Callable, Sequence, Optional, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class QueueBLL(object):
|
||||
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
|
||||
self.worker_bll = worker_bll or WorkerBLL()
|
||||
self.es = es or es_factory.connect("workers")
|
||||
self._metrics = QueueMetrics(self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> QueueMetrics:
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
company_id: str,
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
now = datetime.utcnow()
|
||||
queue = Queue(
|
||||
id=database.utils.id(),
|
||||
company=company_id,
|
||||
created=now,
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
system_tags=system_tags or [],
|
||||
metadata=metadata,
|
||||
last_update=now,
|
||||
)
|
||||
queue.save()
|
||||
return queue
|
||||
|
||||
def get_by_name(
|
||||
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
|
||||
) -> Queue:
|
||||
qs = Queue.objects(name=queue_name, company=company_id)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
|
||||
return qs.first()
|
||||
|
||||
@staticmethod
|
||||
def _get_task_entries_projection(max_task_entries: int) -> dict:
|
||||
return dict(slice__entries=max_task_entries)
|
||||
|
||||
def get_by_id(
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
only: Optional[Sequence[str]] = None,
|
||||
max_task_entries: int = None,
|
||||
) -> Queue:
|
||||
"""
|
||||
Get queue by id
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
qs = Queue.objects(**query)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
if max_task_entries:
|
||||
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
|
||||
queue = qs.first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
|
||||
return queue
|
||||
|
||||
@classmethod
|
||||
def get_queue_with_task(cls, company_id: str, queue_id: str, task_id: str) -> Queue:
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(entries__task=task_id, **query).first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return queue
|
||||
|
||||
def get_default(self, company_id: str) -> Queue:
|
||||
"""
|
||||
Get the default queue
|
||||
:raise errors.bad_request.NoDefaultQueue: if the default queue not found
|
||||
:raise errors.bad_request.MultipleDefaultQueues: if more than one default queue is found
|
||||
"""
|
||||
with translate_errors_context():
|
||||
res = Queue.objects(company=company_id, system_tags="default").only(
|
||||
"id", "name"
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.NoDefaultQueue()
|
||||
if len(res) > 1:
|
||||
raise errors.bad_request.MultipleDefaultQueues(
|
||||
queues=tuple(r.id for r in res)
|
||||
)
|
||||
|
||||
return res.first()
|
||||
|
||||
def update(
|
||||
self, company_id: str, queue_id: str, **update_fields
|
||||
) -> Tuple[int, dict]:
|
||||
"""
|
||||
Partial update of the queue from update_fields
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
:return: number of updated objects and updated fields dictionary
|
||||
"""
|
||||
with translate_errors_context():
|
||||
# validate the queue exists
|
||||
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return Queue.safe_update(company_id, queue_id, update_fields)
|
||||
|
||||
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
|
||||
"""
|
||||
Delete the queue
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if queue.entries:
|
||||
if not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
from apiserver.bll.task import ChangeStatusRequest
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=["id", "status", "enqueue_status", "project"],
|
||||
)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason="Queue was deleted",
|
||||
status_message="",
|
||||
).execute(enqueue_status=None)
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Failed dequeuing task {item.task} from queue: {queue_id}"
|
||||
)
|
||||
|
||||
queue.delete()
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
if queue_id in worker.queues:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
projection = Queue.get_extra_projection("entries.task.name")
|
||||
with translate_errors_context():
|
||||
res = Queue.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
override_projection=projection,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
for queue in worker.queues:
|
||||
queue_workers[queue].append(worker)
|
||||
|
||||
for item in res:
|
||||
item["workers"] = [
|
||||
{
|
||||
"name": w.id,
|
||||
"ip": w.ip,
|
||||
"task": w.task.to_struct() if w.task else None,
|
||||
}
|
||||
for w in queue_workers.get(item["id"], [])
|
||||
]
|
||||
|
||||
return res
|
||||
|
||||
def add_task(self, company_id: str, queue_id: str, task_id: str) -> dict:
|
||||
"""
|
||||
Add the task to the queue and return the queue update results
|
||||
:raise errors.bad_request.TaskAlreadyQueued: if the task is already in the queue
|
||||
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue update operation failed
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def get_next_task(
|
||||
self, company_id: str, queue_id: str, task_id: str = None
|
||||
) -> Optional[Entry]:
|
||||
"""
|
||||
Atomically pop and return the first task from the queue (or None)
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(
|
||||
**query, **({"entries__0__task": task_id} if task_id else {})
|
||||
).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
if not task_id or not Queue.objects(**query).first():
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
return
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
if not queue.entries:
|
||||
return
|
||||
|
||||
try:
|
||||
Queue.objects(**query).update(last_update=datetime.utcnow())
|
||||
except Exception:
|
||||
log.exception("Error while updating Queue.last_update")
|
||||
|
||||
return queue.entries[0]
|
||||
|
||||
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
|
||||
"""
|
||||
Removes the task from the queue and returns the number of removed items
|
||||
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task=task_id, **query).update_one(
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
task_id: str,
|
||||
pos_func: Callable[[int], int],
|
||||
) -> int:
|
||||
"""
|
||||
Moves the task in the queue to the position calculated by pos_func
|
||||
Returns the updated task position in the queue
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
|
||||
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
|
||||
new_position = pos_func(position)
|
||||
|
||||
if new_position != position:
|
||||
entry = queue.entries[position]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
updated = Queue.objects(entries__task=task_id, **query).update_one(
|
||||
pull__entries=entry, last_update=datetime.utcnow()
|
||||
)
|
||||
if not updated:
|
||||
raise errors.bad_request.RemovedDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
|
||||
if new_position >= 0:
|
||||
inst["$push"]["entries"]["$position"] = new_position
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
__raw__=inst
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.FailedAddingDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
|
||||
return new_position
|
||||
|
||||
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
|
||||
res = next(
|
||||
Queue.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": queue_id,
|
||||
}
|
||||
},
|
||||
{"$project": {"count": {"$size": "$entries"}}},
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if res is None:
|
||||
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||
return int(res.get("count"))
|
||||
314
apiserver/bll/queue/queue_metrics.py
Normal file
314
apiserver/bll/queue/queue_metrics.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Sequence
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors.errors import bad_request
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
_conf = config.get("services.queues")
|
||||
_queue_metrics_key_pattern = "queue_metrics_{queue}"
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
|
||||
class QueueMetrics:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def _queue_metrics_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"queue_metrics_{company_id.lower()}_"
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
@staticmethod
|
||||
def _calc_avg_waiting_time(entries: Sequence[Entry]) -> float:
|
||||
"""
|
||||
Calculate avg waiting time for the given tasks.
|
||||
Return 0 if the list is empty
|
||||
"""
|
||||
if not entries:
|
||||
return 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
|
||||
return total_waiting_in_secs / len(entries)
|
||||
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
|
||||
"""
|
||||
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
|
||||
:return: True if the write to es was successful, false otherwise
|
||||
"""
|
||||
es_index = (
|
||||
self._queue_metrics_prefix_for_company(company_id)
|
||||
+ self._get_es_index_suffix()
|
||||
)
|
||||
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
|
||||
def make_doc(queue: Queue) -> dict:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return {
|
||||
EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
EsKeys.QUEUE_FIELD: queue.id,
|
||||
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
|
||||
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
}
|
||||
|
||||
logged = 0
|
||||
for q in queues:
|
||||
queue_doc = make_doc(q)
|
||||
self.es.index(index=es_index, body=queue_doc)
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
|
||||
redis.set(redis_key, json.dumps(queue_doc))
|
||||
logged += 1
|
||||
|
||||
return logged
|
||||
|
||||
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
|
||||
query = dict(company=company_id)
|
||||
if queue_ids:
|
||||
query["id__in"] = list(queue_ids)
|
||||
queues = Queue.objects(**query)
|
||||
self.log_queue_metrics_to_es(company_id, queues=list(queues))
|
||||
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_dates_agg(cls, interval) -> dict:
|
||||
"""
|
||||
Aggregation for building date histogram with internal grouping per queue.
|
||||
We are grouping by queue inside date histogram and not vice versa so that
|
||||
it will be easy to average between queue metrics inside each date bucket.
|
||||
Ignore empty buckets.
|
||||
"""
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": EsKeys.TIMESTAMP_FIELD,
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
"queues": {
|
||||
"terms": {"field": EsKeys.QUEUE_FIELD},
|
||||
"aggs": cls._get_top_waiting_agg(),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_top_waiting_agg(cls) -> dict:
|
||||
"""
|
||||
Aggregation for getting max waiting time and the corresponding queue length
|
||||
inside each date->queue bucket
|
||||
"""
|
||||
return {
|
||||
"top_avg_waiting": {
|
||||
"top_hits": {
|
||||
"sort": [
|
||||
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
],
|
||||
"_source": {
|
||||
"includes": [
|
||||
EsKeys.WAITING_TIME_FIELD,
|
||||
EsKeys.QUEUE_LENGTH_FIELD,
|
||||
]
|
||||
},
|
||||
"size": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def get_queue_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
from_date: float,
|
||||
to_date: float,
|
||||
interval: int,
|
||||
queue_ids: Sequence[str],
|
||||
refresh: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the company queue metrics in the specified time range.
|
||||
Returned as date histograms of average values per queue and metric type.
|
||||
The from_date is extended by 'metrics_before_from_date' seconds from
|
||||
queues.conf due to possibly small amount of points. The default extension is 3600s
|
||||
In case no queue ids are specified the avg across all the
|
||||
company queues is calculated for each metric
|
||||
"""
|
||||
if refresh:
|
||||
self._log_current_metrics(company_id, queue_ids=queue_ids)
|
||||
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
seconds_before = config.get("services.queues.metrics_before_from_date", 3600)
|
||||
must_terms = [QueryBuilder.dates_range(from_date - seconds_before, to_date)]
|
||||
if queue_ids:
|
||||
must_terms.append(QueryBuilder.terms("queue", queue_ids))
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_terms}},
|
||||
"aggs": self._get_dates_agg(interval),
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
res = self._search_company_metrics(company_id, es_req)
|
||||
|
||||
if "aggregations" not in res:
|
||||
return {}
|
||||
|
||||
date_metrics = [
|
||||
dict(
|
||||
timestamp=d["key"],
|
||||
queue_metrics=self._extract_queue_metrics(d["queues"]["buckets"]),
|
||||
)
|
||||
for d in res["aggregations"]["dates"]["buckets"]
|
||||
if d["doc_count"] > 0
|
||||
]
|
||||
if queue_ids:
|
||||
return self._datetime_histogram_per_queue(date_metrics)
|
||||
|
||||
return self._average_datetime_histogram(date_metrics)
|
||||
|
||||
@classmethod
|
||||
def _datetime_histogram_per_queue(cls, date_metrics: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Build datetime histogram per queue from datetime histogram where every
|
||||
bucket contains all the queues metrics
|
||||
"""
|
||||
queues_data = defaultdict(list)
|
||||
for date_data in date_metrics:
|
||||
timestamp = date_data["timestamp"]
|
||||
for queue, metrics in date_data["queue_metrics"].items():
|
||||
queues_data[queue].append({"date": timestamp, **metrics})
|
||||
|
||||
return queues_data
|
||||
|
||||
@classmethod
|
||||
def _average_datetime_histogram(cls, date_metrics: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Calculate weighted averages and total count for each bucket of date_metrics histogram.
|
||||
If for any queue the data is missing then take it from the previous bucket
|
||||
The result is returned as a dictionary with one key 'total'
|
||||
"""
|
||||
queues_total = []
|
||||
last_values = {}
|
||||
for date_data in date_metrics:
|
||||
date_metrics = date_data["queue_metrics"]
|
||||
queue_metrics = {
|
||||
**date_metrics,
|
||||
**{k: v for k, v in last_values.items() if k not in date_metrics},
|
||||
}
|
||||
|
||||
total_length = sum(m["queue_length"] for m in queue_metrics.values())
|
||||
if total_length:
|
||||
total_average = sum(
|
||||
m["avg_waiting_time"] * m["queue_length"] / total_length
|
||||
for m in queue_metrics.values()
|
||||
)
|
||||
else:
|
||||
total_average = 0
|
||||
|
||||
queues_total.append(
|
||||
dict(
|
||||
date=date_data["timestamp"],
|
||||
avg_waiting_time=total_average,
|
||||
queue_length=total_length,
|
||||
)
|
||||
)
|
||||
|
||||
for k, v in date_metrics.items():
|
||||
last_values[k] = v
|
||||
|
||||
return dict(total=queues_total)
|
||||
|
||||
@classmethod
|
||||
def _extract_queue_metrics(cls, queue_buckets: Sequence[dict]) -> dict:
|
||||
"""
|
||||
Extract ES data for single date and queue bucket
|
||||
"""
|
||||
queue_metrics = dict()
|
||||
for queue_data in queue_buckets:
|
||||
if not queue_data["doc_count"]:
|
||||
continue
|
||||
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
|
||||
queue_metrics[queue_data["key"]] = {
|
||||
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
|
||||
}
|
||||
return queue_metrics
|
||||
|
||||
|
||||
class MetricsRefresher:
|
||||
threads = ThreadsManager()
|
||||
|
||||
@classproperty
|
||||
def watch_interval_sec(self):
|
||||
return _conf.get("metrics_refresh_interval_sec", 300)
|
||||
|
||||
@classmethod
|
||||
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
|
||||
def start(cls, queue_metrics: QueueMetrics = None):
|
||||
if not cls.watch_interval_sec:
|
||||
return
|
||||
|
||||
if not queue_metrics:
|
||||
from .queue_bll import QueueBLL
|
||||
|
||||
queue_metrics = QueueBLL().metrics
|
||||
|
||||
sleep(10)
|
||||
while True:
|
||||
try:
|
||||
for queue in Queue.objects():
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
doc_time = 0
|
||||
try:
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
|
||||
data = redis.get(redis_key)
|
||||
if data:
|
||||
queue_doc = json.loads(data)
|
||||
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not doc_time
|
||||
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
|
||||
):
|
||||
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting queue metrics: {str(ex)}")
|
||||
sleep(60)
|
||||
87
apiserver/bll/redis_cache_manager.py
Normal file
87
apiserver/bll/redis_cache_manager.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver import database
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _do_nothing(_: T):
|
||||
return
|
||||
|
||||
|
||||
class RedisCacheManager(Generic[T]):
|
||||
"""
|
||||
Class for store/retrieve of state objects from redis
|
||||
|
||||
self.state_class - class of the state
|
||||
self.redis - instance of redis
|
||||
self.expiration_interval - expiration interval in seconds
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
|
||||
):
|
||||
self.state_class = state_class
|
||||
self.redis = redis
|
||||
self.expiration_interval = expiration_interval
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
):
|
||||
"""
|
||||
Try to retrieve state with the given id from the Redis cache if yes then validates it
|
||||
If no then create a new one with randomly generated id
|
||||
Yield the state and write it back to redis once the user code block exits
|
||||
:param state_id: id of the state to retrieve
|
||||
:param init_state: user callback to init the newly created state
|
||||
If not passed then no init except for the id generation is done
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
finally:
|
||||
self.set_state(state)
|
||||
97
apiserver/bll/statistics/resource_monitor.py
Normal file
97
apiserver/bll/statistics/resource_monitor.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
|
||||
import attr
|
||||
import psutil
|
||||
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
stat_threads = ThreadsManager("Statistics")
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
|
||||
@classmethod
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_current_sample(cls) -> "Sample":
|
||||
return cls(
|
||||
cpu_usage=psutil.cpu_percent(),
|
||||
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
|
||||
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
|
||||
)
|
||||
|
||||
|
||||
class ResourceMonitor:
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
sample = Sample.get_current_sample()
|
||||
self.avg = sample
|
||||
self.min = sample
|
||||
self.max = sample
|
||||
self.time = datetime.utcnow()
|
||||
self.count = 1
|
||||
|
||||
def add_sample(self, sample: Sample):
|
||||
self.min = self.min.min(sample)
|
||||
self.max = self.max.max(sample)
|
||||
self.avg = self.avg.avg(sample, self.count)
|
||||
self.count += 1
|
||||
|
||||
sample_interval_sec = 5
|
||||
_lock = Lock()
|
||||
accumulator = Accumulator()
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("resource_monitor", daemon=True)
|
||||
def start(cls):
|
||||
while True:
|
||||
sleep(cls.sample_interval_sec)
|
||||
sample = Sample.get_current_sample()
|
||||
with cls._lock:
|
||||
cls.accumulator.add_sample(sample)
|
||||
|
||||
@classmethod
|
||||
def get_stats(cls) -> dict:
|
||||
""" Returns current resource statistics and clears internal resource statistics """
|
||||
with cls._lock:
|
||||
min_ = attr.asdict(cls.accumulator.min)
|
||||
max_ = attr.asdict(cls.accumulator.max)
|
||||
avg = attr.asdict(cls.accumulator.avg)
|
||||
interval = datetime.utcnow() - cls.accumulator.time
|
||||
cls.accumulator = cls.Accumulator()
|
||||
|
||||
return {
|
||||
"interval_sec": interval.total_seconds(),
|
||||
"num_cores": psutil.cpu_count(),
|
||||
**{
|
||||
k: {"min": v, "max": max_[k], "avg": avg[k]}
|
||||
for k, v in min_.items()
|
||||
}
|
||||
}
|
||||
305
apiserver/bll/statistics/stats_reporter.py
Normal file
305
apiserver/bll/statistics/stats_reporter.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import logging
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import dpath
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.bll.util import get_server_uuid
|
||||
from apiserver.bll.workers import WorkerStats, WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_deployment_type
|
||||
from apiserver.database.model import Company, User
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor, stat_threads
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
ResourceMonitor.start()
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
Note: in clearml we usually have only a single company
|
||||
"""
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
report_interval = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
sleep(report_interval.total_seconds())
|
||||
while True:
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
).only("id"):
|
||||
stats = cls.get_statistics(company.id)
|
||||
cls.send_queue.put(stats)
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting stats: {str(ex)}")
|
||||
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
url = config.get("apiserver.statistics.url")
|
||||
|
||||
retries = config.get("apiserver.statistics.max_retries", 5)
|
||||
max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(retries))
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
session.headers["Content-type"] = "application/json"
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while True:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
# Set a random backoff factor each time we send a report
|
||||
adapter.max_retries.backoff_factor = random.random() * max_backoff
|
||||
|
||||
session.post(url, data=dumps(report))
|
||||
|
||||
except Exception as ex:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_statistics(cls, company_id: str) -> dict:
|
||||
"""
|
||||
Returns a statistics report per company
|
||||
"""
|
||||
return {
|
||||
"time": datetime.utcnow(),
|
||||
"company_id": company_id,
|
||||
"server": {
|
||||
"version": current_version,
|
||||
"deployment": get_deployment_type(),
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": ResourceMonitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
},
|
||||
"agents": cls._get_agents_statistics(company_id),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
|
||||
result = cls._get_resource_stats_per_agent(company_id, key="resources")
|
||||
dpath.merge(
|
||||
result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
|
||||
)
|
||||
return [{"uuid": agent_id, **data} for agent_id, data in result.items()]
|
||||
|
||||
@classmethod
|
||||
def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_resource_threshold_sec = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
).total_seconds()
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {
|
||||
"categories": {
|
||||
"terms": {"field": "category"},
|
||||
"aggs": {"count": {"cardinality": {"field": "variant"}}},
|
||||
},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": {
|
||||
"min": {"min": {"field": "value"}},
|
||||
"max": {"max": {"field": "value"}},
|
||||
"avg": {"avg": {"field": "value"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
|
||||
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
|
||||
names = {"cpu": "num_cores"}
|
||||
return {
|
||||
names[c["key"]]: safe_get(c, "count/value")
|
||||
for c in categories
|
||||
if c["key"] in names
|
||||
}
|
||||
|
||||
def _get_metric_fields(metrics: Sequence[dict]) -> dict:
|
||||
names = {
|
||||
"cpu_usage": "cpu_usage",
|
||||
"memory_used": "mem_used_gb",
|
||||
"memory_free": "mem_free_gb",
|
||||
}
|
||||
return {
|
||||
names[m["key"]]: {
|
||||
"min": safe_get(m, "min/value"),
|
||||
"max": safe_get(m, "max/value"),
|
||||
"avg": safe_get(m, "avg/value"),
|
||||
}
|
||||
for m in metrics
|
||||
if m["key"] in names
|
||||
}
|
||||
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {
|
||||
key: {
|
||||
"interval_sec": agent_resource_threshold_sec,
|
||||
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
|
||||
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
|
||||
}
|
||||
}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_relevant_threshold = timedelta(
|
||||
days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
|
||||
)
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
|
||||
workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
|
||||
if not workers:
|
||||
return {}
|
||||
|
||||
stats = cls._get_experiments_stats(company_id, list(workers.keys()))
|
||||
return {
|
||||
worker_id: {key: {**workers[worker_id], **stat}}
|
||||
for worker_id, stat in stats.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_active_workers(
|
||||
cls, company_id, from_timestamp: int, to_timestamp: int
|
||||
) -> dict:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats(
|
||||
cls, company_id, workers: Optional[Sequence] = None
|
||||
) -> dict:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"started": {"$exists": True, "$ne": None},
|
||||
"last_update": {"$exists": True, "$ne": None},
|
||||
"status": {"$nin": ["created", "queued"]},
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
"count": {"$sum": 1},
|
||||
"avg_run_time_sec": {
|
||||
"$avg": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$last_update", "$started"]},
|
||||
1000,
|
||||
]
|
||||
}
|
||||
},
|
||||
"avg_iterations": {"$avg": "$last_iteration"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"count": 1,
|
||||
"avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
|
||||
"avg_iterations": {"$trunc": "$avg_iterations"},
|
||||
}
|
||||
},
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(pipeline)
|
||||
}
|
||||
|
||||
|
||||
class WarningFilter(logging.Filter):
|
||||
@classmethod
|
||||
def attach(cls):
|
||||
from urllib3.connectionpool import (
|
||||
ConnectionPool,
|
||||
) # required to make sure the logger is created
|
||||
|
||||
assert ConnectionPool # make sure import is not optimized out
|
||||
|
||||
logging.getLogger("urllib3.connectionpool").addFilter(cls())
|
||||
|
||||
def filter(self, record):
|
||||
if (
|
||||
record.levelno == logging.WARNING
|
||||
and len(record.args) > 2
|
||||
and record.args[2] == "/stats"
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -3,5 +3,4 @@ from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
)
|
||||
86
apiserver/bll/task/artifacts.py
Normal file
86
apiserver/bll/task/artifacts.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
|
||||
def get_artifact_id(artifact: dict):
|
||||
"""
|
||||
Calculate id from 'key' and 'mode' fields
|
||||
Return hash on on the id so that it will not contain mongo illegal characters
|
||||
"""
|
||||
key_hash: str = hash_field_name(artifact["key"])
|
||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||
return f"{key_hash}_{mode}"
|
||||
|
||||
|
||||
def artifacts_prepare_for_save(fields: dict):
|
||||
artifacts_field = ("execution", "artifacts")
|
||||
artifacts = nested_get(fields, artifacts_field)
|
||||
if artifacts is None:
|
||||
return
|
||||
|
||||
nested_set(
|
||||
fields, artifacts_field, value={get_artifact_id(a): a for a in artifacts}
|
||||
)
|
||||
|
||||
|
||||
def artifacts_unprepare_from_saved(fields):
|
||||
artifacts_field = ("execution", "artifacts")
|
||||
artifacts = nested_get(fields, artifacts_field)
|
||||
if artifacts is None:
|
||||
return
|
||||
|
||||
nested_set(
|
||||
fields,
|
||||
artifacts_field,
|
||||
value=sorted(artifacts.values(), key=itemgetter("key")),
|
||||
)
|
||||
|
||||
|
||||
class Artifacts:
|
||||
@classmethod
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
235
apiserver/bll/task/hyperparams.py
Normal file
235
apiserver/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
force: bool,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
|
||||
) -> Dict[str, list]:
|
||||
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
*([skip_empty_condition] if skip_empty else []),
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
98
apiserver/bll/task/non_responsive_tasks_watchdog.py
Normal file
98
apiserver/bll/task/non_responsive_tasks_watchdog.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apiserver.bll.task import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import TaskStatus, Task
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class NonResponsiveTasksWatchdog:
|
||||
threads = ThreadsManager()
|
||||
|
||||
class _Settings:
|
||||
"""
|
||||
Retrieves watchdog settings from the config file
|
||||
The properties are not cached so that the updates in
|
||||
the config file are reflected
|
||||
"""
|
||||
|
||||
_prefix = "services.tasks.non_responsive_tasks_watchdog"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return config.get(f"{self._prefix}.enabled", True)
|
||||
|
||||
@property
|
||||
def watch_interval_sec(self):
|
||||
return config.get(f"{self._prefix}.watch_interval_sec", 900)
|
||||
|
||||
@property
|
||||
def threshold_sec(self):
|
||||
return config.get(f"{self._prefix}.threshold_sec", 7200)
|
||||
|
||||
settings = _Settings()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while True:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
stopped = cls.cleanup_tasks(
|
||||
threshold_sec=cls.settings.threshold_sec
|
||||
)
|
||||
log.info(f"{stopped} non-responsive tasks stopped")
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
sleep(watch_interval)
|
||||
|
||||
@classmethod
|
||||
def cleanup_tasks(cls, threshold_sec):
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(seconds=threshold_sec)
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
|
||||
"id", "name", "status", "project", "last_update"
|
||||
)
|
||||
)
|
||||
log.info(f"{len(tasks)} non-responsive tasks found")
|
||||
if not tasks:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
project_ids = set()
|
||||
now = datetime.utcnow()
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
updated = Task.objects(id=task.id, status=task.status).update(
|
||||
status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
if updated:
|
||||
project_ids.add(task.project)
|
||||
else:
|
||||
err_count += 1
|
||||
except Exception as ex:
|
||||
log.error("Failed setting status: %s", str(ex))
|
||||
|
||||
update_project_time(list(project_ids))
|
||||
|
||||
return len(tasks) - err_count
|
||||
216
apiserver/bll/task/param_utils.py
Normal file
216
apiserver/bll/task/param_utils.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||
(("execution", "model_desc"), "configuration", None),
|
||||
):
|
||||
legacy_params = nested_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not fields.get(new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
nested_set(fields, new_path, new_param)
|
||||
nested_delete(fields, old_params_field)
|
||||
|
||||
def ensure_non_empty(k: str, desc: str) -> str:
|
||||
if not k:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Empty {desc} name is not allowed"
|
||||
)
|
||||
return k
|
||||
|
||||
params = fields.get("hyperparams")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
|
||||
for k, v in value.items()
|
||||
}
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["hyperparams"] = escaped_params
|
||||
|
||||
params = fields.get("configuration")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["configuration"] = escaped_params
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields[param_field] = unescaped_params
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
("hyperparams", ("execution", "parameters"), True),
|
||||
("configuration", ("execution", "model_desc"), False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
fields.get(new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
nested_set(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 4:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", "configuration"),
|
||||
("execution.docker_cmd", "container"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
511
apiserver/bll/task/task_bll.py
Normal file
511
apiserver/bll/task/task_bll.py
Normal file
@@ -0,0 +1,511 @@
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Optional, Dict
|
||||
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
from six import string_types
|
||||
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.metrics import EventStats, MetricEventStats
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
Models,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
queue_bll = QueueBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
class TaskBLL:
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context():
|
||||
ids = set(task_ids)
|
||||
q = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=ids),
|
||||
allow_public=allow_public,
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
# Make sure to reset fields filters (some fields are excluded by default) since this
|
||||
# is an internal call and specific fields were requested.
|
||||
q = q.all_fields().only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
identity = call.identity
|
||||
now = datetime.utcnow()
|
||||
return Task(
|
||||
id=create_id(),
|
||||
user=identity.user,
|
||||
company=identity.company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
**fields,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_input_models(task, allow_only_public=False):
|
||||
if not task.models.input:
|
||||
return
|
||||
|
||||
company = None if allow_only_public else task.company
|
||||
model_ids = set(m.model for m in task.models.input)
|
||||
models = Model.objects(
|
||||
Q(id__in=model_ids) & get_company_or_none_constraint(company)
|
||||
).only("id")
|
||||
missing = model_ids - {m.id for m in models}
|
||||
if missing:
|
||||
raise errors.bad_request.InvalidModelId(models=missing)
|
||||
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
comment: Optional[str] = None,
|
||||
parent: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
container: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
validate_tags(tags, system_tags)
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [
|
||||
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
|
||||
]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
if execution_overrides:
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [
|
||||
ModelItem(
|
||||
model=execution_model,
|
||||
name=TaskModelNames[TaskModelTypes.input],
|
||||
updated=now,
|
||||
)
|
||||
]
|
||||
|
||||
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
||||
if not container and docker_cmd:
|
||||
image, _, arguments = docker_cmd.partition(" ")
|
||||
container = {"image": image, "arguments": arguments}
|
||||
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
|
||||
escape_dict_field(execution_overrides, "model_labels")
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
|
||||
}
|
||||
execution_dict.pop("queue", None)
|
||||
|
||||
new_project_data = None
|
||||
if not project and new_project_name:
|
||||
# Use a project with the provided name, or create a new project
|
||||
project = ProjectBLL.find_or_create(
|
||||
project_name=new_project_name,
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
description="",
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
|
||||
return [
|
||||
tag
|
||||
for tag in input_tags
|
||||
if tag
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else task.id
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
|
||||
return new_task, new_project_data
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_models=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
"""
|
||||
Validate task properties according to the flag
|
||||
Task project is always checked for being writable
|
||||
in order to disable the modification of public projects
|
||||
"""
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not task.parent.startswith(deleted_prefix)
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project:
|
||||
project = Project.get_for_writing(company=task.company, id=task.project)
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
company_id: str,
|
||||
last_update: datetime,
|
||||
**extra_updates,
|
||||
):
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
count = 0
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
updates = {
|
||||
"active_duration": (
|
||||
datetime.utcnow() - task.started
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
count += Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
"""
|
||||
Update task statistics
|
||||
:param task_id: Task's ID.
|
||||
:param company_id: Task's company ID.
|
||||
:param last_update: Last update time. If not provided, defaults to datetime.utcnow().
|
||||
:param last_iteration: Last reported iteration. Use this to set a value regardless of current
|
||||
task's last iteration value.
|
||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||
if the current task's last iteration value is smaller than the provided value.
|
||||
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
|
||||
:param last_events: Last reported metrics summary (value, metric, event type).
|
||||
:param extra_updates: Extra task updates to include in this update call.
|
||||
:return:
|
||||
"""
|
||||
last_update = last_update or datetime.utcnow()
|
||||
|
||||
if last_iteration is not None:
|
||||
extra_updates.update(last_iteration=last_iteration)
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_scalar_events is not None:
|
||||
max_values = config.get("services.tasks.max_last_metrics", 2000)
|
||||
total_metrics = set()
|
||||
if max_values:
|
||||
query = dict(id=task_id)
|
||||
to_add = sum(len(v) for m, v in last_scalar_events.items())
|
||||
if to_add <= max_values:
|
||||
query[f"unique_metrics__{max_values-to_add}__exists"] = True
|
||||
task = Task.objects(**query).only("unique_metrics").first()
|
||||
if task and task.unique_metrics:
|
||||
total_metrics = set(task.unique_metrics)
|
||||
|
||||
new_metrics = []
|
||||
for metric_key, metric_data in last_scalar_events.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
metric = (
|
||||
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
|
||||
)
|
||||
if max_values:
|
||||
if (
|
||||
len(total_metrics) >= max_values
|
||||
and metric not in total_metrics
|
||||
):
|
||||
continue
|
||||
total_metrics.add(metric)
|
||||
|
||||
new_metrics.append(metric)
|
||||
path = f"last_metrics__{metric_key}__{variant_key}"
|
||||
for key, value in variant_data.items():
|
||||
if key == "min_value":
|
||||
extra_updates[f"min__{path}__min_value"] = value
|
||||
elif key == "max_value":
|
||||
extra_updates[f"max__{path}__max_value"] = value
|
||||
elif key in ("metric", "variant", "value"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
if new_metrics:
|
||||
extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
return {
|
||||
event_type: EventStats(last_update=event["timestamp"])
|
||||
for event_type, event in metric_data.items()
|
||||
}
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
return TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
):
|
||||
try:
|
||||
cls.dequeue(task, company_id)
|
||||
except errors.bad_request.InvalidQueueOrTaskNotQueued:
|
||||
# dequeue may fail if the queue was deleted
|
||||
pass
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(enqueue_status=None)
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||
"""
|
||||
Dequeue the task from the queue
|
||||
:param task: task to dequeue
|
||||
:param company_id: task's company ID.
|
||||
:param silent_fail: do not throw exceptions. APIError is still thrown
|
||||
:raise errors.bad_request.InvalidTaskId: if the task's status is not queued
|
||||
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
||||
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
||||
:return: the result of queues.remove_task call. None in case of silent failure
|
||||
"""
|
||||
if task.status not in (TaskStatus.queued,):
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
status=task.status, expected=TaskStatus.queued
|
||||
)
|
||||
|
||||
if not task.execution or not task.execution.queue:
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"task has no queue value", field="execution.queue"
|
||||
)
|
||||
|
||||
return {
|
||||
"removed": queue_bll.remove_task(
|
||||
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
||||
)
|
||||
}
|
||||
324
apiserver/bll/task/task_cleanup.py
Normal file
324
apiserver/bll/task/task_cleanup.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition, bucketize, first
|
||||
from mongoengine import NotUniqueError
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.database.model.url_to_delete import (
|
||||
StorageType,
|
||||
UrlToDelete,
|
||||
FileType,
|
||||
DeletionStatus,
|
||||
)
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskUrls:
|
||||
model_urls: Sequence[str]
|
||||
event_urls: Sequence[str]
|
||||
artifact_urls: Sequence[str]
|
||||
|
||||
def __add__(self, other: "TaskUrls"):
|
||||
if not other:
|
||||
return self
|
||||
|
||||
return TaskUrls(
|
||||
model_urls=list(set(self.model_urls) | set(other.model_urls)),
|
||||
event_urls=list(set(self.event_urls) | set(other.event_urls)),
|
||||
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class CleanupResult:
|
||||
"""
|
||||
Counts of objects modified in task cleanup operation
|
||||
"""
|
||||
|
||||
updated_children: int
|
||||
updated_models: int
|
||||
deleted_models: int
|
||||
urls: TaskUrls = None
|
||||
|
||||
def __add__(self, other: "CleanupResult"):
|
||||
if not other:
|
||||
return self
|
||||
|
||||
return CleanupResult(
|
||||
updated_children=self.updated_children + other.updated_children,
|
||||
updated_models=self.updated_models + other.updated_models,
|
||||
deleted_models=self.deleted_models + other.deleted_models,
|
||||
urls=self.urls + other.urls if self.urls else other.urls,
|
||||
)
|
||||
|
||||
|
||||
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
urls = set()
|
||||
next_scroll_id = None
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
"""
|
||||
after_key = None
|
||||
urls = set()
|
||||
while True:
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company, task_id=task_or_model, after_key=after_key,
|
||||
)
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
break
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
supported_storage_types = {
|
||||
"https://": StorageType.fileserver,
|
||||
"http://": StorageType.fileserver,
|
||||
}
|
||||
|
||||
|
||||
def _schedule_for_delete(
|
||||
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
|
||||
) -> Set[str]:
|
||||
urls_per_storage = bucketize(
|
||||
urls,
|
||||
key=lambda u: first(
|
||||
type_
|
||||
for prefix, type_ in supported_storage_types.items()
|
||||
if u.startswith(prefix)
|
||||
),
|
||||
)
|
||||
urls_per_storage.pop(None, None)
|
||||
|
||||
processed_urls = set()
|
||||
for storage_type, storage_urls in urls_per_storage.items():
|
||||
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
|
||||
scheduled_to_delete = set()
|
||||
for url in storage_urls:
|
||||
folder = None
|
||||
if delete_folders:
|
||||
folder, _, _ = url.rpartition("/")
|
||||
|
||||
to_delete = folder or url
|
||||
if to_delete in scheduled_to_delete:
|
||||
processed_urls.add(url)
|
||||
continue
|
||||
|
||||
try:
|
||||
UrlToDelete(
|
||||
id=db_id(),
|
||||
company=company,
|
||||
user=user,
|
||||
url=to_delete,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
storage_type=storage_type,
|
||||
type=FileType.folder if folder else FileType.file,
|
||||
).save()
|
||||
except (DuplicateKeyError, NotUniqueError):
|
||||
existing = UrlToDelete.objects(company=company, url=to_delete).first()
|
||||
if existing:
|
||||
existing.update(
|
||||
user=user,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
retry_count=0,
|
||||
unset__last_failure_time=1,
|
||||
unset__last_failure_reason=1,
|
||||
status=DeletionStatus.created,
|
||||
)
|
||||
processed_urls.add(url)
|
||||
scheduled_to_delete.add(to_delete)
|
||||
|
||||
return processed_urls
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
company: str,
|
||||
user: str,
|
||||
task: Task,
|
||||
force: bool = False,
|
||||
update_children=True,
|
||||
return_file_urls=False,
|
||||
delete_output_models=True,
|
||||
delete_external_artifacts=True,
|
||||
) -> CleanupResult:
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
|
||||
task, force
|
||||
)
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
event_urls, artifact_urls, model_urls = set(), set(), set()
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls = collect_debug_image_urls(task.company, task.id)
|
||||
event_urls.update(collect_plot_image_urls(task.company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls = {
|
||||
a.uri
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
model_urls = {
|
||||
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
|
||||
}
|
||||
|
||||
deleted_task_id = f"{deleted_prefix}{task.id}"
|
||||
updated_children = 0
|
||||
if update_children:
|
||||
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
|
||||
|
||||
deleted_models = 0
|
||||
updated_models = 0
|
||||
for models, allow_delete in ((draft_models, True), (published_models, False)):
|
||||
if not models:
|
||||
continue
|
||||
if delete_output_models and allow_delete:
|
||||
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
|
||||
for m_id in model_ids:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls.update(collect_debug_image_urls(task.company, m_id))
|
||||
event_urls.update(collect_plot_image_urls(task.company, m_id))
|
||||
try:
|
||||
event_bll.delete_task_events(
|
||||
task.company,
|
||||
m_id,
|
||||
allow_locked=True,
|
||||
model=True,
|
||||
async_delete=async_events_delete,
|
||||
)
|
||||
except errors.bad_request.InvalidModelId as ex:
|
||||
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
|
||||
|
||||
deleted_models += Model.objects(id__in=list(model_ids)).delete()
|
||||
if in_use_model_ids:
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
|
||||
continue
|
||||
|
||||
if update_children:
|
||||
updated_models += Model.objects(id__in=[m.id for m in models]).update(
|
||||
task=deleted_task_id
|
||||
)
|
||||
else:
|
||||
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
|
||||
|
||||
event_bll.delete_task_events(
|
||||
task.company, task.id, allow_locked=force, async_delete=async_events_delete
|
||||
)
|
||||
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=task.id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=not in_use_model_ids and not published_models,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
updated_children=updated_children,
|
||||
updated_models=updated_models,
|
||||
urls=TaskUrls(
|
||||
event_urls=list(event_urls),
|
||||
artifact_urls=list(artifact_urls),
|
||||
model_urls=list(model_urls),
|
||||
)
|
||||
if return_file_urls
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(
|
||||
task, force: bool
|
||||
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
|
||||
if not force:
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
model_fields = ["id", "ready", "uri"]
|
||||
published_models, draft_models = partition(
|
||||
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
|
||||
)
|
||||
if not force and published_models:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(published_models),
|
||||
)
|
||||
|
||||
if task.models and task.models.output:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=output_model.id,
|
||||
)
|
||||
published_models.append(output_model)
|
||||
else:
|
||||
draft_models.append(output_model)
|
||||
|
||||
in_use_model_ids = {}
|
||||
if draft_models:
|
||||
model_ids = {m.id for m in draft_models}
|
||||
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
|
||||
"id", "models"
|
||||
)
|
||||
in_use_model_ids = model_ids & {
|
||||
m.model
|
||||
for m in chain.from_iterable(
|
||||
t.models.input for t in dependent_tasks if t.models
|
||||
)
|
||||
}
|
||||
|
||||
return published_models, draft_models, in_use_model_ids
|
||||
449
apiserver/bll/task/task_operations.py
Normal file
449
apiserver/bll/task/task_operations.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Any, Tuple, Union, Sequence
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
validate_status_change,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskStatus,
|
||||
Task,
|
||||
TaskSystemTags,
|
||||
TaskStatusMessage,
|
||||
ArtifactModes,
|
||||
Execution,
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def archive_task(
|
||||
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Deque and archive task
|
||||
Return 1 if successful
|
||||
"""
|
||||
if isinstance(task, str):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task,
|
||||
company_id=company_id,
|
||||
only=(
|
||||
"id",
|
||||
"execution",
|
||||
"status",
|
||||
"project",
|
||||
"system_tags",
|
||||
"enqueue_status",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task, company_id, status_message, status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task: str, company_id: str, status_message: str, status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
"""
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task, company_id=company_id, only=("id",), requires_write_access=True,
|
||||
)
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str, company_id: str, status_message: str, status_reason: str,
|
||||
) -> Tuple[int, dict]:
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task, company_id, status_message=status_message, status_reason=status_reason,
|
||||
)
|
||||
return 1, res
|
||||
|
||||
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
queue_name: str = None,
|
||||
validate: bool = False,
|
||||
force: bool = False,
|
||||
) -> Tuple[int, dict]:
|
||||
if queue_id and queue_name:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either queue id or queue name should be provided"
|
||||
)
|
||||
|
||||
if queue_name:
|
||||
queue = queue_bll.get_by_name(
|
||||
company_id=company_id, queue_name=queue_name, only=("id",)
|
||||
)
|
||||
if not queue:
|
||||
queue = queue_bll.create(company_id=company_id, name=queue_name)
|
||||
queue_id = queue.id
|
||||
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
if validate:
|
||||
TaskBLL.validate(task)
|
||||
|
||||
res = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.queued,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
|
||||
except Exception:
|
||||
# failed enqueueing, revert to previous state
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
current_status_override=TaskStatus.queued,
|
||||
new_status=task.status,
|
||||
force=True,
|
||||
status_reason="failed enqueueing",
|
||||
).execute(enqueue_status=None)
|
||||
raise
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(**query).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
|
||||
|
||||
nested_set(res, ("fields", "execution.queue"), queue_id)
|
||||
return 1, res
|
||||
|
||||
|
||||
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
try:
|
||||
collection_name = Task._get_collection_name()
|
||||
trash_collection_name = f"{collection_name}__trash"
|
||||
Task.aggregate(
|
||||
[
|
||||
{"$match": {"_id": {"$in": tasks}}},
|
||||
{
|
||||
"$merge": {
|
||||
"into": trash_collection_name,
|
||||
"on": "_id",
|
||||
"whenMatched": "replace",
|
||||
"whenNotMatched": "insert",
|
||||
}
|
||||
},
|
||||
],
|
||||
allow_disk_use=True,
|
||||
)
|
||||
except Exception as ex:
|
||||
log.error(f"Error copying tasks to trash {str(ex)}")
|
||||
|
||||
return Task.objects(id__in=tasks).delete()
|
||||
|
||||
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
if (
|
||||
task.status != TaskStatus.created
|
||||
and EntityVisibility.archived.value not in task.system_tags
|
||||
and not force
|
||||
):
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"due to status, use force=True",
|
||||
task=task.id,
|
||||
expected=TaskStatus.created,
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
if move_to_trash:
|
||||
# make sure that whatever changes were done to the task are saved
|
||||
# the task itself will be deleted later in the move_tasks_to_trash operation
|
||||
task.save()
|
||||
else:
|
||||
task.delete()
|
||||
|
||||
update_project_time(task.project)
|
||||
return 1, task, cleanup_res
|
||||
|
||||
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||
|
||||
dequeued = {}
|
||||
updates = {}
|
||||
|
||||
try:
|
||||
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleaned_up = cleanup_task(
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
update_children=False,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__unique_metrics=[],
|
||||
set__metric_stats={},
|
||||
set__models__output=[],
|
||||
set__runtime={},
|
||||
unset__output__result=1,
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
)
|
||||
|
||||
if clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(), unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
if task.execution and task.execution.artifacts:
|
||||
updates.update(
|
||||
set__execution__artifacts={
|
||||
key: artifact
|
||||
for key, artifact in task.execution.artifacts.items()
|
||||
if artifact.mode == ArtifactModes.input
|
||||
}
|
||||
)
|
||||
|
||||
res = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
).execute(
|
||||
started=None,
|
||||
completed=None,
|
||||
published=None,
|
||||
active_duration=None,
|
||||
enqueue_status=None,
|
||||
**updates,
|
||||
)
|
||||
|
||||
return dequeued, cleaned_up, res
|
||||
|
||||
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
|
||||
previous_task_status = task.status
|
||||
output = task.output or Output()
|
||||
publish_failed = False
|
||||
|
||||
try:
|
||||
# set state to publishing
|
||||
task.status = TaskStatus.publishing
|
||||
task.save()
|
||||
|
||||
# publish task models
|
||||
if task.models and task.models.output and publish_model_func:
|
||||
model_id = task.models.output[-1].model
|
||||
model = (
|
||||
Model.objects(id=model_id, company=company_id)
|
||||
.only("id", "ready")
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.published,
|
||||
force=force,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(published=datetime.utcnow(), output=output)
|
||||
|
||||
except Exception as ex:
|
||||
publish_failed = True
|
||||
raise ex
|
||||
finally:
|
||||
if publish_failed:
|
||||
task.status = previous_task_status
|
||||
task.save()
|
||||
|
||||
|
||||
def stop_task(
|
||||
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Stop a running task. Requires task status 'in_progress' and
|
||||
execution_progress 'running', or force=True. Development task or
|
||||
task that has no associated worker is stopped immediately.
|
||||
For a non-development task with worker only the status message
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
"""Checks if there is an active worker running the task"""
|
||||
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
|
||||
return (
|
||||
t.last_worker
|
||||
and t.last_update
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
is_queued = task.status == TaskStatus.queued
|
||||
set_stopped = (
|
||||
is_queued
|
||||
or TaskSystemTags.development in task.system_tags
|
||||
or not is_run_by_worker(task)
|
||||
)
|
||||
|
||||
if set_stopped:
|
||||
if is_queued:
|
||||
try:
|
||||
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
new_status = task.status
|
||||
status_message = TaskStatusMessage.stopping
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=new_status,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
).execute()
|
||||
@@ -1,18 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Callable, Tuple, Sequence
|
||||
from typing import Sequence, Union
|
||||
|
||||
import attr
|
||||
import six
|
||||
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from database.utils import get_options
|
||||
from timing_context import TimingContext
|
||||
from utilities.attrs import typed_attrs
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
deleted_prefix = "__DELETED__"
|
||||
|
||||
|
||||
@typed_attrs
|
||||
@@ -25,9 +25,10 @@ class ChangeStatusRequest(object):
|
||||
status_message = attr.ib(type=six.string_types, default="")
|
||||
force = attr.ib(type=bool, default=False)
|
||||
allow_same_state_transition = attr.ib(type=bool, default=True)
|
||||
current_status_override = attr.ib(default=None)
|
||||
|
||||
def execute(self, **kwargs):
|
||||
current_status = self.task.status
|
||||
current_status = self.current_status_override or self.task.status
|
||||
project_id = self.task.project
|
||||
|
||||
# Verify new status is allowed from current status (will throw exception if not valid)
|
||||
@@ -42,14 +43,18 @@ class ChangeStatusRequest(object):
|
||||
status_message=self.status_message,
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
fields["pull__system_tags"] = TaskSystemTags.development
|
||||
|
||||
def safe_mongoengine_key(key):
|
||||
return f"__{key}" if key in control else key
|
||||
|
||||
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_status"):
|
||||
with translate_errors_context():
|
||||
# atomic change of task status by querying the task with the EXPECTED status before modifying it
|
||||
params = fields.copy()
|
||||
params.update(control)
|
||||
@@ -66,6 +71,10 @@ class ChangeStatusRequest(object):
|
||||
)
|
||||
|
||||
update_project_time(project_id)
|
||||
|
||||
# make sure that _raw_ queries are not returned back to the client
|
||||
fields.pop("__raw__", None)
|
||||
|
||||
return dict(updated=updated, fields=fields)
|
||||
|
||||
def validate_transition(self, current_status):
|
||||
@@ -95,15 +104,23 @@ def validate_status_change(current_status, new_status):
|
||||
|
||||
|
||||
state_machine = {
|
||||
TaskStatus.created: {TaskStatus.in_progress},
|
||||
TaskStatus.in_progress: {TaskStatus.stopped, TaskStatus.failed, TaskStatus.created},
|
||||
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
|
||||
TaskStatus.in_progress: {
|
||||
TaskStatus.stopped,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.created,
|
||||
TaskStatus.completed,
|
||||
},
|
||||
TaskStatus.stopped: {
|
||||
TaskStatus.closed,
|
||||
TaskStatus.created,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.queued,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.published,
|
||||
TaskStatus.publishing,
|
||||
TaskStatus.completed,
|
||||
},
|
||||
TaskStatus.closed: {
|
||||
TaskStatus.created,
|
||||
@@ -115,6 +132,11 @@ state_machine = {
|
||||
TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
|
||||
TaskStatus.publishing: {TaskStatus.published},
|
||||
TaskStatus.published: set(),
|
||||
TaskStatus.completed: {
|
||||
TaskStatus.published,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.created,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -124,28 +146,50 @@ def get_possible_status_changes(current_status):
|
||||
:return possible states from current state
|
||||
"""
|
||||
possible = state_machine.get(current_status)
|
||||
assert (
|
||||
possible is not None
|
||||
), f"Current status {current_status} not supported by state machine"
|
||||
if possible is None:
|
||||
raise errors.server_error.InternalError(
|
||||
f"Current status {current_status} not supported by state machine"
|
||||
)
|
||||
|
||||
return possible
|
||||
|
||||
|
||||
def update_project_time(project_id):
|
||||
if project_id:
|
||||
Project.objects(id=project_id).update(last_update=datetime.utcnow())
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def split_by(
|
||||
condition: Callable[[T], bool], items: Sequence[T]
|
||||
) -> Tuple[Sequence[T], Sequence[T]]:
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
split "items" to two lists by "condition"
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
applied = zip(map(condition, items), items)
|
||||
return (
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
allowed_statuses = (
|
||||
[TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
|
||||
)
|
||||
if task.status not in allowed_statuses:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
|
||||
now = datetime.utcnow()
|
||||
last_updates = dict(last_change=now)
|
||||
if set_last_update:
|
||||
last_updates.update(last_update=now)
|
||||
return task.update(**update_cmds, **last_updates)
|
||||
@@ -1,7 +1,9 @@
|
||||
from apierrors import errors
|
||||
from apimodels.users import CreateRequest
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.user import User
|
||||
from datetime import datetime
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.users import CreateRequest
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class UserBLL:
|
||||
@@ -12,7 +14,7 @@ class UserBLL:
|
||||
if user_id and User.objects(id=user_id).only("id"):
|
||||
raise errors.bad_request.UserIdExists(id=user_id)
|
||||
|
||||
user = User(**request.to_struct())
|
||||
user = User(**request.to_struct(), created=datetime.utcnow())
|
||||
user.save(force_insert=True)
|
||||
|
||||
@staticmethod
|
||||
134
apiserver/bll/util.py
Normal file
134
apiserver/bll/util.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Dict,
|
||||
Any,
|
||||
Set,
|
||||
Iterable,
|
||||
Tuple,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
class SetFieldsResolver:
|
||||
"""
|
||||
The class receives set fields dictionary
|
||||
and for the set fields that require 'min' or 'max'
|
||||
operation replace them with a simple set in case the
|
||||
DB document does not have these fields set
|
||||
"""
|
||||
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
return self.fields[name]
|
||||
return name
|
||||
|
||||
def get_fields(self, doc: AttributedDocument):
|
||||
"""
|
||||
For the given document return the set fields instructions
|
||||
with min/max operations replaced with a single set in case
|
||||
the document does not have the field set
|
||||
"""
|
||||
return {
|
||||
self._get_updated_name(doc, name): value
|
||||
for name, value in self.orig_fields.items()
|
||||
}
|
||||
|
||||
def get_names(self) -> Set[str]:
|
||||
"""
|
||||
Returns the names of the fields that had min/max modifiers
|
||||
in the format suitable for projection (dot separated)
|
||||
"""
|
||||
return set(name.replace("__", ".") for name in self.fields.values())
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_server_uuid() -> Optional[str]:
|
||||
return Settings.get_by_key("server.uuid")
|
||||
|
||||
|
||||
def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
|
||||
"""
|
||||
Decorates a method for parallel chunked execution. The method should have
|
||||
one positional parameter (that is used for breaking into chunks)
|
||||
and arbitrary number of keyword params. The return value should be iterable
|
||||
The results are concatenated in the same order as the passed params
|
||||
"""
|
||||
if func is None:
|
||||
return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, iterable: Iterable, **kwargs):
|
||||
assert iterutils.is_collection(
|
||||
iterable
|
||||
), "The positional parameter should be an iterable for breaking into chunks"
|
||||
|
||||
func_with_params = functools.partial(func, self, **kwargs)
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
filter(
|
||||
None,
|
||||
pool.map(
|
||||
func_with_params,
|
||||
iterutils.chunked_iter(iterable, chunk_size),
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def run_batch_operation(
|
||||
func: Callable[[str], T], ids: Sequence[str]
|
||||
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
|
||||
results = list()
|
||||
failures = list()
|
||||
for _id in ids:
|
||||
try:
|
||||
results.append((_id, func(_id)))
|
||||
except APIError as err:
|
||||
failures.append(
|
||||
{
|
||||
"id": _id,
|
||||
"error": {
|
||||
"codes": [err.code, err.subcode],
|
||||
"msg": err.msg,
|
||||
"data": err.error_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
return results, failures
|
||||
554
apiserver/bll/workers/__init__.py
Normal file
554
apiserver/bll/workers/__init__.py
Normal file
@@ -0,0 +1,554 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.errors import bad_request, server_error
|
||||
from apiserver.apimodels.workers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
IdNameEntry,
|
||||
WorkerEntry,
|
||||
StatusReportRequest,
|
||||
WorkerResponseEntry,
|
||||
QueueEntry,
|
||||
MachineStats,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
def stats(self) -> WorkerStats:
|
||||
return self._stats
|
||||
|
||||
def register_worker(
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
worker: str,
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
timeout = timeout or DEFAULT_TIMEOUT
|
||||
queues = queues or []
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=user_id, company=company_id)
|
||||
user = User.objects(**query).only("id", "name").first()
|
||||
if not user:
|
||||
raise bad_request.InvalidUserId(**query)
|
||||
company = Company.objects(id=company_id).only("id", "name").first()
|
||||
if not company:
|
||||
raise bad_request.InvalidId("invalid company", company=company_id)
|
||||
|
||||
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
|
||||
if len(queue_objs) < len(queues):
|
||||
invalid = set(queues).difference(q.id for q in queue_objs)
|
||||
raise bad_request.InvalidQueueId(ids=invalid)
|
||||
|
||||
now = datetime.utcnow()
|
||||
entry = WorkerEntry(
|
||||
key=key,
|
||||
id=worker,
|
||||
user=user.to_proper_dict(),
|
||||
company=company.to_proper_dict(),
|
||||
ip=ip,
|
||||
queues=queues,
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
self._save_worker_data(entry)
|
||||
|
||||
return entry
|
||||
|
||||
def unregister_worker(self, company_id: str, user_id: str, worker: str) -> None:
|
||||
"""
|
||||
Unregister a worker
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
|
||||
"""
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
ip: str,
|
||||
report: StatusReportRequest,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
entry = self._get_worker(company_id, user_id, report.worker)
|
||||
|
||||
try:
|
||||
entry.ip = ip
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
if system_tags is not None:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
company_name=entry.company.name,
|
||||
worker=entry.key,
|
||||
timestamp=report.timestamp,
|
||||
task=report.task,
|
||||
machine_stats=report.machine_stats,
|
||||
)
|
||||
|
||||
entry.queue = report.queue
|
||||
|
||||
if report.queues:
|
||||
entry.queues = report.queues
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
update = dict(
|
||||
last_worker=report.worker,
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
if not task:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(
|
||||
id=project.id, name=project.name
|
||||
)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
except Exception as e:
|
||||
msg = "Failed processing worker status report"
|
||||
log.exception(msg)
|
||||
raise server_error.DataError(msg, err=e.args[0])
|
||||
finally:
|
||||
self._save_worker(entry)
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
:param company_id: worker's company id
|
||||
:param last_seen: period in seconds to check. Min value is 1 second
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
if last_seen:
|
||||
ref_time = datetime.utcnow() - timedelta(seconds=max(1, last_seen))
|
||||
workers = [
|
||||
w
|
||||
for w in workers
|
||||
if w.last_activity_time.replace(tzinfo=None) >= ref_time
|
||||
]
|
||||
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: int,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
|
||||
all_queues = set(
|
||||
itertools.chain.from_iterable(helper.queue_ids for helper in helpers)
|
||||
)
|
||||
|
||||
queues_info = {}
|
||||
if all_queues:
|
||||
projection = [
|
||||
{"$match": {"_id": {"$in": list(all_queues)}}},
|
||||
{
|
||||
"$project": {
|
||||
"name": 1,
|
||||
"next_entry": {"$arrayElemAt": ["$entries", 0]},
|
||||
"num_entries": {"$size": "$entries"},
|
||||
}
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
None,
|
||||
(
|
||||
safe_get(info, "next_entry/task")
|
||||
for info in queues_info.values()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
tasks_info = {}
|
||||
if task_ids:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration", "active_duration"
|
||||
)
|
||||
}
|
||||
|
||||
def update_queue_entries(*entries):
|
||||
for entry in entries:
|
||||
if not entry:
|
||||
continue
|
||||
info = queues_info.get(entry.id, None)
|
||||
if not info:
|
||||
continue
|
||||
entry.name = info.get("name", None)
|
||||
entry.num_tasks = info.get("num_entries", 0)
|
||||
task_id = safe_get(info, "next_entry/task")
|
||||
if task_id:
|
||||
task = tasks_info.get(task_id, None)
|
||||
entry.next_task = IdNameEntry(
|
||||
id=task_id, name=task.name if task else None
|
||||
)
|
||||
|
||||
for helper in helpers:
|
||||
worker = helper.worker
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (task.active_duration or 0) * 1000
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
if worker.queues:
|
||||
update_queue_entries(*worker.queues)
|
||||
|
||||
return [helper.worker for helper in helpers]
|
||||
|
||||
@staticmethod
|
||||
def _get_worker_key(company: str, user: str, worker_id: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"worker_{company}_{user}_{worker_id}"
|
||||
|
||||
def _get_worker(self, company_id: str, user_id: str, worker: str) -> WorkerEntry:
|
||||
"""
|
||||
Get a worker entry for the provided worker ID. The entry is loaded from Redis
|
||||
if it exists (i.e. worker has already been registered), otherwise the worker
|
||||
is registered and its entry stored into Redis).
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user ID under which this worker is running
|
||||
:param worker: worker ID
|
||||
:raise bad_request.InvalidWorkerId: in case the worker id was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
key = self._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
data = self.redis.get(key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
entry = WorkerEntry.from_json(data)
|
||||
if not entry.key:
|
||||
entry.key = key
|
||||
self._save_worker(entry)
|
||||
return entry
|
||||
except Exception as e:
|
||||
msg = "Failed parsing worker entry"
|
||||
log.exception(msg)
|
||||
raise server_error.DataError(msg, err=e.args[0])
|
||||
|
||||
# Failed loading worker from Redis
|
||||
if config.get("apiserver.workers.auto_register", False):
|
||||
try:
|
||||
return self.register_worker(company_id, user_id, worker)
|
||||
except Exception:
|
||||
log.error(
|
||||
"Failed auto registration of {} for company {}".format(
|
||||
worker, company_id
|
||||
)
|
||||
)
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
@staticmethod
|
||||
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers.{tags_field}_{company}_{tag}"
|
||||
|
||||
@staticmethod
|
||||
def _get_all_workers_key(company: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers_{company}"
|
||||
|
||||
def _save_worker_data(self, entry: WorkerEntry):
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
company_id = entry.company.id
|
||||
expiration = int(time()) + entry.register_timeout
|
||||
worker_item = {entry.key: expiration}
|
||||
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
|
||||
for tags, tags_field in (
|
||||
(entry.tags, "tags"),
|
||||
(entry.system_tags, "systemtags"),
|
||||
):
|
||||
for tag in tags:
|
||||
name = self._get_tagged_workers_key(company_id, tags_field, tag)
|
||||
self.redis.zadd(name, worker_item)
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self._save_worker_data(entry)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get(
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
worker_id: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
|
||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||
if user == "*":
|
||||
return in_keys
|
||||
user_bytes = user.encode()
|
||||
return {k for k in in_keys if user_bytes in k}
|
||||
|
||||
if user_tags or system_tags:
|
||||
worker_keys = set()
|
||||
for tags, tags_field in (
|
||||
(user_tags, "tags"),
|
||||
(system_tags, "systemtags"),
|
||||
):
|
||||
if not tags:
|
||||
continue
|
||||
timestamp = int(time())
|
||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||
if include:
|
||||
tagged_workers = set()
|
||||
for tag in include:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
tagged_workers.update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
tagged_workers = filter_by_user(tagged_workers)
|
||||
worker_keys = (
|
||||
worker_keys.intersection(tagged_workers)
|
||||
if worker_keys
|
||||
else tagged_workers
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
if exclude:
|
||||
if not worker_keys:
|
||||
all_workers_key = self._get_all_workers_key(company)
|
||||
self.redis.zremrangebyscore(
|
||||
all_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||
worker_keys = filter_by_user(worker_keys)
|
||||
if not worker_keys:
|
||||
return []
|
||||
for tag in exclude:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag[1:]
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.difference_update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
else:
|
||||
match = self._get_worker_key(company, user, "*")
|
||||
worker_keys = self.redis.scan_iter(match)
|
||||
|
||||
entries = []
|
||||
for key in worker_keys:
|
||||
data = self.redis.get(key)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
def _log_stats_to_es(
|
||||
self,
|
||||
company_id: str,
|
||||
company_name: str,
|
||||
worker: str,
|
||||
timestamp: int,
|
||||
task: str,
|
||||
machine_stats: MachineStats,
|
||||
) -> bool:
|
||||
"""
|
||||
Actually writing the worker statistics to Elastic
|
||||
:return: True if successful, False otherwise
|
||||
"""
|
||||
es_index = (
|
||||
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
|
||||
f"{self._get_es_index_suffix()}"
|
||||
)
|
||||
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
company=company_name,
|
||||
task=task,
|
||||
category=category,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
value=float(value),
|
||||
),
|
||||
)
|
||||
|
||||
actions = []
|
||||
for field, value in machine_stats.to_struct().items():
|
||||
if not value:
|
||||
continue
|
||||
category = field.partition("_")[0]
|
||||
metric = field
|
||||
if not isinstance(value, (list, tuple)):
|
||||
actions.append(make_doc(category, metric, "total", value))
|
||||
else:
|
||||
actions.extend(
|
||||
make_doc(category, metric, str(i), val)
|
||||
for i, val in enumerate(value)
|
||||
)
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class WorkerConversionHelper:
|
||||
worker: WorkerResponseEntry
|
||||
task_id: str
|
||||
queue_ids: Set[str]
|
||||
|
||||
@classmethod
|
||||
def from_worker_entry(cls, worker: WorkerEntry):
|
||||
data = worker.to_struct()
|
||||
queue = data.pop("queue", None) or None
|
||||
queue_ids = set(data.pop("queues", []))
|
||||
queues = [QueueEntry(id=id) for id in queue_ids]
|
||||
if queue:
|
||||
queue = next((q for q in queues if q.id == queue), None)
|
||||
return cls(
|
||||
worker=WorkerResponseEntry(queues=queues, queue=queue, **data),
|
||||
task_id=worker.task.id if worker.task else None,
|
||||
queue_ids=queue_ids,
|
||||
)
|
||||
240
apiserver/bll/workers/stats.py
Normal file
240
apiserver/bll/workers/stats.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from operator import attrgetter
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
|
||||
from apiserver.apierrors.errors import bad_request
|
||||
from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatItem
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerStats:
|
||||
def __init__(self, es):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def worker_stats_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"worker_stats_{company_id.lower()}_"
|
||||
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
def get_worker_stats_keys(
|
||||
self, company_id: str, worker_ids: Optional[Sequence[str]]
|
||||
) -> dict:
|
||||
"""
|
||||
Get dictionary of metric types grouped by categories
|
||||
:param company_id: company id
|
||||
:param worker_ids: optional list of workers to get metric types from.
|
||||
If not specified them metrics for all the company workers returned
|
||||
:return:
|
||||
"""
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"categories": {
|
||||
"terms": {"field": "category"},
|
||||
"aggs": {"metrics": {"terms": {"field": "metric"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
if worker_ids:
|
||||
es_req["query"] = QueryBuilder.terms("worker", worker_ids)
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
|
||||
return {
|
||||
category["key"]: [
|
||||
metric["key"] for metric in category["metrics"]["buckets"]
|
||||
]
|
||||
for category in res["aggregations"]["categories"]["buckets"]
|
||||
}
|
||||
|
||||
def get_worker_stats(self, company_id: str, request: GetStatsRequest) -> dict:
|
||||
"""
|
||||
Get statistics for company workers metrics in the specified time range
|
||||
Returned as date histograms for different aggregation types
|
||||
grouped by worker, metric type (and optionally metric variant)
|
||||
Buckets with no metrics are not returned
|
||||
Note: all the statistics are retrieved as one ES query
|
||||
"""
|
||||
if request.from_date >= request.to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
def get_dates_agg() -> dict:
|
||||
es_to_agg_types = (
|
||||
("avg", AggregationType.avg.value),
|
||||
("min", AggregationType.min.value),
|
||||
("max", AggregationType.max.value),
|
||||
)
|
||||
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
agg_type: {es_agg: {"field": "value"}}
|
||||
for es_agg, agg_type in es_to_agg_types
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def get_variants_agg() -> dict:
|
||||
return {
|
||||
"variants": {"terms": {"field": "variant"}, "aggs": get_dates_agg()}
|
||||
}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": get_variants_agg()
|
||||
if request.split_by_variant
|
||||
else get_dates_agg(),
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
query_terms = [
|
||||
QueryBuilder.dates_range(request.from_date, request.to_date),
|
||||
QueryBuilder.terms("metric", {item.key for item in request.items}),
|
||||
]
|
||||
if request.worker_ids:
|
||||
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
|
||||
es_req["query"] = {"bool": {"must": query_terms}}
|
||||
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
return self._extract_results(data, request.items, request.split_by_variant)
|
||||
|
||||
@staticmethod
|
||||
def _extract_results(
|
||||
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
|
||||
) -> dict:
|
||||
"""
|
||||
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
|
||||
leave only aggregation types requested by the user and return a clean dictionary
|
||||
and return a "clean" dictionary of
|
||||
:param data: aggregation data retrieved from ES
|
||||
:param request_items: aggs types requested by the user
|
||||
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
|
||||
"""
|
||||
if "aggregations" not in data:
|
||||
return {}
|
||||
|
||||
items_by_key = bucketize(request_items, key=attrgetter("key"))
|
||||
aggs_per_metric = {
|
||||
key: [item.aggregation for item in items]
|
||||
for key, items in items_by_key.items()
|
||||
}
|
||||
|
||||
def extract_date_stats(date: dict, metric_key) -> dict:
|
||||
return {
|
||||
"date": date["key"],
|
||||
"count": date["doc_count"],
|
||||
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
|
||||
}
|
||||
|
||||
def extract_metric_results(
|
||||
metric_or_variant: dict, metric_key: str
|
||||
) -> Sequence[dict]:
|
||||
return [
|
||||
extract_date_stats(date, metric_key)
|
||||
for date in metric_or_variant["dates"]["buckets"]
|
||||
if date["doc_count"]
|
||||
]
|
||||
|
||||
def extract_variant_results(metric: dict) -> dict:
|
||||
metric_key = metric["key"]
|
||||
return {
|
||||
variant["key"]: extract_metric_results(variant, metric_key)
|
||||
for variant in metric["variants"]["buckets"]
|
||||
}
|
||||
|
||||
def extract_worker_results(worker: dict) -> dict:
|
||||
return {
|
||||
metric["key"]: extract_variant_results(metric)
|
||||
if split_by_variant
|
||||
else extract_metric_results(metric, metric["key"])
|
||||
for metric in worker["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
return {
|
||||
worker["key"]: extract_worker_results(worker)
|
||||
for worker in data["aggregations"]["workers"]["buckets"]
|
||||
}
|
||||
|
||||
def get_activity_report(
|
||||
self,
|
||||
company_id: str,
|
||||
from_date: float,
|
||||
to_date: float,
|
||||
interval: int,
|
||||
active_only: bool,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get statistics for company workers metrics in the specified time range
|
||||
Returned as date histograms for different aggregation types
|
||||
grouped by worker, metric type (and optionally metric variant)
|
||||
Note: all the statistics are retrieved using one ES query
|
||||
"""
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
||||
must = [QueryBuilder.dates_range(from_date, to_date)]
|
||||
if active_only:
|
||||
must.append({"exists": {"field": "task"}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if "aggregations" not in data:
|
||||
return {}
|
||||
|
||||
ret = [
|
||||
dict(date=date["key"], count=date["workers_count"]["value"])
|
||||
for date in data["aggregations"]["dates"]["buckets"]
|
||||
]
|
||||
|
||||
if ret and ret[-1]["date"] > (to_date - 0.9 * interval):
|
||||
# remove last interval if it's incomplete. Allow 10% tolerance
|
||||
ret.pop()
|
||||
|
||||
return ret
|
||||
1
apiserver/config/__init__.py
Normal file
1
apiserver/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import BasicConfig, ConfigurationError
|
||||
217
apiserver/config/basic.py
Normal file
217
apiserver/config/basic.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, TypeVar, Sequence
|
||||
|
||||
from boltons.iterutils import first
|
||||
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
|
||||
from pyparsing import (
|
||||
ParseFatalException,
|
||||
ParseException,
|
||||
RecursiveGrammarException,
|
||||
ParseSyntaxException,
|
||||
)
|
||||
|
||||
from apiserver.utilities import json
|
||||
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
|
||||
DEFAULT_PREFIXES = ("clearml", "trains")
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
|
||||
extra_config_values_env_key_sep = "__"
|
||||
default_config_dir = "default"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder: str = None,
|
||||
verbose: bool = True,
|
||||
prefix: Sequence[str] = DEFAULT_PREFIXES,
|
||||
):
|
||||
folder = (
|
||||
Path(folder)
|
||||
if folder
|
||||
else Path(__file__).with_name(self.default_config_dir)
|
||||
)
|
||||
if not folder.is_dir():
|
||||
raise ValueError("Invalid configuration folder")
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self.extra_config_path_override_var = [
|
||||
f"{p.upper()}_CONFIG_DIR" for p in prefix
|
||||
]
|
||||
|
||||
self.prefix = prefix[0]
|
||||
self.extra_config_values_env_key_prefix = [
|
||||
f"{p.upper()}{self.extra_config_values_env_key_sep}"
|
||||
for p in reversed(prefix)
|
||||
]
|
||||
|
||||
self._paths = [folder, *self._get_paths()]
|
||||
self._config = self._reload()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._config[key]
|
||||
|
||||
def get(self, key: str, default: Any = NotSet) -> Any:
|
||||
value = self._config.get(key, default)
|
||||
if value is self.NotSet:
|
||||
raise KeyError(
|
||||
f"Unable to find value for key '{key}' and default value was not provided."
|
||||
)
|
||||
return value
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return self._config.as_plain_ordered_dict()
|
||||
|
||||
def as_json(self) -> str:
|
||||
return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
if name == "__init__" and Path(name).parent.stem:
|
||||
name = Path(name).parent.stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self) -> ConfigTree:
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
|
||||
for prefix in self.extra_config_values_env_key_prefix:
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = self._merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _get_paths(self) -> List[Path]:
|
||||
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
|
||||
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
|
||||
|
||||
paths = [
|
||||
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
|
||||
]
|
||||
|
||||
if value is not default_paths:
|
||||
invalid = [path for path in paths if not path.is_dir()]
|
||||
if invalid:
|
||||
print(
|
||||
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
|
||||
)
|
||||
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def reload(self):
|
||||
self._config = self._reload()
|
||||
|
||||
def _reload(self) -> ConfigTree:
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
|
||||
configs = [self._read_recursive(path) for path in self._paths]
|
||||
|
||||
return reduce(
|
||||
lambda last, config: self._merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
|
||||
"""Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix"""
|
||||
for key, value in b.items():
|
||||
override = key.startswith(override_prefix)
|
||||
if override:
|
||||
key = key[len(override_prefix):]
|
||||
# if key is in both a and b and both values are dictionary then merge it otherwise override it
|
||||
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
|
||||
if copy_trees:
|
||||
a[key] = a[key].copy()
|
||||
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
|
||||
else:
|
||||
if isinstance(value, ConfigValues):
|
||||
value.parent = a
|
||||
value.key = key
|
||||
if key in a:
|
||||
value.overriden_value = a[key]
|
||||
a[key] = value
|
||||
if a.root:
|
||||
if b.root:
|
||||
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
|
||||
else:
|
||||
a.history[key] = a.history.get(key, []) + [value]
|
||||
|
||||
return a
|
||||
|
||||
def _read_recursive(self, conf_root) -> ConfigTree:
|
||||
conf = ConfigTree()
|
||||
|
||||
if not conf_root:
|
||||
return conf
|
||||
|
||||
if not conf_root.is_dir():
|
||||
if self.verbose:
|
||||
if not conf_root.exists():
|
||||
print(f"No config in {conf_root}")
|
||||
else:
|
||||
print(f"Not a directory: {conf_root}")
|
||||
return conf
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading config from {conf_root}")
|
||||
|
||||
for file in conf_root.rglob("*.conf"):
|
||||
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
|
||||
conf.put(key, self._read_single_file(file))
|
||||
|
||||
return conf
|
||||
|
||||
def _read_single_file(self, file_path):
|
||||
if self.verbose:
|
||||
print(f"Loading config from file {file_path}")
|
||||
|
||||
try:
|
||||
return ConfigFactory.parse_file(file_path)
|
||||
except ParseSyntaxException as ex:
|
||||
msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): (at char {ex.loc}, line:{ex.lineno}, col:{ex.column})"
|
||||
raise ConfigurationError(msg, file_path=file_path) from ex
|
||||
except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
|
||||
msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): {ex}"
|
||||
raise ConfigurationError(msg) from ex
|
||||
except Exception as ex:
|
||||
print(f"Failed loading {file_path}: {ex}")
|
||||
raise
|
||||
|
||||
def initialize_logging(self):
|
||||
logging_config = self.get("logging", None)
|
||||
if not logging_config:
|
||||
return
|
||||
logging.config.dictConfig(logging_config)
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
def __init__(self, msg, file_path=None, *args):
|
||||
super().__init__(msg, *args)
|
||||
self.file_path = file_path
|
||||
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
|
||||
156
apiserver/config/default/apiserver.conf
Normal file
156
apiserver/config/default/apiserver.conf
Normal file
@@ -0,0 +1,156 @@
|
||||
{
|
||||
watch: false # Watch for changes (dev only)
|
||||
debug: false # Debug mode
|
||||
pretty_json: false # prettify json response
|
||||
return_stack: true # return stack trace on error
|
||||
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
|
||||
|
||||
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
|
||||
# valid values are:
|
||||
# - an integer number, specifying a status code
|
||||
# - a tuple of (code, subcode or list of subcodes)
|
||||
return_stack_on_code: [
|
||||
[500, 0] # raise on internal server error with no subcode
|
||||
]
|
||||
|
||||
listen {
|
||||
ip : "0.0.0.0"
|
||||
port: 8008
|
||||
}
|
||||
|
||||
version {
|
||||
required: false
|
||||
default: 1.0
|
||||
# if set then calls to endpoints with the version
|
||||
# greater that the current max version will be rejected
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
strict: false
|
||||
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
auth {
|
||||
# verify user tokens
|
||||
verify_user_tokens: false
|
||||
|
||||
# max token expiration timeout in seconds (1 year)
|
||||
max_expiration_sec: 31536000
|
||||
|
||||
# default token expiration timeout in seconds (30 days)
|
||||
default_expiration_sec: 2592000
|
||||
|
||||
# cookie containing auth token, for requests arriving from a web-browser
|
||||
session_auth_cookie_name: "clearml_token_basic"
|
||||
|
||||
# cookie configuration for authorization cookies generated by auth.login
|
||||
cookies {
|
||||
httponly: true # allow only http to access the cookies (no JS etc)
|
||||
secure: false # not using HTTPS
|
||||
domain: null # Limit to localhost is not supported
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
# provide a cookie domain override per company
|
||||
# cookies_domain_override {
|
||||
# <company-id>: <domain>
|
||||
# }
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
# enabled: true
|
||||
# pass_hashed: false
|
||||
# users: [
|
||||
# {
|
||||
# username: "john"
|
||||
# password: "123456"
|
||||
# name: "john doe"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# }
|
||||
}
|
||||
|
||||
cors {
|
||||
origins: "*"
|
||||
|
||||
# Not supported when origins is "*"
|
||||
supports_credentials: true
|
||||
}
|
||||
|
||||
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
|
||||
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
}
|
||||
|
||||
check_for_updates {
|
||||
enabled: true
|
||||
|
||||
# Check for updates every 24 hours
|
||||
check_interval_sec: 86400
|
||||
|
||||
url: "https://updates.clear.ml/updates"
|
||||
|
||||
component_name: "clearml-server"
|
||||
|
||||
# GET request timeout
|
||||
request_timeout_sec: 3.0
|
||||
}
|
||||
|
||||
statistics {
|
||||
# Note: statistics are sent ONLY if the user has actively opted-in
|
||||
supported: true
|
||||
|
||||
url: "https://updates.clear.ml/stats"
|
||||
|
||||
report_interval_hours: 24
|
||||
agent_relevant_threshold_days: 30
|
||||
|
||||
max_retries: 5
|
||||
max_backoff_sec: 5
|
||||
}
|
||||
|
||||
getting_started_info {
|
||||
"agentName": "clearml",
|
||||
"configure": "clearml-init",
|
||||
"install": "pip install clearml",
|
||||
"packageName": "clearml"
|
||||
}
|
||||
|
||||
}
|
||||
47
apiserver/config/default/hosts.conf
Normal file
47
apiserver/config/default/hosts.conf
Normal file
@@ -0,0 +1,47 @@
|
||||
fileserver = "http://localhost:8081"
|
||||
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
}
|
||||
|
||||
mongo {
|
||||
backend {
|
||||
host: "mongodb://127.0.0.1:27017/backend"
|
||||
}
|
||||
auth {
|
||||
host: "mongodb://127.0.0.1:27017/auth"
|
||||
}
|
||||
}
|
||||
|
||||
redis {
|
||||
apiserver {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
db: 0
|
||||
}
|
||||
workers {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
db: 4
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,7 @@
|
||||
backupCount: 3
|
||||
maxBytes: 10240000,
|
||||
class: "logging.handlers.RotatingFileHandler",
|
||||
filename: "/var/log/trains/apiserver.log"
|
||||
filename: "/var/log/clearml/apiserver.log"
|
||||
}
|
||||
}
|
||||
root {
|
||||
@@ -13,17 +13,22 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
apiserver {
|
||||
role: "system"
|
||||
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
|
||||
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
|
||||
}
|
||||
webserver {
|
||||
role: "system"
|
||||
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
}
|
||||
}
|
||||
4
apiserver/config/default/services/_mongo.conf
Normal file
4
apiserver/config/default/services/_mongo.conf
Normal file
@@ -0,0 +1,4 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
@@ -0,0 +1,12 @@
|
||||
# if set to True then on task delete/reset external file urls for know storage types are scheduled for async delete
|
||||
# otherwise they are returned to a client for the client side delete
|
||||
enabled: false
|
||||
max_retries: 3
|
||||
retry_timeout_sec: 60
|
||||
|
||||
fileserver {
|
||||
# fileserver url prefixes. Evaluated in the order of priority
|
||||
# Can be in the form <schema>://host:port/path or /path
|
||||
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
|
||||
timeout_sec: 300
|
||||
}
|
||||
16
apiserver/config/default/services/auth.conf
Normal file
16
apiserver/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
45
apiserver/config/default/services/events.conf
Normal file
45
apiserver/config/default/services/events.conf
Normal file
@@ -0,0 +1,45 @@
|
||||
es_index_prefix: "events"
|
||||
|
||||
ignore_iteration {
|
||||
metrics: [":monitor:machine", ":monitor:gpu"]
|
||||
}
|
||||
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
|
||||
dynamic_metrics_count: true
|
||||
|
||||
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
|
||||
dynamic_metrics_count_threshold: 80
|
||||
|
||||
# the max amount of metrics to aggregate on
|
||||
max_metrics_count: 100
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
|
||||
debug_images {
|
||||
# Allow to return the debug images for the variants with uninitialized valid iterations border
|
||||
allow_uninitialized_variants: true
|
||||
}
|
||||
|
||||
max_raw_scalars_size: 200000
|
||||
|
||||
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
# and the result of the check is written to the db
|
||||
validate_plot_str: false
|
||||
|
||||
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
|
||||
plot_compression_threshold: 100000
|
||||
|
||||
# async events delete threshold
|
||||
max_async_deleted_events_per_sec: 1000
|
||||
7
apiserver/config/default/services/models.conf
Normal file
7
apiserver/config/default/services/models.conf
Normal file
@@ -0,0 +1,7 @@
|
||||
metadata_values {
|
||||
# maximal amount of distinct model values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
3
apiserver/config/default/services/organization.conf
Normal file
3
apiserver/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
18
apiserver/config/default/services/projects.conf
Normal file
18
apiserver/config/default/services/projects.conf
Normal file
@@ -0,0 +1,18 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured {
|
||||
order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
|
||||
# default featured index for public projects not specified in the order
|
||||
public_default: 9999
|
||||
}
|
||||
|
||||
sub_projects {
|
||||
# the max sub project depth
|
||||
max_depth: 10
|
||||
}
|
||||
8
apiserver/config/default/services/queues.conf
Normal file
8
apiserver/config/default/services/queues.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
metrics_before_from_date: 3600
|
||||
# interval in seconds to update queue metrics. Put 0 to disable
|
||||
metrics_refresh_interval_sec: 300
|
||||
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
|
||||
# or search_hidden is set
|
||||
hidden_tags: [k8s-glue]
|
||||
}
|
||||
29
apiserver/config/default/services/tasks.conf
Normal file
29
apiserver/config/default/services/tasks.conf
Normal file
@@ -0,0 +1,29 @@
|
||||
non_responsive_tasks_watchdog {
|
||||
enabled: true
|
||||
|
||||
# In-progress tasks older than this value in seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
# Watchdog will sleep for this number of seconds after each cycle
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
hyperparam_values {
|
||||
# maximal amount of distinct hyperparam values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# max allowed outdate time for the cashed result
|
||||
cache_allowed_outdate_sec: 60
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
|
||||
# the maximum amount of unique last metrics/variants combinations
|
||||
# for which the last values are stored in a task
|
||||
max_last_metrics: 2000
|
||||
|
||||
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
|
||||
async_events_delete: false
|
||||
51
apiserver/config/info.py
Normal file
51
apiserver/config/info.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.version import __version__
|
||||
|
||||
root = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def _get(prop_name, env_suffix=None, default=""):
|
||||
suffix = env_suffix or prop_name
|
||||
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
|
||||
value = first(map(getenv, keys))
|
||||
if value:
|
||||
return value
|
||||
|
||||
try:
|
||||
return (root / prop_name).read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return default
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_build_number():
|
||||
return _get("BUILD")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_version():
|
||||
return _get("VERSION", default=__version__)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_commit_number():
|
||||
return _get("COMMIT")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_deployment_type() -> str:
|
||||
return _get("DEPLOY", env_suffix="DEPLOYMENT_TYPE", default="manual")
|
||||
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
4
apiserver/config_repo.py
Normal file
4
apiserver/config_repo.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from apiserver.config import BasicConfig
|
||||
|
||||
config = BasicConfig()
|
||||
config.initialize_logging()
|
||||
137
apiserver/database/__init__.py
Normal file
137
apiserver/database/__init__.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from furl import furl
|
||||
from jsonmodels import models
|
||||
from jsonmodels.errors import ValidationError
|
||||
from jsonmodels.fields import StringField
|
||||
from mongoengine import register_connection
|
||||
from mongoengine.connection import get_connection, disconnect
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from .defs import Database
|
||||
from .utils import get_items
|
||||
|
||||
log = config.logger("database")
|
||||
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_HOST",
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_PORT",
|
||||
"TRAINS_MONGODB_SERVICE_PORT",
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
|
||||
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
|
||||
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
alias = StringField()
|
||||
|
||||
|
||||
class DatabaseFactory:
|
||||
_entries = []
|
||||
|
||||
@classmethod
|
||||
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
|
||||
return DatabaseEntry(alias=alias, **settings)
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
|
||||
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
|
||||
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string template {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
if override_username:
|
||||
log.info(f"Using override mongodb username {override_username}")
|
||||
if override_password:
|
||||
log.info(f"Using override mongodb password ******")
|
||||
if override_query:
|
||||
log.info(f"Using override mongodb query {override_query}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_connection_string:
|
||||
con_str = f"{override_connection_string.rstrip('/')}/{key}"
|
||||
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
|
||||
entry.host = con_str
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_username:
|
||||
entry.host = furl(entry.host).set(username=override_username).url
|
||||
if override_password:
|
||||
entry.host = furl(entry.host).set(password=override_password).url
|
||||
if override_query:
|
||||
entry.host = furl(entry.host).set(query=override_query).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(**entry.to_struct())
|
||||
|
||||
cls._entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Missing database configuration for %s" % ", ".join(missing)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_entries(cls):
|
||||
return cls._entries
|
||||
|
||||
@classmethod
|
||||
def get_hosts(cls):
|
||||
return [entry.host for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def get_aliases(cls):
|
||||
return [entry.alias for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def reconnect(cls):
|
||||
for entry in cls.get_entries():
|
||||
# there is bug in the current implementation that prevents
|
||||
# reconnection from work so workaround this
|
||||
# get_connection(entry.alias, reconnect=True)
|
||||
disconnect(entry.alias)
|
||||
register_connection(**entry.to_struct())
|
||||
get_connection(entry.alias)
|
||||
|
||||
|
||||
db = DatabaseFactory()
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from textwrap import shorten
|
||||
|
||||
import dpath
|
||||
from dpath.exceptions import InvalidKeyName
|
||||
@@ -17,7 +18,7 @@ from mongoengine.errors import (
|
||||
)
|
||||
from pymongo.errors import PyMongoError, NotMasterError
|
||||
|
||||
from apierrors import errors
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MakeGetAllQueryError(Exception):
|
||||
@@ -33,7 +34,7 @@ class ParseCallError(Exception):
|
||||
self.params = kwargs
|
||||
|
||||
|
||||
def throws_default_error(err_cls):
|
||||
def throws_default_error(err_cls, shorten_width: int = None):
|
||||
"""
|
||||
Used to make functions (Exception, str) -> Optional[str] searching for specialized error messages raise those
|
||||
messages in ``err_cls``. If the decorated function does not find a suitable error message,
|
||||
@@ -45,25 +46,49 @@ def throws_default_error(err_cls):
|
||||
@wraps(func)
|
||||
def wrapper(self, e, message, **kwargs):
|
||||
extra_info = func(self, e, message, **kwargs)
|
||||
raise err_cls(message, err=e, extra_info=extra_info)
|
||||
err = str(e)
|
||||
if shorten_width:
|
||||
err = shorten(err, shorten_width, placeholder="...")
|
||||
raise err_cls(message, err=err, extra_info=extra_info)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# noinspection RegExpRedundantEscape
|
||||
class ElasticErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError)
|
||||
def _bulk_meta_error(cls, error):
|
||||
try:
|
||||
_, err_type = next(dpath.search(error, "*/error/type", yielded=True))
|
||||
_, reason = next(dpath.search(error, "*/error/reason", yielded=True))
|
||||
if err_type == "cluster_block_exception":
|
||||
raise errors.server_error.LowDiskSpace(
|
||||
"metrics, logs and all indexed data is in read-only mode!",
|
||||
reason=re.sub(r"^index\s\[.*?\]\s", "", reason) if reason else ""
|
||||
)
|
||||
return
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError, shorten_width=200)
|
||||
def bulk_error(cls, e, _, **__):
|
||||
if not e.errors:
|
||||
return
|
||||
|
||||
# Currently we only handle the first error
|
||||
error = e.errors[0]
|
||||
|
||||
cls._bulk_meta_error(error)
|
||||
|
||||
# Else try returning a better error string
|
||||
for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
|
||||
return reason
|
||||
|
||||
|
||||
# noinspection RegExpRedundantEscape
|
||||
class MongoEngineErrorsHandler(object):
|
||||
# NotUniqueError
|
||||
__not_unique_regex = re.compile(
|
||||
@@ -81,6 +106,7 @@ class MongoEngineErrorsHandler(object):
|
||||
def validation_error(cls, e: ValidationError, message, **_):
|
||||
# Thrown when a document is validated. Documents are validated by default on save and on update
|
||||
err_dict = e.errors or {e.field_name: e.message}
|
||||
err_dict = {key: str(value) for key, value in err_dict.items()}
|
||||
raise errors.bad_request.DataValidationError(message, **err_dict)
|
||||
|
||||
@classmethod
|
||||
@@ -140,7 +166,10 @@ class MongoEngineErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.InternalError)
|
||||
def invalid_query_error(cls, e, message, **_):
|
||||
pass
|
||||
if e.args:
|
||||
inner = e.args[0]
|
||||
if isinstance(inner, LookUpError):
|
||||
cls.lookup_error(inner, message)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from operator import itemgetter
|
||||
from sys import maxsize
|
||||
from typing import Type, Tuple
|
||||
|
||||
import six
|
||||
from mongoengine import (
|
||||
@@ -11,7 +12,11 @@ from mongoengine import (
|
||||
SortedListField,
|
||||
MapField,
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar, EmailField
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
@@ -88,102 +93,22 @@ class CustomFloatField(FloatField):
|
||||
self.error("Float value must be greater than %s" % str(self.greater_than))
|
||||
|
||||
|
||||
# TODO: bucket name should be at most 63 characters....
|
||||
aws_s3_bucket_only_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
)
|
||||
class CanonicEmailField(EmailField):
|
||||
"""email field that is always lower cased"""
|
||||
def __set__(self, instance, value: str):
|
||||
if value is not None:
|
||||
try:
|
||||
value = value.lower()
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__set__(instance, value)
|
||||
|
||||
aws_s3_url_with_bucket_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
|
||||
)
|
||||
|
||||
non_aws_s3_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
|
||||
)
|
||||
|
||||
google_gs_bucket_only_regex = (
|
||||
r"^gs://"
|
||||
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
|
||||
)
|
||||
|
||||
file_regex = r"^file://"
|
||||
|
||||
generic_url_regex = (
|
||||
r"^%s://" # scheme placeholder
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
)
|
||||
|
||||
path_suffix = r"(?:/?|[/?]\S+)$"
|
||||
file_path_suffix = r"(?:/\S*[^/]+)$"
|
||||
|
||||
|
||||
class _RegexURLField(StringField):
|
||||
_regex = []
|
||||
|
||||
def __init__(self, regex, **kwargs):
|
||||
super(_RegexURLField, self).__init__(**kwargs)
|
||||
regex = regex if isinstance(regex, (tuple, list)) else [regex]
|
||||
self._regex = [
|
||||
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
|
||||
for e in regex
|
||||
]
|
||||
|
||||
def validate(self, value):
|
||||
# Check first if the scheme is valid
|
||||
if not any(regex for regex in self._regex if regex.match(value)):
|
||||
self.error("Invalid URL: {}".format(value))
|
||||
return
|
||||
|
||||
|
||||
class OutputDestinationField(_RegexURLField):
|
||||
""" A field representing task output URL """
|
||||
|
||||
schemes = ["s3", "gs", "file"]
|
||||
_expressions = (
|
||||
aws_s3_bucket_only_regex + path_suffix,
|
||||
aws_s3_url_with_bucket_regex + path_suffix,
|
||||
non_aws_s3_regex + path_suffix,
|
||||
google_gs_bucket_only_regex + path_suffix,
|
||||
file_regex + path_suffix,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
|
||||
|
||||
|
||||
class SupportedURLField(_RegexURLField):
|
||||
""" A field representing a model URL """
|
||||
|
||||
schemes = ["s3", "gs", "file", "http", "https"]
|
||||
|
||||
_expressions = tuple(
|
||||
pattern + file_path_suffix
|
||||
for pattern in (
|
||||
aws_s3_bucket_only_regex,
|
||||
aws_s3_url_with_bucket_regex,
|
||||
non_aws_s3_regex,
|
||||
google_gs_bucket_only_regex,
|
||||
file_regex,
|
||||
(generic_url_regex % "http"),
|
||||
(generic_url_regex % "https"),
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
|
||||
def prepare_query_value(self, op, value):
|
||||
if not isinstance(op, six.string_types):
|
||||
return value
|
||||
if value is not None:
|
||||
value = value.lower()
|
||||
return super().prepare_query_value(op, value)
|
||||
|
||||
|
||||
class StrippedStringField(StringField):
|
||||
@@ -221,17 +146,89 @@ def contains_empty_key(d):
|
||||
return True
|
||||
|
||||
|
||||
class SafeMapField(MapField):
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeMapField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField):
|
||||
class NullableStringField(StringField):
|
||||
def validate(self, value):
|
||||
super(SafeDictField, self).validate(value)
|
||||
if value is None:
|
||||
return
|
||||
super(NullableStringField, self).validate(value)
|
||||
|
||||
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
|
||||
|
||||
class SafeSortedListField(SortedListField):
|
||||
"""
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
except TypeError:
|
||||
return self._safe_to_mongo(*args, **kwargs)
|
||||
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
|
||||
class UnionField(DynamicField):
|
||||
def __init__(self, types, *args, **kwargs):
|
||||
super(UnionField, self).__init__(*args, **kwargs)
|
||||
self.types: Tuple[Type] = tuple(types)
|
||||
|
||||
def validate(self, value, clean=True):
|
||||
if not isinstance(value, self.types):
|
||||
type_names = [t.__name__ for t in self.types]
|
||||
expected = " or ".join(
|
||||
filter(
|
||||
None,
|
||||
(", ".join(type_names[:-1]), type_names[-1]))
|
||||
)
|
||||
self.error(
|
||||
f"Expected {expected}, got {type(value).__name__}: {value}"
|
||||
)
|
||||
super(UnionField, self).validate(value, clean)
|
||||
@@ -1,9 +1,11 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Document, StringField
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.base import DbModelMixin, ABSTRACT_FLAG
|
||||
from database.model.company import Company
|
||||
from database.model.user import User
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.base import DbModelMixin, ABSTRACT_FLAG
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class AttributedDocument(DbModelMixin, Document):
|
||||
@@ -54,3 +56,8 @@ def validate_id(cls, company, **kwargs):
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
)
|
||||
|
||||
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
hidden = "hidden"
|
||||
@@ -6,10 +6,10 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import AuthDocument
|
||||
from database.utils import get_options
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import AuthDocument
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class Entities(object):
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
@@ -43,15 +45,19 @@ class Role(object):
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
meta = {"db_alias": Database.auth, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StringField(unique_with="company")
|
||||
name = StringField()
|
||||
|
||||
created = DateTimeField()
|
||||
""" User auth entry creation time """
|
||||
@@ -68,5 +74,5 @@ class User(DbModelMixin, AuthDocument):
|
||||
credentials = EmbeddedDocumentListField(Credentials, default=list)
|
||||
""" Credentials generated for this user """
|
||||
|
||||
email = EmailField(unique=True, required=True)
|
||||
email = EmailField(unique=True, sparse=True)
|
||||
""" Email uniquely identifying the user """
|
||||
1211
apiserver/database/model/base.py
Normal file
1211
apiserver/database/model/base.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user