package main import ( "bytes" "context" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "git.softuniq.eu/UniqAI/GoClaw/gateway/internal/db" ) // ─── Mock DB agent config ───────────────────────────────────────────────────── func mockAgentConfig() *db.AgentConfig { return &db.AgentConfig{ ID: 42, Name: "Test Agent", Model: "qwen2.5:7b", SystemPrompt: "You are a test agent.", AllowedTools: []string{"http_request", "file_list"}, Temperature: 0.7, MaxTokens: 2048, IsSystem: false, IsOrchestrator: false, IsActive: true, ContainerImage: "goclaw-agent-worker:latest", ContainerStatus: "running", ServiceName: "goclaw-agent-42", ServicePort: 8001, } } // ─── Unit: AgentWorker struct ───────────────────────────────────────────────── func TestAgentWorkerInit(t *testing.T) { w := &AgentWorker{ agentID: 42, cfg: mockAgentConfig(), taskQueue: make(chan *Task, taskQueueDepth), tasks: make(map[string]*Task), } if w.agentID != 42 { t.Errorf("expected agentID=42, got %d", w.agentID) } if w.cfg.Name != "Test Agent" { t.Errorf("expected name 'Test Agent', got %q", w.cfg.Name) } } // ─── Unit: Task enqueue ─────────────────────────────────────────────────────── func TestEnqueueTask(t *testing.T) { w := &AgentWorker{ agentID: 42, cfg: mockAgentConfig(), taskQueue: make(chan *Task, taskQueueDepth), tasks: make(map[string]*Task), } task := w.EnqueueTask(TaskRequest{ Input: "hello world", TimeoutSecs: 30, }) if task.ID == "" { t.Error("task ID should not be empty") } if task.Status != TaskPending { t.Errorf("expected status=pending, got %q", task.Status) } if task.Input != "hello world" { t.Errorf("expected input='hello world', got %q", task.Input) } if len(w.taskQueue) != 1 { t.Errorf("expected 1 task in queue, got %d", len(w.taskQueue)) } // Task should be in store w.tasksMu.RLock() stored, ok := w.tasks[task.ID] w.tasksMu.RUnlock() if !ok { t.Error("task not found in store") } if stored.ID != task.ID { t.Errorf("stored task ID mismatch: %q != %q", stored.ID, task.ID) } } func TestEnqueueTask_QueueFull(t *testing.T) { // Queue depth = 1 for this test w := &AgentWorker{ agentID: 42, cfg: mockAgentConfig(), taskQueue: make(chan *Task, 1), tasks: make(map[string]*Task), } // Fill the queue w.EnqueueTask(TaskRequest{Input: "task 1"}) // Overflow task2 := w.EnqueueTask(TaskRequest{Input: "task 2"}) w.tasksMu.RLock() status := task2.Status w.tasksMu.RUnlock() if status != TaskFailed { t.Errorf("expected task2 status=failed when queue full, got %q", status) } } func TestEnqueueTask_DefaultTimeout(t *testing.T) { w := &AgentWorker{ agentID: 42, cfg: mockAgentConfig(), taskQueue: make(chan *Task, taskQueueDepth), tasks: make(map[string]*Task), } task := w.EnqueueTask(TaskRequest{Input: "no timeout set"}) if task.TimeoutSecs != defaultTimeout { t.Errorf("expected default timeout=%d, got %d", defaultTimeout, task.TimeoutSecs) } } // ─── HTTP Handlers ──────────────────────────────────────────────────────────── func makeTestWorker() *AgentWorker { return &AgentWorker{ agentID: 42, cfg: mockAgentConfig(), taskQueue: make(chan *Task, taskQueueDepth), tasks: make(map[string]*Task), } } func TestHandleHealth(t *testing.T) { w := makeTestWorker() rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/health", nil) w.handleHealth(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } var body map[string]any if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { t.Fatalf("invalid JSON response: %v", err) } if body["status"] != "ok" { t.Errorf("expected status=ok, got %v", body["status"]) } if int(body["agentId"].(float64)) != 42 { t.Errorf("expected agentId=42, got %v", body["agentId"]) } } func TestHandleInfo(t *testing.T) { w := makeTestWorker() rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/info", nil) w.handleInfo(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } var body map[string]any json.NewDecoder(rr.Body).Decode(&body) if body["name"] != "Test Agent" { t.Errorf("expected name='Test Agent', got %v", body["name"]) } } func TestHandleTask_Valid(t *testing.T) { w := makeTestWorker() body := `{"input":"do something useful","timeout_secs":60}` req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(body)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() w.handleTask(rr, req) if rr.Code != http.StatusAccepted { t.Errorf("expected 202, got %d", rr.Code) } var resp map[string]any json.NewDecoder(rr.Body).Decode(&resp) if resp["task_id"] == "" || resp["task_id"] == nil { t.Error("task_id should be in response") } if resp["status"] != string(TaskPending) { t.Errorf("expected status=pending, got %v", resp["status"]) } } func TestHandleTask_EmptyInput(t *testing.T) { w := makeTestWorker() req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(`{"input":""}`)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() w.handleTask(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", rr.Code) } } func TestHandleTask_InvalidJSON(t *testing.T) { w := makeTestWorker() req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(`not-json`)) rr := httptest.NewRecorder() w.handleTask(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", rr.Code) } } func TestHandleGetTask_NotFound(t *testing.T) { // We can't easily use chi.URLParam in unit tests without a full router. // Test the store logic directly instead. w := makeTestWorker() w.tasksMu.RLock() _, ok := w.tasks["nonexistent-id"] w.tasksMu.RUnlock() if ok { t.Error("nonexistent task should not be found") } } func TestHandleListTasks_Empty(t *testing.T) { w := makeTestWorker() rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/tasks", nil) w.handleListTasks(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } var resp map[string]any json.NewDecoder(rr.Body).Decode(&resp) if resp["total"].(float64) != 0 { t.Errorf("expected total=0, got %v", resp["total"]) } } func TestHandleListTasks_WithTasks(t *testing.T) { w := makeTestWorker() w.EnqueueTask(TaskRequest{Input: "task A"}) w.EnqueueTask(TaskRequest{Input: "task B"}) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/tasks", nil) w.handleListTasks(rr, req) var resp map[string]any json.NewDecoder(rr.Body).Decode(&resp) if int(resp["total"].(float64)) != 2 { t.Errorf("expected total=2, got %v", resp["total"]) } } func TestHandleMemory_NoDB(t *testing.T) { w := makeTestWorker() // no database set rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/memory", nil) w.handleMemory(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } var resp map[string]any json.NewDecoder(rr.Body).Decode(&resp) if int(resp["total"].(float64)) != 0 { t.Errorf("expected total=0 without DB, got %v", resp["total"]) } } // ─── Unit: getAgentTools ────────────────────────────────────────────────────── func TestGetAgentTools_WithAllowedTools(t *testing.T) { w := makeTestWorker() agentTools := w.getAgentTools() // Worker has allowedTools = ["http_request", "file_list"] if len(agentTools) == 0 { t.Error("expected some tools, got none") } names := make(map[string]bool) for _, t := range agentTools { names[t.Function.Name] = true } if !names["http_request"] { t.Error("expected http_request in allowed tools") } if !names["file_list"] { t.Error("expected file_list in allowed tools") } // shell_exec should NOT be allowed if names["shell_exec"] { t.Error("shell_exec should NOT be in allowed tools for this agent") } } func TestGetAgentTools_EmptyAllowedTools_UsesDefaults(t *testing.T) { cfg := mockAgentConfig() cfg.AllowedTools = []string{} // empty w := &AgentWorker{agentID: 1, cfg: cfg, taskQueue: make(chan *Task, 1), tasks: map[string]*Task{}} tools := w.getAgentTools() if len(tools) == 0 { t.Error("expected default tools when allowedTools is empty") } } // ─── Unit: recent task ring ─────────────────────────────────────────────────── func TestRecentRing_MaxCapacity(t *testing.T) { w := makeTestWorker() // Enqueue more than maxRecentTasks for i := 0; i < maxRecentTasks+10; i++ { // Don't block — drain queue w.EnqueueTask(TaskRequest{Input: "task"}) select { case <-w.taskQueue: default: } } w.recentMu.Lock() count := len(w.recentKeys) w.recentMu.Unlock() if count > maxRecentTasks { t.Errorf("recent ring should not exceed %d, got %d", maxRecentTasks, count) } } // ─── Unit: Task lifecycle ───────────────────────────────────────────────────── func TestTaskLifecycle_Timestamps(t *testing.T) { w := makeTestWorker() before := time.Now() task := w.EnqueueTask(TaskRequest{Input: "lifecycle test"}) after := time.Now() if task.CreatedAt.Before(before) || task.CreatedAt.After(after) { t.Errorf("CreatedAt=%v should be between %v and %v", task.CreatedAt, before, after) } if task.StartedAt != nil { t.Error("StartedAt should be nil for pending task") } if task.DoneAt != nil { t.Error("DoneAt should be nil for pending task") } } // ─── Unit: HTTP Chat handler (no LLM) ──────────────────────────────────────── func TestHandleChat_InvalidJSON(t *testing.T) { w := makeTestWorker() req := httptest.NewRequest(http.MethodPost, "/chat", bytes.NewBufferString(`not-json`)) rr := httptest.NewRecorder() w.handleChat(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", rr.Code) } } func TestHandleChat_EmptyMessages(t *testing.T) { w := makeTestWorker() req := httptest.NewRequest(http.MethodPost, "/chat", bytes.NewBufferString(`{"messages":[]}`)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() w.handleChat(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected 400 for empty messages, got %d", rr.Code) } } // ─── Integration: worker goroutine processes task ───────────────────────────── func TestWorkerProcessesTask_WithMockLLM(t *testing.T) { // Create a mock LLM server that returns a simple response mockLLM := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{ "choices": []map[string]any{ { "message": map[string]string{"role": "assistant", "content": "Mock answer"}, "finish_reason": "stop", }, }, "model": "mock-model", }) })) defer mockLLM.Close() // We can't easily create a full AgentWorker with llm client without more refactoring, // so we test the task state machine directly w := makeTestWorker() task := w.EnqueueTask(TaskRequest{Input: "test task", TimeoutSecs: 5}) if task.Status != TaskPending { t.Errorf("expected pending, got %s", task.Status) } // Simulate task processing (without LLM) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() now := time.Now() w.tasksMu.Lock() task.Status = TaskRunning task.StartedAt = &now w.tasksMu.Unlock() // Simulate done doneAt := time.Now() w.tasksMu.Lock() task.Status = TaskDone task.Result = "completed" task.DoneAt = &doneAt w.tasksMu.Unlock() _ = ctx w.tasksMu.RLock() finalStatus := task.Status w.tasksMu.RUnlock() if finalStatus != TaskDone { t.Errorf("expected task done, got %s", finalStatus) } }