ref: 92cfd764bf08dcd464f2edf2992ca67cc847b1e9
parent: 6ab15be07c62e3d31c916084a377ff3b6ba7d0ad
author: eli <eli@owl>
date: Fri Jul 11 20:28:22 EDT 2025
github.com/karpathy/llama2.c
--- /dev/null
+++ b/llama2.c
@@ -1,0 +1,1093 @@
+/* Inference for Llama-2 Transformer model in pure C */
+
+#include <u.h>
+#include <libc.h>
+#include <stdio.h>
+//#include <stdlib.h>
+#include <ctype.h>
+//#include <time.h>
+//#include <math.h>
+//#include <string.h>
+//#include <fcntl.h>
+//#if defined _WIN32
+// #include "win.h"
+//#else
+// #include <unistd.h>
+// #include <sys/mman.h>
+//#endif
+
+// Shims mapping the C99/POSIX names used by upstream llama2.c onto Plan 9 types and libc.
+#define int8_t char
+#define ssize_t vlong // POSIX ssize_t is a signed type; vlong is the signed 64-bit type
+#define size_t ulong
+#define EXIT_FAILURE "exits" // becomes the status string handed to exits()
+#define exit exits
+#define O_RDONLY OREAD
+// the float math names are routed to the double-precision functions
+#define sqrtf sqrt
+#define expf exp
+#define powf pow
+#define cosf cos
+#define sinf sin
+
+// ----------------------------------------------------------------------------
+// Transformer model
+
+// Model hyperparameters, read from the first 7 int32 fields of the checkpoint
+// in exactly this order (see read_checkpoint).
+typedef struct {
+    int dim; // transformer dimension
+    int hidden_dim; // for ffn layers
+    int n_layers; // number of layers
+    int n_heads; // number of query heads
+    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
+    int vocab_size; // vocabulary size, usually 256 (byte-level); stored negated in the file when the classifier weights are not shared
+    int seq_len; // max sequence length
+} Config;
+
+// On-disk size of the checkpoint header: 7 int32 fields (was miscounted as 24).
+#define SIZEOFCONFIG (7*4)
+
+// Read 8 bytes from fd and assemble them as a little-endian 64-bit value.
+// Exits the program if fewer than 8 bytes could be read.
+vlong read8(int fd) {
+    unsigned char buf[8]; // raw bytes: an int[8] here would leave buf[1..7] uninitialised
+    uvlong result;
+    int i;
+
+    if (read(fd, buf, 8) != 8)
+        exit(EXIT_FAILURE);
+
+    result = 0;
+    for (i = 0; i < 8; i++)
+        result |= (uvlong)buf[i] << (i*8); // byte i supplies bits 8*i .. 8*i+7; unsigned shift avoids UB
+
+    return (vlong)result;
+}
+
+// Read 4 bytes from fd and decode them as a little-endian 32-bit int, the same
+// byte order read8 assembles, so the result does not depend on the host's byte
+// order. Exits the program if fewer than 4 bytes could be read.
+int read4(int fd) {
+    unsigned char buf[4];
+    unsigned int v;
+
+    if (read(fd, buf, 4) != 4)
+        exit(EXIT_FAILURE);
+
+    v = (unsigned int)buf[0]
+        | (unsigned int)buf[1] << 8
+        | (unsigned int)buf[2] << 16
+        | (unsigned int)buf[3] << 24;
+    return (int)v;
+}
+
+// Model parameters. Every field is a pointer into the loaded checkpoint buffer,
+// assigned by memory_map_weights; none of them is separately allocated or freed.
+typedef struct {
+ // token embedding table
+ float* token_embedding_table; // (vocab_size, dim)
+ // weights for rmsnorms
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
+ float* rms_ffn_weight; // (layer, dim)
+ // weights for matmuls. note dim == n_heads * head_size
+ float* wq; // (layer, dim, n_heads * head_size)
+ float* wk; // (layer, dim, n_kv_heads * head_size)
+ float* wv; // (layer, dim, n_kv_heads * head_size)
+ float* wo; // (layer, n_heads * head_size, dim)
+ // weights for ffn
+ float* w1; // (layer, hidden_dim, dim)
+ float* w2; // (layer, dim, hidden_dim)
+ float* w3; // (layer, hidden_dim, dim)
+ // final rmsnorm
+ float* rms_final_weight; // (dim,)
+ // (optional) classifier weights for the logits, on the last layer
+ // (equal to token_embedding_table when the checkpoint shares weights)
+ float* wcls;
+} TransformerWeights;
+
+// hand-written size: 12 = the number of pointer fields above
+#define SIZEOFTRANSFORMERWEIGHTS (12*sizeof(float*))
+
+// Scratch buffers for one forward() step plus the kv cache. Everything except
+// k and v is heap-allocated by malloc_run_state and released by free_run_state;
+// k and v are set by forward() to point at the current position inside the caches.
+typedef struct {
+ // current wave of activations
+ float *x; // activation at current time stamp (dim,)
+ float *xb; // same, but inside a residual branch (dim,)
+ float *xb2; // an additional buffer just for convenience (dim,)
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
+ float *q; // query (dim,)
+ float *k; // key (dim,)
+ float *v; // value (dim,)
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
+ float *logits; // output logits
+ // kv cache
+ float* key_cache; // (layer, seq_len, kv_dim)
+ float* value_cache; // (layer, seq_len, kv_dim)
+} RunState;
+
+// hand-written size: 12 = the number of pointer fields above
+#define SIZEOFRUNSTATE (12*sizeof(float*))
+
+// Everything needed to run the model: hyperparameters, weights, scratch state,
+// and the bookkeeping free_transformer uses to release the loaded checkpoint.
+typedef struct {
+    Config config; // the hyperparameters of the architecture (the blueprint)
+    TransformerWeights weights; // the weights of the model (pointers into data)
+    RunState state; // buffers for the "wave" of activations in the forward pass
+    // state needed to properly clean up the loaded checkpoint
+    int fd; // descriptor on the checkpoint file; free_transformer closes it unless -1
+    float* data; // heap copy of the whole checkpoint file (read in place of mmap)
+    ssize_t file_size; // size of the checkpoint file in bytes
+} Transformer;
+
+// Let the compiler compute the size: a hand-written sum of member sizes misses
+// the padding inserted before pointer-aligned members.
+#define SIZEOFTRANSFORMER sizeof(Transformer)
+
+// Allocate every RunState buffer, sized from the hyperparameters in p.
+// s->k and s->v are left alone; forward() points them into the caches.
+// Exits the program if any allocation fails.
+void malloc_run_state(RunState* s, Config* p) {
+    // calloc rather than malloc so all buffers start zeroed (keeps valgrind happy)
+    int kv_dim = p->dim * p->n_kv_heads / p->n_heads;
+    int cache_len = p->n_layers * p->seq_len * kv_dim;
+
+    s->x = calloc(p->dim, sizeof(float));
+    s->xb = calloc(p->dim, sizeof(float));
+    s->xb2 = calloc(p->dim, sizeof(float));
+    s->hb = calloc(p->hidden_dim, sizeof(float));
+    s->hb2 = calloc(p->hidden_dim, sizeof(float));
+    s->q = calloc(p->dim, sizeof(float));
+    s->key_cache = calloc(cache_len, sizeof(float));
+    s->value_cache = calloc(cache_len, sizeof(float));
+    s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
+    s->logits = calloc(p->vocab_size, sizeof(float));
+
+    // a single NULL anywhere means we cannot run
+    if (s->x == NULL || s->xb == NULL || s->xb2 == NULL
+        || s->hb == NULL || s->hb2 == NULL || s->q == NULL
+        || s->key_cache == NULL || s->value_cache == NULL
+        || s->att == NULL || s->logits == NULL) {
+        fprintf(stderr, "malloc failed!\n");
+        exit(EXIT_FAILURE);
+    }
+}
+
+// Release the buffers allocated by malloc_run_state. s->k and s->v are not
+// freed: they only ever point into key_cache / value_cache.
+void free_run_state(RunState* s) {
+    // kv cache
+    free(s->key_cache);
+    free(s->value_cache);
+    // per-step activations
+    free(s->x);
+    free(s->xb);
+    free(s->xb2);
+    free(s->q);
+    free(s->att);
+    free(s->logits);
+    // ffn hidden buffers
+    free(s->hb);
+    free(s->hb2);
+}
+
+// Point each field of w at successive regions of the flat float array ptr, in
+// checkpoint order. Nothing is copied: w aliases ptr's storage.
+// shared_weights != 0 makes wcls reuse the token embedding table; otherwise the
+// classifier weights are taken from the region following the skipped RoPE tables.
+void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
+ int head_size = p->dim / p->n_heads;
+ // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
+ unsigned long long n_layers = p->n_layers;
+ w->token_embedding_table = ptr;
+ // NOTE(review): vocab_size * dim is computed in int here, unlike the n_layers products below
+ ptr = &ptr[p->vocab_size * p->dim];
+ w->rms_att_weight = ptr;
+ ptr += n_layers * p->dim;
+ w->wq = ptr;
+ ptr += n_layers * p->dim * (p->n_heads * head_size);
+ w->wk = ptr;
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
+ w->wv = ptr;
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
+ w->wo = ptr;
+ ptr += n_layers * (p->n_heads * head_size) * p->dim;
+ w->rms_ffn_weight = ptr;
+ ptr += n_layers * p->dim;
+ w->w1 = ptr;
+ ptr += n_layers * p->dim * p->hidden_dim;
+ w->w2 = ptr;
+ ptr += n_layers * p->hidden_dim * p->dim;
+ w->w3 = ptr;
+ ptr += n_layers * p->dim * p->hidden_dim;
+ w->rms_final_weight = ptr;
+ ptr += p->dim;
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
+ w->wcls = shared_weights ? w->token_embedding_table : ptr;
+}
+
+// Load a llama2.c checkpoint. The 7-int header is parsed into *config, the
+// whole file is read into a heap buffer (*data, *file_size bytes), and *weights
+// is pointed into that buffer just past the header. This port reads the file
+// instead of using mmap. *fd is left open on the checkpoint, for
+// free_transformer to close. Exits the program on any error.
+void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
+                     int* fd, float** data, ssize_t* file_size) {
+    // HEADER_BYTES: the 7 int32 header fields that precede the first weight.
+    // MAX_READ: read() takes a long count, so large files are read in chunks.
+    enum { HEADER_BYTES = 7 * 4, MAX_READ = 1 << 30 };
+    Dir *dstat;
+    char *buf;
+    vlong offset;
+    long chunk, ret;
+    int fdt;
+
+    fdt = open(checkpoint, OREAD);
+    if (fdt < 0) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
+
+    // read in the config header, field by field in Config order
+    config->dim = read4(fdt);
+    config->hidden_dim = read4(fdt);
+    config->n_layers = read4(fdt);
+    config->n_heads = read4(fdt);
+    config->n_kv_heads = read4(fdt);
+    config->vocab_size = read4(fdt);
+    config->seq_len = read4(fdt);
+
+    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
+    int shared_weights = config->vocab_size > 0 ? 1 : 0;
+    config->vocab_size = abs(config->vocab_size);
+
+    // figure out the file size from the descriptor we already hold
+    dstat = dirfstat(fdt);
+    if (dstat == nil) { fprintf(stderr, "couldn't stat %s\n", checkpoint); exit(EXIT_FAILURE); }
+    *file_size = dstat->length;
+    free(dstat);
+    // the size must cover the header and survive malloc's size_t argument
+    if (*file_size < HEADER_BYTES || (size_t)*file_size != *file_size) {
+        fprintf(stderr, "bad checkpoint size for %s\n", checkpoint);
+        exit(EXIT_FAILURE);
+    }
+
+    // load the whole file (header included) into memory in place of mmap
+    *data = malloc(*file_size);
+    if (*data == nil) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+    if (seek(fdt, 0, 0) != 0) { fprintf(stderr, "seek failed!\n"); exit(EXIT_FAILURE); }
+    buf = (char*)*data;
+    for (offset = 0; offset < *file_size; offset += ret) {
+        chunk = *file_size - offset > MAX_READ ? MAX_READ : *file_size - offset;
+        ret = read(fdt, buf + offset, chunk);
+        // an error or an early EOF would otherwise leave part of the weights unset
+        if (ret <= 0) { fprintf(stderr, "failed read of %s\n", checkpoint); exit(EXIT_FAILURE); }
+    }
+
+    // keep the descriptor open; free_transformer closes it
+    *fd = fdt;
+    float* weights_ptr = (float*)(buf + HEADER_BYTES);
+    memory_map_weights(weights, config, weights_ptr, shared_weights);
+}
+
+// Load the checkpoint at checkpoint_path into t (config and weights), then
+// allocate the RunState buffers sized from that config.
+void build_transformer(Transformer *t, char* checkpoint_path) {
+    Config *cfg = &t->config;
+
+    read_checkpoint(checkpoint_path, cfg, &t->weights, &t->fd, &t->data, &t->file_size);
+    malloc_run_state(&t->state, cfg);
+}
+
+// Release everything build_transformer acquired: the heap copy of the
+// checkpoint, the descriptor read_checkpoint left open, and the RunState buffers.
+// t->weights points into t->data and must not be used afterwards.
+void free_transformer(Transformer* t) {
+    // without mmap the weights live in a malloc'd buffer, which must be freed here
+    free(t->data);
+    t->data = NULL;
+    if (t->fd != -1) {
+        close(t->fd);
+        t->fd = -1;
+    }
+    // free the RunState buffers
+    free_run_state(&t->state);
+}
+
+// ----------------------------------------------------------------------------
+// neural net blocks; the dynamics of the Transformer
+
+// RMS-normalise x[0..size-1] and scale elementwise by weight, writing into o:
+//   o[j] = weight[j] * (x[j] / sqrt(mean(x^2) + 1e-5))
+// o may alias x (forward calls rmsnorm(x, x, ...)).
+void rmsnorm(float* o, float* x, float* weight, int size) {
+    float sumsq = 0.0f;
+    int k;
+
+    for (k = 0; k < size; k++)
+        sumsq += x[k] * x[k];
+
+    // reciprocal root-mean-square; the epsilon keeps it finite for all-zero x
+    float inv_rms = 1.0f / sqrtf(sumsq / size + 1e-5f);
+
+    for (k = 0; k < size; k++)
+        o[k] = weight[k] * (inv_rms * x[k]);
+}
+
+// In-place softmax over x[0..size-1]. The maximum is subtracted before
+// exponentiating so expf never sees large positive arguments.
+void softmax(float* x, int size) {
+    float peak = x[0];
+    float total = 0.0f;
+    int i;
+
+    for (i = 1; i < size; i++)
+        peak = x[i] > peak ? x[i] : peak;
+
+    // exponentiate and accumulate the normaliser in one pass
+    for (i = 0; i < size; i++) {
+        x[i] = expf(x[i] - peak);
+        total += x[i];
+    }
+
+    for (i = 0; i < size; i++)
+        x[i] /= total;
+}
+
+// xout = W @ x, with W stored row-major as d rows of n columns:
+//   xout[r] = sum over c of w[r*n + c] * x[c],  for r in [0, d)
+// Nearly all inference time is spent here. Rows are independent, so the outer
+// loop is marked for OpenMP parallelisation (ignored by compilers without it).
+void matmul(float* xout, float* x, float* w, int n, int d) {
+    int r;
+    #pragma omp parallel for private(r)
+    for (r = 0; r < d; r++) {
+        float* row = w + r * n;
+        float acc = 0.0f;
+        for (int c = 0; c < n; c++)
+            acc += row[c] * x[c];
+        xout[r] = acc;
+    }
+}
+
+// Run one decoding step. The function:
+//  - embeds `token` at position `pos`;
+//  - passes it through every layer, writing this position's key/value into the kv cache;
+//  - returns the vocab_size logits for the next token.
+// The returned pointer is s->logits, so the next call overwrites it.
+// NOTE(review): no bounds checks -- presumably callers keep 0 <= token < vocab_size
+// and 0 <= pos < seq_len; verify.
+float* forward(Transformer* transformer, int token, int pos) {
+ // a few convenience variables
+ Config* p = &transformer->config;
+ TransformerWeights* w = &transformer->weights;
+ RunState* s = &transformer->state;
+ float *x = s->x;
+ int dim = p->dim;
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
+ int hidden_dim = p->hidden_dim;
+ int head_size = dim / p->n_heads;
+
+ // copy the token embedding into x
+ float* content_row = &w->token_embedding_table[token * dim];
+ memcpy(x, content_row, dim*sizeof(float));
+
+ // forward all the layers
+ for(unsigned long long l = 0; l < p->n_layers; l++) {
+
+ // attention rmsnorm
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
+
+ // key and value point to the kv cache
+ // NOTE(review): loff is an int; assumes n_layers * seq_len * kv_dim fits in int
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+ s->k = s->key_cache + loff + pos * kv_dim;
+ s->v = s->value_cache + loff + pos * kv_dim;
+
+ // qkv matmuls for this position
+ // (the l*dim*dim stride for wq/wo relies on dim == n_heads * head_size)
+ matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
+ matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
+ matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
+
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
+ for (int i = 0; i < dim; i+=2) {
+ int head_dim = i % head_size;
+ float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
+ float val = pos * freq;
+ float fcr = cosf(val);
+ float fci = sinf(val);
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+ for (int v = 0; v < rotn; v++) {
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
+ float v0 = vec[i];
+ float v1 = vec[i+1];
+ vec[i] = v0 * fcr - v1 * fci;
+ vec[i+1] = v0 * fci + v1 * fcr;
+ }
+ }
+
+ // multihead attention. iterate over all heads
+ // (heads are independent: each writes only its own att row and xb slice)
+ int h;
+ #pragma omp parallel for private(h)
+ for (h = 0; h < p->n_heads; h++) {
+ // get the query vector for this head
+ float* q = s->q + h * head_size;
+ // attention scores for this head
+ float* att = s->att + h * p->seq_len;
+ // iterate over all timesteps, including the current one
+ for (int t = 0; t <= pos; t++) {
+ // get the key vector for this head and at this timestep
+ // (query head h reads kv head h / kv_mul, so kv_mul query heads share one kv head)
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+ // calculate the attention score as the dot product of q and k
+ float score = 0.0f;
+ for (int i = 0; i < head_size; i++) {
+ score += q[i] * k[i];
+ }
+ score /= sqrtf(head_size);
+ // save the score to the attention buffer
+ att[t] = score;
+ }
+
+ // softmax the scores to get attention weights, from 0..pos inclusively
+ softmax(att, pos + 1);
+
+ // weighted sum of the values, store back into xb
+ float* xb = s->xb + h * head_size;
+ // head_size * 4 bytes: assumes sizeof(float) == 4
+ memset(xb, 0, head_size * 4);
+ for (int t = 0; t <= pos; t++) {
+ // get the value vector for this head and at this timestep
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+ // get the attention weight for this timestep
+ float a = att[t];
+ // accumulate the weighted value into xb
+ for (int i = 0; i < head_size; i++) {
+ xb[i] += a * v[i];
+ }
+ }
+ }
+
+ // final matmul to get the output of the attention
+ matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
+
+ // residual connection back into x
+ for (int i = 0; i < dim; i++) {
+ x[i] += s->xb2[i];
+ }
+
+ // ffn rmsnorm
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
+
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+ // first calculate self.w1(x) and self.w3(x)
+ matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
+ matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
+
+ // SwiGLU non-linearity
+ for (int i = 0; i < hidden_dim; i++) {
+ float val = s->hb[i];
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+ val *= (1.0f / (1.0f + expf(-val)));
+ // elementwise multiply with w3(x)
+ val *= s->hb2[i];
+ s->hb[i] = val;
+ }
+
+ // final matmul to get the output of the ffn
+ matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
+
+ // residual connection
+ for (int i = 0; i < dim; i++) {
+ x[i] += s->xb[i];
+ }
+ }
+
+ // final rmsnorm
+ rmsnorm(x, x, w->rms_final_weight, dim);
+
+ // classifier into logits
+ matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
+ return s->logits;
+}
+
+// ----------------------------------------------------------------------------
+// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
+
+// A vocabulary string paired with its token id. encode sorts an array of these
+// by str so that str_lookup can binary-search it.
+typedef struct {
+    char *str;
+    int id;
+} TokenIndex;
+
+// sizeof includes the padding after id (e.g. 16 bytes, not 12, with 8-byte pointers).
+#define SIZEOFTOKENINDEX sizeof(TokenIndex)
+
+// BPE tokenizer state, filled by build_tokenizer and released by free_tokenizer.
+typedef struct {
+ char** vocab; // token text indexed by token id (vocab_size entries)
+ float* vocab_scores; // per-token score; encode merges the highest-scoring pair first
+ TokenIndex *sorted_vocab; // vocab sorted by string; NULL until encode builds it
+ int vocab_size;
+ unsigned int max_token_length; // read from the tokenizer file header; sizes encode's merge buffer
+ unsigned char byte_pieces[512]; // stores all single-byte strings
+ // (256 pairs of {byte, '\0'}; decode returns these for raw-byte tokens)
+} Tokenizer;
+
+// qsort/bsearch comparator: order TokenIndex entries by their string (strcmp order).
+int compare_tokens(const void *a, const void *b) {
+    const TokenIndex *lhs = a;
+    const TokenIndex *rhs = b;
+
+    return strcmp(lhs->str, rhs->str);
+}
+
+// Load the tokenizer file at tokenizer_path. The file holds an int
+// max_token_length header, followed by vocab_size records of
+// (float score, int len, len bytes of token text). Also fills byte_pieces,
+// which decode uses. Exits the program on any open/read/allocation error.
+void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
+    int fd, len;
+
+    // i should have written the vocab_size into the tokenizer file... sigh
+    t->vocab_size = vocab_size;
+    // malloc space to hold the scores and the strings
+    t->vocab = malloc(vocab_size * sizeof(char*));
+    t->vocab_scores = malloc(vocab_size * sizeof(float));
+    if (t->vocab == NULL || t->vocab_scores == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+    t->sorted_vocab = NULL; // initialized lazily
+    for (int i = 0; i < 256; i++) {
+        t->byte_pieces[i * 2] = (unsigned char)i;
+        t->byte_pieces[i * 2 + 1] = '\0';
+    }
+    // read in the file; open reports failure with a negative descriptor
+    fd = open(tokenizer_path, OREAD);
+    if (fd < 0) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
+    if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+    for (int i = 0; i < vocab_size; i++) {
+        if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+        if (read(fd, &len, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+        // a corrupt negative length would turn malloc(len + 1) and the read below into garbage
+        if (len < 0) { fprintf(stderr, "bad token length %d\n", len); exit(EXIT_FAILURE); }
+        t->vocab[i] = malloc(len + 1);
+        if (t->vocab[i] == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+        if (read(fd, t->vocab[i], len) != len) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+        t->vocab[i][len] = '\0'; // add the string terminating token
+    }
+    close(fd);
+}
+
+// Free every token string, the id-indexed arrays, and the sorted index
+// (which is still NULL if encode never ran; free(NULL) does nothing).
+void free_tokenizer(Tokenizer* t) {
+    int i = t->vocab_size;
+
+    while (i-- > 0)
+        free(t->vocab[i]);
+    free(t->vocab);
+    free(t->vocab_scores);
+    free(t->sorted_vocab);
+}
+
+// Value of one hexadecimal digit character, or -1 if c is not a hex digit.
+static int hexdigit_value(int c) {
+    if (c >= '0' && c <= '9') return c - '0';
+    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
+    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
+    return -1;
+}
+
+// Return the text of `token`, given the token before it. Raw-byte tokens
+// spelled "<0xHH>" map to the one-byte string for that byte in byte_pieces.
+// The result points into tokenizer storage; the caller must not free it.
+char* decode(Tokenizer* t, int prev_token, int token) {
+    char *piece = t->vocab[token];
+    int hi, lo;
+
+    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
+    if (prev_token == 1 && piece[0] == ' ') { piece++; }
+    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'.
+    // Parsed by hand rather than with sscanf("<0x%02hhX>"): "hh" is a C99 length
+    // modifier that the Plan 9 stdio included here may not implement, and one
+    // that treats it as "h" would store a short into a 1-byte variable.
+    // The && chain stops at the first mismatch, so no byte past the NUL is read.
+    if (piece[0] == '<' && piece[1] == '0' && piece[2] == 'x'
+        && (hi = hexdigit_value((unsigned char)piece[3])) >= 0) {
+        unsigned char byte_val = hi;
+        // like %02X, accept one or two hex digits
+        if ((lo = hexdigit_value((unsigned char)piece[4])) >= 0)
+            byte_val = hi * 16 + lo;
+        piece = (char*)t->byte_pieces + byte_val * 2;
+    }
+    return piece;
+}
+
+// Print piece to stdout unless it is NULL, empty, or a single byte that is
+// neither printable nor whitespace. Raw-byte tokens can be control codes,
+// backspace, etc.
+void safe_printf(char *piece) {
+    if (piece == NULL || piece[0] == '\0')
+        return;
+    if (piece[1] == '\0') {
+        unsigned char only = piece[0];
+        if (!isprint(only) && !isspace(only))
+            return; // bad byte, don't print it
+    }
+    printf("%s", piece);
+}
+
+// Binary-search list[A..B) (B exclusive), which must be sorted consistently
+// with comp, for an entry that compares equal to tok. Returns a pointer to the
+// match, or nil if there is none. An empty range (A >= B) returns nil without
+// reading list. n and s are unused and kept only for the call signature.
+TokenIndex *bsearch_call(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b), int A, int B) {
+    int lo = A, hi = B;
+
+    while (lo < hi) {
+        int middle = lo + (hi - lo) / 2; // avoids overflow of lo + hi
+        int result = comp(tok, &list[middle]);
+        if (result == 0)
+            return &list[middle];
+        if (result < 0)
+            hi = middle;     // any match lies below middle
+        else
+            lo = middle + 1; // any match lies above middle
+    }
+    return nil;
+}
+
+// Local replacement for bsearch, specialised to TokenIndex arrays: search the
+// n sorted entries of list for tok. s (element size) is unused because list is
+// indexed as TokenIndex directly.
+void *bsearch(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b)) {
+    return bsearch_call(tok, list, n, s, comp, 0, n);
+}
+
+
+// Return the id of the vocabulary entry whose text equals str exactly, or -1
+// if there is none. sorted_vocab must be sorted with compare_tokens.
+int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
+    TokenIndex key = { .str = str }; // only .str is examined by compare_tokens
+    TokenIndex *hit = bsearch(&key, sorted_vocab, vocab_size, SIZEOFTOKENINDEX, compare_tokens);
+
+    if (hit == NULL)
+        return -1;
+    return hit->id;
+}
+
+// Encode text into BPE token ids, written to tokens[] (which the caller sizes as
+// an upper bound); the count is stored in *n_tokens. Steps:
+//  - builds t->sorted_vocab on first use;
+//  - maps each UTF-8 codepoint to a token, falling back to raw-byte tokens;
+//  - greedily merges the highest-scoring adjacent pair until no merge applies.
+// Exits the program if text is NULL or an allocation fails.
+void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
+    // encode the string text (input) into an upper-bound preallocated tokens[] array
+    // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
+    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
+
+    if (t->sorted_vocab == NULL) {
+        // lazily malloc and sort the vocabulary
+        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
+        if (t->sorted_vocab == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+        for (int i = 0; i < t->vocab_size; i++) {
+            t->sorted_vocab[i].str = t->vocab[i];
+            t->sorted_vocab[i].id = i;
+        }
+        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
+    }
+
+    // create a temporary buffer that will store merge candidates of always two consecutive tokens
+    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
+    char* str_buffer = malloc((t->max_token_length*2 +1 +2));
+    if (str_buffer == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+    size_t str_len = 0;
+
+    // start at 0 tokens
+    *n_tokens = 0;
+
+    // add optional BOS (=1) token, if desired
+    if (bos) tokens[(*n_tokens)++] = 1;
+
+    // add_dummy_prefix is true by default
+    // so prepend a dummy prefix token to the input string, but only if text != ""
+    // TODO: pretty sure this isn't correct in the general case but I don't have the
+    // energy to read more of the sentencepiece code to figure out what it's doing
+    if (text[0] != '\0') {
+        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
+        // skip it if the vocab has no " " entry: id -1 would later index t->vocab out of bounds
+        if (dummy_prefix != -1) tokens[(*n_tokens)++] = dummy_prefix;
+    }
+
+    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
+    // Code point ↔ UTF-8 conversion
+    // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
+    // U+0000 U+007F 0xxxxxxx
+    // U+0080 U+07FF 110xxxxx 10xxxxxx
+    // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
+    // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+
+    // process the raw (UTF-8) byte sequence of the input string
+    for (char *c = text; *c != '\0'; c++) {
+
+        // reset buffer if the current byte is ASCII or a leading byte
+        // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
+        // 0x80 is 10000000
+        // in UTF-8, all continuation bytes start with "10" in first two bits
+        // so in English this is: "if this byte is not a continuation byte"
+        if ((*c & 0xC0) != 0x80) {
+            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
+            // => reset our location, as we're starting a new UTF-8 codepoint
+            str_len = 0;
+        }
+
+        // append the current byte to the buffer
+        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
+        str_buffer[str_len] = '\0';
+
+        // while the next character is a continuation byte, continue appending
+        // but if there are too many of them, just stop to avoid overruning str_buffer size.
+        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
+            continue;
+        }
+
+        // ok c+1 is not a continuation byte, so we've read in a full codepoint
+        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+
+        if (id != -1) {
+            // we found this codepoint in vocab, add it as a token
+            tokens[(*n_tokens)++] = id;
+        } else {
+            // byte_fallback encoding: just encode each byte as a token
+            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
+            // so the individual bytes only start at index 3
+            for (int i=0; i < str_len; i++) {
+                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+            }
+        }
+        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
+    }
+
+    // merge the best consecutive pair each iteration, according the scores in vocab_scores
+    while (1) {
+        float best_score = -1e10;
+        int best_id = -1;
+        int best_idx = -1;
+
+        for (int i=0; i < (*n_tokens-1); i++) {
+            // check if we can merge the pair (tokens[i], tokens[i+1])
+            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+            if (id != -1 && t->vocab_scores[id] > best_score) {
+                // this merge pair exists in vocab! record its score and position
+                best_score = t->vocab_scores[id];
+                best_id = id;
+                best_idx = i;
+            }
+        }
+
+        if (best_idx == -1) {
+            break; // we couldn't find any more pairs to merge, so we're done
+        }
+
+        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+        tokens[best_idx] = best_id;
+        // delete token at position best_idx+1, shift the entire sequence back 1
+        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
+            tokens[i] = tokens[i+1];
+        }
+        (*n_tokens)--; // token length decreased
+    }
+
+    // add optional EOS (=2) token, if desired
+    if (eos) tokens[(*n_tokens)++] = 2;
+
+    free(str_buffer);
+}
+
+// ----------------------------------------------------------------------------
+// The Sampler, which takes logits and returns a sampled token
+// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
+
+// (probability, token index) pair, sorted by compare in descending prob order.
+typedef struct {
+ float prob;
+ int index;
+} ProbIndex; // struct used when sorting probabilities during top-p sampling
+
+// hand-written sizeof(ProbIndex): one float + one int
+// NOTE(review): assumes 4-byte float and int with no padding; sizeof(ProbIndex) would be safer
+#define SIZEOFPROBINDEX 8
+
+// Sampling configuration and state used by sample().
+typedef struct {
+ int vocab_size; // number of logits sample() reads
+ ProbIndex* probindex; // buffer used in top-p sampling
+ float temperature; // 0 selects greedy argmax; otherwise logits are divided by it
+ float topp; // nucleus threshold; values <= 0 or >= 1 sample the full distribution
+ unsigned long long rng_state; // xorshift state, advanced on every random draw
+} Sampler;
+
+// Index of the largest value in probabilities[0..n-1]; ties resolve to the
+// lowest index.
+int sample_argmax(float* probabilities, int n) {
+    int best = 0;
+
+    for (int i = 1; i < n; i++)
+        if (probabilities[i] > probabilities[best])
+            best = i;
+    return best;
+}
+
+// Draw an index from the distribution probabilities[0..n-1] (expected to sum
+// to 1) using coin in [0, 1). Returns the first index whose running total
+// exceeds coin. If rounding keeps the total at or below coin, returns n - 1.
+int sample_mult(float* probabilities, int n, float coin) {
+    float running = 0.0f;
+    int i = 0;
+
+    while (i < n) {
+        running += probabilities[i];
+        if (coin < running)
+            return i;
+        i++;
+    }
+    return n - 1; // in case of rounding errors
+}
+
+// qsort comparator for ProbIndex: larger prob sorts first. Returns -1, 1, or 0;
+// incomparable values (NaN) compare as equal.
+int compare(const void* a, const void* b) {
+    float pa = ((const ProbIndex*)a)->prob;
+    float pb = ((const ProbIndex*)b)->prob;
+
+    return (pa < pb) - (pa > pb);
+}
+
+// Top-p ("nucleus") sampling: sample from the smallest set of highest-probability
+// tokens whose cumulative probability exceeds topp. This way we never sample
+// tokens with very low probabilities, which are more likely to go "off the rails".
+// probabilities must hold n values summing to ~1. probindex is scratch space for
+// n entries. coin is a random number in [0, 1), usually from random_f32().
+int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
+    int n0 = 0;
+    // For efficiency, drop candidates that cannot be in the nucleus before sorting.
+    // Suppose the k-th largest probability p_k is in the nucleus. Then entries k..n
+    // sum to at least 1 - topp, and none of them exceeds p_k, so
+    // p_k >= (1 - topp) / n. Dividing by n - 1 instead (as before) could crop
+    // every entry when topp < 1/n, leaving nothing to sample, and divided by
+    // zero when n == 1.
+    const float cutoff = (1.0f - topp) / n;
+    for (int i = 0; i < n; i++) {
+        if (probabilities[i] >= cutoff) {
+            probindex[n0].index = i;
+            probindex[n0].prob = probabilities[i];
+            n0++;
+        }
+    }
+    if (n0 == 0) {
+        // only reachable on degenerate input (e.g. NaN probabilities);
+        // fall back the same way sample_mult does
+        return n - 1;
+    }
+    // quicksort candidates in descending order of probability
+    qsort(probindex, n0, sizeof(ProbIndex), compare);
+
+    // truncate the list where cumulative probability exceeds topp
+    float cumulative_prob = 0.0f;
+    int last_idx = n0 - 1; // in case of rounding errors consider all elements
+    for (int i = 0; i < n0; i++) {
+        cumulative_prob += probindex[i].prob;
+        if (cumulative_prob > topp) {
+            last_idx = i;
+            break; // we've exceeded topp by including last_idx
+        }
+    }
+
+    // sample from the truncated list
+    float r = coin * cumulative_prob;
+    float cdf = 0.0f;
+    for (int i = 0; i <= last_idx; i++) {
+        cdf += probindex[i].prob;
+        if (r < cdf) {
+            return probindex[i].index;
+        }
+    }
+    return probindex[last_idx].index; // in case of rounding errors
+}
+
+// Initialise sampler with its settings and allocate the top-p scratch buffer.
+// Exits the program if the allocation fails.
+void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
+    sampler->vocab_size = vocab_size;
+    sampler->temperature = temperature;
+    sampler->topp = topp;
+    sampler->rng_state = rng_seed;
+    // buffer only used with nucleus sampling; may not need but it's ~small
+    sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
+    if (sampler->probindex == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+}
+
+// Release the top-p scratch buffer allocated by build_sampler.
+void free_sampler(Sampler* sampler) {
+    ProbIndex *scratch = sampler->probindex;
+
+    free(scratch);
+}
+
+// xorshift64* step: advance *state in place, then return the upper 32 bits of
+// state * 0x2545F4914F6CDD1D. https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
+unsigned int random_u32(unsigned long long *state) {
+    unsigned long long s = *state;
+
+    s ^= s >> 12;
+    s ^= s << 25;
+    s ^= s >> 27;
+    *state = s;
+    return (unsigned int)((s * 0x2545F4914F6CDD1Dull) >> 32);
+}
+
+// Uniform float in [0, 1): the top 24 bits of one random_u32 draw, scaled by 2^-24.
+float random_f32(unsigned long long *state) {
+    unsigned int top24 = random_u32(state) >> 8;
+
+    return top24 / 16777216.0f;
+}
+
+// Choose the next token from logits using the sampler's settings:
+//   temperature == 0          -> greedy argmax over the raw logits
+//   topp <= 0 or topp >= 1    -> sample the full softmax distribution
+//   otherwise                 -> top-p (nucleus) sampling
+// For temperature != 0 the logits array is overwritten with probabilities.
+int sample(Sampler* sampler, float* logits) {
+    int n = sampler->vocab_size;
+
+    if (sampler->temperature == 0.0f)
+        return sample_argmax(logits, n);
+
+    // scale by temperature, then turn the logits into probabilities in place
+    for (int i = 0; i < n; i++)
+        logits[i] /= sampler->temperature;
+    softmax(logits, n);
+
+    // a single uniform draw drives whichever sampling scheme is used
+    float coin = random_f32(&sampler->rng_state);
+    int full_distribution = sampler->topp <= 0 || sampler->topp >= 1;
+    return full_distribution
+        ? sample_mult(logits, n, coin)
+        : sample_topp(logits, n, sampler->topp, sampler->probindex, coin);
+}
+
+// ----------------------------------------------------------------------------
+// utilities: time
+
+/*long time_in_ms() {
+ // return time in milliseconds, for benchmarking the model speed
+ struct timespec time;
+ clock_gettime(CLOCK_REALTIME, &time);
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
+} */
+
+// Wall-clock milliseconds for benchmarking, derived from nsec() by dividing by 1e6.
+// NOTE(review): assumes nsec() returns a 64-bit nanosecond count; storing the
+// result in a 32-bit type would truncate it.
+#define time_in_ms() (nsec()/1000000)
+
+// ----------------------------------------------------------------------------
+// generation loop
+
+// Plain text completion. The prompt (NULL means "") is encoded with BOS and its
+// tokens are forced through the model. Then tokens are sampled until `steps`
+// positions have run or BOS is produced. Each token is printed as it is
+// decoded, and the achieved tokens/s is reported on stderr.
+void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
+    char *empty_prompt = "";
+    if (prompt == NULL) { prompt = empty_prompt; }
+
+    // encode the (string) prompt into tokens sequence
+    int num_prompt_tokens = 0;
+    int* prompt_tokens = malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
+    if (prompt_tokens == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+    if (num_prompt_tokens < 1) {
+        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
+        exit(EXIT_FAILURE);
+    }
+
+    // start the main loop
+    // vlong, not long: long is 32 bits on Plan 9 and cannot hold an absolute ms timestamp
+    vlong start = 0; // used to time our code, only initialized after first iteration
+    int next; // will store the next token in the sequence
+    int token = prompt_tokens[0]; // kick off with the first token in the prompt
+    int pos = 0; // position in the sequence
+    while (pos < steps) {
+
+        // forward the transformer to get logits for the next token
+        float* logits = forward(transformer, token, pos);
+
+        // advance the state machine
+        if (pos < (num_prompt_tokens - 1)) {
+            // if we are still processing the input prompt, force the next prompt token
+            next = prompt_tokens[pos + 1];
+        } else {
+            // otherwise sample the next token from the logits
+            next = sample(sampler, logits);
+        }
+        pos++;
+
+        // data-dependent terminating condition: the BOS (=1) token delimits sequences
+        if (next == 1) { break; }
+
+        // print the token as string, decode it with the Tokenizer object
+        char* piece = decode(tokenizer, token, next);
+        safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+        fflush(stdout);
+        token = next;
+
+        // init the timer here because the first iteration can be slower
+        if (start == 0) { start = time_in_ms(); }
+    }
+    printf("\n");
+
+    // report achieved tok/s (pos-1 because the timer starts after first iteration)
+    if (pos > 1) {
+        vlong end = time_in_ms();
+        if (end > start) // a sub-millisecond run would otherwise divide by zero
+            fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
+    }
+
+    free(prompt_tokens);
+}
+
+// Print guide, then read one line from stdin into buffer (at most bufsize - 1
+// characters) and drop the trailing newline if there is one.
+void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+    printf("%s", guide);
+    if (fgets(buffer, bufsize, stdin) == NULL)
+        return;
+
+    size_t n = strlen(buffer);
+    if (n != 0 && buffer[n - 1] == '\n')
+        buffer[n - 1] = '\0'; // strip newline
+}
+
+// ----------------------------------------------------------------------------
+// chat loop
+// I manually inspected the tokens for a few chat conversations compared to
+// python reference and that seemed ok, but this was not thoroughly tested and
+// is not safely implemented, it's more a proof of concept atm.
+
+void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+          char *cli_user_prompt, char *cli_system_prompt, int steps) {
+
+    // buffers for reading the system prompt and user prompt from stdin or the command line
+    // (fixed-size; CLI-supplied prompts are truncated to fit rather than overflowing them)
+    char system_prompt[512] = ""; // initialized so it is never read uninitialized
+    char user_prompt[512] = "";
+    char rendered_prompt[1152];
+    int num_prompt_tokens = 0;
+    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
+    int user_idx = 0;
+    if (prompt_tokens == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
+    // start the main loop
+    int8_t user_turn = 1; // user starts
+    int next = user_turn; // will store the next token in the sequence
+    int token; // stores the current token to feed into the transformer
+    int pos = 0; // position in the sequence
+    while (pos < steps) {
+
+        // when it is the user's turn to contribute tokens to the dialog...
+        if (user_turn) {
+            // get the (optional) system prompt at position 0
+            if (pos == 0) {
+                // at position 0, the user can also contribute a system prompt
+                if (cli_system_prompt == NULL) {
+                    // system prompt was not passed in, attempt to get it from stdin
+                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+                } else {
+                    // system prompt was passed in, use it (truncated to the buffer size)
+                    snprintf(system_prompt, sizeof(system_prompt), "%s", cli_system_prompt);
+                }
+            }
+            // get the user prompt
+            if (pos == 0 && cli_user_prompt != NULL) {
+                // user prompt for position 0 was passed in, use it (truncated to the buffer size)
+                snprintf(user_prompt, sizeof(user_prompt), "%s", cli_user_prompt);
+            } else {
+                // otherwise get user prompt from stdin
+                read_stdin("User: ", user_prompt, sizeof(user_prompt));
+            }
+            // render user/system prompts into the Llama 2 Chat schema (bounded by rendered_prompt's size)
+            if (pos == 0 && system_prompt[0] != '\0') {
+                char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
+                snprintf(rendered_prompt, sizeof(rendered_prompt), system_template, system_prompt, user_prompt);
+            } else {
+                char user_template[] = "[INST] %s [/INST]";
+                snprintf(rendered_prompt, sizeof(rendered_prompt), user_template, user_prompt);
+            }
+            // encode the rendered prompt into tokens
+            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+            user_idx = 0; // reset the user index
+            user_turn = 0;
+            printf("Assistant: ");
+        }
+
+        // determine the token to pass into the transformer next
+        if (user_idx < num_prompt_tokens) {
+            // if we are still processing the input prompt, force the next prompt token
+            token = prompt_tokens[user_idx++];
+        } else {
+            // otherwise use the next token sampled from previous turn
+            token = next;
+        }
+        // EOS (=2) token ends the Assistant turn
+        if (token == 2) { user_turn = 1; }
+
+        // forward the transformer to get logits for the next token
+        float* logits = forward(transformer, token, pos);
+        next = sample(sampler, logits);
+        pos++;
+
+        if (user_idx >= num_prompt_tokens && next != 2) {
+            // the Assistant is responding, so print its output
+            char* piece = decode(tokenizer, token, next);
+            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+            fflush(stdout);
+        }
+        if (next == 2) { printf("\n"); }
+    }
+    printf("\n");
+    free(prompt_tokens);
+}
+
+
+// ----------------------------------------------------------------------------
+// CLI: usage message and program entry point
+
+void error_usage(void) { // write the command-line help text to stderr, then terminate with a failure status
+    fputs("Usage: run <checkpoint> [options]\n", stderr);
+    fputs("Example: run model.bin -n 256 -i \"Once upon a time\"\n", stderr);
+    fputs("Options:\n", stderr);
+    fputs(" -t <float> temperature in [0,inf], default 1.0\n", stderr);
+    fputs(" -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n", stderr);
+    fputs(" -s <int> random seed, default time(NULL)\n", stderr);
+    fputs(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n", stderr);
+    fputs(" -i <string> input prompt\n", stderr);
+    fputs(" -z <string> optional path to custom tokenizer\n", stderr);
+    fputs(" -m <string> mode: generate|chat, default: generate\n", stderr);
+    fputs(" -y <string> (optional) system prompt in chat mode\n", stderr);
+    exit(EXIT_FAILURE);
+}
+
+int main(int argc, char *argv[]) {
+    // entry point: parse flags, build transformer/tokenizer/sampler, run the selected mode, clean up
+    // default parameters
+    char *checkpoint_path = NULL;  // e.g. out/model.bin
+    char *tokenizer_path = "tokenizer.bin";
+    float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+    float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+    int steps = 256;            // number of steps to run for
+    char *prompt = NULL;        // prompt string
+    unsigned long long rng_seed = 0; // seed rng with time by default
+    char *mode = "generate";    // generate|chat
+    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
+
+    // poor man's C argparse so we can override the defaults above from the command line
+    if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
+    for (int i = 2; i < argc; i+=2) {
+        // do some basic validation
+        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
+        if (argv[i][0] != '-') { error_usage(); } // must start with dash
+        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
+        // read in the args
+        if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
+        else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
+        else if (argv[i][1] == 's') { rng_seed = strtoull(argv[i + 1], NULL, 10); } // keep the full 64-bit seed; atoi would overflow int
+        else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
+        else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
+        else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
+        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
+        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
+        else { error_usage(); }
+    }
+
+    // parameter validation/overrides
+    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
+    if (temperature < 0.0) temperature = 0.0;
+    if (topp < 0.0 || 1.0 < topp) topp = 0.9;
+    if (steps < 0) steps = 0;
+
+    // build the Transformer via the model .bin file
+    Transformer transformer;
+    build_transformer(&transformer, checkpoint_path);
+    if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
+
+    // build the Tokenizer via the tokenizer .bin file
+    Tokenizer tokenizer;
+    build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
+
+    // build the Sampler
+    Sampler sampler;
+    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
+
+    // run!
+    if (strcmp(mode, "generate") == 0) {
+        generate(&transformer, &tokenizer, &sampler, prompt, steps);
+    } else if (strcmp(mode, "chat") == 0) {
+        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
+    } else {
+        fprintf(stderr, "unknown mode: %s\n", mode);
+        error_usage();
+    }
+
+    // memory and file handles cleanup
+    free_sampler(&sampler);
+    free_tokenizer(&tokenizer);
+    free_transformer(&transformer);
+    exits(nil); // Plan 9: returning from main makes _main call exits("main"), i.e. report failure
+}
+
--- a/round.c
+++ /dev/null
@@ -1,34 +1,0 @@
-#include <u.h>
-#include <libc.h>
-
-float round(float in) {
- float f;
-
- f = fmod(in, 1.0);
-
- if (in > 0) {
- if (f < 0.5)
- return floor(in);
- return ceil(in);
- }
-
- if (f > -0.5)
- return ceil(in);
- return floor(in);
-}
-
-void main() {
- char buf[1024];
- int r;
- float f;
-
- r = read(0, buf, sizeof(buf));
- if (r <= 0)
- return;
-
- buf[r] = '\0';
-
- f = round(atof(buf));
-
- print("%f\n", f);
-}
--
⑨