Initial working copy

This commit is contained in:
Chandler Swift 2025-12-19 23:14:28 -06:00
parent 2a335176e6
commit 2c876cef42
19 changed files with 783 additions and 126 deletions

View file

@ -1,4 +1,4 @@
package llama
package gollamacpp
import (
"flag"
@ -9,7 +9,14 @@ import (
"github.com/go-skynet/go-llama.cpp"
)
func main() {
var (
threads = 4
tokens = 128
gpulayers = 0
seed = -1
)
func Run() {
var model string
flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)

View file

@ -0,0 +1,116 @@
package llamaserver
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
)
type LlamaServerProvider struct {
Host string // http://localhost:8080/
Model string
}
type Message struct {
Role string `json:"role"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Content string `json:"content"`
}
type Request struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
ChatTemplateKwargs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
}
type Response struct {
Choices []struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Created int64 `json:"created"` // unix timestamp; TODO: decode into time.Time
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Object string `json:"object"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
ID string `json:"id"`
Timings struct {
CacheN int `json:"cache_n"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
PromptPerTokenMS float64 `json:"prompt_per_token_ms"`
PromptPerSecond float64 `json:"prompt_per_second"`
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
PredictedPerSecond float64 `json:"predicted_per_second"`
} `json:"timings"`
}
func (p LlamaServerProvider) Health() (err error) {
client := http.Client{
Timeout: 100 * time.Millisecond,
}
res, err := client.Get(p.Host + "health")
if err != nil {
return err
}
if res.StatusCode != 200 {
return fmt.Errorf("llama-server health check returned status %v (%v)", res.StatusCode, res.Status)
}
return nil
}
func (p LlamaServerProvider) Complete(ctx context.Context, prompt string) (response string, err error) {
req := Request{
Messages: []Message{
{
Role: "user",
Content: prompt,
},
},
Model: p.Model,
ChatTemplateKwargs: map[string]interface{}{
"reasoning_effort": "low",
},
}
encReq, err := json.Marshal(req)
if err != nil {
return "", fmt.Errorf("marshaling json: %w", err)
}
res, err := http.Post(p.Host+"/v1/chat/completions", "application/json", bytes.NewReader(encReq))
if err != nil {
return "", err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("reading response body: %w", err)
}
log.Println(string(body))
resData := Response{}
dec := json.NewDecoder(bytes.NewReader(body))
if err := dec.Decode(&resData); err != nil {
return "", fmt.Errorf("decoding response: %w", err)
}
if len(resData.Choices) == 0 {
log.Println(resData)
return "", fmt.Errorf("no choices in response")
}
log.Printf("Generated %v (%v) tokens in %v ms (%v T/s)", resData.Usage.CompletionTokens, resData.Timings.PredictedN, resData.Timings.PredictedMS, resData.Timings.PredictedPerSecond)
return resData.Choices[0].Message.Content, nil
}

View file

View file

@ -0,0 +1,17 @@
curl 'http://localhost:8080/v1/chat/completions' \
-X POST \
-H 'User-Agent: Mozilla/5.0 (X11; Linux x86_64; rv:146.0) Gecko/20100101 Firefox/146.0' \
-H 'Accept: */*' \
-H 'Accept-Language: en-US,en;q=0.5' \
-H 'Accept-Encoding: gzip, deflate, br, zstd' \
-H 'Referer: http://localhost:8080/' \
-H 'Content-Type: application/json' \
-H 'Origin: http://localhost:8080' \
-H 'Connection: keep-alive' \
-H 'Sec-Fetch-Dest: empty' \
-H 'Sec-Fetch-Mode: cors' \
-H 'Sec-Fetch-Site: same-origin' \
-H 'Priority: u=4' \
-H 'Pragma: no-cache' \
-H 'Cache-Control: no-cache' \
--data @req.json

View file

@ -0,0 +1,117 @@
#include "llama.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
void null_log_callback(enum ggml_log_level level, const char *message, void *user_data) {}
int64_t time_us(void) {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
}
void chandlerscustomllama(char* prompt) {
int n_predict = 100;
printf("Prompt: %s\n", prompt);
struct llama_model_params model_params = llama_model_default_params();
// printf("model_params.n_gpu_layers: %d\n", model_params.n_gpu_layers);
llama_log_set(null_log_callback, NULL); // Disable logging
struct llama_model *model = llama_model_load_from_file("/home/chandler/llms/gpt-oss-20b-Q4_K_M.gguf", model_params);
if (model == NULL) {
fprintf(stderr, "Failed to load model\n");
return;
}
const struct llama_vocab * vocab = llama_model_get_vocab(model);
const int n_prompt = -llama_tokenize(vocab, prompt, strlen(prompt), NULL, 0, true, true);
llama_token * prompt_tokens = malloc(sizeof(llama_token) * n_prompt);
if (llama_tokenize(vocab, prompt, strlen(prompt), prompt_tokens, n_prompt, true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize prompt\n", __func__);
return;
}
struct llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_prompt + n_predict - 1;
ctx_params.n_batch = n_prompt;
ctx_params.no_perf = false; // TODO: true
struct llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create llama_context\n", __func__);
return;
}
// initialize the sampler
struct llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
struct llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
// prepare a batch for the prompt
struct llama_batch batch = llama_batch_get_one(prompt_tokens, n_prompt);
if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
}
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_vocab_bos(vocab);
}
batch = llama_batch_get_one(&decoder_start_token_id, 1);
}
int64_t start = time_us();
int n_decode = 0;
llama_token new_token_id;
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
}
n_pos += batch.n_tokens;
// sample the next token
{
new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id)) {
break;
}
char buf[128]; // TODO: how do we know that this is enough?
int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); // TODO: do I want special tokens?
if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return;
}
buf[n] = 0;
printf("%s", buf); // TODO: null terminator?
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
n_decode += 1;
}
}
int64_t end = time_us();
fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (end - start) / 1000000.0f, n_decode / ((end - start) / 1000000.0f));
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);
}

View file

@ -0,0 +1,15 @@
package llamacpp
/*
#include "llamacpp.h"
#include <stdlib.h>
*/
import "C"
import "unsafe"
func Run() {
prompt := C.CString("Here is a very short story about a brave knight. One day, the knight")
C.chandlerscustomllama(prompt)
C.free(unsafe.Pointer(prompt))
}

View file

@ -0,0 +1 @@
void chandlerscustomllama(char* prompt);

View file

@ -5,10 +5,17 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
)
type OpenRouterProvider struct {
// Endpoint string
Token string
Model string // "openai/gpt-oss-20b:free"
}
type Message struct {
Role string `json:"role"` // "system" | "user" | "assistant"
Content string `json:"content"`
@ -21,6 +28,10 @@ type ChatCompletionRequest struct {
MaxTokens *int `json:"max_tokens,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string; keep flexible
Provider struct {
Sort string `json:"sort,omitempty"`
} `json:"provider,omitempty"`
}
type ChatCompletionResponse struct {
@ -35,60 +46,52 @@ type ChatCompletionResponse struct {
} `json:"choices"`
}
func ChatCompletion(ctx context.Context, req ChatCompletionRequest) (ChatCompletionResponse, error) {
httpClient := http.Client{Timeout: 10 * time.Second}
func (p OpenRouterProvider) Complete(ctx context.Context, prompt string) (string, error) {
req := ChatCompletionRequest{
Model: p.Model,
Messages: []Message{
{
Role: "user",
Content: prompt,
},
},
Provider: struct {
Sort string `json:"sort,omitempty"`
}{
Sort: "throughput",
},
}
httpClient := http.Client{Timeout: 10 * time.Second}
body, err := json.Marshal(req)
if err != nil {
return ChatCompletionResponse{}, err
return "", err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", "https://openrouter.ai/api/v1/chat/completions", bytes.NewReader(body))
if err != nil {
return ChatCompletionResponse{}, err
return "", err
}
httpReq.Header.Set("Authorization", "Bearer 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")
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %v", p.Token))
httpReq.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(httpReq)
if err != nil {
return ChatCompletionResponse{}, err
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// You may want to decode OpenRouter's error JSON here for better messages
return ChatCompletionResponse{}, fmt.Errorf("openrouter status %d", resp.StatusCode)
return "", fmt.Errorf("openrouter status %d", resp.StatusCode)
}
var out ChatCompletionResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return ChatCompletionResponse{}, err
}
return out, nil
}
func doLLM() {
req := ChatCompletionRequest{
Model: "openai/gpt-oss-20b:free",
Messages: []Message{
{
Role: "user",
Content: "Write a short poem about software development.",
},
},
}
ctx := context.Background()
resp, err := client.ChatCompletion(ctx, req)
if err != nil {
fmt.Println("Error:", err)
return
}
for _, choice := range resp.Choices {
fmt.Printf("Response: %s\n", choice.Message.Content)
log.Println(err)
log.Println(out)
return "", err
}
return out.Choices[0].Message.Content, nil
}

View file

@ -49,4 +49,6 @@ func Complete(ctx context.Context, prompt string) (string, error) {
}
fmt.Println()
return "", nil
}