llama_ros: llama.cpp for ROS 2
llama.hpp
// MIT License
//
// Copyright (c) 2023 Miguel Ángel González Santamarta
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef LLAMA_ROS__LLAMA_HPP
#define LLAMA_ROS__LLAMA_HPP

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "common.h"
#include "json.hpp"
#include "llama.h"
#include "sampling.h"

#include "llama_utils/spinner.hpp"

namespace llama_ros {

// llama structs
struct TokenProb {
  llama_token token;
  float probability;
};

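// LoRA adapter handle used by list_loras() and update_loras(): the adapter
// id, its file path, and the scale it is applied with.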
struct LoRA {
  int id;
  std::string path;
  float scale;
};

struct CompletionOutput {
  std::vector<TokenProb> probs;
  llama_token token;
};

enum StopType {
  NO_STOP,
  FULL_STOP,
  PARTIAL_STOP,
  CANCEL,
  ABORT
};

struct ResponseOutput {
  std::vector<CompletionOutput> completions;
  StopType stop;
};

struct EmbeddingsOuput {
  std::vector<float> embeddings;
  int32_t n_tokens;
};

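// GGUF metadata read from the model file, grouped into general, model, and
// tokenizer sections.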
struct Metadata {
  struct GeneralInfo {
    std::string architecture;
    uint32_t quantization_version;
    uint32_t alignment;

    std::string name;
    std::string author;
    std::string version;
    std::string organization;

    std::string basename;
    std::string finetune;
    std::string description;
    std::string quantized_by;
    std::string size_label;

    std::string license;
    std::string license_name;
    std::string license_link;

    std::string url;
    std::string repo_url;
    std::string doi;
    std::string uuid;

    std::vector<std::string> tags;
    std::vector<std::string> languages;
    std::vector<std::string> datasets;
    std::string file_type;
  };

  struct AttentionInfo {
    uint64_t head_count;
    uint64_t head_count_kv;

    float max_alibi_bias;
    float clamp_kqv;

    float layer_norm_epsilon;
    float layer_norm_rms_epsilon;

    uint32_t key_length;
    uint32_t value_length;
  };

  struct RoPEInfo {
    uint64_t dimension_count;
    float freq_base;

    std::string scaling_type;
    float scaling_factor;
    uint32_t scaling_original_context_length;
    bool scaling_finetuned;
  };

  struct ModelInfo {
    uint64_t context_length;
    uint64_t embedding_length;
    uint64_t block_count;
    uint64_t feed_forward_length;

    bool use_parallel_residual;
    std::string tensor_data_layout;

    uint32_t expert_count;
    uint32_t expert_used_count;

    AttentionInfo attention;
    RoPEInfo rope;
  };

  struct TokenizerInfo {
    std::string model;

    uint32_t bos_token_id;
    uint32_t eos_token_id;
    uint32_t unknown_token_id;
    uint32_t padding_token_id;
    uint32_t separator_token_id;
    bool add_bos_token;

    std::string chat_template;
  };

  GeneralInfo general;
  ModelInfo model;
  TokenizerInfo tokenizer;
};

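// Streaming callback: receives each CompletionOutput as it is sampled during
// generate_response().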
using GenerateResponseCallback = std::function<void(struct CompletionOutput)>;

class Llama {

public:
  Llama(const struct common_params &params, std::string system_prompt = "",
        bool initial_reset = true);
  virtual ~Llama();

  std::vector<llama_token> tokenize(const std::string &text, bool add_bos,
                                    bool special = false);
  std::string detokenize(const std::vector<llama_token> &tokens);

  virtual void reset();
  void cancel();

  std::string format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,
                                 bool add_ass);
  std::vector<struct LoRA> list_loras();
  void update_loras(std::vector<struct LoRA> loras);

  std::vector<llama_token>
  truncate_tokens(const std::vector<llama_token> &tokens, int limit_size,
                  bool add_eos = true);
  struct EmbeddingsOuput generate_embeddings(const std::string &input_prompt,
                                             int normalization = 2);
  struct EmbeddingsOuput
  generate_embeddings(const std::vector<llama_token> &tokens,
                      int normalization = 2);
  float rank_document(const std::string &query, const std::string &document);
  std::vector<float> rank_documents(const std::string &query,
                                    const std::vector<std::string> &documents);

  struct ResponseOutput
  generate_response(const std::string &input_prompt,
                    struct common_params_sampling sparams,
                    GenerateResponseCallback callback = nullptr,
                    std::vector<std::string> stop = {});
  struct ResponseOutput
  generate_response(const std::string &input_prompt,
                    GenerateResponseCallback callback = nullptr,
                    std::vector<std::string> stop = {});

  const struct llama_context *get_ctx() { return this->ctx; }
  const struct llama_model *get_model() { return this->model; }
  const struct llama_vocab *get_vocab() {
    return llama_model_get_vocab(this->model);
  }

  int get_n_ctx() { return llama_n_ctx(this->ctx); }
  int get_n_ctx_train() { return llama_model_n_ctx_train(this->model); }
  int get_n_embd() { return llama_model_n_embd(this->model); }
  int get_n_vocab() { return llama_vocab_n_tokens(this->get_vocab()); }

  std::string get_metadata(const std::string &key, size_t size);
  std::string get_metadata(const std::string &model_name,
                           const std::string &key, size_t size);
  int get_int_metadata(const std::string &key, size_t size);
  int get_int_metadata(const std::string &model_name, const std::string &key,
                       size_t size);
  float get_float_metadata(const std::string &key, size_t size);
  float get_float_metadata(const std::string &model_name,
                           const std::string &key, size_t size);
  struct Metadata get_metadata();

  bool is_embedding() { return this->params.embedding; }
  bool is_reranking() { return this->params.reranking; }

  bool add_bos_token() { return llama_vocab_get_add_bos(this->get_vocab()); }
  bool is_eog() {
    return llama_vocab_is_eog(this->get_vocab(),
                              common_sampler_last(this->sampler));
  }
  llama_token get_token_eos() { return llama_vocab_eos(this->get_vocab()); }
  llama_token get_token_bos() { return llama_vocab_bos(this->get_vocab()); }
  llama_token get_token_sep() { return llama_vocab_sep(this->get_vocab()); }

protected:
  struct common_params params;

  // model
  struct common_init_result llama_init;
  struct llama_context *ctx;
  struct llama_model *model;
  std::vector<common_adapter_lora_info> lora_adapters;
  struct common_sampler *sampler;
  struct ggml_threadpool *threadpool;
  struct ggml_threadpool *threadpool_batch;

  // aux
  std::string system_prompt;
  bool canceled;
  llama_utils::Spinner spinner;
  std::vector<llama_token> prompt_tokens;

  // eval
  int32_t n_past;
  int32_t n_consumed;
  int32_t ga_i;

  virtual void load_prompt(const std::string &input_prompt, bool add_pfx,
                           bool add_sfx);

  StopType
  find_stop(std::vector<struct CompletionOutput> completion_result_list,
            std::vector<std::string> stopping_words);
  StopType
  find_stop_word(std::vector<struct CompletionOutput> completion_result_list,
                 std::string stopping_word);

  bool eval_system_prompt();
  virtual bool eval_prompt();
  bool eval_prompt(std::vector<llama_token> prompt_tokens);
  bool eval_token(llama_token token);
  bool eval(std::vector<llama_token> tokens);
  virtual bool eval(struct llama_batch batch);

  std::vector<struct TokenProb> get_probs();
  struct CompletionOutput sample();

private:
  // lock
  std::recursive_mutex mutex;
};

} // namespace llama_ros

#endif
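
Example: a minimal usage sketch of the Llama class declared above. The include path, the common_params field names, and the model path are assumptions that depend on the vendored llama.cpp revision; they are not taken from this header.

#include <iostream>
#include <string>
#include <vector>

#include "llama_ros/llama.hpp" // assumed install path for this header

int main() {
  struct common_params params;
  params.model = "/path/to/model.gguf"; // placeholder; the field layout
                                        // follows the llama.cpp revision
  params.n_ctx = 2048;

  // The system prompt is evaluated once and kept at the head of the context.
  llama_ros::Llama llama(params, "You are a helpful robot assistant.");

  // The callback receives one CompletionOutput per sampled token, so the
  // reply can be printed while it is being generated.
  auto on_token = [&llama](struct llama_ros::CompletionOutput out) {
    std::cout << llama.detokenize({out.token}) << std::flush;
  };

  struct common_params_sampling sparams; // default sampling parameters
  struct llama_ros::ResponseOutput result = llama.generate_response(
      "How many wheels does a differential drive robot have?", sparams,
      on_token);

  // StopType reports why generation ended (EOS, stop word, cancel, ...).
  if (result.stop == llama_ros::FULL_STOP) {
    std::cout << std::endl
              << result.completions.size() << " tokens" << std::endl;
  }

  return 0;
}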
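
Example: a sketch of the embedding and reranking entry points, under the same assumptions as above. generate_embeddings() expects a model loaded with params.embedding enabled, and rank_documents() a reranker loaded with params.reranking enabled; mixing both in one program as done here is for illustration only.

#include <iostream>
#include <string>
#include <vector>

#include "llama_ros/llama.hpp" // assumed install path for this header

int main() {
  struct common_params params;
  params.model = "/path/to/embedding-model.gguf"; // placeholder
  params.embedding = true; // so is_embedding() holds

  llama_ros::Llama llama(params);

  // normalization = 2 (the default) matches the Euclidean (L2) mode of
  // llama.cpp's embedding normalization.
  struct llama_ros::EmbeddingsOuput out =
      llama.generate_embeddings("mobile robot navigation");
  std::cout << out.embeddings.size() << " dims from " << out.n_tokens
            << " tokens" << std::endl;

  // Reranking scores one query against several candidate documents.
  std::vector<float> scores = llama.rank_documents(
      "How do I localize a robot?",
      {"AMCL is a particle filter localization system.",
       "A PID controller regulates wheel speed."});
  for (float s : scores) {
    std::cout << "score: " << s << std::endl;
  }

  return 0;
}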