For everything that could have been,
At least we took the ride,
There is no relief in bitterness,
Might as well let it die.
void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}
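A quick sanity check (the main below is my own sketch, not part of run.c): subtracting max_val changes nothing mathematically, because softmax is invariant to shifting every input by the same constant, but it keeps expf from overflowing on large logits.

#include <math.h>
#include <stdio.h>

// paste softmax() from above here

int main(void) {
    // logits this large would overflow expf() without the max trick
    float logits[3] = {1000.0f, 1001.0f, 1002.0f};
    softmax(logits, 3);
    for (int i = 0; i < 3; i++) { printf("%f\n", logits[i]); }
    // prints roughly 0.090, 0.245, 0.665 -- the same as softmax of {0, 1, 2}
    return 0;
}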
void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}
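To make the w[i * n + j] indexing concrete, here is a toy call with made-up numbers: W is a 2x3 matrix stored row-major, so n = 3 is the input dimension and d = 2 the output dimension.

#include <stdio.h>

// matmul() as above, with the OpenMP pragma omitted for brevity
void matmul(float* xout, float* x, float* w, int n, int d) {
    for (int i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) { val += w[i * n + j] * x[j]; }
        xout[i] = val;
    }
}

int main(void) {
    float w[6] = {1, 2, 3,   // row 0 of W
                  4, 5, 6};  // row 1 of W
    float x[3] = {1, 1, 1};
    float xout[2];
    matmul(xout, x, w, 3, 2);
    printf("%f %f\n", xout[0], xout[1]); // 6.0 15.0 (the two row sums)
    return 0;
}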
// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(&transformer, checkpoint_path);
if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
// build the Tokenizer via the tokenizer .bin file
Tokenizer tokenizer;
build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
typedef struct {
    Config config;              // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state;             // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    int fd;            // file descriptor for memory mapping
    float* data;       // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;
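The last three fields exist because the checkpoint is memory-mapped rather than copied into malloc'd buffers, so the weights are demand-paged straight from the file. A minimal sketch of how such a mapping is set up with the usual POSIX calls (my illustration; map_checkpoint is a hypothetical helper, not the actual read_checkpoint):

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// map a checkpoint file read-only and record the three fields
// the Transformer struct needs for cleanup later
int map_checkpoint(const char* path, int* fd, float** data, ssize_t* file_size) {
    *fd = open(path, O_RDONLY);
    if (*fd == -1) { return -1; }
    struct stat st;
    if (fstat(*fd, &st) == -1) { close(*fd); return -1; }
    *file_size = st.st_size;
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { close(*fd); return -1; }
    return 0;
}

// cleanup later: munmap(data, file_size); close(fd);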
typedef struct {
    int dim;        // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers;   // number of layers
    int n_heads;    // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len;    // max sequence length
} Config;
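To put real numbers behind these fields, here they are filled in by hand with the published Llama 2 7B hyperparameters (my example, not something read from a .bin file):

// example: Llama 2 7B's published hyperparameters
Config llama2_7b = {
    .dim        = 4096,
    .hidden_dim = 11008,
    .n_layers   = 32,
    .n_heads    = 32,
    .n_kv_heads = 32,    // 7B uses full multi-head attention; only 70B shares kv heads
    .vocab_size = 32000,
    .seq_len    = 4096,
};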
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
if (num_prompt_tokens < 1) {
    fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
    exit(EXIT_FAILURE);
}
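(For what it's worth, the 1, 0 pair maps to encode's bos and eos flags: prepend the BOS token, don't append an EOS token.)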
// start the main loop
long start = 0;  // used to time our code, only initialized after first iteration
int next;        // will store the next token in the sequence
int token = prompt_tokens[0]; // kick off with the first token in the prompt
int pos = 0;     // position in the sequence
while (pos < steps) {
    // forward the transformer to get logits for the next token
    float* logits = forward(transformer, token, pos);
    // advance the state machine
    if (pos < num_prompt_tokens - 1) {
        // if we are still processing the input prompt, force the next prompt token
        next = prompt_tokens[pos + 1];
    } else {
        // otherwise sample the next token from the logits
        next = sample(sampler, logits);
    }
    pos++;
    // data-dependent terminating condition: the BOS (=1) token delimits sequences
    if (next == 1) { break; }
    // print the token as string, decode it with the Tokenizer object
    char* piece = decode(tokenizer, token, next);
    safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
    fflush(stdout);
    token = next;
    // init the timer here because the first iteration can be slower
    if (start == 0) { start = time_in_ms(); }
}
printf("\n");
// report achieved tok/s (pos-1 because the timer starts after first iteration)
if (pos > 1) {
    long end = time_in_ms();
    fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}
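As a worked example with numbers of my own: if generation stops at pos = 257 and end - start comes out to 4000 ms, the report is (257 - 1) / 4000 * 1000 = 64 tok/s.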
float* forward(Transformer* transformer, int token, int pos) {

    // a few convenience variables
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim = p->hidden_dim;
    int head_size = dim / p->n_heads;
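To make the derived quantities concrete, take a hypothetical grouped-query setup with dim = 4096, n_heads = 32, n_kv_heads = 8: then head_size = 4096 / 32 = 128, kv_dim = 4096 * 8 / 32 = 1024, and kv_mul = 32 / 8 = 4, i.e. each key/value head is shared by four query heads, which is exactly what shrinks the KV cache.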
    // copy the token embedding into x
    float* content_row = w->token_embedding_table + token * dim;
    memcpy(x, content_row, dim*sizeof(*x));
    // forward all the layers
    for (unsigned long long l = 0; l < p->n_layers; l++) {
        // ...
    }
    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);
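(For reference, rmsnorm(o, x, weight) computes o_i = weight_i * x_i / sqrt(mean_j(x_j^2) + eps); unlike LayerNorm it rescales without re-centering.)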
// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
for (h = 0; h < p->n_heads; h++) {
    // get the query vector for this head
    float* q = s->q + h * head_size;
    // attention scores for this head
    float* att = s->att + h * p->seq_len;
    // iterate over all timesteps, including the current one
    for (int t = 0; t <= pos; t++) {
        // get the key vector for this head and at this timestep
        float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        // calculate the attention score as the dot product of q and k
        float score = 0.0f;
        for (int i = 0; i < head_size; i++) {
            score += q[i] * k[i];
        }
        score /= sqrtf(head_size);
        // save the score to the attention buffer
        att[t] = score;
    }
    // softmax the scores to get attention weights, from 0..pos inclusively
    softmax(att, pos + 1);
    // weighted sum of the values, store back into xb
    float* xb = s->xb + h * head_size;
    memset(xb, 0, head_size * sizeof(float));
    for (int t = 0; t <= pos; t++) {
        // get the value vector for this head and at this timestep
        float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        // get the attention weight for this timestep
        float a = att[t];
        // accumulate the weighted value into xb
        for (int i = 0; i < head_size; i++) {
            xb[i] += a * v[i];
        }
    }
}
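Per head this is exactly scaled dot-product attention, softmax(q . K^T / sqrt(head_size)) . V, just unrolled one cached timestep at a time; the h / kv_mul index is what routes several query heads to the same cached key/value head.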
With the per-head computation done, xb now holds the concatenation of all the heads' outputs. One more linear projection plus a residual connection, and the attention stage is finished.
// final matmul to get the output of the attention
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
// residual connection back into x
for (int i = 0; i < dim; i++) {
    x[i] += s->xb2[i];
}
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
    float val = s->hb[i];
    // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
    val *= (1.0f / (1.0f + expf(-val)));
    // elementwise multiply with w3(x)
    val *= s->hb2[i];
    s->hb[i] = val;
}
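A couple of sanity values for silu (numbers mine): silu(0) = 0 and silu(1) = 1 / (1 + e^-1) ≈ 0.731; it approaches the identity for large positive inputs and 0 for large negative ones, so silu(w1(x)) acts as a smooth gate on the w3(x) branch.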
// final matmul to get the output of the ffn
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
// residual connection
for (int i = 0; i < dim; i++) {
    x[i] += s->xb[i];
}
}