package yzma

import (
	"context"
	"fmt"
	"strings"

	"github.com/hybridgroup/yzma/pkg/llama"
)

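// Complete generates a completion for prompt from a local GGUF model using
// yzma's llama.cpp bindings, streaming each decoded token to stdout and
// returning the accumulated text.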
func Complete(ctx context.Context, prompt string) (string, error) {
	// TODO: derive from something?
	libPath := "/nix/store/jml3vhvay9yy94qj8bmmhbf2dhx6q2n1-llama-cpp-7356/lib"
	modelFile := "./SmolLM-135M.Q2_K.gguf"
	responseLength := int32(128) // maximum number of tokens to generate

	// collects the generated text so it can be returned as well as printed
	var out strings.Builder

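	// load the llama.cpp shared library, silence its logging, and
	// initialize the backend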
	llama.Load(libPath)
	llama.LogSet(llama.LogSilent())
	llama.Init()

	model, err := llama.ModelLoadFromFile(modelFile, llama.ModelDefaultParams())
	if err != nil {
		return "", fmt.Errorf("loading model %s: %w", modelFile, err)
	}

	lctx, err := llama.InitFromModel(model, llama.ContextDefaultParams())
	if err != nil {
		return "", fmt.Errorf("creating context: %w", err)
	}

	vocab := llama.ModelGetVocab(model)

	// get tokens from the prompt
	tokens := llama.Tokenize(vocab, prompt, true, false)

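	// the first batch holds the entire tokenized prompt; inside the loop it
	// is replaced by single-token batches carrying each sampled token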
	batch := llama.BatchGetOne(tokens)

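	// a minimal sampler chain with a single greedy sampler: always pick
	// the most probable next token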
	sampler := llama.SamplerChainInit(llama.SamplerChainDefaultParams())
	llama.SamplerChainAdd(sampler, llama.SamplerInitGreedy())

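	// decode/sample loop: the first pass decodes the whole prompt, each
	// later pass decodes the single token sampled on the pass before it,
	// until responseLength is reached or the model emits end-of-generation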
	for pos := int32(0); pos < responseLength; pos += batch.NTokens {
		// stop early if the caller cancels the context
		if err := ctx.Err(); err != nil {
			return out.String(), err
		}

		llama.Decode(lctx, batch)
		token := llama.SamplerSample(sampler, lctx, -1)

		if llama.VocabIsEOG(vocab, token) {
			fmt.Println()
			break
		}

		// convert the sampled token back into text
		buf := make([]byte, 36)
		n := llama.TokenToPiece(vocab, token, buf, 0, true)

		fmt.Print(string(buf[:n]))
		out.Write(buf[:n])

		// feed the sampled token back in as the next batch
		batch = llama.BatchGetOne([]llama.Token{token})
	}

	fmt.Println()

	return out.String(), nil
}
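
// Example usage (a minimal sketch; it assumes the hard-coded /nix/store
// library path and the SmolLM GGUF file above exist on the machine
// running it):
//
//	text, err := yzma.Complete(context.Background(), "Once upon a time")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(text)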