Files
GoClaw/gateway/cmd/agent-worker/main_test.go
bboxwtf f8e0ca7d5d feat(gateway): restore Phase C full agent lifecycle API
- Restored Phase C gateway code (handlers, main.go, docker client, db)
- Added routes: GET /api/agents/running, POST /api/agents (CRUD),
  POST /api/agents/{id}/deploy, POST /api/agents/{id}/stop,
  POST /api/agents/{id}/restart, POST /api/agents/{id}/scale
- Fixed StopAgent: always try to stop by canonical name goclaw-agent-{id}
  even when serviceName is empty in DB
- Fixed DeployAgent: handle 409 conflict by removing existing container
  and retrying once (idempotent deploy)
- Added swarm_manager.go: background SwarmManager for dead-letter recovery
- Added AGENT_NETWORK and AGENT_DB_URL config options
- Updated .gitignore to exclude gateway binaries
- All agents use standalone docker run (not Swarm) on bridge network

Verified on prod: deploy/stop/restart cycle works correctly,
/api/agents/running returns live running agents with containerStatus
2026-04-19 11:40:39 +00:00

549 lines
16 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.softuniq.eu/UniqAI/GoClaw/gateway/internal/db"
)
// ─── Mock DB agent config ─────────────────────────────────────────────────────
// mockAgentConfig returns a freshly allocated db.AgentConfig fixture used by
// the tests in this file: agent ID 42, named "Test Agent", active, and
// restricted to the "http_request" and "file_list" tools.
func mockAgentConfig() *db.AgentConfig {
	cfg := new(db.AgentConfig)

	// Identity and LLM settings.
	cfg.ID = 42
	cfg.Name = "Test Agent"
	cfg.Model = "qwen2.5:7b"
	cfg.SystemPrompt = "You are a test agent."
	cfg.AllowedTools = []string{"http_request", "file_list"}
	cfg.Temperature = 0.7
	cfg.MaxTokens = 2048

	// Role flags.
	cfg.IsSystem = false
	cfg.IsOrchestrator = false
	cfg.IsActive = true

	// Container / service placement.
	cfg.ContainerImage = "goclaw-agent-worker:latest"
	cfg.ContainerStatus = "running"
	cfg.ServiceName = "goclaw-agent-42"
	cfg.ServicePort = 8001

	return cfg
}
// ─── Unit: AgentWorker struct ─────────────────────────────────────────────────
// TestAgentWorkerInit checks that an AgentWorker assembled from a struct
// literal exposes the agent ID and the config name it was built with.
func TestAgentWorkerInit(t *testing.T) {
	const (
		wantID   = 42
		wantName = "Test Agent"
	)

	worker := &AgentWorker{
		agentID:   wantID,
		cfg:       mockAgentConfig(),
		taskQueue: make(chan *Task, taskQueueDepth),
		tasks:     make(map[string]*Task),
	}

	if got := worker.agentID; got != wantID {
		t.Errorf("expected agentID=42, got %d", got)
	}
	if got := worker.cfg.Name; got != wantName {
		t.Errorf("expected name 'Test Agent', got %q", got)
	}
}
// ─── Unit: Task enqueue ───────────────────────────────────────────────────────
// TestEnqueueTask verifies that EnqueueTask returns a pending task carrying
// the request input, leaves exactly one entry on the task queue, and records
// the task in the worker's task map under its ID.
func TestEnqueueTask(t *testing.T) {
	w := &AgentWorker{
		agentID:   42,
		cfg:       mockAgentConfig(),
		taskQueue: make(chan *Task, taskQueueDepth),
		tasks:     make(map[string]*Task),
	}
	task := w.EnqueueTask(TaskRequest{
		Input:       "hello world",
		TimeoutSecs: 30,
	})
	if task.ID == "" {
		t.Error("task ID should not be empty")
	}
	if task.Status != TaskPending {
		t.Errorf("expected status=pending, got %q", task.Status)
	}
	if task.Input != "hello world" {
		t.Errorf("expected input='hello world', got %q", task.Input)
	}
	if len(w.taskQueue) != 1 {
		t.Errorf("expected 1 task in queue, got %d", len(w.taskQueue))
	}

	// Task should be in store.
	w.tasksMu.RLock()
	stored, ok := w.tasks[task.ID]
	w.tasksMu.RUnlock()
	if !ok {
		// Fatal rather than Error: stored is nil on a miss, so the ID
		// comparison below would panic on a nil *Task instead of reporting
		// a clean test failure.
		t.Fatal("task not found in store")
	}
	if stored.ID != task.ID {
		t.Errorf("stored task ID mismatch: %q != %q", stored.ID, task.ID)
	}
}
// TestEnqueueTask_QueueFull asserts that a task enqueued while the task queue
// is already at capacity comes back with status TaskFailed.
func TestEnqueueTask_QueueFull(t *testing.T) {
	// A queue of depth 1 overflows on the second enqueue.
	const depth = 1
	worker := &AgentWorker{
		agentID:   42,
		cfg:       mockAgentConfig(),
		taskQueue: make(chan *Task, depth),
		tasks:     make(map[string]*Task),
	}

	// The first task occupies the only slot; nothing drains the queue.
	worker.EnqueueTask(TaskRequest{Input: "task 1"})

	// The second task has nowhere to go.
	overflow := worker.EnqueueTask(TaskRequest{Input: "task 2"})

	worker.tasksMu.RLock()
	got := overflow.Status
	worker.tasksMu.RUnlock()

	if got != TaskFailed {
		t.Errorf("expected task2 status=failed when queue full, got %q", got)
	}
}
// TestEnqueueTask_DefaultTimeout asserts that a request without TimeoutSecs
// yields a task whose TimeoutSecs equals defaultTimeout.
func TestEnqueueTask_DefaultTimeout(t *testing.T) {
	worker := &AgentWorker{
		agentID:   42,
		cfg:       mockAgentConfig(),
		taskQueue: make(chan *Task, taskQueueDepth),
		tasks:     make(map[string]*Task),
	}

	// TimeoutSecs is deliberately left at its zero value.
	req := TaskRequest{Input: "no timeout set"}
	queued := worker.EnqueueTask(req)

	if got := queued.TimeoutSecs; got != defaultTimeout {
		t.Errorf("expected default timeout=%d, got %d", defaultTimeout, got)
	}
}
// ─── HTTP Handlers ────────────────────────────────────────────────────────────
// makeTestWorker builds an AgentWorker for handler tests. Its rate-limit
// semaphore has capacity defaultMaxConcurrent and starts completely full, so
// every token is initially available. No database is attached.
func makeTestWorker() *AgentWorker {
	limit := defaultMaxConcurrent

	// Pre-fill the semaphore until it reaches capacity.
	tokens := make(chan struct{}, limit)
	for len(tokens) < limit {
		tokens <- struct{}{}
	}

	worker := &AgentWorker{
		agentID:       42,
		cfg:           mockAgentConfig(),
		taskQueue:     make(chan *Task, taskQueueDepth),
		tasks:         make(map[string]*Task),
		rateSem:       tokens,
		maxConcurrent: limit,
	}
	return worker
}
// TestHandleHealth asserts that handleHealth answers 200 with a JSON body
// whose "status" is "ok" and whose "agentId" is 42.
func TestHandleHealth(t *testing.T) {
	w := makeTestWorker()
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	w.handleHealth(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	var body map[string]any
	if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	if body["status"] != "ok" {
		t.Errorf("expected status=ok, got %v", body["status"])
	}
	// encoding/json decodes JSON numbers into float64 when the target is
	// any. Use the checked form so a missing or non-numeric field fails the
	// test instead of panicking.
	agentID, ok := body["agentId"].(float64)
	if !ok {
		t.Fatalf("expected numeric agentId, got %v", body["agentId"])
	}
	if int(agentID) != 42 {
		t.Errorf("expected agentId=42, got %v", body["agentId"])
	}
}
// TestHandleInfo asserts that handleInfo answers 200 with a JSON body whose
// "name" matches the mock agent config ("Test Agent").
func TestHandleInfo(t *testing.T) {
	w := makeTestWorker()
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/info", nil)
	w.handleInfo(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	var body map[string]any
	// Surface malformed responses explicitly; otherwise the name check
	// below would fail with a misleading "got <nil>".
	if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	if body["name"] != "Test Agent" {
		t.Errorf("expected name='Test Agent', got %v", body["name"])
	}
}
// TestHandleTask_Valid posts a well-formed task request to handleTask and
// asserts a 202 response containing a non-empty "task_id" and a "status"
// equal to TaskPending.
func TestHandleTask_Valid(t *testing.T) {
	w := makeTestWorker()
	body := `{"input":"do something useful","timeout_secs":60}`
	req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()
	w.handleTask(rr, req)
	if rr.Code != http.StatusAccepted {
		t.Errorf("expected 202, got %d", rr.Code)
	}
	var resp map[string]any
	// Report a malformed body as such rather than as missing fields.
	if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	if resp["task_id"] == "" || resp["task_id"] == nil {
		t.Error("task_id should be in response")
	}
	if resp["status"] != string(TaskPending) {
		t.Errorf("expected status=pending, got %v", resp["status"])
	}
}
// TestHandleTask_EmptyInput asserts that handleTask answers 400 when the
// request's "input" field is an empty string.
func TestHandleTask_EmptyInput(t *testing.T) {
	const payload = `{"input":""}`

	worker := makeTestWorker()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(payload))
	req.Header.Set("Content-Type", "application/json")

	worker.handleTask(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestHandleTask_InvalidJSON asserts that handleTask answers 400 when the
// request body is not valid JSON.
func TestHandleTask_InvalidJSON(t *testing.T) {
	const payload = `not-json`

	worker := makeTestWorker()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/task", bytes.NewBufferString(payload))

	worker.handleTask(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestHandleGetTask_NotFound checks that an unknown task ID is absent from a
// fresh worker's task map.
//
// The HTTP handler itself is not invoked: it relies on chi.URLParam, which is
// awkward to exercise without a full router, so only the store lookup is
// covered here.
func TestHandleGetTask_NotFound(t *testing.T) {
	worker := makeTestWorker()

	worker.tasksMu.RLock()
	defer worker.tasksMu.RUnlock()

	if _, found := worker.tasks["nonexistent-id"]; found {
		t.Error("nonexistent task should not be found")
	}
}
// TestHandleListTasks_Empty asserts that handleListTasks answers 200 with
// "total" equal to 0 when no task has been enqueued.
func TestHandleListTasks_Empty(t *testing.T) {
	w := makeTestWorker()
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/tasks", nil)
	w.handleListTasks(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	var resp map[string]any
	if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	// Checked assertion: a missing or non-numeric "total" must fail the
	// test, not panic it.
	total, ok := resp["total"].(float64)
	if !ok {
		t.Fatalf("expected numeric total, got %v", resp["total"])
	}
	if total != 0 {
		t.Errorf("expected total=0, got %v", resp["total"])
	}
}
// TestHandleListTasks_WithTasks enqueues two tasks and asserts that
// handleListTasks answers 200 with "total" equal to 2.
func TestHandleListTasks_WithTasks(t *testing.T) {
	w := makeTestWorker()
	w.EnqueueTask(TaskRequest{Input: "task A"})
	w.EnqueueTask(TaskRequest{Input: "task B"})
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/tasks", nil)
	w.handleListTasks(rr, req)
	// Same status expectation as the empty-list test for this handler.
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	var resp map[string]any
	if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	// Checked assertion: a missing or non-numeric "total" must fail the
	// test, not panic it.
	total, ok := resp["total"].(float64)
	if !ok {
		t.Fatalf("expected numeric total, got %v", resp["total"])
	}
	if int(total) != 2 {
		t.Errorf("expected total=2, got %v", resp["total"])
	}
}
// TestHandleMemory_NoDB asserts that handleMemory answers 200 with "total"
// equal to 0 when the worker has no database attached.
func TestHandleMemory_NoDB(t *testing.T) {
	w := makeTestWorker() // no database set
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/memory", nil)
	w.handleMemory(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	var resp map[string]any
	if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
		t.Fatalf("invalid JSON response: %v", err)
	}
	// Checked assertion: a missing or non-numeric "total" must fail the
	// test, not panic it.
	total, ok := resp["total"].(float64)
	if !ok {
		t.Fatalf("expected numeric total, got %v", resp["total"])
	}
	if int(total) != 0 {
		t.Errorf("expected total=0 without DB, got %v", resp["total"])
	}
}
// ─── Unit: getAgentTools ──────────────────────────────────────────────────────
// TestGetAgentTools_WithAllowedTools asserts that getAgentTools returns the
// tools named in the config's AllowedTools ("http_request" and "file_list")
// and does not include "shell_exec".
func TestGetAgentTools_WithAllowedTools(t *testing.T) {
	w := makeTestWorker()
	agentTools := w.getAgentTools()
	// Worker has allowedTools = ["http_request", "file_list"].
	if len(agentTools) == 0 {
		t.Error("expected some tools, got none")
	}
	names := make(map[string]bool, len(agentTools))
	// The loop variable is named tool, not t, so it does not shadow the
	// *testing.T inside the loop body.
	for _, tool := range agentTools {
		names[tool.Function.Name] = true
	}
	if !names["http_request"] {
		t.Error("expected http_request in allowed tools")
	}
	if !names["file_list"] {
		t.Error("expected file_list in allowed tools")
	}
	// shell_exec should NOT be allowed.
	if names["shell_exec"] {
		t.Error("shell_exec should NOT be in allowed tools for this agent")
	}
}
// TestGetAgentTools_EmptyAllowedTools_UsesDefaults asserts that getAgentTools
// still returns at least one tool when the config's AllowedTools list is
// empty.
func TestGetAgentTools_EmptyAllowedTools_UsesDefaults(t *testing.T) {
	agentCfg := mockAgentConfig()
	agentCfg.AllowedTools = []string{} // empty, non-nil

	worker := &AgentWorker{
		agentID:   1,
		cfg:       agentCfg,
		taskQueue: make(chan *Task, 1),
		tasks:     map[string]*Task{},
	}

	if got := worker.getAgentTools(); len(got) == 0 {
		t.Error("expected default tools when allowedTools is empty")
	}
}
// ─── Unit: recent task ring ───────────────────────────────────────────────────
// TestRecentRing_MaxCapacity enqueues maxRecentTasks+10 tasks and asserts
// that the worker's recentKeys slice never holds more than maxRecentTasks
// entries.
func TestRecentRing_MaxCapacity(t *testing.T) {
	worker := makeTestWorker()

	// Enqueue more than maxRecentTasks.
	for remaining := maxRecentTasks + 10; remaining > 0; remaining-- {
		worker.EnqueueTask(TaskRequest{Input: "task"})

		// Non-blocking receive: pop one queued task if present so the
		// queue does not fill up over the course of the loop.
		select {
		case <-worker.taskQueue:
		default:
		}
	}

	worker.recentMu.Lock()
	size := len(worker.recentKeys)
	worker.recentMu.Unlock()

	if size > maxRecentTasks {
		t.Errorf("recent ring should not exceed %d, got %d", maxRecentTasks, size)
	}
}
// ─── Unit: Task lifecycle ─────────────────────────────────────────────────────
// TestTaskLifecycle_Timestamps asserts that a freshly enqueued task has a
// CreatedAt inside the window in which EnqueueTask ran, and that StartedAt
// and DoneAt are still nil.
func TestTaskLifecycle_Timestamps(t *testing.T) {
	worker := makeTestWorker()

	// Bracket the call so CreatedAt can be bounded on both sides.
	windowStart := time.Now()
	queued := worker.EnqueueTask(TaskRequest{Input: "lifecycle test"})
	windowEnd := time.Now()

	created := queued.CreatedAt
	withinWindow := !created.Before(windowStart) && !created.After(windowEnd)
	if !withinWindow {
		t.Errorf("CreatedAt=%v should be between %v and %v", created, windowStart, windowEnd)
	}

	if queued.StartedAt != nil {
		t.Error("StartedAt should be nil for pending task")
	}
	if queued.DoneAt != nil {
		t.Error("DoneAt should be nil for pending task")
	}
}
// ─── Unit: HTTP Chat handler (no LLM) ────────────────────────────────────────
// TestHandleChat_InvalidJSON asserts that handleChat answers 400 when the
// request body is not valid JSON.
func TestHandleChat_InvalidJSON(t *testing.T) {
	const payload = `not-json`

	worker := makeTestWorker()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/chat", bytes.NewBufferString(payload))

	worker.handleChat(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestHandleChat_EmptyMessages asserts that handleChat answers 400 when the
// request carries an empty "messages" array.
func TestHandleChat_EmptyMessages(t *testing.T) {
	const payload = `{"messages":[]}`

	worker := makeTestWorker()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/chat", bytes.NewBufferString(payload))
	req.Header.Set("Content-Type", "application/json")

	worker.handleChat(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for empty messages, got %d", got)
	}
}
// ─── Integration: worker goroutine processes task ─────────────────────────────
// TestWorkerProcessesTask_WithMockLLM walks an enqueued task through the
// pending -> running -> done states by mutating it directly under tasksMu,
// then asserts the final status is TaskDone.
//
// NOTE(review): the mock LLM server is started but never contacted, and the
// timeout context is created but never passed anywhere; no worker goroutine
// runs in this test. Wiring both into a real processing path needs an
// AgentWorker with an LLM client, which this test does not build.
func TestWorkerProcessesTask_WithMockLLM(t *testing.T) {
	// Stand-in LLM endpoint that would reply with a single assistant message.
	llmStub := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		reply := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]string{"role": "assistant", "content": "Mock answer"},
					"finish_reason": "stop",
				},
			},
			"model": "mock-model",
		}
		_ = json.NewEncoder(rw).Encode(reply)
	}))
	defer llmStub.Close()

	// Only the task state machine is exercised below.
	worker := makeTestWorker()
	job := worker.EnqueueTask(TaskRequest{Input: "test task", TimeoutSecs: 5})
	if job.Status != TaskPending {
		t.Errorf("expected pending, got %s", job.Status)
	}

	// Timeout context for the simulated processing (currently unused).
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = ctx

	// mutate applies a state change while holding the task-store write lock.
	mutate := func(change func()) {
		worker.tasksMu.Lock()
		defer worker.tasksMu.Unlock()
		change()
	}

	// Simulate the task being picked up.
	startedAt := time.Now()
	mutate(func() {
		job.Status = TaskRunning
		job.StartedAt = &startedAt
	})

	// Simulate completion.
	finishedAt := time.Now()
	mutate(func() {
		job.Status = TaskDone
		job.Result = "completed"
		job.DoneAt = &finishedAt
	})

	worker.tasksMu.RLock()
	got := job.Status
	worker.tasksMu.RUnlock()

	if got != TaskDone {
		t.Errorf("expected task done, got %s", got)
	}
}
// ─── Phase C: Rate-limiting tests ─────────────────────────────────────────────
// TestRateLimiting_TokensInitialized checks that a worker whose semaphore was
// pre-filled with maxConcurrent tokens reports both a length and a capacity
// equal to that limit.
func TestRateLimiting_TokensInitialized(t *testing.T) {
	limit := 3

	// Fill the semaphore to capacity, one token per allowed concurrent task.
	tokens := make(chan struct{}, limit)
	for len(tokens) < limit {
		tokens <- struct{}{}
	}

	worker := &AgentWorker{
		agentID:       42,
		cfg:           mockAgentConfig(),
		taskQueue:     make(chan *Task, taskQueueDepth),
		tasks:         make(map[string]*Task),
		rateSem:       tokens,
		maxConcurrent: limit,
	}

	if got := len(worker.rateSem); got != limit {
		t.Errorf("expected %d tokens in semaphore, got %d", limit, got)
	}
	if got := cap(worker.rateSem); got != limit {
		t.Errorf("expected semaphore capacity=%d, got %d", limit, got)
	}
}
// TestRateLimiting_TokenAcquireRelease drains every token from the worker's
// semaphore and then returns them one at a time, checking the number of free
// tokens after each step. Per the original note, this mirrors what
// processTask does with the semaphore.
func TestRateLimiting_TokenAcquireRelease(t *testing.T) {
	limit := 2

	tokens := make(chan struct{}, limit)
	for len(tokens) < limit {
		tokens <- struct{}{}
	}

	worker := &AgentWorker{
		agentID:       42,
		cfg:           mockAgentConfig(),
		taskQueue:     make(chan *Task, taskQueueDepth),
		tasks:         make(map[string]*Task),
		rateSem:       tokens,
		maxConcurrent: limit,
	}

	// Acquire: take every token out of the semaphore.
	for taken := 0; taken < limit; taken++ {
		<-worker.rateSem
	}
	if free := len(worker.rateSem); free != 0 {
		t.Errorf("expected 0 free tokens after acquiring all, got %d", free)
	}

	// Release the first token.
	worker.rateSem <- struct{}{}
	if free := len(worker.rateSem); free != 1 {
		t.Errorf("expected 1 free token after release, got %d", free)
	}

	// Release the second token; the semaphore is full again.
	worker.rateSem <- struct{}{}
	if free := len(worker.rateSem); free != limit {
		t.Errorf("expected %d free tokens after full release, got %d", limit, free)
	}
}
// TestRateLimiting_HealthShowsActiveTasks verifies the /health endpoint reports
// active task count and rate-limit info: with one of three tokens consumed it
// expects maxConcurrent=3, rateLimitFree=2 and activeTasks=1.
func TestRateLimiting_HealthShowsActiveTasks(t *testing.T) {
	mc := 3
	sem := make(chan struct{}, mc)
	for i := 0; i < mc; i++ {
		sem <- struct{}{}
	}
	w := &AgentWorker{
		agentID:       42,
		cfg:           mockAgentConfig(),
		taskQueue:     make(chan *Task, taskQueueDepth),
		tasks:         make(map[string]*Task),
		rateSem:       sem,
		maxConcurrent: mc,
	}
	// Simulate 1 active task (consume 1 token).
	<-w.rateSem
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	w.handleHealth(rr, req)
	var body map[string]any
	if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	expectations := []struct {
		key  string
		want int
	}{
		{"maxConcurrent", mc},
		{"rateLimitFree", mc - 1},
		{"activeTasks", 1},
	}
	for _, exp := range expectations {
		// Checked assertion: a missing or non-numeric field is reported as
		// a failure for that field instead of panicking the whole test.
		got, ok := body[exp.key].(float64)
		if !ok {
			t.Errorf("expected numeric %s, got %v", exp.key, body[exp.key])
			continue
		}
		if int(got) != exp.want {
			t.Errorf("expected %s=%d, got %v", exp.key, exp.want, body[exp.key])
		}
	}
}