svs-services-server/completion/llama-server/llama-server.go

package llamaserver

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"time"
)
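
// LlamaServerProvider is a completion provider backed by a running
// llama-server instance. Host is the server's base URL and must include a
// trailing slash, since endpoint paths are appended to it directly.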
type LlamaServerProvider struct {
	Host  string // e.g. http://localhost:8080/
	Model string
}

// Message is a single chat message in the OpenAI-compatible schema.
// ReasoningContent carries the reasoning text that some models emit
// alongside the final answer.
type Message struct {
	Role             string `json:"role"`
	ReasoningContent string `json:"reasoning_content,omitempty"`
	Content          string `json:"content"`
}

// Request is the JSON body sent to /v1/chat/completions.
type Request struct {
	Messages           []Message              `json:"messages"`
	Model              string                 `json:"model"`
	ChatTemplateKwargs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
}
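
// Response mirrors llama-server's chat completion response body, including
// a timings block with performance counters.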
type Response struct {
	Choices []struct {
		Index        int     `json:"index"`
		Message      Message `json:"message"`
		FinishReason string  `json:"finish_reason"`
	} `json:"choices"`
	Created           int64  `json:"created"` // unix timestamp; TODO: decode into time.Time
	Model             string `json:"model"`
	SystemFingerprint string `json:"system_fingerprint"`
	Object            string `json:"object"`
	Usage             struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	ID      string `json:"id"`
	Timings struct {
		CacheN              int     `json:"cache_n"`
		PromptN             int     `json:"prompt_n"`
		PromptMS            float64 `json:"prompt_ms"`
		PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
		PromptPerSecond     float64 `json:"prompt_per_second"`
		PredictedN          int     `json:"predicted_n"`
		PredictedMS         float64 `json:"predicted_ms"`
		PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
		PredictedPerSecond  float64 `json:"predicted_per_second"`
	} `json:"timings"`
}
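
// Health pings llama-server's /health endpoint with a short timeout and
// reports an error if the server is unreachable or not returning 200 OK.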
func (p LlamaServerProvider) Health() (err error) {
	client := http.Client{
		Timeout: 100 * time.Millisecond,
	}
	res, err := client.Get(p.Host + "health")
	if err != nil {
		return err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return fmt.Errorf("llama-server health check returned status %v (%v)", res.StatusCode, res.Status)
	}
	return nil
}
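
// Complete sends prompt as a single user message to the OpenAI-compatible
// /v1/chat/completions endpoint and returns the content of the first choice.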
func (p LlamaServerProvider) Complete(ctx context.Context, prompt string) (response string, err error) {
	req := Request{
		Messages: []Message{
			{
				Role:    "user",
				Content: prompt,
			},
		},
		Model: p.Model,
		ChatTemplateKwargs: map[string]interface{}{
			"reasoning_effort": "low",
		},
	}
	encReq, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("marshaling json: %w", err)
	}
	// Build the request with the caller's context so cancellation and
	// deadlines propagate to the HTTP round trip. Host already ends in a
	// slash, so the endpoint path is appended without a leading one.
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.Host+"v1/chat/completions", bytes.NewReader(encReq))
	if err != nil {
		return "", fmt.Errorf("building request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	res, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return "", err
	}
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", fmt.Errorf("reading response body: %w", err)
	}
	if res.StatusCode != http.StatusOK {
		return "", fmt.Errorf("llama-server returned status %v: %s", res.StatusCode, body)
	}
	log.Println(string(body))
	resData := Response{}
	if err := json.Unmarshal(body, &resData); err != nil {
		return "", fmt.Errorf("decoding response: %w", err)
	}
	if len(resData.Choices) == 0 {
		log.Println(resData)
		return "", fmt.Errorf("no choices in response")
	}
	log.Printf("Generated %v (%v) tokens in %v ms (%v T/s)", resData.Usage.CompletionTokens, resData.Timings.PredictedN, resData.Timings.PredictedMS, resData.Timings.PredictedPerSecond)
	return resData.Choices[0].Message.Content, nil
}
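
// Usage sketch (assumes a llama-server instance listening on localhost:8080;
// the model name is a placeholder):
//
//	p := llamaserver.LlamaServerProvider{
//		Host:  "http://localhost:8080/",
//		Model: "my-model",
//	}
//	if err := p.Health(); err != nil {
//		log.Fatal(err)
//	}
//	answer, err := p.Complete(context.Background(), "Why is the sky blue?")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(answer)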