116 lines
3.2 KiB
Go
116 lines
3.2 KiB
Go
package llamaserver
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// LlamaServerProvider is a client for a llama-server instance reached over HTTP.
//
// Host is the server's base URL, for example "http://localhost:8080/".
// Model is copied into the "model" field of every chat completion request.
type LlamaServerProvider struct {
	Host, Model string
}
|
|
|
|
// Message is one chat turn. It is sent inside Request.Messages and decoded
// from each choice of a Response.
//
// Role names the speaker (Complete sends "user"). ReasoningContent is left
// out of the encoded JSON when empty; presumably it carries the model's
// separate reasoning text in replies (TODO confirm against llama-server).
type Message struct {
	Role             string `json:"role"`
	ReasoningContent string `json:"reasoning_content,omitempty"`
	Content          string `json:"content"`
}
|
|
|
|
type Request struct {
|
|
Messages []Message `json:"messages"`
|
|
Model string `json:"model"`
|
|
ChatTemplateKwargs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
|
|
}
|
|
|
|
type Response struct {
|
|
Choices []struct {
|
|
Index int `json:"index"`
|
|
Message Message `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Created int64 `json:"created"` // unix timestamp; TODO: decode into time.Time
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Object string `json:"object"`
|
|
Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
ID string `json:"id"`
|
|
Timings struct {
|
|
CacheN int `json:"cache_n"`
|
|
PromptN int `json:"prompt_n"`
|
|
PromptMS float64 `json:"prompt_ms"`
|
|
PromptPerTokenMS float64 `json:"prompt_per_token_ms"`
|
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
|
PredictedN int `json:"predicted_n"`
|
|
PredictedMS float64 `json:"predicted_ms"`
|
|
PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
|
|
PredictedPerSecond float64 `json:"predicted_per_second"`
|
|
} `json:"timings"`
|
|
}
|
|
|
|
func (p LlamaServerProvider) Health() (err error) {
|
|
client := http.Client{
|
|
Timeout: 100 * time.Millisecond,
|
|
}
|
|
res, err := client.Get(p.Host + "health")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if res.StatusCode != 200 {
|
|
return fmt.Errorf("llama-server health check returned status %v (%v)", res.StatusCode, res.Status)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p LlamaServerProvider) Complete(ctx context.Context, prompt string) (response string, err error) {
|
|
req := Request{
|
|
Messages: []Message{
|
|
{
|
|
Role: "user",
|
|
Content: prompt,
|
|
},
|
|
},
|
|
Model: p.Model,
|
|
ChatTemplateKwargs: map[string]interface{}{
|
|
"reasoning_effort": "low",
|
|
},
|
|
}
|
|
encReq, err := json.Marshal(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshaling json: %w", err)
|
|
}
|
|
res, err := http.Post(p.Host+"/v1/chat/completions", "application/json", bytes.NewReader(encReq))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("reading response body: %w", err)
|
|
}
|
|
log.Println(string(body))
|
|
|
|
resData := Response{}
|
|
dec := json.NewDecoder(bytes.NewReader(body))
|
|
if err := dec.Decode(&resData); err != nil {
|
|
return "", fmt.Errorf("decoding response: %w", err)
|
|
}
|
|
if len(resData.Choices) == 0 {
|
|
log.Println(resData)
|
|
return "", fmt.Errorf("no choices in response")
|
|
}
|
|
|
|
log.Printf("Generated %v (%v) tokens in %v ms (%v T/s)", resData.Usage.CompletionTokens, resData.Timings.PredictedN, resData.Timings.PredictedMS, resData.Timings.PredictedPerSecond)
|
|
|
|
return resData.Choices[0].Message.Content, nil
|
|
}
|