
Faster LLMs with Quantization - How to get faster inference times with quantization

A comparison of different quantization techniques and their impact on inference time, and how to optimize inference with quantized models using vllm.



Exploring the inference speed of quantized models.
Created with my custom Flux LoRA on Replicate

tl;dr - What is the Post About?

  • Quantization significantly improves inference speed by reducing precision while maintaining most of the model's performance.
  • A comparison of AWQ, GPTQ, and BF16 quantization techniques using Meta's Llama 3.1 8B model and the ifeval benchmark shows that AWQ and GPTQ have nearly identical throughput, processing ~3x more requests per second than the full-precision BF16 model.
  • Optimizing GPTQ with vllm settings (e.g., enabling chunked prefill and increasing the maximum number of sequences) resulted in 2-3x higher token throughput compared to the default configuration in the vllm serving benchmark. These optimizations can also be applied to AWQ models.
  • Swapping requests out of the KV cache, triggered when the cache capacity is exceeded, leads to massive performance drops.
  • A higher maximum number of sequences maximizes throughput and GPU utilization but increases latency, making it unsuitable for real-time or chat use cases.
  • Consumer GPUs like the RTX 3090 are heavily memory-bound due to their slower GDDR VRAM bandwidth (<1 TB/s). In contrast, HBM-based GPUs offer higher bandwidth (roughly 2 TB/s on the A100 and even more on the H100).

About the Series

This is the second part of the small local LLMs series created for the talk "Energy Efficiency in AI: Use of Quantized Language Models" at the CNCF Sustainability Week Stuttgart 2024 event on October 10, 2024. A big thank you to Red Hat for hosting such a great event and giving me the opportunity to speak on this topic.

On February 26, 2025, the second Cloud Native Stuttgart will take place. You can sign up for the event here.

Introduction

In this part of the series, we will focus on the inference speed of quantized models, comparing various quantization techniques. For the comparison, I will use the vllm library with Meta's Llama 3.1 8B model. The evaluation will include BF16 precision, as well as AWQ and GPTQ quantization methods. While the focus of inference optimization will be on the GPTQ variant, similar techniques can also be applied to AWQ quantization.

All inference tests will be run on my setup, which includes two RTX 3090 GPUs. However, only one GPU will be used for the tests.
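
If you have multiple GPUs, one simple way to restrict vllm to a single card is the CUDA_VISIBLE_DEVICES environment variable; the device index below is just an example and depends on your setup:

# Expose only the first GPU to the following vllm commands
export CUDA_VISIBLE_DEVICES=0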

Benchmark of Ifeval Inference Times

In the first part of the series, I already evaluated the inference accuracy of the Llama 3.1 8B model using the Ifeval benchmark. Therefore, I use the same benchmark here to compare inference times.

As before, I served the vllm OpenAI-compatible endpoint with the following commands:

# For the standard model (setting the hf_token necessary)
vllm serve meta-llama/Llama-3.1-8B --max-model-len 8096 --gpu-memory-utilization 0.8
# For the AWQ quantized model
vllm serve hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --max-model-len 8096 --gpu-memory-utilization 0.8 --quantization awq
# For the GPTQ quantized model
vllm serve hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 --max-model-len 8096 --gpu-memory-utilization 0.8 --quantization gptq
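
To quickly check that a server is up before running the benchmark, you can send a single request to the OpenAI-compatible chat endpoint, for example with curl (the model name has to match the model passed to vllm serve):

# Send a single test request to the OpenAI-compatible endpoint (default port 8000)
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 32
      }'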

After that I ran the Ifeval benchmark with the lm_eval library.

lm_eval --model local-chat-completions \
        --tasks ifeval \
        --model_args model="<model_name>",base_url=http://localhost:8000/v1/chat/completions,num_concurrent=32,max_retries=3,tokenized_requests=False  \
        --apply_chat_template \
        --fewshot_as_multiturn \
        --output_path="./llama3.1_<quantization>.jsonl"

More information about the Ifeval benchmark can be found here. I limited the number of concurrent requests to 32. This number was chosen somewhat arbitrarily and has not been optimized. Without this limit, the benchmark fails because the OpenAI-compatible server may drop requests. While 32 concurrent requests is on the lower side, the total throughput of the quantized models could be even higher, since their smaller weights leave more GPU memory for the KV cache.

We can examine the requests per second, as shown in the following diagram:

[Diagram: requests per second on the Ifeval benchmark for the BF16, AWQ, and GPTQ models]

Overall, the inference times for the AWQ and GPTQ models are very similar. In comparison, the full-precision model handles only one-third of the requests per second achieved by the quantized models.

Benchmark Serving with vllm

Another way to test throughput is by using the benchmark_serving.py script from the vllm project. This script utilizes the ShareGPT dataset, which consists of real-world chat samples. These samples vary significantly and are available on Hugging Face. The benchmarking script automatically loads and filters the dataset. You can view the filtering process here.

benchmarks/benchmark_serving.py#L89
import json
import random
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizerBase


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int, None]]:
    # Load the dataset.
    with open(dataset_path, encoding='utf-8') as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]
 
    # Shuffle the dataset.
    random.shuffle(dataset)
 
    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break
 
        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or (fixed_output_len is None and output_len < 4):
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len, None))
 
    return filtered_dataset

The dataset is filtered to include only conversations with at least two turns, retaining just the first two turns of each conversation. Sequences that are too short (fewer than 4 tokens) or too long (a prompt exceeding 1024 tokens or a combined prompt and completion exceeding 2048 tokens) are removed. While this setup is closer to real-world scenarios, the chats are likely shorter than typical real-world conversations.
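
For reference, a typical invocation of the serving benchmark looks roughly like the following. The exact flag names can differ between vllm versions, and the ShareGPT JSON file is assumed to have been downloaded locally beforehand, so treat this as a sketch rather than the exact command:

# Run the serving benchmark against the already running vllm server
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 1000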

With this data, we get the following results:

[Diagram: serving benchmark throughput with the default configuration]

In the optimized version, I increased GPU utilization by enabling chunked prefill with --enable-chunked-prefill and by setting --max-num-seqs 380. This resulted in approximately 90% KV cache usage and improved throughput: the request throughput of the GPTQ model increased by about 2x.
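
Put together, the optimized GPTQ serving configuration looks roughly like this, keeping the remaining flags from the earlier commands:

# GPTQ model with chunked prefill and a higher maximum number of sequences
vllm serve hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
    --max-model-len 8096 \
    --gpu-memory-utilization 0.8 \
    --quantization gptq \
    --enable-chunked-prefill \
    --max-num-seqs 380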

[Diagram: serving benchmark throughput with the optimized configuration]

Conclusion

Quantization can significantly speed up inference, with the AWQ and GPTQ techniques delivering similar speeds. The GPTQ model was further optimized through vllm serving settings, resulting in 2-3x higher throughput. These optimization settings can be applied to AWQ quantization as well.

This optimization is focused on maximizing throughput and is not suitable for real-time applications, as the latency to the first token becomes too high; it is a tradeoff between latency and throughput. The improved inference speed is achieved through greater use of the KV cache and higher GPU utilization. Since GDDR-based VRAM is slower than HBM-based VRAM, increasing utilization and batch size can lead to higher throughput. For reference, consumer GPUs like the RTX 3090 have a limited memory bandwidth of less than 1 TB/s, whereas HBM-based GPUs such as the A100 offer roughly 2 TB/s and the H100 even more. This makes consumer GPUs particularly memory-bound.

Additionally, quantizing the KV cache reduces VRAM usage, but on the RTX 3090 it led to a decrease in throughput. More modern GPUs with native FP8 support could potentially benefit from it. In other use cases, inference with SGLang was significantly faster, particularly when KV cache hits occurred because requests shared similar starting text.
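
For completeness, KV cache quantization in vllm can be enabled via the --kv-cache-dtype flag; the supported values depend on the vllm version and the GPU, so the following is only a sketch:

# Serve the GPTQ model with an FP8-quantized KV cache to reduce VRAM usage
vllm serve hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
    --max-model-len 8096 \
    --gpu-memory-utilization 0.8 \
    --quantization gptq \
    --kv-cache-dtype fp8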