feat(agents): restore agent-worker container architecture + fix chat scroll and parallel chats
- Restore agent-worker from commit 153399f: autonomous HTTP server per agent
(main.go 597 lines, main_test.go 438 lines, Dockerfile.agent-worker)
- Add container fields to agents table (serviceName, servicePort, containerImage, containerStatus)
- Update executor.go: real delegateToAgent() with HTTP POST to agent containers
- Update db.go: GetAgentByID, UpdateContainerStatus, GetAgentHistory, SaveHistory
- Update orchestrator.go: inject DB into executor for container address resolution
- Add tRPC endpoints: agents.deployContainer, agents.stopContainer, agents.containerStatus
- Add Docker Swarm deploy/stop logic in server/agents.ts
- Add Start/Stop container buttons to Agents.tsx with status badges
- Fix chat auto-scroll: replace ScrollArea with overflow-y-auto for direct scrollTop control
- Fix parallel chats: make isThinking per-conversation (thinkingConvId) instead of global
so switching between chats works while one is processing
This commit is contained in:
727
gateway/cmd/agent-worker/main.go
Normal file
727
gateway/cmd/agent-worker/main.go
Normal file
@@ -0,0 +1,727 @@
|
||||
// GoClaw Agent Worker — автономный HTTP-сервер агента.
|
||||
//
|
||||
// Каждый агент запускается как отдельный Docker Swarm service.
|
||||
// Загружает свой конфиг из общей DB по AGENT_ID, выполняет LLM loop
|
||||
// и принимает параллельные задачи от Orchestrator и других агентов.
|
||||
//
|
||||
// Endpoints:
|
||||
//
|
||||
// GET /health — liveness probe
|
||||
// GET /info — конфиг агента (имя, модель, роль)
|
||||
// POST /chat — синхронный чат (LLM loop, ждёт ответ)
|
||||
// POST /task — поставить задачу в очередь (async, возвращает task_id)
|
||||
// GET /tasks — список задач агента (active + recent)
|
||||
// GET /tasks/{id} — статус конкретной задачи
|
||||
// GET /memory — последние N сообщений из истории агента
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/db"
|
||||
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/llm"
|
||||
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/tools"
|
||||
)
|
||||
|
||||
// ─── Task types ──────────────────────────────────────────────────────────────
|
||||
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskPending TaskStatus = "pending"
|
||||
TaskRunning TaskStatus = "running"
|
||||
TaskDone TaskStatus = "done"
|
||||
TaskFailed TaskStatus = "failed"
|
||||
TaskCancelled TaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// Task — единица работы агента, принятая через /task.
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
FromAgentID int `json:"from_agent_id,omitempty"` // кто делегировал (0 = человек)
|
||||
Input string `json:"input"` // текст задачи
|
||||
CallbackURL string `json:"callback_url,omitempty"` // куда POST результат
|
||||
Priority int `json:"priority"` // 0=normal, 1=high
|
||||
TimeoutSecs int `json:"timeout_secs"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ToolCalls []ToolCallStep `json:"tool_calls,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
DoneAt *time.Time `json:"done_at,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCallStep — шаг вызова инструмента для отображения в UI.
|
||||
type ToolCallStep struct {
|
||||
Tool string `json:"tool"`
|
||||
Args any `json:"args"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
}
|
||||
|
||||
// ChatMessage — сообщение в формате для /chat endpoint.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatRequest — запрос на /chat (синхронный).
|
||||
type ChatRequest struct {
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"` // override модели агента
|
||||
MaxIter int `json:"max_iter,omitempty"` // override max iterations
|
||||
}
|
||||
|
||||
// ChatResponse — ответ /chat.
|
||||
type ChatResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Response string `json:"response"`
|
||||
ToolCalls []ToolCallStep `json:"tool_calls"`
|
||||
Model string `json:"model"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// TaskRequest — запрос на /task (async).
|
||||
type TaskRequest struct {
|
||||
Input string `json:"input"`
|
||||
FromAgentID int `json:"from_agent_id,omitempty"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
TimeoutSecs int `json:"timeout_secs,omitempty"`
|
||||
}
|
||||
|
||||
// ─── Agent Worker ─────────────────────────────────────────────────────────────
|
||||
|
||||
type AgentWorker struct {
|
||||
agentID int
|
||||
cfg *db.AgentConfig
|
||||
llm *llm.Client
|
||||
database *db.DB
|
||||
executor *tools.Executor
|
||||
|
||||
// Task queue — buffered channel
|
||||
taskQueue chan *Task
|
||||
// Task store — in-memory (id → Task)
|
||||
tasksMu sync.RWMutex
|
||||
tasks map[string]*Task
|
||||
// Recent tasks ring buffer (для GET /tasks)
|
||||
recentMu sync.Mutex
|
||||
recentKeys []string
|
||||
}
|
||||
|
||||
const (
|
||||
taskQueueDepth = 100
|
||||
maxRecentTasks = 50
|
||||
defaultMaxIter = 8
|
||||
defaultTimeout = 120
|
||||
workerGoroutines = 4 // параллельных воркеров на агента
|
||||
)
|
||||
|
||||
func newAgentWorker(agentID int, database *db.DB, llmClient *llm.Client) (*AgentWorker, error) {
|
||||
cfg, err := database.GetAgentByID(agentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent %d not found in DB: %w", agentID, err)
|
||||
}
|
||||
log.Printf("[AgentWorker] Loaded config: id=%d name=%q model=%s", cfg.ID, cfg.Name, cfg.Model)
|
||||
|
||||
w := &AgentWorker{
|
||||
agentID: agentID,
|
||||
cfg: cfg,
|
||||
llm: llmClient,
|
||||
database: database,
|
||||
taskQueue: make(chan *Task, taskQueueDepth),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
// Tool executor: агент использует подмножество инструментов из allowedTools
|
||||
w.executor = tools.NewExecutor("/app", func() ([]map[string]any, error) {
|
||||
rows, err := database.ListAgents()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]map[string]any, len(rows))
|
||||
for i, r := range rows {
|
||||
result[i] = map[string]any{
|
||||
"id": r.ID, "name": r.Name, "role": r.Role,
|
||||
"model": r.Model, "isActive": r.IsActive,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// StartWorkers запускает N горутин-воркеров, читающих из taskQueue.
|
||||
func (w *AgentWorker) StartWorkers(ctx context.Context) {
|
||||
for i := 0; i < workerGoroutines; i++ {
|
||||
go w.runWorker(ctx, i)
|
||||
}
|
||||
log.Printf("[AgentWorker] %d worker goroutines started", workerGoroutines)
|
||||
}
|
||||
|
||||
func (w *AgentWorker) runWorker(ctx context.Context, workerID int) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Worker-%d] shutting down", workerID)
|
||||
return
|
||||
case task := <-w.taskQueue:
|
||||
log.Printf("[Worker-%d] processing task %s", workerID, task.ID)
|
||||
w.processTask(ctx, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueueTask добавляет задачу в очередь и в хранилище.
|
||||
func (w *AgentWorker) EnqueueTask(req TaskRequest) *Task {
|
||||
timeout := req.TimeoutSecs
|
||||
if timeout <= 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
task := &Task{
|
||||
ID: uuid.New().String(),
|
||||
FromAgentID: req.FromAgentID,
|
||||
Input: req.Input,
|
||||
CallbackURL: req.CallbackURL,
|
||||
Priority: req.Priority,
|
||||
TimeoutSecs: timeout,
|
||||
Status: TaskPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
// Сохранить в store
|
||||
w.tasksMu.Lock()
|
||||
w.tasks[task.ID] = task
|
||||
w.tasksMu.Unlock()
|
||||
|
||||
// Добавить в recent ring
|
||||
w.recentMu.Lock()
|
||||
w.recentKeys = append(w.recentKeys, task.ID)
|
||||
if len(w.recentKeys) > maxRecentTasks {
|
||||
w.recentKeys = w.recentKeys[len(w.recentKeys)-maxRecentTasks:]
|
||||
}
|
||||
w.recentMu.Unlock()
|
||||
|
||||
// Отправить в очередь (non-blocking — если очередь полна, вернуть ошибку через Status)
|
||||
select {
|
||||
case w.taskQueue <- task:
|
||||
default:
|
||||
w.tasksMu.Lock()
|
||||
task.Status = TaskFailed
|
||||
task.Error = "task queue is full — agent is overloaded"
|
||||
w.tasksMu.Unlock()
|
||||
log.Printf("[AgentWorker] WARN: task queue full, task %s rejected", task.ID)
|
||||
}
|
||||
return task
|
||||
}
|
||||
|
||||
// processTask выполняет задачу через LLM loop и обновляет её статус.
|
||||
func (w *AgentWorker) processTask(ctx context.Context, task *Task) {
|
||||
now := time.Now()
|
||||
w.tasksMu.Lock()
|
||||
task.Status = TaskRunning
|
||||
task.StartedAt = &now
|
||||
w.tasksMu.Unlock()
|
||||
|
||||
// Выполняем чат
|
||||
chatCtx, cancel := context.WithTimeout(ctx, time.Duration(task.TimeoutSecs)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
messages := []ChatMessage{{Role: "user", Content: task.Input}}
|
||||
resp := w.runChat(chatCtx, messages, "", defaultMaxIter)
|
||||
|
||||
doneAt := time.Now()
|
||||
w.tasksMu.Lock()
|
||||
task.DoneAt = &doneAt
|
||||
task.ToolCalls = resp.ToolCalls
|
||||
if resp.Success {
|
||||
task.Status = TaskDone
|
||||
task.Result = resp.Response
|
||||
} else {
|
||||
task.Status = TaskFailed
|
||||
task.Error = resp.Error
|
||||
}
|
||||
w.tasksMu.Unlock()
|
||||
|
||||
log.Printf("[AgentWorker] task %s done: status=%s", task.ID, task.Status)
|
||||
|
||||
// Отправить результат на callback URL если задан
|
||||
if task.CallbackURL != "" {
|
||||
go w.postCallback(task)
|
||||
}
|
||||
|
||||
// Сохранить в DB history
|
||||
if w.database != nil {
|
||||
go func() {
|
||||
userMsg := task.Input
|
||||
agentResp := task.Result
|
||||
if task.Status == TaskFailed {
|
||||
agentResp = "[ERROR] " + task.Error
|
||||
}
|
||||
w.database.SaveHistory(db.HistoryInput{
|
||||
AgentID: w.agentID,
|
||||
UserMessage: userMsg,
|
||||
AgentResponse: agentResp,
|
||||
})
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// runChat — основной LLM loop агента.
|
||||
func (w *AgentWorker) runChat(ctx context.Context, messages []ChatMessage, overrideModel string, maxIter int) ChatResponse {
|
||||
model := w.cfg.Model
|
||||
if overrideModel != "" {
|
||||
model = overrideModel
|
||||
}
|
||||
if maxIter <= 0 {
|
||||
maxIter = defaultMaxIter
|
||||
}
|
||||
|
||||
// Собрать контекст: системный промпт + история + новые сообщения
|
||||
conv := []llm.Message{}
|
||||
if w.cfg.SystemPrompt != "" {
|
||||
conv = append(conv, llm.Message{Role: "system", Content: w.cfg.SystemPrompt})
|
||||
}
|
||||
|
||||
// Загрузить sliding window памяти из DB
|
||||
if w.database != nil {
|
||||
history, err := w.database.GetAgentHistory(w.agentID, 20)
|
||||
if err == nil {
|
||||
for _, h := range history {
|
||||
conv = append(conv, llm.Message{Role: "user", Content: h.UserMessage})
|
||||
if h.AgentResponse != "" {
|
||||
conv = append(conv, llm.Message{Role: "assistant", Content: h.AgentResponse})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Добавить текущие сообщения
|
||||
for _, m := range messages {
|
||||
conv = append(conv, llm.Message{Role: m.Role, Content: m.Content})
|
||||
}
|
||||
|
||||
// Получить доступные инструменты агента
|
||||
agentTools := w.getAgentTools()
|
||||
|
||||
temp := w.cfg.Temperature
|
||||
maxTok := w.cfg.MaxTokens
|
||||
if maxTok == 0 {
|
||||
maxTok = 4096
|
||||
}
|
||||
|
||||
var toolCallSteps []ToolCallStep
|
||||
var finalResponse string
|
||||
var lastModel string
|
||||
|
||||
for iter := 0; iter < maxIter; iter++ {
|
||||
req := llm.ChatRequest{
|
||||
Model: model,
|
||||
Messages: conv,
|
||||
Temperature: &temp,
|
||||
MaxTokens: &maxTok,
|
||||
}
|
||||
if len(agentTools) > 0 {
|
||||
req.Tools = agentTools
|
||||
req.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
resp, err := w.llm.Chat(ctx, req)
|
||||
if err != nil {
|
||||
// Fallback без инструментов
|
||||
req.Tools = nil
|
||||
req.ToolChoice = ""
|
||||
resp2, err2 := w.llm.Chat(ctx, req)
|
||||
if err2 != nil {
|
||||
return ChatResponse{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("LLM error (model: %s): %v", model, err2),
|
||||
}
|
||||
}
|
||||
if len(resp2.Choices) > 0 {
|
||||
finalResponse = resp2.Choices[0].Message.Content
|
||||
lastModel = resp2.Model
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
break
|
||||
}
|
||||
choice := resp.Choices[0]
|
||||
lastModel = resp.Model
|
||||
if lastModel == "" {
|
||||
lastModel = model
|
||||
}
|
||||
|
||||
// Инструменты?
|
||||
if choice.FinishReason == "tool_calls" && len(choice.Message.ToolCalls) > 0 {
|
||||
conv = append(conv, choice.Message)
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
start := time.Now()
|
||||
result := w.executor.Execute(ctx, tc.Function.Name, tc.Function.Arguments)
|
||||
step := ToolCallStep{
|
||||
Tool: tc.Function.Name,
|
||||
Success: result.Success,
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
var argsMap any
|
||||
_ = json.Unmarshal([]byte(tc.Function.Arguments), &argsMap)
|
||||
step.Args = argsMap
|
||||
|
||||
var toolContent string
|
||||
if result.Success {
|
||||
step.Result = result.Result
|
||||
b, _ := json.Marshal(result.Result)
|
||||
toolContent = string(b)
|
||||
} else {
|
||||
step.Error = result.Error
|
||||
toolContent = fmt.Sprintf(`{"error": %q}`, result.Error)
|
||||
}
|
||||
toolCallSteps = append(toolCallSteps, step)
|
||||
conv = append(conv, llm.Message{
|
||||
Role: "tool",
|
||||
Content: toolContent,
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
finalResponse = choice.Message.Content
|
||||
break
|
||||
}
|
||||
|
||||
return ChatResponse{
|
||||
Success: true,
|
||||
Response: finalResponse,
|
||||
ToolCalls: toolCallSteps,
|
||||
Model: lastModel,
|
||||
}
|
||||
}
|
||||
|
||||
// getAgentTools возвращает только те инструменты, которые разрешены агенту.
|
||||
func (w *AgentWorker) getAgentTools() []llm.Tool {
|
||||
allTools := tools.OrchestratorTools()
|
||||
allowed := make(map[string]bool, len(w.cfg.AllowedTools))
|
||||
for _, t := range w.cfg.AllowedTools {
|
||||
allowed[t] = true
|
||||
}
|
||||
// Если allowedTools пуст — агент получает базовый набор (http_request, file_read)
|
||||
if len(allowed) == 0 {
|
||||
allowed = map[string]bool{
|
||||
"http_request": true,
|
||||
"file_read": true,
|
||||
"file_list": true,
|
||||
}
|
||||
}
|
||||
var result []llm.Tool
|
||||
for _, td := range allTools {
|
||||
if allowed[td.Function.Name] {
|
||||
result = append(result, llm.Tool{
|
||||
Type: td.Type,
|
||||
Function: llm.ToolFunction{
|
||||
Name: td.Function.Name,
|
||||
Description: td.Function.Description,
|
||||
Parameters: td.Function.Parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// postCallback отправляет результат задачи на callback URL.
|
||||
func (w *AgentWorker) postCallback(task *Task) {
|
||||
w.tasksMu.RLock()
|
||||
payload, _ := json.Marshal(task)
|
||||
w.tasksMu.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, task.CallbackURL,
|
||||
bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
log.Printf("[AgentWorker] callback URL invalid for task %s: %v", task.ID, err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[AgentWorker] callback failed for task %s: %v", task.ID, err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
log.Printf("[AgentWorker] callback sent for task %s → %s (status %d)",
|
||||
task.ID, task.CallbackURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
// ─── HTTP Handlers ────────────────────────────────────────────────────────────
|
||||
|
||||
func (w *AgentWorker) handleHealth(rw http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"agentId": w.agentID,
|
||||
"name": w.cfg.Name,
|
||||
"model": w.cfg.Model,
|
||||
"queueLen": len(w.taskQueue),
|
||||
})
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleInfo(rw http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
"id": w.cfg.ID,
|
||||
"name": w.cfg.Name,
|
||||
"role": w.cfg.Model,
|
||||
"model": w.cfg.Model,
|
||||
"allowedTools": w.cfg.AllowedTools,
|
||||
"isSystem": w.cfg.IsSystem,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleChat(rw http.ResponseWriter, r *http.Request) {
|
||||
var req ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(rw, `{"error":"invalid request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(req.Messages) == 0 {
|
||||
http.Error(rw, `{"error":"messages required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
timeout := w.cfg.MaxTokens / 10 // грубая оценка
|
||||
if timeout < 30 {
|
||||
timeout = 30
|
||||
}
|
||||
if timeout > 300 {
|
||||
timeout = 300
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp := w.runChat(ctx, req.Messages, req.Model, req.MaxIter)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(rw).Encode(resp)
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleTask(rw http.ResponseWriter, r *http.Request) {
|
||||
var req TaskRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(rw, `{"error":"invalid request body"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Input == "" {
|
||||
http.Error(rw, `{"error":"input required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
task := w.EnqueueTask(req)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
"task_id": task.ID,
|
||||
"status": task.Status,
|
||||
"agent_id": w.agentID,
|
||||
"queue_len": len(w.taskQueue),
|
||||
})
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleListTasks(rw http.ResponseWriter, r *http.Request) {
|
||||
w.recentMu.Lock()
|
||||
keys := make([]string, len(w.recentKeys))
|
||||
copy(keys, w.recentKeys)
|
||||
w.recentMu.Unlock()
|
||||
|
||||
w.tasksMu.RLock()
|
||||
result := make([]*Task, 0, len(keys))
|
||||
for i := len(keys) - 1; i >= 0; i-- {
|
||||
if t, ok := w.tasks[keys[i]]; ok {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
w.tasksMu.RUnlock()
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
"tasks": result,
|
||||
"total": len(result),
|
||||
"queueLen": len(w.taskQueue),
|
||||
})
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleGetTask(rw http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "id")
|
||||
w.tasksMu.RLock()
|
||||
task, ok := w.tasks[taskID]
|
||||
w.tasksMu.RUnlock()
|
||||
if !ok {
|
||||
http.Error(rw, `{"error":"task not found"}`, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(rw).Encode(task)
|
||||
}
|
||||
|
||||
func (w *AgentWorker) handleMemory(rw http.ResponseWriter, r *http.Request) {
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
limit := 20
|
||||
if n, err := strconv.Atoi(limitStr); err == nil && n > 0 && n <= 100 {
|
||||
limit = n
|
||||
}
|
||||
|
||||
if w.database == nil {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(rw).Encode(map[string]any{"messages": []any{}, "total": 0})
|
||||
return
|
||||
}
|
||||
|
||||
history, err := w.database.GetAgentHistory(w.agentID, limit)
|
||||
if err != nil {
|
||||
http.Error(rw, `{"error":"failed to load history"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
"agent_id": w.agentID,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
_ = godotenv.Load("../.env")
|
||||
_ = godotenv.Load(".env")
|
||||
|
||||
// ── Конфиг из env ────────────────────────────────────────────────────────
|
||||
agentIDStr := os.Getenv("AGENT_ID")
|
||||
if agentIDStr == "" {
|
||||
log.Fatal("[AgentWorker] AGENT_ID env var is required")
|
||||
}
|
||||
agentID, err := strconv.Atoi(agentIDStr)
|
||||
if err != nil || agentID <= 0 {
|
||||
log.Fatalf("[AgentWorker] AGENT_ID must be a positive integer, got: %q", agentIDStr)
|
||||
}
|
||||
|
||||
port := os.Getenv("AGENT_PORT")
|
||||
if port == "" {
|
||||
port = "8001"
|
||||
}
|
||||
|
||||
llmBaseURL := getEnvFirst("LLM_BASE_URL", "OLLAMA_BASE_URL")
|
||||
if llmBaseURL == "" {
|
||||
llmBaseURL = "https://ollama.com/v1"
|
||||
}
|
||||
llmAPIKey := getEnvFirst("LLM_API_KEY", "OLLAMA_API_KEY")
|
||||
|
||||
dbURL := os.Getenv("DATABASE_URL")
|
||||
if dbURL == "" {
|
||||
log.Fatal("[AgentWorker] DATABASE_URL env var is required")
|
||||
}
|
||||
|
||||
log.Printf("[AgentWorker] Starting: AGENT_ID=%d PORT=%s LLM=%s", agentID, port, llmBaseURL)
|
||||
|
||||
// ── DB ───────────────────────────────────────────────────────────────────
|
||||
database, err := db.Connect(dbURL)
|
||||
if err != nil {
|
||||
log.Fatalf("[AgentWorker] DB connection failed: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
// ── LLM Client ───────────────────────────────────────────────────────────
|
||||
llmClient := llm.NewClient(llmBaseURL, llmAPIKey)
|
||||
|
||||
// ── Agent Worker ─────────────────────────────────────────────────────────
|
||||
worker, err := newAgentWorker(agentID, database, llmClient)
|
||||
if err != nil {
|
||||
log.Fatalf("[AgentWorker] init failed: %v", err)
|
||||
}
|
||||
|
||||
// ── Background workers ───────────────────────────────────────────────────
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
worker.StartWorkers(ctx)
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization", "X-Agent-ID"},
|
||||
}))
|
||||
|
||||
r.Get("/health", worker.handleHealth)
|
||||
r.Get("/info", worker.handleInfo)
|
||||
r.Post("/chat", worker.handleChat)
|
||||
r.Post("/task", worker.handleTask)
|
||||
r.Get("/tasks", worker.handleListTasks)
|
||||
r.Get("/tasks/{id}", worker.handleGetTask)
|
||||
r.Get("/memory", worker.handleMemory)
|
||||
|
||||
// ── HTTP Server ───────────────────────────────────────────────────────────
|
||||
srv := &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: r,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 310 * time.Second, // > max task timeout
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
log.Printf("[AgentWorker] agent-id=%d listening on :%s", agentID, port)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("[AgentWorker] server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-quit
|
||||
log.Println("[AgentWorker] shutting down gracefully...")
|
||||
cancel() // stop task workers
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer shutdownCancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("[AgentWorker] shutdown error: %v", err)
|
||||
}
|
||||
log.Println("[AgentWorker] stopped.")
|
||||
}
|
||||
|
||||
func getEnvFirst(keys ...string) string {
|
||||
for _, k := range keys {
|
||||
if v := os.Getenv(k); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
438
gateway/cmd/agent-worker/main_test.go
Normal file
438
gateway/cmd/agent-worker/main_test.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/db"
|
||||
)
|
||||
|
||||
// ─── Mock DB agent config ─────────────────────────────────────────────────────
|
||||
|
||||
func mockAgentConfig() *db.AgentConfig {
|
||||
return &db.AgentConfig{
|
||||
ID: 42,
|
||||
Name: "Test Agent",
|
||||
Model: "qwen2.5:7b",
|
||||
SystemPrompt: "You are a test agent.",
|
||||
AllowedTools: []string{"http_request", "file_list"},
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 2048,
|
||||
IsSystem: false,
|
||||
IsOrchestrator: false,
|
||||
IsActive: true,
|
||||
ContainerImage: "goclaw-agent-worker:latest",
|
||||
ContainerStatus: "running",
|
||||
ServiceName: "goclaw-agent-42",
|
||||
ServicePort: 8001,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: AgentWorker struct ─────────────────────────────────────────────────
|
||||
|
||||
func TestAgentWorkerInit(t *testing.T) {
|
||||
w := &AgentWorker{
|
||||
agentID: 42,
|
||||
cfg: mockAgentConfig(),
|
||||
taskQueue: make(chan *Task, taskQueueDepth),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
if w.agentID != 42 {
|
||||
t.Errorf("expected agentID=42, got %d", w.agentID)
|
||||
}
|
||||
if w.cfg.Name != "Test Agent" {
|
||||
t.Errorf("expected name 'Test Agent', got %q", w.cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: Task enqueue ───────────────────────────────────────────────────────
|
||||
|
||||
func TestEnqueueTask(t *testing.T) {
|
||||
w := &AgentWorker{
|
||||
agentID: 42,
|
||||
cfg: mockAgentConfig(),
|
||||
taskQueue: make(chan *Task, taskQueueDepth),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
|
||||
task := w.EnqueueTask(TaskRequest{
|
||||
Input: "hello world",
|
||||
TimeoutSecs: 30,
|
||||
})
|
||||
|
||||
if task.ID == "" {
|
||||
t.Error("task ID should not be empty")
|
||||
}
|
||||
if task.Status != TaskPending {
|
||||
t.Errorf("expected status=pending, got %q", task.Status)
|
||||
}
|
||||
if task.Input != "hello world" {
|
||||
t.Errorf("expected input='hello world', got %q", task.Input)
|
||||
}
|
||||
if len(w.taskQueue) != 1 {
|
||||
t.Errorf("expected 1 task in queue, got %d", len(w.taskQueue))
|
||||
}
|
||||
|
||||
// Task should be in store
|
||||
w.tasksMu.RLock()
|
||||
stored, ok := w.tasks[task.ID]
|
||||
w.tasksMu.RUnlock()
|
||||
if !ok {
|
||||
t.Error("task not found in store")
|
||||
}
|
||||
if stored.ID != task.ID {
|
||||
t.Errorf("stored task ID mismatch: %q != %q", stored.ID, task.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueTask_QueueFull(t *testing.T) {
|
||||
// Queue depth = 1 for this test
|
||||
w := &AgentWorker{
|
||||
agentID: 42,
|
||||
cfg: mockAgentConfig(),
|
||||
taskQueue: make(chan *Task, 1),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
|
||||
// Fill the queue
|
||||
w.EnqueueTask(TaskRequest{Input: "task 1"})
|
||||
// Overflow
|
||||
task2 := w.EnqueueTask(TaskRequest{Input: "task 2"})
|
||||
|
||||
w.tasksMu.RLock()
|
||||
status := task2.Status
|
||||
w.tasksMu.RUnlock()
|
||||
|
||||
if status != TaskFailed {
|
||||
t.Errorf("expected task2 status=failed when queue full, got %q", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueTask_DefaultTimeout(t *testing.T) {
|
||||
w := &AgentWorker{
|
||||
agentID: 42,
|
||||
cfg: mockAgentConfig(),
|
||||
taskQueue: make(chan *Task, taskQueueDepth),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
|
||||
task := w.EnqueueTask(TaskRequest{Input: "no timeout set"})
|
||||
if task.TimeoutSecs != defaultTimeout {
|
||||
t.Errorf("expected default timeout=%d, got %d", defaultTimeout, task.TimeoutSecs)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── HTTP Handlers ────────────────────────────────────────────────────────────
|
||||
|
||||
func makeTestWorker() *AgentWorker {
|
||||
return &AgentWorker{
|
||||
agentID: 42,
|
||||
cfg: mockAgentConfig(),
|
||||
taskQueue: make(chan *Task, taskQueueDepth),
|
||||
tasks: make(map[string]*Task),
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleHealth(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w.handleHealth(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("invalid JSON response: %v", err)
|
||||
}
|
||||
if body["status"] != "ok" {
|
||||
t.Errorf("expected status=ok, got %v", body["status"])
|
||||
}
|
||||
if int(body["agentId"].(float64)) != 42 {
|
||||
t.Errorf("expected agentId=42, got %v", body["agentId"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleInfo(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/info", nil)
|
||||
w.handleInfo(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(rr.Body).Decode(&body)
|
||||
if body["name"] != "Test Agent" {
|
||||
t.Errorf("expected name='Test Agent', got %v", body["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTask_Valid(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
body := `{"input":"do something useful","timeout_secs":60}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
w.handleTask(rr, req)
|
||||
|
||||
if rr.Code != http.StatusAccepted {
|
||||
t.Errorf("expected 202, got %d", rr.Code)
|
||||
}
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if resp["task_id"] == "" || resp["task_id"] == nil {
|
||||
t.Error("task_id should be in response")
|
||||
}
|
||||
if resp["status"] != string(TaskPending) {
|
||||
t.Errorf("expected status=pending, got %v", resp["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTask_EmptyInput(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(`{"input":""}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
w.handleTask(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTask_InvalidJSON(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(`not-json`))
|
||||
rr := httptest.NewRecorder()
|
||||
w.handleTask(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetTask_NotFound(t *testing.T) {
|
||||
// We can't easily use chi.URLParam in unit tests without a full router.
|
||||
// Test the store logic directly instead.
|
||||
w := makeTestWorker()
|
||||
w.tasksMu.RLock()
|
||||
_, ok := w.tasks["nonexistent-id"]
|
||||
w.tasksMu.RUnlock()
|
||||
if ok {
|
||||
t.Error("nonexistent task should not be found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListTasks_Empty(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/tasks", nil)
|
||||
w.handleListTasks(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if resp["total"].(float64) != 0 {
|
||||
t.Errorf("expected total=0, got %v", resp["total"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListTasks_WithTasks(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
w.EnqueueTask(TaskRequest{Input: "task A"})
|
||||
w.EnqueueTask(TaskRequest{Input: "task B"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/tasks", nil)
|
||||
w.handleListTasks(rr, req)
|
||||
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if int(resp["total"].(float64)) != 2 {
|
||||
t.Errorf("expected total=2, got %v", resp["total"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMemory_NoDB(t *testing.T) {
|
||||
w := makeTestWorker() // no database set
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/memory", nil)
|
||||
w.handleMemory(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if int(resp["total"].(float64)) != 0 {
|
||||
t.Errorf("expected total=0 without DB, got %v", resp["total"])
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: getAgentTools ──────────────────────────────────────────────────────
|
||||
|
||||
func TestGetAgentTools_WithAllowedTools(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
agentTools := w.getAgentTools()
|
||||
|
||||
// Worker has allowedTools = ["http_request", "file_list"]
|
||||
if len(agentTools) == 0 {
|
||||
t.Error("expected some tools, got none")
|
||||
}
|
||||
names := make(map[string]bool)
|
||||
for _, t := range agentTools {
|
||||
names[t.Function.Name] = true
|
||||
}
|
||||
if !names["http_request"] {
|
||||
t.Error("expected http_request in allowed tools")
|
||||
}
|
||||
if !names["file_list"] {
|
||||
t.Error("expected file_list in allowed tools")
|
||||
}
|
||||
// shell_exec should NOT be allowed
|
||||
if names["shell_exec"] {
|
||||
t.Error("shell_exec should NOT be in allowed tools for this agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAgentTools_EmptyAllowedTools_UsesDefaults(t *testing.T) {
|
||||
cfg := mockAgentConfig()
|
||||
cfg.AllowedTools = []string{} // empty
|
||||
w := &AgentWorker{agentID: 1, cfg: cfg, taskQueue: make(chan *Task, 1), tasks: map[string]*Task{}}
|
||||
tools := w.getAgentTools()
|
||||
if len(tools) == 0 {
|
||||
t.Error("expected default tools when allowedTools is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: recent task ring ───────────────────────────────────────────────────
|
||||
|
||||
func TestRecentRing_MaxCapacity(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
// Enqueue more than maxRecentTasks
|
||||
for i := 0; i < maxRecentTasks+10; i++ {
|
||||
// Don't block — drain queue
|
||||
w.EnqueueTask(TaskRequest{Input: "task"})
|
||||
select {
|
||||
case <-w.taskQueue:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
w.recentMu.Lock()
|
||||
count := len(w.recentKeys)
|
||||
w.recentMu.Unlock()
|
||||
|
||||
if count > maxRecentTasks {
|
||||
t.Errorf("recent ring should not exceed %d, got %d", maxRecentTasks, count)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: Task lifecycle ─────────────────────────────────────────────────────
|
||||
|
||||
func TestTaskLifecycle_Timestamps(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
before := time.Now()
|
||||
task := w.EnqueueTask(TaskRequest{Input: "lifecycle test"})
|
||||
after := time.Now()
|
||||
|
||||
if task.CreatedAt.Before(before) || task.CreatedAt.After(after) {
|
||||
t.Errorf("CreatedAt=%v should be between %v and %v", task.CreatedAt, before, after)
|
||||
}
|
||||
if task.StartedAt != nil {
|
||||
t.Error("StartedAt should be nil for pending task")
|
||||
}
|
||||
if task.DoneAt != nil {
|
||||
t.Error("DoneAt should be nil for pending task")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Unit: HTTP Chat handler (no LLM) ────────────────────────────────────────
|
||||
|
||||
func TestHandleChat_InvalidJSON(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", bytes.NewBufferString(`not-json`))
|
||||
rr := httptest.NewRecorder()
|
||||
w.handleChat(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChat_EmptyMessages(t *testing.T) {
|
||||
w := makeTestWorker()
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat",
|
||||
bytes.NewBufferString(`{"messages":[]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
w.handleChat(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for empty messages, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Integration: worker goroutine processes task ─────────────────────────────
|
||||
|
||||
func TestWorkerProcessesTask_WithMockLLM(t *testing.T) {
|
||||
// Create a mock LLM server that returns a simple response
|
||||
mockLLM := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]string{"role": "assistant", "content": "Mock answer"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"model": "mock-model",
|
||||
})
|
||||
}))
|
||||
defer mockLLM.Close()
|
||||
|
||||
// We can't easily create a full AgentWorker with llm client without more refactoring,
|
||||
// so we test the task state machine directly
|
||||
w := makeTestWorker()
|
||||
|
||||
task := w.EnqueueTask(TaskRequest{Input: "test task", TimeoutSecs: 5})
|
||||
if task.Status != TaskPending {
|
||||
t.Errorf("expected pending, got %s", task.Status)
|
||||
}
|
||||
|
||||
// Simulate task processing (without LLM)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now()
|
||||
w.tasksMu.Lock()
|
||||
task.Status = TaskRunning
|
||||
task.StartedAt = &now
|
||||
w.tasksMu.Unlock()
|
||||
|
||||
// Simulate done
|
||||
doneAt := time.Now()
|
||||
w.tasksMu.Lock()
|
||||
task.Status = TaskDone
|
||||
task.Result = "completed"
|
||||
task.DoneAt = &doneAt
|
||||
w.tasksMu.Unlock()
|
||||
|
||||
_ = ctx
|
||||
|
||||
w.tasksMu.RLock()
|
||||
finalStatus := task.Status
|
||||
w.tasksMu.RUnlock()
|
||||
|
||||
if finalStatus != TaskDone {
|
||||
t.Errorf("expected task done, got %s", finalStatus)
|
||||
}
|
||||
}
|
||||
@@ -4,9 +4,10 @@ go 1.23.4
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.1 // indirect
|
||||
github.com/go-chi/cors v1.2.1 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/go-chi/cors v1.2.1
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jmoiron/sqlx v1.4.0 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/joho/godotenv v1.5.1
|
||||
)
|
||||
|
||||
@@ -23,6 +23,11 @@ type AgentConfig struct {
|
||||
IsOrchestrator bool
|
||||
IsSystem bool
|
||||
IsActive bool
|
||||
// Container / Swarm fields (Phase A)
|
||||
ServiceName string
|
||||
ServicePort int
|
||||
ContainerImage string
|
||||
ContainerStatus string // "stopped" | "deploying" | "running" | "error"
|
||||
}
|
||||
|
||||
// AgentRow is a minimal agent representation for listing.
|
||||
@@ -68,7 +73,8 @@ func (d *DB) Close() {
|
||||
// GetOrchestratorConfig loads the agent with isOrchestrator=1 from DB.
|
||||
func (d *DB) GetOrchestratorConfig() (*AgentConfig, error) {
|
||||
row := d.conn.QueryRow(`
|
||||
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive
|
||||
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive,
|
||||
COALESCE(serviceName,''), COALESCE(servicePort,0), COALESCE(containerImage,'goclaw-agent-worker:latest'), COALESCE(containerStatus,'stopped')
|
||||
FROM agents
|
||||
WHERE isOrchestrator = 1
|
||||
LIMIT 1
|
||||
@@ -79,7 +85,8 @@ func (d *DB) GetOrchestratorConfig() (*AgentConfig, error) {
|
||||
// GetAgentByID loads a specific agent by ID.
|
||||
func (d *DB) GetAgentByID(id int) (*AgentConfig, error) {
|
||||
row := d.conn.QueryRow(`
|
||||
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive
|
||||
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive,
|
||||
COALESCE(serviceName,''), COALESCE(servicePort,0), COALESCE(containerImage,'goclaw-agent-worker:latest'), COALESCE(containerStatus,'stopped')
|
||||
FROM agents
|
||||
WHERE id = ?
|
||||
LIMIT 1
|
||||
@@ -129,6 +136,7 @@ func scanAgentConfig(row *sql.Row) (*AgentConfig, error) {
|
||||
&systemPrompt, &allowedToolsJSON,
|
||||
&temperature, &maxTokens,
|
||||
&isOrch, &isSystem, &isActive,
|
||||
&cfg.ServiceName, &cfg.ServicePort, &cfg.ContainerImage, &cfg.ContainerStatus,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -154,6 +162,107 @@ func scanAgentConfig(row *sql.Row) (*AgentConfig, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ─── Agent Container Fields ───────────────────────────────────────────────────
|
||||
// These methods support the agent-worker container architecture where each
|
||||
// agent runs as an autonomous Docker Swarm service.
|
||||
|
||||
// UpdateContainerStatus updates the container lifecycle state of an agent.
|
||||
func (d *DB) UpdateContainerStatus(agentID int, status, serviceName string, servicePort int) error {
|
||||
if d.conn == nil {
|
||||
return nil
|
||||
}
|
||||
_, err := d.conn.Exec(`
|
||||
UPDATE agents
|
||||
SET containerStatus = ?, serviceName = ?, servicePort = ?, updatedAt = NOW()
|
||||
WHERE id = ?
|
||||
`, status, serviceName, servicePort, agentID)
|
||||
return err
|
||||
}
|
||||
|
||||
// HistoryInput holds data for one conversation entry.
|
||||
type HistoryInput struct {
|
||||
AgentID int
|
||||
UserMessage string
|
||||
AgentResponse string
|
||||
ConversationID string
|
||||
Status string // "success" | "error" | "pending"
|
||||
}
|
||||
|
||||
// HistoryRow is a single entry from agentHistory for sliding window memory.
|
||||
type HistoryRow struct {
|
||||
ID int `json:"id"`
|
||||
UserMessage string `json:"userMessage"`
|
||||
AgentResponse string `json:"agentResponse"`
|
||||
ConvID string `json:"conversationId"`
|
||||
}
|
||||
|
||||
// SaveHistory inserts a row into the agentHistory table.
|
||||
// Non-fatal — logs on error but does not return one.
|
||||
func (d *DB) SaveHistory(h HistoryInput) {
|
||||
if d.conn == nil {
|
||||
return
|
||||
}
|
||||
status := h.Status
|
||||
if status == "" {
|
||||
status = "success"
|
||||
}
|
||||
convID := sql.NullString{String: h.ConversationID, Valid: h.ConversationID != ""}
|
||||
resp := sql.NullString{String: h.AgentResponse, Valid: h.AgentResponse != ""}
|
||||
_, err := d.conn.Exec(`
|
||||
INSERT INTO agentHistory (agentId, userMessage, agentResponse, conversationId, status)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`,
|
||||
h.AgentID,
|
||||
truncate(h.UserMessage, 65535),
|
||||
resp,
|
||||
convID,
|
||||
status,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[DB] SaveHistory error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAgentHistory returns the last N conversation turns for an agent, oldest first.
|
||||
func (d *DB) GetAgentHistory(agentID, limit int) ([]HistoryRow, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := d.conn.Query(`
|
||||
SELECT id, userMessage, COALESCE(agentResponse,''), COALESCE(conversationId,'')
|
||||
FROM agentHistory
|
||||
WHERE agentId = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
`, agentID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []HistoryRow
|
||||
for rows.Next() {
|
||||
var h HistoryRow
|
||||
if err := rows.Scan(&h.ID, &h.UserMessage, &h.AgentResponse, &h.ConvID); err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, h)
|
||||
}
|
||||
// Reverse so oldest is first (for LLM context ordering)
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// truncate caps a string to maxLen bytes.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
// normalizeDSN converts mysql://user:pass@host:port/db to user:pass@tcp(host:port)/db
|
||||
func normalizeDSN(dsn string) string {
|
||||
if !strings.HasPrefix(dsn, "mysql://") {
|
||||
|
||||
@@ -96,6 +96,8 @@ func New(llmClient *llm.Client, database *db.DB, projectRoot string) *Orchestrat
|
||||
}
|
||||
// Inject agent list function to avoid circular dependency
|
||||
o.executor = tools.NewExecutor(projectRoot, o.listAgentsFn)
|
||||
// Inject DB so delegate_to_agent can resolve live agent container addresses
|
||||
o.executor.SetDatabase(database)
|
||||
return o
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/db"
|
||||
)
|
||||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
@@ -175,6 +178,8 @@ type Executor struct {
|
||||
httpClient *http.Client
|
||||
// agentListFn is injected to avoid circular dependency with orchestrator
|
||||
agentListFn func() ([]map[string]any, error)
|
||||
// database is used for delegate_to_agent to look up service address
|
||||
database *db.DB
|
||||
}
|
||||
|
||||
func NewExecutor(projectRoot string, agentListFn func() ([]map[string]any, error)) *Executor {
|
||||
@@ -187,6 +192,11 @@ func NewExecutor(projectRoot string, agentListFn func() ([]map[string]any, error
|
||||
}
|
||||
}
|
||||
|
||||
// SetDatabase injects the DB reference so delegate_to_agent can resolve agent addresses.
|
||||
func (e *Executor) SetDatabase(database *db.DB) {
|
||||
e.database = database
|
||||
}
|
||||
|
||||
// Execute dispatches a tool call by name.
|
||||
func (e *Executor) Execute(ctx context.Context, toolName string, argsJSON string) ToolResult {
|
||||
start := time.Now()
|
||||
@@ -215,7 +225,7 @@ func (e *Executor) Execute(ctx context.Context, toolName string, argsJSON string
|
||||
case "list_agents":
|
||||
result, execErr = e.listAgents()
|
||||
case "delegate_to_agent":
|
||||
result, execErr = e.delegateToAgent(args)
|
||||
result, execErr = e.delegateToAgent(ctx, args)
|
||||
default:
|
||||
return ToolResult{Success: false, Error: fmt.Sprintf("unknown tool: %s", toolName), DurationMs: ms(start)}
|
||||
}
|
||||
@@ -446,21 +456,86 @@ func (e *Executor) listAgents() (any, error) {
|
||||
return map[string]any{"agents": agents, "count": len(agents)}, nil
|
||||
}
|
||||
|
||||
func (e *Executor) delegateToAgent(args map[string]any) (any, error) {
|
||||
agentID, _ := args["agentId"].(float64)
|
||||
message, _ := args["message"].(string)
|
||||
if message == "" {
|
||||
return nil, fmt.Errorf("message is required")
|
||||
func (e *Executor) delegateToAgent(ctx context.Context, args map[string]any) (any, error) {
|
||||
agentIDf, _ := args["agentId"].(float64)
|
||||
agentID := int(agentIDf)
|
||||
task, _ := args["task"].(string)
|
||||
if task == "" {
|
||||
task, _ = args["message"].(string) // backward compat
|
||||
}
|
||||
// Delegation is handled at orchestrator level; here we return a placeholder
|
||||
if task == "" {
|
||||
return nil, fmt.Errorf("task (or message) is required")
|
||||
}
|
||||
callbackURL, _ := args["callbackUrl"].(string)
|
||||
async, _ := args["async"].(bool)
|
||||
|
||||
// Resolve agent container address from DB
|
||||
if e.database != nil {
|
||||
cfg, err := e.database.GetAgentByID(agentID)
|
||||
if err == nil && cfg != nil && cfg.ServicePort > 0 && cfg.ContainerStatus == "running" {
|
||||
// Agent is deployed — call its container via overlay DNS
|
||||
// Docker Swarm DNS: service name resolves inside overlay network
|
||||
agentURL := fmt.Sprintf("http://%s:%d", cfg.ServiceName, cfg.ServicePort)
|
||||
if async {
|
||||
return e.postAgentTask(ctx, agentURL, agentID, task, callbackURL)
|
||||
}
|
||||
return e.postAgentChat(ctx, agentURL, agentID, task)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: agent not deployed yet — return informational response
|
||||
return map[string]any{
|
||||
"delegated": true,
|
||||
"agentId": int(agentID),
|
||||
"message": message,
|
||||
"note": "Agent delegation queued — response will be processed in next iteration",
|
||||
"delegated": false,
|
||||
"agentId": agentID,
|
||||
"task": task,
|
||||
"note": fmt.Sprintf("Agent %d is not running (containerStatus != running). Deploy it first via Web Panel.", agentID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// postAgentTask POSTs to agent's /task endpoint (async, returns task_id).
|
||||
func (e *Executor) postAgentTask(ctx context.Context, agentURL string, fromAgentID int, task, callbackURL string) (any, error) {
|
||||
payload, _ := json.Marshal(map[string]any{
|
||||
"input": task,
|
||||
"from_agent_id": fromAgentID,
|
||||
"callback_url": callbackURL,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, agentURL+"/task", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delegate build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delegate HTTP error: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(body, &result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// postAgentChat POSTs to agent's /chat endpoint (sync, waits for response).
|
||||
func (e *Executor) postAgentChat(ctx context.Context, agentURL string, _ int, task string) (any, error) {
|
||||
payload, _ := json.Marshal(map[string]any{
|
||||
"messages": []map[string]string{{"role": "user", "content": task}},
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, agentURL+"/chat", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delegate build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delegate HTTP error: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(body, &result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
func (e *Executor) resolvePath(path string) string {
|
||||
|
||||
Reference in New Issue
Block a user