package yzma import ( "context" "fmt" "github.com/hybridgroup/yzma/pkg/llama" ) func Complete(ctx context.Context, prompt string) (string, error) { // TODO: derive from something? libPath := "/nix/store/jml3vhvay9yy94qj8bmmhbf2dhx6q2n1-llama-cpp-7356/lib" modelFile := "./SmolLM-135M.Q2_K.gguf" responseLength := int32(128) llama.Load(libPath) llama.LogSet(llama.LogSilent()) llama.Init() model, _ := llama.ModelLoadFromFile(modelFile, llama.ModelDefaultParams()) lctx, _ := llama.InitFromModel(model, llama.ContextDefaultParams()) vocab := llama.ModelGetVocab(model) // get tokens from the prompt tokens := llama.Tokenize(vocab, prompt, true, false) batch := llama.BatchGetOne(tokens) sampler := llama.SamplerChainInit(llama.SamplerChainDefaultParams()) llama.SamplerChainAdd(sampler, llama.SamplerInitGreedy()) for pos := int32(0); pos < responseLength; pos += batch.NTokens { llama.Decode(lctx, batch) token := llama.SamplerSample(sampler, lctx, -1) if llama.VocabIsEOG(vocab, token) { fmt.Println() break } buf := make([]byte, 36) len := llama.TokenToPiece(vocab, token, buf, 0, true) fmt.Print(string(buf[:len])) batch = llama.BatchGetOne([]llama.Token{token}) } fmt.Println() return "", nil }