A Survey of Speculative Decoding Techniques in LLM Inference

omnivore inference high-priority


Highlights

one full forward pass of the model results in the generation of a single token. This is a highly inefficient use of the available compute resources of the GPU or the accelerator chip, and speculative decoding solves this by enabling the prediction of multiple tokens in a single forward pass. ⤴️

It would be prudent if, once the weights are moved into the accelerator, multiple tokens could be predicted in parallel so that the available GPU compute was maximally utilized. However, the autoregressive nature of the LLM architecture prohibits this, because at every step of next-token generation the model needs the previously generated tokens as part of its context. ⤴️
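To make the bottleneck concrete, here is a minimal sketch of plain autoregressive greedy decoding, where each new token costs one forward pass of the model. The `model` name is a placeholder assumed to map token ids directly to logits (a real Hugging Face model returns an output object and would use a KV cache, but the one-pass-per-token structure is the same):

```python
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens):
    """Plain autoregressive decoding: one forward pass per generated token.

    `model` is assumed to map token ids of shape (1, seq_len) to logits of
    shape (1, seq_len, vocab_size).
    """
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                    # one full forward pass ...
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # ... yields a single token
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```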

The underlying motivation behind speculative decoding is to improve GPU resource utilization and, in the process, improve LLM inference throughput by predicting multiple tokens in parallel.

This is inspired by the concept of speculative execution in traditional processors (CPUs). ⤴️

The speculative decoding paper proposes the use of a smaller, more compute-efficient model for predicting the next few tokens. This smaller model is called the speculator model, but you may also see it referred to as the draft model in some places. ⤴️

the same prefix is fed to the base LLM and the next token is predicted. If the probability of the token as generated by the speculator is less than or equal to the probability of the same token as generated by the base LLM, then it is accepted; otherwise it is rejected. The intuition is that we don’t want to accept a token that the speculator generates with high probability but that the base LLM is highly unlikely to produce. ⤴️
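A minimal sketch of how this draft-then-verify loop could fit together, assuming greedy drafting and callables (`draft_model`, `base_model`) that map token ids to logits; `gamma`, the number of drafted tokens, is an illustrative parameter. Note that this only implements the simplified deterministic acceptance rule quoted above; the full algorithm in the paper additionally accepts a "rejected" token with probability p/q and resamples from an adjusted distribution:

```python
import torch

@torch.no_grad()
def speculative_step(base_model, draft_model, input_ids, gamma=4):
    """One speculative decoding step: draft `gamma` tokens with the small
    model, then score them all with a single forward pass of the base model."""
    # 1. Draft: the speculator proposes gamma tokens autoregressively.
    draft_ids, draft_probs = input_ids, []
    for _ in range(gamma):
        q = torch.softmax(draft_model(draft_ids)[:, -1, :], dim=-1)
        tok = q.argmax(dim=-1, keepdim=True)
        draft_probs.append(q.gather(-1, tok))            # q(token) from the speculator
        draft_ids = torch.cat([draft_ids, tok], dim=-1)

    # 2. Verify: one forward pass of the base model scores every drafted
    #    position in parallel.
    base_probs = torch.softmax(base_model(draft_ids), dim=-1)

    # 3. Accept drafted tokens left to right while q(token) <= p(token).
    n_prompt, accepted = input_ids.shape[-1], input_ids
    for i in range(gamma):
        tok = draft_ids[:, n_prompt + i].unsqueeze(-1)
        p = base_probs[:, n_prompt + i - 1, :].gather(-1, tok)   # p(token) from the base LLM
        if (draft_probs[i] <= p).all():
            accepted = torch.cat([accepted, tok], dim=-1)
        else:
            break  # full algorithm: accept with prob p/q, else resample
    return accepted
```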

Medusa: 2024 ICML paper.

Instead of using a separate speculator model, the Medusa paper proposes adding multiple prediction heads on top of the base LLM. These prediction heads are each added in the form of a single feed-forward layer on top of the last hidden layer of the base LLM.

If there are k prediction heads, the prediction head at the first position predicts the next token, the prediction head at the second position predicts the token after that, and similarly, the prediction head at the kth position predicts the token k steps ahead. ⤴️
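A minimal sketch of what such heads might look like; the class name is invented, and the residual connection with a SiLU activation mirrors the published Medusa head design but should be treated as an assumption here:

```python
import torch
import torch.nn as nn

class MedusaStyleHeads(nn.Module):
    """k extra prediction heads on top of a frozen base LM. Head i reads the
    base model's last hidden state for the current position and predicts the
    token i steps ahead."""
    def __init__(self, hidden_size, vocab_size, k=4):
        super().__init__()
        self.proj = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(k)])
        self.lm_heads = nn.ModuleList(
            [nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(k)])

    def forward(self, last_hidden):                      # (batch, hidden_size)
        logits = []
        for proj, lm_head in zip(self.proj, self.lm_heads):
            # a single feed-forward layer (with a residual connection)
            h = last_hidden + nn.functional.silu(proj(last_hidden))
            logits.append(lm_head(h))                    # (batch, vocab_size)
        return torch.stack(logits, dim=1)                # (batch, k, vocab_size)
```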

The Medusa model consists of multiple prediction heads which are used to predict the next k tokens. However, these heads receive the same input, which is the embedding vector produced by the last hidden layer of the base model for predicting the current token. The further into the future a head is predicting, the further it moves away from the ground truth, because it has no information about the intermediate tokens predicted in between. ⤴️

employing a smaller, faster speculator model to propose candidate tokens, speculative decoding enables the main LLM to process multiple tokens in parallel, which significantly improves the token throughput and saves compute cost. ⤴️

The original speculative decoding method relied on a separate draft model, introducing challenges in training and integration, as well as potential distribution mismatches. The Medusa architecture addressed these limitations by incorporating multiple prediction heads directly within the base LLM, eliminating the need for a separate speculator model. ⤴️

However, that is not the end of the story. Even though Medusa improves upon the original speculative decoding design, it also has some limitations, which were outlined in the IBM report. They suggest some interesting modifications, such as using a multi-stage MLP and employing multi-candidate decoding for faster verification of candidate tokens. ⤴️
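As a rough illustration of the first idea only (this is not the IBM architecture verbatim; the class name, layer sizes, and the choice to condition each stage on the previously drafted token are assumptions, made to show how a multi-stage speculator can avoid the "same input for every head" limitation noted above):

```python
import torch
import torch.nn as nn

class MultiStageMLPSpeculator(nn.Module):
    """Sketch of a multi-stage MLP speculator: stage i drafts the token i
    steps ahead, conditioning on a running state AND the embedding of the
    token drafted by the previous stage (unlike Medusa heads, which all see
    the same hidden state)."""
    def __init__(self, hidden_size, vocab_size, n_stages=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.stages = nn.ModuleList([
            nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU(),
                          nn.Linear(hidden_size, hidden_size))
            for _ in range(n_stages)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, last_hidden, last_token):
        # last_hidden: (batch, hidden_size) from the base model
        # last_token:  (batch,) id of the most recently generated token
        state, token, all_logits = last_hidden, last_token, []
        for stage in self.stages:
            state = stage(torch.cat([state, self.embed(token)], dim=-1))
            logits = self.lm_head(state)                 # (batch, vocab_size)
            token = logits.argmax(dim=-1)                # greedy draft token
            all_logits.append(logits)
        return torch.stack(all_logits, dim=1)            # (batch, n_stages, vocab_size)
```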