Achieve ~2x speed-up in LLM inference with Medusa-1 on Amazon SageMaker AI

This blog post is co-written with Moran Beladev, Manos Stergiadis, and Ilya Gusev from Booking.com.
Large language models (LLMs) have revolutionized the field of natural language processing with their ability to understand and generate humanlike text. Trained on broad, generic datasets spanning a wide range of topics and domains, LLMs use their parametric knowledge to perform increasingly complex and versatile tasks across multiple business use cases. Furthermore, companies are increasingly investing resources in customizing LLMs through few-shot learning and fine-tuning to optimize their performance for specialized applications.
However, the impressive performance of LLMs comes at the cost of significant computational requirements, driven by their large number of parameters and an autoregressive decoding process that is sequential in nature. This combination makes achieving low latency a challenge for use cases such as real-time text completion, simultaneous translation, or conversational voice assistants, where subsecond response times are essential.
Researchers developed Medusa, a framework to speed up LLM inference by adding extra heads to predict multiple tokens simultaneously. This post demonstrates how to use Medusa-1, the first version of the framework, to speed up an LLM by fine-tuning it on Amazon SageMaker AI, and confirms the speedup with deployment and a simple load test. Medusa-1 achieves an inference speedup of around two times without sacrificing model quality, with the exact improvement varying based on model size and the data used. In this post, we demonstrate its effectiveness with a 1.8 times speedup observed on a sample dataset.
Introduction to Medusa and its benefits for LLM inference speed
LLMs generate text sequentially using autoregressive sampling, where each new token is conditioned on the previous ones. Generating K tokens therefore requires K sequential executions of the model. This token-by-token processing introduces inherent latency and computational overhead, because the model has to perform a separate forward pass for each new token in the output sequence. The following diagram from Role-Play with Large Language Models illustrates this flow.
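To make the cost of this sequential process concrete, the following minimal sketch (illustrative only, not taken from the post) shows a greedy autoregressive decoding loop with a Hugging Face causal language model, where each generated token requires one full forward pass:

```python
# Illustrative greedy decoding loop: generating K tokens costs K forward passes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # any causal LM works for this illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

def greedy_decode(prompt: str, max_new_tokens: int = 30) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    for _ in range(max_new_tokens):           # K sequential steps
        logits = model(input_ids).logits      # one full forward pass per new token
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
```

Production inference servers cache attention keys and values to avoid recomputation, but the one-pass-per-token structure remains, which is exactly what Medusa and speculative decoding target.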
Speculative decoding tackles this challenge by using a smaller, faster draft model to generate multiple candidate token continuations in parallel, which are then verified by a larger, more accurate target model. This parallelization accelerates text generation while maintaining the quality of the target model, because the verification task is faster than autoregressive token generation. For a detailed explanation of the concept, refer to the paper Accelerating Large Language Model Decoding with Speculative Sampling. The speculative decoding technique can be implemented using the inference optimization toolkit on Amazon SageMaker JumpStart.
The paper Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads introduced Medusa as an alternative to speculative decoding. Instead of adding a separate draft model, it adds extra decoding heads to the LLM that generate candidate continuations simultaneously. These candidates are then evaluated in parallel using a tree-based attention mechanism. This parallel processing reduces the number of sequential steps needed, leading to faster inference times. The main advantage of Medusa over speculative decoding is that it eliminates the need to acquire and maintain a separate draft model while achieving higher speedups. For example, when tested on the MT-Bench dataset, the paper reports that Medusa-2 (the second version of Medusa) accelerates inference time by 2.8 times. This outperforms speculative decoding, which only manages to speed up inference time by 1.5 times on the same dataset.
The Medusa framework currently supports Llama and Mistral models. Although it offers significant speed improvements, it does come with a memory trade-off (similar to speculative decoding). For instance, adding five Medusa heads to the 7-billion-parameter Mistral model increases the total parameter count by 750 million (150 million per head), which means these additional parameters must be stored in GPU memory, leading to a higher memory requirement. However, in most cases, this increase doesn't necessitate switching to an instance with more GPU memory. For example, you can still use an ml.g5.4xlarge instance with 24 GB of GPU memory to host your 7-billion-parameter Llama or Mistral model with extra Medusa heads.
Training Medusa heads requires additional development time and computational resources, which should be factored into project planning and resource allocation. Another important limitation to mention is that the current framework, when deployed on an Amazon SageMaker AI endpoint, only supports a batch size of one, a configuration typically used for low-latency applications.
The following diagram from the original Medusa paper authors' FasterDecoding repository gives a visual overview of the Medusa framework.
There are two main variants of Medusa:
- Medusa-1 – Requires a two-stage approach where you first fine-tune your LLM and then add Medusa heads and train them on top of your frozen fine-tuned LLM
- Medusa-2 – Introduced later as an improvement, it fine-tunes both the additional heads and the backbone LLM parameters together, enabling potentially even further latency speedups
The Medusa paper reports that across models of different sizes, you can achieve inference speedups of around two times for Medusa-1 and around three times for Medusa-2. With Medusa-1, the predictions are identical to those of the originally fine-tuned LLM. In contrast, with Medusa-2, we might observe slightly different results compared to simple fine-tuning of the LLM, because both the heads and the backbone LLM parameters are updated together. In this post, we focus on Medusa-1.
Solution overview
We cover the following steps in our solution:
- Prerequisites
- Load and prepare the dataset
- Fine-tune an LLM using a SageMaker AI training job
- Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job
- Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
- Demonstrate LLM inference speedup
By following this solution, you can accelerate LLM inference in your applications, leading to faster response times and an improved user experience.
Prerequisites
To build the solution yourself, you need the following prerequisites:
Load and prepare the dataset
Now that you have cloned the GitHub repository and opened the medusa_1_train.ipynb notebook, you will load and prepare the dataset in the notebook. We encourage you to read this post while running the code in the notebook. For this post, we use a dataset called sql-create-context, which contains samples of natural language instructions, schema definitions, and the corresponding SQL query. It contains 78,577 examples of natural language queries, SQL CREATE TABLE statements, and SQL queries answering the question using the CREATE statement as context. For demonstration purposes, we select 3,000 samples and split them into train, validation, and test sets.
You need to run the "Load and prepare the dataset" section of the medusa_1_train.ipynb notebook to prepare the dataset for fine-tuning. We also included a data exploration script to analyze the length of input and output tokens. After data exploration, we prepare the train, validation, and test sets and upload them to Amazon Simple Storage Service (Amazon S3).
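The following is a minimal sketch of these preparation steps, assuming the public sql-create-context dataset on the Hugging Face Hub and an 80/10/10 split; the notebook's exact ratios, file formats, and S3 prefixes may differ:

```python
# Illustrative data preparation sketch (not the notebook's verbatim code).
import sagemaker
from datasets import load_dataset

# Load the public dataset and sample 3,000 examples for the demonstration.
dataset = load_dataset("b-mc2/sql-create-context", split="train")
dataset = dataset.shuffle(seed=42).select(range(3000))

# Assumed 80/10/10 train/validation/test split.
splits = dataset.train_test_split(test_size=0.2, seed=42)
holdout = splits["test"].train_test_split(test_size=0.5, seed=42)
train_ds, val_ds, test_ds = splits["train"], holdout["train"], holdout["test"]

# Write the splits locally and upload them to Amazon S3.
train_ds.to_json("train.jsonl")
val_ds.to_json("validation.jsonl")
test_ds.to_json("test.jsonl")

sess = sagemaker.Session()
train_s3_path = sess.upload_data("train.jsonl", key_prefix="medusa/data")
validation_s3_path = sess.upload_data("validation.jsonl", key_prefix="medusa/data")
test_s3_path = sess.upload_data("test.jsonl", key_prefix="medusa/data")
```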
Fine-tune an LLM using a SageMaker AI training job
We use the Zephyr 7B β model as our backbone LLM. Zephyr is a series of language models trained to act as helpful assistants, and Zephyr 7B β is a fine-tuned version of Mistral-7B-v0.1, trained on a mix of publicly available and synthetic datasets using Direct Preference Optimization.
To launch a SageMaker AI training job, we need to use the PyTorch or Hugging Face estimator. SageMaker AI starts and manages all the necessary Amazon Elastic Compute Cloud (Amazon EC2) instances for us, supplies the appropriate containers, downloads data from our S3 bucket to the container, and uploads and runs the required training script, in our case fine_tune_llm.py. We select the hyperparameters based on the QLoRA paper, but we encourage you to experiment with your own combinations. To expedite the execution of this code, we set the number of epochs to 1. However, for better results, it's generally recommended to set the number of epochs to at least 2 or 3.
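As a rough sketch of what launching this job looks like with the Hugging Face estimator, see the snippet below; the entry point name comes from the post, while the instance type, container versions, and hyperparameter names are assumptions you should adapt to the notebook:

```python
# Hedged sketch of the fine-tuning job launch; values marked as assumed may differ.
from sagemaker.huggingface import HuggingFace

finetune_estimator = HuggingFace(
    entry_point="fine_tune_llm.py",
    source_dir="scripts",                  # assumed location of the training script
    instance_type="ml.g5.4xlarge",         # assumed instance type
    instance_count=1,
    transformers_version="4.36",           # assumed container versions
    pytorch_version="2.1",
    py_version="py310",
    role=role,                             # SageMaker execution role from the prerequisites
    hyperparameters={
        "model_id": "HuggingFaceH4/zephyr-7b-beta",
        "epochs": 1,                       # set to 1 to expedite; 2-3 recommended for better results
        "per_device_train_batch_size": 4,  # assumed QLoRA-style settings
        "learning_rate": 2e-4,
    },
)

finetune_estimator.fit({
    "train": train_s3_path,
    "validation": validation_s3_path,
})
```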
When our training job has completed successfully after approximately 1 hour, we can use the fine-tuned model artifact for the next step, training the Medusa heads on top of it. To visualize the training metrics in TensorBoard, you can follow the guidance in this documentation: Load and visualize output tensors using the TensorBoard application.
Train Medusa heads on top of the frozen fine-tuned LLM using a SageMaker AI training job
For training the Medusa heads, we can reuse the functions previously mentioned to launch the training job. We selected hyperparameters based on a combination of what the Medusa paper reported and what we found to perform best after a few experiments. We set the number of Medusa heads to 5 and used the 8-bit AdamW optimizer, as recommended by the paper. For simplicity, we maintained a constant learning rate of 1e-4 with a constant scheduler, similar to the previous fine-tuning step. Although the paper recommends an increased learning rate and a cosine scheduler, we found that our chosen combination of hyperparameters performed well on this dataset. However, we encourage you to experiment with your own hyperparameter settings to potentially achieve even better results.
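A sketch of the Medusa heads training job is shown below; only the values called out above (5 heads, 8-bit AdamW, constant 1e-4 learning rate, 3 epochs) come from the post, while the channel names, hyperparameter names, and container versions are assumptions:

```python
# Hedged sketch of the Medusa-heads training job launch.
medusa_estimator = HuggingFace(
    entry_point="train_medusa_heads.py",
    source_dir="scripts",              # assumed
    instance_type="ml.g5.4xlarge",     # assumed
    instance_count=1,
    transformers_version="4.36",
    pytorch_version="2.1",
    py_version="py310",
    role=role,
    hyperparameters={
        "medusa_num_heads": 5,         # number of extra decoding heads
        "optimizer": "adamw_bnb_8bit", # 8-bit AdamW, as recommended by the paper
        "learning_rate": 1e-4,
        "lr_scheduler_type": "constant",
        "epochs": 3,
    },
)

medusa_estimator.fit({
    "train": train_s3_path,
    "validation": validation_s3_path,
    # The fine-tuned model artifact from the previous job, passed as an input channel (assumed channel name).
    "basemodel": finetune_estimator.model_data,
})
```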
We found that after 3 epochs, the evaluation loss of the Medusa heads was converging, which can be observed in the TensorBoard graph in the following image.
Apart from the hyperparameters, the main difference is that we pass train_medusa_heads.py as the training entry point, where we first add the Medusa heads, then freeze the fine-tuned LLM, and create a custom MedusaSFTTrainer class, which is a subclass of the SFTTrainer from the Hugging Face TRL library.
In the add_medusa_heads() function, we add the residual blocks of the Medusa heads, and also override the forward pass of our model to make sure we don't train the frozen backbone LLM:
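The full implementation lives in the repository; the following is a simplified sketch of the idea, with each Medusa head built from residual blocks followed by a vocabulary-sized projection (names and exact wiring are assumptions, not the post's verbatim code):

```python
# Simplified sketch of add_medusa_heads(); the real code also overrides the
# model's forward pass so the loss is computed from the Medusa head logits.
import torch.nn as nn

class ResBlock(nn.Module):
    """Residual block used inside each Medusa head."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

def add_medusa_heads(model, medusa_num_heads: int = 5, medusa_num_layers: int = 1):
    # Freeze the fine-tuned backbone so only the new heads receive gradients (Medusa-1).
    for param in model.parameters():
        param.requires_grad = False

    hidden_size = model.config.hidden_size
    vocab_size = model.config.vocab_size

    # One extra decoding head per future token position; each head maps the last
    # hidden state to a distribution over the vocabulary.
    model.medusa_heads = nn.ModuleList(
        nn.Sequential(
            *[ResBlock(hidden_size) for _ in range(medusa_num_layers)],
            nn.Linear(hidden_size, vocab_size, bias=False),
        )
        for _ in range(medusa_num_heads)
    )
    return model
```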
After the model training is finished (which takes about 1 hour), we prepare the model artifacts for deployment and upload them to Amazon S3. Your final model artifact contains both the original fine-tuned model from the previous step, under the base-model prefix, and the trained Medusa heads in a file named medusa_heads.safetensors.
Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
The Medusa framework is supported by the Text Generation Inference (TGI) server. After training the LLM with Medusa heads, we deploy it to a SageMaker AI real-time endpoint using the Hugging Face Inference Container set up with TGI.
First, we create a SageMaker AI HuggingFaceModel object and then deploy the model to an endpoint with the following function:
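The following is a hedged sketch of such a deployment helper; the TGI container version and environment variables are assumptions based on the Hugging Face LLM inference container, not the post's exact code:

```python
# Hedged sketch of deploying a model artifact behind a TGI-backed endpoint.
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

def deploy_model(model_s3_path, endpoint_name, instance_type="ml.g5.4xlarge"):
    # Retrieve the Hugging Face LLM (TGI) container image; the version is assumed.
    llm_image = get_huggingface_llm_image_uri("huggingface", version="2.0.2")

    model = HuggingFaceModel(
        model_data=model_s3_path,            # S3 URI of the model artifact
        image_uri=llm_image,
        role=role,
        env={
            "HF_MODEL_ID": "/opt/ml/model",  # serve the model shipped inside the artifact
            "SM_NUM_GPUS": "1",
            "MAX_INPUT_LENGTH": "1024",      # assumed limits for this dataset
            "MAX_TOTAL_TOKENS": "2048",
        },
    )
    return model.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        endpoint_name=endpoint_name,
        container_startup_health_check_timeout=900,
    )
```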
We deploy three LLMs on three SageMaker AI endpoints:
- The base LLM, which isn't fine-tuned
- The LLM that we fine-tuned
- The fine-tuned LLM that also has trained Medusa heads
You can deploy the three models in parallel by using a function that we included in the notebook, or you can deploy the models one by one by running the code below:
After the status for each endpoint becomes InService, which should take around 15 minutes, we can invoke them for inference. We send the following input:
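An invocation sketch follows; the prompt variable is a placeholder for a natural language question plus its CREATE TABLE context in the dataset's instruction format, not the exact example from the post:

```python
# Illustrative invocation of a deployed endpoint with boto3.
import json
import boto3

runtime = boto3.client("sagemaker-runtime")

payload = {
    "inputs": prompt,  # placeholder: question + CREATE TABLE schema as context
    "parameters": {"max_new_tokens": 128, "do_sample": False},
}

response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read())[0]["generated_text"])
```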
We can observe the following responses:
- The base LLM response contains extra words that aren't needed:
- The fine-tuned LLM response is improved significantly and contains only the required output:
- The fine-tuned LLM with trained Medusa heads provides the exact same response as the fine-tuned model, demonstrating that Medusa-1, by design, maintains the output (quality) of the original model:
Demonstrate LLM inference speedup
To measure the inference speed improvements, we compare the response times of the deployed fine-tuned LLM and the fine-tuned LLM with Medusa heads on 450 test observations with the following code:
First, we run predictions using the fine-tuned LLM:
Then, we run predictions using the fine-tuned LLM with Medusa heads:
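The following is a hedged sketch of this comparison; the helper and variable names are assumptions, and the notebook code may time the requests differently:

```python
# Hedged sketch of the latency comparison over the 450 test observations.
import time

def benchmark(predictor, test_samples):
    latencies, outputs = [], []
    for sample in test_samples:
        payload = {
            "inputs": sample["prompt"],  # assumed field name in the prepared test set
            "parameters": {"max_new_tokens": 128, "do_sample": False},
        }
        start = time.perf_counter()
        result = predictor.predict(payload)
        latencies.append(time.perf_counter() - start)
        outputs.append(result[0]["generated_text"])
    return latencies, outputs

ft_latencies, ft_outputs = benchmark(finetuned_predictor, test_samples)
medusa_latencies, medusa_outputs = benchmark(medusa_predictor, test_samples)

print(f"Fine-tuned average latency:          {1000 * sum(ft_latencies) / len(ft_latencies):.0f} ms")
print(f"Fine-tuned + Medusa average latency: {1000 * sum(medusa_latencies) / len(medusa_latencies):.0f} ms")
```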
The prediction runs should take around 8 and 4 minutes, respectively. We can observe that the average latency decreased from 950 to 530 milliseconds, an improvement of 1.8 times. You can achieve even higher improvements if your dataset contains longer inputs and outputs. In our dataset, we only had an average of 18 input tokens and 30 output tokens.
We want to highlight once again that, with this technique, the output quality is fully maintained, and all the prediction outputs are the same. The model responses for the test set of 450 observations are the same both with and without Medusa heads:
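A simple way to verify this, continuing the sketch above, is to compare the two output lists directly:

```python
# Compare the generated outputs of the two endpoints (names follow the sketch above).
matches = sum(a == b for a, b in zip(ft_outputs, medusa_outputs))
print(f"{matches}/{len(ft_outputs)} outputs match exactly "
      f"({100 * matches / len(ft_outputs):.1f}%)")
```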
You might find in your own run that a few observations don't match exactly, and you might get a 99% match because of small differences in floating point operations caused by optimizations on GPUs.
Clean up
At the end of this experiment, don't forget to delete the SageMaker AI endpoints you created:
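A cleanup sketch follows, assuming you kept references to the predictors returned by the deployment step; deleting the endpoint with the SageMaker Python SDK also removes its endpoint configuration by default:

```python
# Delete the three endpoints, their endpoint configurations, and the models.
for predictor in [base_predictor, finetuned_predictor, medusa_predictor]:
    predictor.delete_model()
    predictor.delete_endpoint()  # also deletes the endpoint configuration by default
```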
Conclusion
In this post, we demonstrated how to fine-tune and deploy an LLM with Medusa heads using the Medusa-1 technique on Amazon SageMaker AI to accelerate LLM inference. By using this framework and the scalable infrastructure of SageMaker AI, we showed how to achieve up to twofold speedups in LLM inference while maintaining model quality. This solution is particularly beneficial for applications requiring low-latency text generation, such as customer service chat assistants, content creation, and recommendation systems.
As a next step, you can explore fine-tuning your own LLM with Medusa heads on your own dataset and benchmark the results for your specific use case, using the provided GitHub repository.
About the authors
Daniel Zagyva is a Senior ML Engineer at AWS Professional Services. He specializes in developing scalable, production-grade machine learning solutions for AWS customers. His experience extends across different areas, including natural language processing, generative AI, and machine learning operations.
Aleksandra Dokic is a Senior Data Scientist at AWS Professional Services. She enjoys supporting customers in building innovative AI/ML solutions on AWS, and she is excited about business transformations through the power of data.
Moran Beladev is a Senior ML Manager at Booking.com. She is leading the content intelligence track, which is focused on building, training, and deploying content models (computer vision, NLP, and generative AI) using the most advanced technologies and models. Moran is also a PhD candidate, researching the application of NLP models to social graphs.
Manos Stergiadis is a Senior ML Scientist at Booking.com. He specializes in generative NLP and has experience researching, implementing, and deploying large deep learning models at scale.
Ilya Gusev is a Senior Machine Learning Engineer at Booking.com. He leads the development of the LLM systems inside Booking.com. His work focuses on building production ML systems that help millions of travelers plan their trips effectively.
Laurens van der Maas is a Machine Learning Engineer at AWS Professional Services. He works closely with customers building their machine learning solutions on AWS, specializes in natural language processing, experimentation, and responsible AI, and is passionate about using machine learning to drive meaningful change in the world.