- db.go: added SaveMetric(MetricInput) and SaveHistory(HistoryInput) methods that write directly to MySQL; both are non-fatal (errors are only logged).
- handlers.go (OrchestratorStream): after each SSE stream finishes, an async goroutine saves agentMetrics (agentId, requestId, tokens, processingTimeMs, model, toolsCalled, status) and agentHistory (userMessage, agentResponse); both the success and error paths are covered, and orchAgentID is resolved from the DB.
- routers.ts (agents.chat): saveMetric() is now called on both the success and error paths of the Node.js direct-chat fallback (previously only agentHistory was saved).
- Verified: agentMetrics row ID=2 shows processingTimeMs=2133, totalTokens=143, model=minimax-m2.7, and Cyrillic text is stored correctly as UTF-8.
325 lines
9.0 KiB
Go
325 lines
9.0 KiB
Go
// Package db provides MySQL/TiDB connectivity and agent config queries.
|
|
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// AgentConfig is the complete configuration of a single row of the agents
// table (orchestrator or regular agent), as produced by scanAgentConfig.
type AgentConfig struct {
	ID int

	// SystemPrompt is "" when the database column is NULL.
	Name, Model, SystemPrompt string

	// AllowedTools is decoded from the allowedTools JSON column; it stays nil
	// when that column is NULL, empty, or the JSON literal null.
	AllowedTools []string

	Temperature float64 // a stored 0 or NULL is replaced by 0.5 when loaded
	MaxTokens   int     // a stored 0 or NULL is replaced by 8192 when loaded

	IsOrchestrator, IsSystem, IsActive bool
}
|
|
|
|
// AgentRow is a minimal agent representation for listing, filled in by
// ListAgents. The json tags give its camelCase keys when encoded with
// encoding/json.
type AgentRow struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Role        string `json:"role"`
	Model       string `json:"model"`
	Description string `json:"description"` // "" when the DB column is NULL (ListAgents COALESCEs it)

	// The flags below come from integer columns; ListAgents sets each one to
	// true only when the stored value is exactly 1.
	IsActive       bool `json:"isActive"`
	IsSystem       bool `json:"isSystem"`
	IsOrchestrator bool `json:"isOrchestrator"`
}
|
|
|
|
// DB wraps the *sql.DB connection pool used for agent, provider, metric and
// history queries. Create it with Connect and release it with Close.
type DB struct {
	// conn is the underlying pool; SaveMetric and SaveHistory return without
	// doing anything when it is nil.
	conn *sql.DB
}
|
|
|
|
func Connect(dsn string) (*DB, error) {
|
|
if dsn == "" {
|
|
return nil, fmt.Errorf("DATABASE_URL is empty")
|
|
}
|
|
// Convert mysql:// URL to DSN format if needed
|
|
dsn = normalizeDSN(dsn)
|
|
|
|
conn, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open DB: %w", err)
|
|
}
|
|
if err := conn.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to ping DB: %w", err)
|
|
}
|
|
log.Println("[DB] Connected to MySQL")
|
|
return &DB{conn: conn}, nil
|
|
}
|
|
|
|
func (d *DB) Close() {
|
|
if d.conn != nil {
|
|
_ = d.conn.Close()
|
|
}
|
|
}
|
|
|
|
// GetOrchestratorConfig loads the agent with isOrchestrator=1 from DB.
|
|
func (d *DB) GetOrchestratorConfig() (*AgentConfig, error) {
|
|
row := d.conn.QueryRow(`
|
|
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive
|
|
FROM agents
|
|
WHERE isOrchestrator = 1
|
|
LIMIT 1
|
|
`)
|
|
return scanAgentConfig(row)
|
|
}
|
|
|
|
// GetAgentByID loads a specific agent by ID.
|
|
func (d *DB) GetAgentByID(id int) (*AgentConfig, error) {
|
|
row := d.conn.QueryRow(`
|
|
SELECT id, name, model, systemPrompt, allowedTools, temperature, maxTokens, isOrchestrator, isSystem, isActive
|
|
FROM agents
|
|
WHERE id = ?
|
|
LIMIT 1
|
|
`, id)
|
|
return scanAgentConfig(row)
|
|
}
|
|
|
|
// ListAgents returns all active agents.
|
|
func (d *DB) ListAgents() ([]AgentRow, error) {
|
|
rows, err := d.conn.Query(`
|
|
SELECT id, name, role, model, COALESCE(description,''), isActive, isSystem, isOrchestrator
|
|
FROM agents
|
|
ORDER BY isOrchestrator DESC, isSystem DESC, id ASC
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var agents []AgentRow
|
|
for rows.Next() {
|
|
var a AgentRow
|
|
var isActive, isSystem, isOrch int
|
|
if err := rows.Scan(&a.ID, &a.Name, &a.Role, &a.Model, &a.Description, &isActive, &isSystem, &isOrch); err != nil {
|
|
continue
|
|
}
|
|
a.IsActive = isActive == 1
|
|
a.IsSystem = isSystem == 1
|
|
a.IsOrchestrator = isOrch == 1
|
|
agents = append(agents, a)
|
|
}
|
|
return agents, nil
|
|
}
|
|
|
|
// ─── LLM Provider ─────────────────────────────────────────────────────────────
|
|
|
|
// ProviderRow describes one LLM provider record from the llmProviders table.
type ProviderRow struct {
	ID            int
	Name, BaseURL string

	// APIKey is always left empty by GetActiveProvider: the stored key is
	// encrypted by the Node.js server and is not decrypted on the Go side.
	APIKey string
}
|
|
|
|
// GetActiveProvider returns the active LLM provider from the llmProviders table.
|
|
// Note: The API key is stored AES-256-GCM encrypted by the Node.js server.
|
|
// The Go gateway reads the raw encrypted bytes but cannot decrypt them (no shared key in Go).
|
|
// The proper flow: Node.js decrypts the key and passes it via /api/providers/reload.
|
|
// For now, GetActiveProvider returns the stored encrypted bytes as-is (not useful for direct use).
|
|
// Use UpdateCredentials on the LLM client instead.
|
|
func (d *DB) GetActiveProvider() (*ProviderRow, error) {
|
|
var p ProviderRow
|
|
var apiKeyEncrypted sql.NullString
|
|
row := d.conn.QueryRow(`
|
|
SELECT id, name, baseUrl, COALESCE(apiKeyEncrypted, '')
|
|
FROM llmProviders
|
|
WHERE isActive = 1
|
|
LIMIT 1
|
|
`)
|
|
err := row.Scan(&p.ID, &p.Name, &p.BaseURL, &apiKeyEncrypted)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// We cannot decrypt the key in Go (different crypto impl from Node.js)
|
|
// Return empty key — the LLM client will use its env-configured key
|
|
p.APIKey = ""
|
|
return &p, nil
|
|
}
|
|
|
|
// ─── Metrics & History ────────────────────────────────────────────────────────
|
|
|
|
// MetricInput carries everything SaveMetric records about one orchestrator
// request as a row of the agentMetrics table.
type MetricInput struct {
	AgentID   int
	RequestID string

	// Both texts are length-capped by SaveMetric before insertion.
	UserMessage, AgentResponse string

	InputTokens, OutputTokens, TotalTokens int

	ProcessingTimeMs int64    // request duration in milliseconds
	Status           string   // "success" | "error" | "timeout"
	ErrorMessage     string
	ToolsCalled      []string // stored JSON-encoded by SaveMetric
	Model            string
}
|
|
|
|
// SaveMetric inserts a row into the agentMetrics table.
|
|
// Non-fatal — logs on error but does not return one.
|
|
func (d *DB) SaveMetric(m MetricInput) {
|
|
if d.conn == nil {
|
|
return
|
|
}
|
|
toolsJSON, _ := json.Marshal(m.ToolsCalled)
|
|
_, err := d.conn.Exec(`
|
|
INSERT INTO agentMetrics
|
|
(agentId, requestId, userMessage, agentResponse,
|
|
inputTokens, outputTokens, totalTokens,
|
|
processingTimeMs, status, errorMessage, toolsCalled, model)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`,
|
|
m.AgentID,
|
|
m.RequestID,
|
|
truncate(m.UserMessage, 65535),
|
|
truncate(m.AgentResponse, 65535),
|
|
m.InputTokens, m.OutputTokens, m.TotalTokens,
|
|
m.ProcessingTimeMs,
|
|
m.Status,
|
|
m.ErrorMessage,
|
|
string(toolsJSON),
|
|
m.Model,
|
|
)
|
|
if err != nil {
|
|
log.Printf("[DB] SaveMetric error: %v", err)
|
|
}
|
|
}
|
|
|
|
// HistoryInput is a single conversation turn for SaveHistory to write into
// the agentHistory table.
type HistoryInput struct {
	AgentID int

	// An empty AgentResponse or ConversationID is written as SQL NULL.
	UserMessage, AgentResponse, ConversationID string

	// Status is "success" | "error" | "pending"; SaveHistory uses "success"
	// when it is left empty.
	Status string
}
|
|
|
|
// SaveHistory inserts a row into the agentHistory table.
|
|
// Non-fatal — logs on error but does not return one.
|
|
func (d *DB) SaveHistory(h HistoryInput) {
|
|
if d.conn == nil {
|
|
return
|
|
}
|
|
status := h.Status
|
|
if status == "" {
|
|
status = "success"
|
|
}
|
|
convID := sql.NullString{String: h.ConversationID, Valid: h.ConversationID != ""}
|
|
resp := sql.NullString{String: h.AgentResponse, Valid: h.AgentResponse != ""}
|
|
_, err := d.conn.Exec(`
|
|
INSERT INTO agentHistory (agentId, userMessage, agentResponse, conversationId, status)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
`,
|
|
h.AgentID,
|
|
truncate(h.UserMessage, 65535),
|
|
resp,
|
|
convID,
|
|
status,
|
|
)
|
|
if err != nil {
|
|
log.Printf("[DB] SaveHistory error: %v", err)
|
|
}
|
|
}
|
|
|
|
// truncate caps s to at most maxLen bytes without splitting a multi-byte
// UTF-8 character.
//
// If the cut would land inside a character, it moves back to that
// character's first byte, so the result can be up to 3 bytes shorter than
// maxLen for valid UTF-8. Cutting mid-character would produce invalid UTF-8.
// MySQL can reject that for utf8mb4 columns ("Incorrect string value" in
// strict mode), failing the whole INSERT, which is a real risk for non-ASCII
// text such as Cyrillic. A non-positive maxLen yields "".
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	if maxLen <= 0 {
		return ""
	}
	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back until
	// s[cut] starts a character so s[:cut] ends on a character boundary.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}
|
|
|
|
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
func scanAgentConfig(row *sql.Row) (*AgentConfig, error) {
|
|
var cfg AgentConfig
|
|
var systemPrompt sql.NullString
|
|
var allowedToolsJSON sql.NullString
|
|
var temperature sql.NullFloat64
|
|
var maxTokens sql.NullInt64
|
|
var isOrch, isSystem, isActive int
|
|
|
|
err := row.Scan(
|
|
&cfg.ID, &cfg.Name, &cfg.Model,
|
|
&systemPrompt, &allowedToolsJSON,
|
|
&temperature, &maxTokens,
|
|
&isOrch, &isSystem, &isActive,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cfg.SystemPrompt = systemPrompt.String
|
|
cfg.Temperature = temperature.Float64
|
|
if cfg.Temperature == 0 {
|
|
cfg.Temperature = 0.5
|
|
}
|
|
cfg.MaxTokens = int(maxTokens.Int64)
|
|
if cfg.MaxTokens == 0 {
|
|
cfg.MaxTokens = 8192
|
|
}
|
|
cfg.IsOrchestrator = isOrch == 1
|
|
cfg.IsSystem = isSystem == 1
|
|
cfg.IsActive = isActive == 1
|
|
|
|
if allowedToolsJSON.Valid && allowedToolsJSON.String != "" && allowedToolsJSON.String != "null" {
|
|
_ = json.Unmarshal([]byte(allowedToolsJSON.String), &cfg.AllowedTools)
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// normalizeDSN converts a URL-style DSN such as
//
//	mysql://user:pass@host:port/db?param=value
//
// into the go-sql-driver/mysql form
//
//	user:pass@tcp(host:port)/db?param=value&parseTime=true&charset=utf8mb4[&tls=true]
//
// Strings that do not start with "mysql://" are assumed to be in driver form
// already and are returned unchanged.
//
// Details:
//   - userinfo is everything before the LAST '@', so a raw '@' inside the
//     password is preserved. Percent-escapes are passed through undecoded.
//   - A URL without userinfo yields "tcp(host)/db…". A URL without a database
//     path still gets the '/' separator that the driver requires.
//   - Query parameters from the URL are kept verbatim (still encoded).
//     parseTime=true and charset=utf8mb4 are appended only when absent.
//   - Node.js-style ssl / sslmode / ssl-mode parameters are removed. The Go
//     driver would send any unknown parameter as a `SET name=value` system
//     variable, which fails the connection. Any value other than false, 0,
//     off, disable or disabled becomes tls=true.
//   - An explicit tls parameter always wins. Otherwise TLS is enabled when the
//     host looks like managed cloud MySQL. This is a loose, case-insensitive
//     substring match on tidbcloud, tidb.cloud, aws, gcp or azure.
func normalizeDSN(dsn string) string {
	if !strings.HasPrefix(dsn, "mysql://") {
		return dsn
	}
	rest := strings.TrimPrefix(dsn, "mysql://")

	userInfo, hostDB := "", rest
	if at := strings.LastIndex(rest, "@"); at >= 0 {
		userInfo, hostDB = rest[:at], rest[at+1:]
	}
	hostPath, rawQuery, _ := strings.Cut(hostDB, "?")
	hostPort, dbName, _ := strings.Cut(hostPath, "/")

	var (
		params      []string
		seen        = make(map[string]bool)
		explicitTLS bool // a "tls" param was supplied by the URL
		sslSeen     bool // some ssl/sslmode/ssl-mode param was supplied
		sslOn       bool // ...and at least one of them asked for TLS
	)
	for _, pair := range strings.Split(rawQuery, "&") {
		if pair == "" {
			continue
		}
		key, value, _ := strings.Cut(pair, "=")
		switch strings.ToLower(key) {
		case "ssl", "sslmode", "ssl-mode":
			sslSeen = true
			switch strings.ToLower(value) {
			case "false", "0", "off", "disable", "disabled":
			default:
				sslOn = true
			}
			continue // never forward: the driver would treat it as a system variable
		}
		if key == "tls" {
			explicitTLS = true
		}
		seen[key] = true
		params = append(params, pair)
	}
	if !seen["parseTime"] {
		params = append(params, "parseTime=true")
	}
	if !seen["charset"] {
		params = append(params, "charset=utf8mb4")
	}

	// TiDB Cloud and other managed MySQL offerings require TLS.
	managedHost := false
	lowerHost := strings.ToLower(hostPort)
	for _, marker := range []string{"tidbcloud", "tidb.cloud", "aws", "gcp", "azure"} {
		if strings.Contains(lowerHost, marker) {
			managedHost = true
			break
		}
	}
	if !explicitTLS && (sslOn || (!sslSeen && managedHost)) {
		params = append(params, "tls=true")
	}

	prefix := ""
	if userInfo != "" {
		prefix = userInfo + "@"
	}
	return prefix + "tcp(" + hostPort + ")/" + dbName + "?" + strings.Join(params, "&")
}
|