// File: GoClaw/gateway/internal/llm/client.go (Go, 197 lines, 5.1 KiB)

// Package llm provides an OpenAI-compatible client for the Ollama Cloud API.
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ─── Types ────────────────────────────────────────────────────────────────────
// Message is a single entry in a chat conversation. The same type is sent in
// ChatRequest.Messages and received in ChatChoice.Message.
//
// Role and Content are always serialized (no omitempty), so an empty Content
// is sent as "content": "". ToolCallID, Name and ToolCalls are omitted from
// the JSON when empty.
//
// NOTE(review): these fields presumably follow the OpenAI chat schema, i.e.
// ToolCalls appears on messages in which the model requests tool invocations
// and ToolCallID marks a message carrying a tool's result for the ToolCall
// with that ID. Confirm against the Ollama Cloud API.
type Message struct {
	Role       string     `json:"role"`
	Content    string     `json:"content"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
}
// ToolCall is one tool invocation listed in Message.ToolCalls. All three
// fields are always serialized.
//
// NOTE(review): per the OpenAI schema, Type is presumably "function" and ID
// is presumably the value a follow-up Message.ToolCallID refers to. Confirm.
type ToolCall struct {
	ID       string           `json:"id"`
	Type     string           `json:"type"`
	Function ToolCallFunction `json:"function"`
}
// ToolCallFunction names the function a ToolCall invokes and carries its
// arguments. Arguments is kept as a raw string rather than a decoded object.
//
// NOTE(review): Arguments presumably holds JSON-encoded arguments that the
// caller must unmarshal itself, as in the OpenAI schema. Confirm.
type ToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
// Tool declares a tool the model may call. A list of these is sent in
// ChatRequest.Tools.
//
// NOTE(review): Type is presumably "function" per the OpenAI schema. Confirm.
type Tool struct {
	Type     string       `json:"type"`
	Function ToolFunction `json:"function"`
}
// ToolFunction describes a callable function offered to the model via
// Tool.Function. All fields are always serialized; a nil Parameters map is
// encoded as JSON null.
//
// NOTE(review): Parameters presumably holds a JSON Schema object describing
// the function's arguments, as in the OpenAI tools schema. Confirm that the
// server accepts null for a function without arguments.
type ToolFunction struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  map[string]any `json:"parameters"`
}
// ChatRequest is the JSON body that Chat POSTs to /chat/completions.
//
// Stream is always serialized, but Chat overwrites it with false before
// sending. Temperature and MaxTokens are pointers so that nil omits the field
// entirely while an explicit zero is still sent. Tools and ToolChoice are
// omitted when empty.
//
// NOTE(review): the OpenAI schema also accepts an object form for
// tool_choice (forcing a specific function). Only the string form can be
// expressed here; confirm that is sufficient.
type ChatRequest struct {
	Model       string    `json:"model"`
	Messages    []Message `json:"messages"`
	Stream      bool      `json:"stream"`
	Temperature *float64  `json:"temperature,omitempty"`
	MaxTokens   *int      `json:"max_tokens,omitempty"`
	Tools       []Tool    `json:"tools,omitempty"`
	ToolChoice  string    `json:"tool_choice,omitempty"`
}
// ChatChoice is one completion alternative in ChatResponse.Choices. Message
// holds the generated message, including any ToolCalls the model emitted.
//
// NOTE(review): FinishReason is passed through from the server unvalidated.
// Its values (presumably "stop", "length", "tool_calls", ...) are
// server-defined.
type ChatChoice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}
// Usage holds the token counts the server reports for a completion, carried
// in ChatResponse.Usage. Counts absent from the JSON decode as 0.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// ChatResponse is the decoded JSON body of a non-streaming /chat/completions
// reply, as returned by Chat.
//
// Usage is a pointer and stays nil when the server omits it. Choices is
// returned as decoded, and Chat does not check that it is non-empty.
type ChatResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []ChatChoice `json:"choices"`
	Usage   *Usage       `json:"usage,omitempty"`
}
// Model is one entry in ModelsResponse.Data, as listed by the /models
// endpoint (see ListModels).
//
// NOTE(review): Created is presumably a Unix timestamp in seconds, per the
// OpenAI schema. Confirm.
type Model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
// ModelsResponse is the decoded JSON body of GET /models, as returned by
// ListModels. Data holds the models the server lists.
type ModelsResponse struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}
// ─── Client ───────────────────────────────────────────────────────────────────
// Client talks to an OpenAI-compatible HTTP API rooted at baseURL.
//
// Create it with NewClient. NewClient strips trailing slashes from baseURL
// and installs an http.Client with a 180-second timeout. The zero value has
// no http.Client and is not ready for use. apiKey is sent as a Bearer token
// only when it is non-empty.
type Client struct {
	baseURL    string
	apiKey     string
	httpClient *http.Client
}
// NewClient returns a Client for the OpenAI-compatible API rooted at baseURL.
//
// Trailing slashes on baseURL are removed so endpoint paths such as "/models"
// can be appended directly. An empty apiKey means no Authorization header is
// sent. Each request is bounded by a 180-second overall timeout, in addition
// to any deadline carried by the caller's context.
func NewClient(baseURL, apiKey string) *Client {
	const requestTimeout = 180 * time.Second

	c := new(Client)
	c.baseURL = strings.TrimRight(baseURL, "/")
	c.apiKey = apiKey
	c.httpClient = &http.Client{Timeout: requestTimeout}
	return c
}
// headers returns the HTTP headers attached to every API request:
//   - a JSON Content-Type;
//   - a Bearer Authorization header, but only when an API key is configured.
//
// A new map is built on every call, so callers may modify the result freely.
func (c *Client) headers() map[string]string {
	out := make(map[string]string, 2)
	out["Content-Type"] = "application/json"
	if len(c.apiKey) == 0 {
		return out
	}
	out["Authorization"] = "Bearer " + c.apiKey
	return out
}
// Health checks whether the API is reachable by issuing GET baseURL+"/models".
//
// It returns:
//   - ok: true only when the server answers with HTTP 200;
//   - latency: wall-clock milliseconds from just before the request is built
//     until the response headers arrive or the request fails (0 if the
//     request could not be built);
//   - err: the request-construction or transport error, if any.
//
// A non-200 status is reported as ok == false with a nil error.
func (c *Client) Health(ctx context.Context) (bool, int64, error) {
	start := time.Now()
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/models", nil)
	if err != nil {
		return false, 0, err
	}
	for k, v := range c.headers() {
		req.Header.Set(k, v)
	}
	resp, err := c.httpClient.Do(req)
	latency := time.Since(start).Milliseconds()
	if err != nil {
		return false, latency, err
	}
	defer func() {
		// Only the status code matters here, but the body must be read to EOF
		// (up to a cap) before closing. Otherwise the Transport cannot return
		// the connection to its pool, and every probe dials a new one.
		// Drain/close errors cannot change the health verdict, so they are
		// deliberately ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
		_ = resp.Body.Close()
	}()
	return resp.StatusCode == http.StatusOK, latency, nil
}
// ListModels fetches the model list from GET baseURL+"/models" and decodes it
// into a ModelsResponse.
//
// A non-200 status yields an error of the form
// "ollama API error (<status>): <body>", with the body excerpt capped at
// 64 KiB. Request-construction, transport and decoding failures are wrapped
// with %w, so callers can still inspect them with errors.Is / errors.As.
func (c *Client) ListModels(ctx context.Context) (*ModelsResponse, error) {
	// Upper bound on how much of a response body is read for error messages
	// or drained before close.
	const maxBodyRead = 64 << 10

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/models", nil)
	if err != nil {
		return nil, fmt.Errorf("building list models request: %w", err)
	}
	for k, v := range c.headers() {
		req.Header.Set(k, v)
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("listing models: %w", err)
	}
	defer func() {
		// The JSON decoder may stop before EOF (e.g. trailing whitespace).
		// Drain the remainder so the Transport can reuse the connection.
		// Errors here cannot affect the result, so they are ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxBodyRead))
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is still reported if reading fails.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxBodyRead))
		return nil, fmt.Errorf("ollama API error (%d): %s", resp.StatusCode, string(body))
	}
	var result ModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decoding models response: %w", err)
	}
	return &result, nil
}
// Chat sends a non-streaming chat completion request to
// POST baseURL+"/chat/completions" and returns the decoded response.
//
// req is received by value, so forcing Stream to false does not modify the
// caller's ChatRequest. A non-200 status yields an error of the form
// "ollama chat API error (<status>): <body>", with the body excerpt capped at
// 64 KiB. Encoding, request-construction, transport and decoding failures are
// wrapped with %w, so callers can still inspect them with errors.Is /
// errors.As.
//
// The response is returned exactly as decoded. Choices is not checked and may
// be empty.
func (c *Client) Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) {
	// Upper bound on how much of a response body is read for error messages
	// or drained before close.
	const maxBodyRead = 64 << 10

	// The reply is decoded below as a single JSON object, so always request a
	// non-streamed response regardless of what the caller set.
	req.Stream = false
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("encoding chat request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
		c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building chat request: %w", err)
	}
	for k, v := range c.headers() {
		httpReq.Header.Set(k, v)
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("sending chat request: %w", err)
	}
	defer func() {
		// The JSON decoder may stop before EOF (e.g. trailing whitespace).
		// Drain the remainder so the Transport can reuse the connection.
		// Errors here cannot affect the result, so they are ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxBodyRead))
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is still reported if reading fails.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxBodyRead))
		return nil, fmt.Errorf("ollama chat API error (%d): %s", resp.StatusCode, string(respBody))
	}
	var result ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decoding chat response: %w", err)
	}
	return &result, nil
}