Achieve ~2x speed-up in LLM inference with Medusa-1 on Amazon SageMaker AI


This blog post is co-written with Moran Beladev, Manos Stergiadis, and Ilya Gusev from Booking.com.

Large language models (LLMs) have revolutionized the field of natural language processing with their ability to understand and generate humanlike text. Trained on broad, generic datasets spanning a wide range of topics and domains, LLMs use their parametric knowledge to perform increasingly complex and versatile tasks across multiple business use cases. Furthermore, companies are increasingly investing resources in customizing LLMs through few-shot learning and fine-tuning to optimize their performance for specialized applications.

However, the impressive performance of LLMs comes at the cost of significant computational requirements, driven by their large number of parameters and their autoregressive decoding process, which is sequential in nature. This combination makes achieving low latency a challenge for use cases such as real-time text completion, simultaneous translation, or conversational voice assistants, where subsecond response times are essential.

Researchers developed Medusa, a framework to speed up LLM inference by adding extra heads to predict multiple tokens simultaneously. This post demonstrates how to use Medusa-1, the first version of the framework, to speed up an LLM by fine-tuning it on Amazon SageMaker AI, and confirms the speedup with deployment and a simple load test. Medusa-1 achieves an inference speedup of around two times without sacrificing model quality, with the exact improvement varying based on model size and the data used. In this post, we demonstrate its effectiveness with a 1.8 times speedup observed on a sample dataset.

Introduction to Medusa and its benefits for LLM inference speed

LLMs generate text in a sequential manner, which involves autoregressive sampling, with each new token conditional on the previous ones. Generating K tokens necessitates K sequential executions of the model. This token-by-token processing introduces an inherent latency and computational overhead because the model needs to perform a separate forward pass for each new token in the output sequence. The following diagram from Role-Play with Large Language Models illustrates this flow.

Autoregressive sampling overview
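To make the sequential bottleneck concrete, the following minimal sketch (the model name and prompt are illustrative placeholders, not from the original notebook) generates tokens greedily one forward pass at a time, which is exactly the loop that Medusa aims to shorten.

# Minimal sketch of autoregressive (greedy) decoding: one forward pass per generated token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"   # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

input_ids = tokenizer("SELECT time FROM", return_tensors="pt").input_ids
for _ in range(20):                                   # generating K=20 tokens takes 20 sequential passes
    with torch.no_grad():
        logits = model(input_ids).logits              # full forward pass for a single new token
    next_token = logits[:, -1, :].argmax(dim=-1)      # greedy pick of the next token
    input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

print(tokenizer.decode(input_ids[0]))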

Speculative decoding tackles this challenge by using a smaller, faster draft model to generate multiple potential token continuations in parallel, which are then verified by a larger, more accurate target model. This parallelization speeds up text generation while maintaining the quality of the target model, because the verification task is faster than autoregressive token generation. For a detailed explanation of the concept, refer to the paper Accelerating Large Language Model Decoding with Speculative Sampling. The speculative decoding technique can be implemented using the inference optimization toolkit on Amazon SageMaker JumpStart.
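For intuition, here is a simplified greedy sketch of one draft-and-verify step. The function and model handles are our own illustration, assuming Hugging Face-style causal LMs that share a tokenizer and a batch size of 1; the paper uses a probabilistic acceptance rule rather than exact matching.

# Simplified greedy draft-and-verify step for speculative decoding (illustrative only, batch size 1).
# draft_model and target_model are assumed to be Hugging Face-style causal LMs sharing a tokenizer.
import torch

def speculative_step(draft_model, target_model, input_ids, k=4):
    # 1) The small draft model proposes k tokens autoregressively (cheap but sequential).
    draft_ids = input_ids
    for _ in range(k):
        next_tok = draft_model(draft_ids).logits[:, -1, :].argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

    # 2) The large target model scores all k proposed positions in a single forward pass (parallel).
    target_preds = target_model(draft_ids).logits[:, input_ids.shape[1] - 1:-1, :].argmax(dim=-1)

    # 3) Accept the longest prefix on which draft and target agree, then append the target's
    #    own token at the first disagreement (the real algorithm accepts probabilistically).
    proposed = draft_ids[:, input_ids.shape[1]:]
    matches = int((proposed == target_preds).int().cumprod(dim=-1).sum())
    accepted = proposed[:, :matches]
    if matches < k:
        accepted = torch.cat([accepted, target_preds[:, matches:matches + 1]], dim=-1)
    return torch.cat([input_ids, accepted], dim=-1)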

The paper Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads introduced Medusa as an alternative to speculative decoding. Instead of adding a separate draft model, it adds extra decoding heads to the LLM that generate candidate continuations simultaneously. These candidates are then evaluated in parallel using a tree-based attention mechanism. This parallel processing reduces the number of sequential steps needed, leading to faster inference times. The main advantage of Medusa over speculative decoding is that it eliminates the need to acquire and maintain a separate draft model while achieving higher speedups. For example, when tested on the MT-Bench dataset, the paper reports that Medusa-2 (the second version of Medusa) speeds up inference time by 2.8 times. This outperforms speculative decoding, which only manages to speed up inference time by 1.5 times on the same dataset.

The Medusa framework currently supports Llama and Mistral models. Although it offers significant speed improvements, it does come with a memory trade-off (similar to speculative decoding). For instance, adding 5 Medusa heads to the 7-billion-parameter Mistral model increases the total parameter count by 750 million (150 million per head), which means these additional parameters must be stored in GPU memory, leading to a higher memory requirement. However, in most cases, this increase doesn't necessitate switching to an instance with more GPU memory. For example, you can still use an ml.g5.4xlarge instance with 24 GB of GPU memory to host your 7-billion-parameter Llama or Mistral model with extra Medusa heads.
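As a rough back-of-the-envelope check (assuming 2 bytes per parameter for bfloat16 or float16 weights), the extra GPU memory for the heads can be estimated as follows:

# Rough estimate of the extra GPU memory needed for 5 Medusa heads on Mistral 7B (bf16/fp16 weights).
medusa_heads = 5
params_per_head = 150e6            # ~150 million parameters per head, as quoted above
bytes_per_param = 2                # bfloat16 / float16
extra_gb = medusa_heads * params_per_head * bytes_per_param / 1024**3
print(f"Extra memory for Medusa heads: ~{extra_gb:.1f} GB")   # ~1.4 GB on top of the base model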

Training Medusa heads requires additional development time and computational resources, which should be factored into project planning and resource allocation. Another important limitation to mention is that the current framework, when deployed on an Amazon SageMaker AI endpoint, only supports a batch size of one, a configuration typically used for low-latency applications.

The following diagram from the original Medusa paper authors' FasterDecoding repository gives a visual overview of the Medusa framework.

Medusa framework overview

There are two main variants of Medusa:

  1. Medusa-1 – Requires a two-stage approach where you first fine-tune your LLM and then add Medusa heads and train them on top of your frozen fine-tuned LLM
  2. Medusa-2 – Introduced later as an improvement, fine-tunes both the additional heads and the backbone LLM parameters together, enabling potentially even further latency speedups

The Medusa paper reports that, across models of different sizes, you can achieve inference speedups of around two times for Medusa-1 and around three times for Medusa-2. With Medusa-1, the predictions are identical to those of the originally fine-tuned LLM. In contrast, with Medusa-2, we might observe slightly different results compared to simple fine-tuning of the LLM, because both the heads and the backbone LLM parameters are updated together. In this post, we focus on Medusa-1.

Solution overview

We cover the following steps in our solution:

  • Prerequisites
  • Load and prepare the dataset
  • Fine-tune an LLM using a SageMaker AI training job
  • Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job
  • Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
  • Demonstrate LLM inference speedup

By following this solution, you can accelerate LLM inference in your applications, leading to faster response times and an improved user experience.

Prerequisites

To build the solution yourself, you need the following prerequisites:

Load and prepare the dataset

Now that you have cloned the GitHub repository and opened the medusa_1_train.ipynb notebook, you will load and prepare the dataset in the notebook. We encourage you to read this post while running the code in the notebook. For this post, we use a dataset called sql-create-context, which contains samples of natural language instructions, schema definitions, and the corresponding SQL query. It contains 78,577 examples of natural language queries, SQL CREATE TABLE statements, and SQL queries answering the question using the CREATE statement as context. For demonstration purposes, we select 3,000 samples and split them into train, validation, and test sets.

Run the "Load and prepare the dataset" section of medusa_1_train.ipynb to prepare the dataset for fine-tuning. We also included a data exploration script to analyze the length of input and output tokens. After data exploration, we prepare the train, validation, and test sets and upload them to Amazon Simple Storage Service (Amazon S3).
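The exact preprocessing lives in the notebook; a minimal sketch of the loading and splitting step could look like the following. The Hugging Face dataset ID is b-mc2/sql-create-context, and the split proportions are our assumption, chosen so the test set matches the 450 observations used later in this post.

# Minimal sketch: load sql-create-context and carve out train/validation/test subsets.
from datasets import load_dataset

dataset = load_dataset("b-mc2/sql-create-context", split="train")   # 78,577 examples
dataset = dataset.shuffle(seed=42).select(range(3000))              # 3,000 samples for the demo

splits = dataset.train_test_split(test_size=0.3, seed=42)           # 70% train
holdout = splits["test"].train_test_split(test_size=0.5, seed=42)   # 15% validation, 15% test (450 samples)
train_dataset, eval_dataset, test_dataset = splits["train"], holdout["train"], holdout["test"]

train_dataset.to_json("train_dataset.json")   # these files are then uploaded to Amazon S3
eval_dataset.to_json("eval_dataset.json")
test_dataset.to_json("test_dataset.json")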

Fine-tune an LLM using a SageMaker AI training job

We use the Zephyr 7B β model as our backbone LLM. Zephyr is a series of language models trained to act as helpful assistants, and Zephyr 7B β is a fine-tuned version of Mistral-7B-v0.1, trained on a mix of publicly available and synthetic datasets using Direct Preference Optimization.

To launch a SageMaker AI training job, we need to use the PyTorch or Hugging Face estimator. SageMaker AI starts and manages all the necessary Amazon Elastic Compute Cloud (Amazon EC2) instances for us, supplies the appropriate containers, downloads data from our S3 bucket to the container, and uploads and runs the necessary training script, in our case fine_tune_llm.py. We select the hyperparameters based on the QLoRA paper, but we encourage you to experiment with your own combinations. To expedite the execution of this code, we set the number of epochs to 1. However, for better results, it is generally recommended to set the number of epochs to at least 2 or 3.

from sagemaker.pytorch.estimator import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig
import time
import os

def get_current_time():
    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

def create_estimator(hyperparameters_dict, job_name, role, sess, train_script_path):
    metric = [
        {"Name": "loss", "Regex": r"'loss':\s*([0-9.]+)"},
        {"Name": "epoch", "Regex": r"'epoch':\s*([0-9.]+)"},
    ]

    tensorboard_s3_output_path = os.path.join(
       "s3://", sess.default_bucket(), job_name, 'tensorboard'
    )
    print("Tensorboard output path:", tensorboard_s3_output_path)

    tensorboard_output_config = TensorBoardOutputConfig(
        s3_output_path=tensorboard_s3_output_path,
        container_local_output_path=hyperparameters_dict['logging_dir']
    )
    estimator = PyTorch(
        sagemaker_session   = sess,
        entry_point         = train_script_path,    # training script
        source_dir          = "train",              # directory which includes all the files needed for training
        instance_type       = "ml.g5.4xlarge",      # instance type used for the training job, "local_gpu" for local mode
        metric_definitions  = metric,
        instance_count      = 1,                    # the number of instances used for training
        role                = role,                 # IAM role used in the training job to access AWS resources, e.g. S3
        volume_size         = 300,                  # the size of the EBS volume in GB
        framework_version   = '2.1.0',              # the PyTorch version used in the training job
        py_version          = 'py310',              # the Python version used in the training job
        hyperparameters     = hyperparameters_dict, # the hyperparameters passed to the training job
        disable_output_compression = True,          # don't compress output to save training time and cost
        tensorboard_output_config = tensorboard_output_config
    )
    return estimator
    
# hyperparameters, which are passed into the training job
sft_hyperparameters = {
  ### SCRIPT PARAMETERS ###
  'train_dataset_path': '/opt/ml/input/data/train/train_dataset.json', # path where SageMaker will save the training dataset
  'eval_dataset_path': '/opt/ml/input/data/eval/eval_dataset.json',    # path where SageMaker will save the evaluation dataset
  'model_id': model_id,
  'max_seq_len': 256,                                # max sequence length for the model and packing of the dataset
  'use_qlora': True,                                 # use QLoRA
  ### TRAINING PARAMETERS ###
  'num_train_epochs': 1,                             # number of training epochs
  'per_device_train_batch_size': 1,                  # batch size per device during training
  'gradient_accumulation_steps': 16,                 # number of steps before performing a backward/update pass
  'gradient_checkpointing': True,                    # use gradient checkpointing to save memory
  'optim': "adamw_8bit",                             # use the 8-bit AdamW optimizer
  'logging_steps': 15,                               # log every 15 steps
  'save_strategy': "steps",                          # save a checkpoint every save_steps steps
  'save_steps': 15,
  'save_total_limit': 2,
  'eval_strategy': "steps",
  'eval_steps': 15,
  'learning_rate': 1e-4,                             # learning rate, based on the QLoRA paper
  'bf16': True,                                      # use bfloat16 precision
  'max_grad_norm': 10,                               # max gradient norm based on the QLoRA paper
  'warmup_ratio': 0.03,                              # warmup ratio based on the QLoRA paper
  'lr_scheduler_type': "constant",                   # use a constant learning rate scheduler
  'output_dir': '/opt/ml/checkpoints/',              # temporary output directory for model checkpoints
  'merge_adapters': True,                            # merge LoRA adapters into the model for easier deployment
  'report_to': "tensorboard",                        # report metrics to TensorBoard
  'logging_dir': "/opt/ml/output/tensorboard"        # TensorBoard logging directory
}
 
sft_job_name = f"sft-qlora-text-to-sql-{get_current_time()}"
data = {
    'train': train_dataset_path,
    'eval': eval_dataset_path
}

sft_estimator = create_estimator(sft_hyperparameters, sft_job_name, role, sess, "fine_tune_llm.py")

sft_estimator.fit(job_name=sft_job_name, inputs=data, wait=False)

When our training job has completed successfully after approximately 1 hour, we can use the fine-tuned model artifact for the next step, training the Medusa heads on top of it. To visualize the training metrics in TensorBoard, you can follow the guidance in this documentation: Load and visualize output tensors using the TensorBoard application.
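If you restart the notebook or want to confirm where SageMaker stored the artifact, one way to look up the S3 path of the fine-tuned model (used as fine_tuned_model_path in the next step) is a quick boto3 call; this is a sketch, and the notebook may expose the path differently:

# Look up the S3 location of the fine-tuned model artifact from the completed training job.
import boto3

sm_client = boto3.client("sagemaker")
job_description = sm_client.describe_training_job(TrainingJobName=sft_job_name)
fine_tuned_model_path = job_description["ModelArtifacts"]["S3ModelArtifacts"]
print(fine_tuned_model_path)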

Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job

For training the Medusa heads, we can reuse the functions previously mentioned to launch the training job. We selected hyperparameters based on a combination of what the Medusa paper reported and what we found to perform best after a few experiments. We set the number of Medusa heads to 5 and used the 8-bit AdamW optimizer, as recommended by the paper. For simplicity, we maintained a constant learning rate of 1e-4 with a constant scheduler, similar to the previous fine-tuning step. Although the paper recommends an increased learning rate and a cosine scheduler, we found that our chosen combination of hyperparameters performed well on this dataset. However, we encourage you to experiment with your own hyperparameter settings to potentially achieve even better results.

# hyperparameters, which are passed into the training job
medusa_hyperparameters = {
  ### SCRIPT PARAMETERS ###
  'train_dataset_path': '/opt/ml/input/data/train/train_dataset.json', # path where SageMaker will save the training dataset
  'eval_dataset_path': '/opt/ml/input/data/eval/eval_dataset.json',    # path where SageMaker will save the evaluation dataset
  'model_path': '/opt/ml/input/data/fine-tuned-model/',
  'max_seq_len': 256,                                # max sequence length for the model and packing of the dataset
  'medusa_num_heads': 5,
  ### TRAINING PARAMETERS ###
  'num_train_epochs': 3,                             # number of training epochs
  'per_device_train_batch_size': 1,                  # batch size per device during training
  'gradient_accumulation_steps': 16,                 # number of steps before performing a backward/update pass
  'gradient_checkpointing': True,                    # use gradient checkpointing to save memory
  'optim': "adamw_8bit",                             # use the 8-bit AdamW optimizer
  'logging_steps': 15,                               # log every 15 steps
  'save_strategy': "steps",                          # save a checkpoint every save_steps steps
  'save_steps': 15,
  'save_total_limit': 2,
  'eval_strategy': "steps",
  'eval_steps': 15,
  'learning_rate': 1e-4,                             # learning rate
  'bf16': True,                                      # use bfloat16 precision
  'max_grad_norm': 10,                               # max gradient norm based on the QLoRA paper
  'warmup_ratio': 0.03,                              # warmup ratio based on the QLoRA paper
  'lr_scheduler_type': "constant",                   # use a constant learning rate scheduler
  'output_dir': '/opt/ml/checkpoints/',              # temporary output directory for model checkpoints
  'report_to': "tensorboard",                        # report metrics to TensorBoard
  'logging_dir': "/opt/ml/output/tensorboard"        # TensorBoard logging directory
}

medusa_train_job_name = f"medusa-text-to-sql-{get_current_time()}"
data = {
    'train': train_dataset_path,
    'eval': eval_dataset_path,
    'fine-tuned-model': fine_tuned_model_path
}

medusa_estimator = create_estimator(medusa_hyperparameters, medusa_train_job_name, role, sess, "train_medusa_heads.py")

medusa_estimator.fit(job_name=medusa_train_job_name, inputs=data, wait=False)

We found that after 3 epochs, the evaluation loss of the Medusa heads was converging, which can be observed in the TensorBoard graph in the following image.

TensorBoard graph showing the evaluation loss during Medusa heads training

Apart from the hyperparameters, the main difference is that we pass train_medusa_heads.py as the training entrypoint, where we first add the Medusa heads, then freeze the fine-tuned LLM, and create a custom MedusaSFTTrainer class, which is a subclass of the transformers SFTTrainer.

# Add medusa heads and freeze base model
add_medusa_heads(
    model,
    medusa_num_heads=script_args.medusa_num_heads,
)
freeze_layers(model)
model.config.torch_dtype = torch_dtype
model.config.use_cache = False

logger.info("Finished loading model and medusa heads")

tokenizer = AutoTokenizer.from_pretrained(script_args.model_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

################
# Training
################
trainer = MedusaSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False,  # No need to add additional separator token
    },
    medusa_num_heads=script_args.medusa_num_heads,
    medusa_heads_coefficient=script_args.medusa_heads_coefficient,
    medusa_decay_coefficient=script_args.medusa_decay_coefficient,
    medusa_scheduler=script_args.medusa_scheduler,
    train_only_medusa_heads=script_args.train_only_medusa_heads,
    medusa_lr_multiplier=script_args.medusa_lr_multiplier
)
trainer.train()

In the add_medusa_heads() function, we add the residual blocks of the Medusa heads and also override the forward pass of the model to make sure the frozen backbone LLM is not trained:

def add_medusa_heads(
    model,
    medusa_num_heads,
):
    """
    Args:
        model (nn.Module): The base language model to be used.
        medusa_num_heads (int, optional): Number of additional tokens to predict
    """
    hidden_size = model.lm_head.weight.shape[-1]
    vocab_size = model.lm_head.weight.shape[0]
    model.config.medusa_num_layers = 1
    model.config.medusa_num_heads = medusa_num_heads
    model.medusa_num_heads = medusa_num_heads
    # Create a list of Medusa heads
    model.medusa_heads = nn.ModuleList(
        [
            nn.Sequential(
                ResBlock(hidden_size),
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(medusa_num_heads)
        ]
    )

    # Ensure medusa_heads' dtype and device align with the base model
    model.medusa_heads.to(model.dtype).to(model.device)
    logger.info(f"Loading medusa heads in {str(model.dtype)} to device {model.device}")

    for i in range(medusa_num_heads):
        # Initialize the weights of each medusa_head using the base model's weights
        model.medusa_heads[i][-1].weight.data[:] = model.lm_head.weight.data[:]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train_only_medusa_heads: bool = False,
    ):
        """Forward pass of the MedusaModel.
        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
            (Optional) Original predictions from the base model's LM head.
        """
        maybe_grad = torch.no_grad() if train_only_medusa_heads else nullcontext()
        with maybe_grad:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = outputs[0]
            medusa_logits = [self.lm_head(hidden_states)]
        for i in range(self.medusa_num_heads):
            medusa_logits.append(self.medusa_heads[i](hidden_states))
        return torch.stack(medusa_logits, dim=0)

    model.forward = types.MethodType(forward, model)

After the model training is finished (which takes about 1 hour), we prepare the model artifacts for deployment and upload them to Amazon S3. Your final model artifact contains both the original fine-tuned model from the previous step under the base-model prefix and the trained Medusa heads in a file named medusa_heads.safetensors.
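For reference, a sketch of how the trained Medusa head weights could be exported to medusa_heads.safetensors next to the backbone; the variable names and the exact save logic are assumptions, and the training script in the repository handles this packaging for you.

# Sketch (assumed names): export the trained Medusa head weights into the artifact layout
# described above:
#
#   /opt/ml/model/
#     base-model/               <- the fine-tuned backbone (weights, config, tokenizer)
#     medusa_heads.safetensors  <- the trained Medusa head weights
#
from safetensors.torch import save_file

# `model` is the trained model with Medusa heads attached (assumed to be in scope in the script)
medusa_state_dict = {k: v.clone().contiguous() for k, v in model.medusa_heads.state_dict().items()}
save_file(medusa_state_dict, "/opt/ml/model/medusa_heads.safetensors")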

Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint

The Medusa framework is supported by the Text Generation Inference (TGI) server. After training the LLM with Medusa heads, we deploy it to a SageMaker AI real-time endpoint using the Hugging Face Inference Container set up with TGI.

First, we create a SageMaker AI HuggingFaceModel object and then deploy the model to an endpoint with the following function:

import json
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri


def deploy_model(endpoint_name, instance_type, model_s3_path=None, hf_model_id=None):
    llm_image = get_huggingface_llm_image_uri(
      "huggingface",
      version="2.2.0",
      session=sess,
    )

    print(f"llm image uri: {llm_image}")

    model_data = None
    if model_s3_path:
        model_data = {'S3DataSource': {'S3Uri': model_s3_path, 'S3DataType': 'S3Prefix', 'CompressionType': 'None'}}
        hf_model_id = "/opt/ml/model"
    else:
        assert hf_model_id, "You need to provide either a pretrained HF model id, or S3 model data to deploy"
    config = {
      'HF_MODEL_ID': hf_model_id,  # path to where SageMaker stores the model
      'SM_NUM_GPUS': json.dumps(1),  # Number of GPUs used per replica
      'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text
      'MAX_TOTAL_TOKENS': json.dumps(2048),  # Max length of the generation (including input text)
    }

    llm_model = HuggingFaceModel(
      name=endpoint_name,
      role=role,
      image_uri=llm_image,
      model_data=model_data,
      env=config
    )

    deployed_llm = llm_model.deploy(
      endpoint_name=endpoint_name,
      initial_instance_count=1,
      instance_type=instance_type,
      container_startup_health_check_timeout=300,
    )
    return deployed_llm

We deploy three LLMs on three SageMaker AI endpoints:

  1. The base LLM, which is not fine-tuned
  2. The LLM that we fine-tuned
  3. The fine-tuned LLM that also has trained Medusa heads

You can deploy the three models in parallel by using a function that we included in the notebook, or you can deploy the models one by one by running the following code:

base_deployed_llm = deploy_model(f"base-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=None, hf_model_id=model_id)
sft_deployed_llm = deploy_model(f"sft-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=fine_tuned_model_path)
medusa_deployed_llm = deploy_model(f"medusa-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=medusa_trained_model_path)

After the status of each endpoint becomes InService, which should take around 15 minutes, we can invoke them for inference. We send the following input:

“You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA. SCHEMA: CREATE TABLE table_name_32 (time VARCHAR, heat VARCHAR, name VARCHAR)
What was Liu Limin's time in heat 4?”

We can observe the following responses:

  1. The base LLM response contains extra words that are not needed:
“To retrieve the time of Liu Limin in heat 4 based on the provided SCHEMA, we need to write a SQL query. Since the table name is not provided, let's assume it is "my_table". Assuming the table has a primary key (which is not mentioned in the SCHEMA), we can join the table with itself to compare the heat and name of Liu Limin in heat 4 with all the records in the table. Here's the SQL query: ```sql SELECT t1.time FROM my_table t1 JOIN my_table t2 ON t1.name = t2.name AND t2.heat = 4 WHERE t1.name="Liu Limin" AND t1.heat <> 4; ``` Explanation: 1. We're selecting the `time` column from the first table (t1) in the FROM clause. 2. We're joining the table with itself (my_table t1 JOIN my_table t2) to compare the name and heat of Liu Limin in heat 4 with all the records in the table. 3. We're filtering the results using the WHERE clause. We're selecting only the records where the name is 'Liu Limin' and the heat is not equal to 4 (i.e., not heat 4). This is to ensure that we're selecting the time of Liu Limin in heat 3. Note: This query assumes that the table has a unique primary key. If the table doesn't have a primary key, you may need to add additional conditions to the JOIN and WHERE clauses to ensure that we're selecting the correct records.”

  2. The fine-tuned LLM response is improved significantly and contains only the required output:
'SELECT time FROM table_name_32 WHERE heat = 4 AND name = "liu limin"'

  3. The fine-tuned LLM with trained Medusa heads provides the exact same response as the fine-tuned model, demonstrating that Medusa-1, by design, maintains the output (quality) of the original model:
'SELECT time FROM table_name_32 WHERE heat = 4 AND name = "liu limin"'

Demonstrate LLM inference speedup

To measure the inference speed improvements, we compare the response times of the deployed fine-tuned LLM and the fine-tuned LLM with Medusa heads on 450 test observations with the following code:

import time
import numpy as np
from tqdm import tqdm

def request(sample, deployed_llm):
    prompt = tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=True)
    outputs = deployed_llm.predict({
      "inputs": prompt,
      "parameters": {
        "max_new_tokens": 512,
        "do_sample": False,
        "return_full_text": False,
      }
    })
    return {"role": "assistant", "content": outputs[0]["generated_text"].strip()}

def predict(deployed_llm, test_dataset):
    predicted_answers = []
    latencies = []

    for sample in tqdm(test_dataset):
        start_time = time.time()
        predicted_answer = request(sample["messages"][:2], deployed_llm)
        end_time = time.time()

        latency = end_time - start_time
        latencies.append(latency)
        predicted_answers.append(predicted_answer)

    # Calculate p90 and average latencies
    p90_latency = np.percentile(latencies, 90)
    avg_latency = np.mean(latencies)

    print(f"P90 Latency: {p90_latency:.2f} seconds")
    print(f"Average Latency: {avg_latency:.2f} seconds")

    return predicted_answers

First, we run predictions using the fine-tuned LLM:

sft_predictions = predict(sft_deployed_llm, test_dataset)
P90 Latency: 1.28 seconds
Average Latency: 0.95 seconds

Then, we run predictions using the fine-tuned LLM with Medusa heads:

medusa_predictions = predict(medusa_deployed_llm, test_dataset)
P90 Latency: 0.80 seconds
Average Latency: 0.53 seconds

The prediction runs should take around 8 and 4 minutes, respectively. We can observe that the average latency decreased from 950 to 530 milliseconds, which is an improvement of 1.8 times. You can achieve even higher improvements if your dataset contains longer inputs and outputs. In our dataset, we only had an average of 18 input tokens and 30 output tokens.
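The speedup figure follows directly from the measured averages:

# Speedup implied by the measured average latencies from the two runs above.
avg_latency_sft = 0.95      # seconds, fine-tuned LLM
avg_latency_medusa = 0.53   # seconds, fine-tuned LLM with Medusa heads
print(f"Average-latency speedup: {avg_latency_sft / avg_latency_medusa:.1f}x")   # ~1.8x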

We want to highlight once more that, with this approach, the output quality is fully maintained and all the prediction outputs are the same. The model responses for the test set of 450 observations are the same with and without Medusa heads:

match_percentage = sum(a["content"] == b["content"] for a, b in zip(sft_predictions, medusa_predictions)) / len(sft_predictions) * 100
print(f"Predictions with the fine-tuned mannequin with medusa heads are the identical as with out medusa heads: {match_percentage:.2f}% of check set ")

Predictions with the fine-tuned model with medusa heads are the same as without medusa heads: 100.00% of test set 

You might notice in your run that a few observations don't match exactly, and you might get a 99% match due to small errors in floating point operations caused by optimizations on GPUs.

Cleanup

At the end of this experiment, don't forget to delete the SageMaker AI endpoints you created:

base_deployed_llm.delete_model()
base_deployed_llm.delete_endpoint()
sft_deployed_llm.delete_model()
sft_deployed_llm.delete_endpoint()
medusa_deployed_llm.delete_model()
medusa_deployed_llm.delete_endpoint()

Conclusion

In this post, we demonstrated how to fine-tune and deploy an LLM with Medusa heads using the Medusa-1 technique on Amazon SageMaker AI to accelerate LLM inference. By using this framework and SageMaker AI scalable infrastructure, we showed how to achieve up to twofold speedups in LLM inference while maintaining model quality. This solution is particularly beneficial for applications requiring low-latency text generation, such as customer service chat assistants, content creation, and recommendation systems.

As a next step, you can explore fine-tuning your own LLM with Medusa heads on your own dataset and benchmark the results for your specific use case, using the provided GitHub repository.


About the authors

Daniel Zagyva is a Senior ML Engineer at AWS Professional Services. He specializes in developing scalable, production-grade machine learning solutions for AWS customers. His experience extends across different areas, including natural language processing, generative AI, and machine learning operations.

Aleksandra Dokic is a Senior Data Scientist at AWS Professional Services. She enjoys supporting customers in building innovative AI/ML solutions on AWS and is excited about business transformations through the power of data.

Moran Beladev is a Senior ML Manager at Booking.com. She is leading the content intelligence track, which is focused on building, training, and deploying content models (computer vision, NLP, and generative AI) using the most advanced technologies and models. Moran is also a PhD candidate, researching the application of NLP models to social graphs.

Manos Stergiadis is a Senior ML Scientist at Booking.com. He specializes in generative NLP and has experience researching, implementing, and deploying large deep learning models at scale.

Ilya Gusev is a Senior Machine Learning Engineer at Booking.com. He leads the development of several LLM systems inside Booking.com. His work focuses on building production ML systems that help millions of travelers plan their trips effectively.

Laurens van der Maas is a Machine Learning Engineer at AWS Professional Services. He works closely with customers building their machine learning solutions on AWS, specializes in natural language processing, experimentation, and responsible AI, and is passionate about using machine learning to drive meaningful change in the world.
