Vision Language Model

SGLang supports vision language models in the same way as completion models. Here are some example models:

Launch a Server

The following code is equivalent to running this in the shell:

python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \
 --port=30010 --chat-template=llama_3_vision

Remember to add --chat-template=llama_3_vision to specify the vision chat template; otherwise, the server supports only text input.
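To make the role of that flag explicit, here is a small sketch that assembles the launch command as an argument list, which lets you start the server without passing a string through a shell. build_launch_args is a hypothetical helper, not part of SGLang; only the --model-path, --port, and --chat-template flags come from this page.

```python
import shlex
from typing import List, Optional


def build_launch_args(
    model_path: str, port: int, chat_template: Optional[str] = None
) -> List[str]:
    """Hypothetical helper: argv for `python3 -m sglang.launch_server`.

    Only the flags used on this page are covered. Leaving chat_template
    as None reproduces the text-only launch the warning above refers to.
    """
    args = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_path,
        f"--port={port}",
    ]
    if chat_template is not None:
        args.append(f"--chat-template={chat_template}")
    return args


args = build_launch_args(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", 30010, "llama_3_vision"
)
# shlex.join turns the list back into a shell-safe string for logging.
print(shlex.join(args))
# To start the server without a shell (not executed here):
#   import subprocess
#   proc = subprocess.Popen(args)
```

Keeping the arguments as a list means a model path or template name is never reinterpreted by the shell.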

[1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

embedding_process = execute_shell_command(
    """
    python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \
        --port=30010 --chat-template=llama_3_vision

"""
)

wait_for_server("http://localhost:30010")
[2024-11-01 07:55:46] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-11B-Vision-Instruct', tokenizer_path='meta-llama/Llama-3.2-11B-Vision-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-11B-Vision-Instruct', chat_template='llama_3_vision', is_embedding=False, host='127.0.0.1', port=30010, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=606882579, constrained_json_whitespace_pattern=None, decode_log_interval=40, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, watchdog_timeout=600, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_penalizer=False, disable_nan_detection=False, enable_overlap_schedule=False, enable_mixed_chunk=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1)
[2024-11-01 07:55:53] Use chat template for the OpenAI-compatible API server: llama_3_vision
[2024-11-01 07:56:03 TP0] Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models.
[2024-11-01 07:56:03 TP0] Init torch distributed begin.
[2024-11-01 07:56:04 TP0] Load weight begin. avail mem=78.59 GB
[2024-11-01 07:56:04 TP0] lm_eval is not installed, GPTQ may not be usable
INFO 11-01 07:56:04 weight_utils.py:243] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:01<00:04,  1.20s/it]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:02<00:03,  1.26s/it]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:03<00:02,  1.28s/it]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:05<00:01,  1.31s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.07s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.16s/it]

[2024-11-01 07:56:10 TP0] Load weight end. type=MllamaForConditionalGeneration, dtype=torch.bfloat16, avail mem=58.43 GB
[2024-11-01 07:56:10 TP0] Memory pool end. avail mem=11.80 GB
[2024-11-01 07:56:10 TP0] Capture cuda graph begin. This can take up to several minutes.
[2024-11-01 07:56:21 TP0] max_total_num_tokens=298440, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2024-11-01 07:56:22] INFO:     Started server process [1241850]
[2024-11-01 07:56:22] INFO:     Waiting for application startup.
[2024-11-01 07:56:22] INFO:     Application startup complete.
[2024-11-01 07:56:22] INFO:     Uvicorn running on http://127.0.0.1:30010 (Press CTRL+C to quit)
[2024-11-01 07:56:22] INFO:     127.0.0.1:44614 - "GET /v1/models HTTP/1.1" 200 OK
[2024-11-01 07:56:23] INFO:     127.0.0.1:44616 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-11-01 07:56:23 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:56:23] INFO:     127.0.0.1:44620 - "POST /generate HTTP/1.1" 200 OK
[2024-11-01 07:56:23] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and the notebook code together, so their outputs are interleaved.
To improve clarity, the rendered notebook shows server logs in the default black text and highlights notebook outputs in blue.
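When the server runs in its own terminal as the note describes, client code needs its own readiness check. The following is a standard-library sketch that polls the /v1/models endpoint (the first request visible in the logs above) until it answers with HTTP 200. It assumes that endpoint is a sufficient readiness signal; it is not the implementation of sglang.utils.wait_for_server, and wait_until_ready and http_ok are hypothetical names.

```python
import time
import urllib.request
from typing import Callable


def http_ok(url: str, timeout: float = 5.0) -> bool:
    """Return True if a GET on url answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError, HTTPError, refused connections and resets are all OSError.
        return False


def wait_until_ready(
    base_url: str,
    timeout_s: float = 600.0,
    interval_s: float = 2.0,
    probe: Callable[[str], bool] = http_ok,
) -> bool:
    """Poll {base_url}/v1/models until it returns 200 or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if probe(f"{base_url}/v1/models"):
            return True
        time.sleep(interval_s)
    return False


# Against the server launched above (not executed here):
#   wait_until_ready("http://localhost:30010")
```

The probe parameter exists so the polling loop can be exercised without a running server.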

Use curl

[2]:
import subprocess, json, os

curl_command = """
curl http://localhost:30010/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer None" \
  -d '{
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "What’s in this image?"
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
            }
          }
        ]
      }
    ],
    "max_tokens": 300
  }'
"""

response = json.loads(subprocess.check_output(curl_command, shell=True))
print_highlight(response)
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   559    0     0  100   559      0    174  0:00:03  0:00:03 --:--:--   174
/actions-runner/_work/_tool/Python/3.9.20/x64/lib/python3.9/site-packages/torch/storage.py:414: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(io.BytesIO(b))
[2024-11-01 07:56:31 TP0] Prefill batch. #new-seq: 1, #new-token: 6463, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
100   559    0     0  100   559      0    132  0:00:04  0:00:04 --:--:--   132
[2024-11-01 07:56:32 TP0] Decode batch. #running-req: 1, #token: 6496, token usage: 0.02, gen throughput (token/s): 3.66, #queue-req: 0
[2024-11-01 07:56:32 TP0] Decode batch. #running-req: 1, #token: 6536, token usage: 0.02, gen throughput (token/s): 99.00, #queue-req: 0
100   559    0     0  100   559      0    107  0:00:05  0:00:05 --:--:--   107
[2024-11-01 07:56:33] INFO:     127.0.0.1:44624 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100  1434  100   875  100   559    160    102  0:00:05  0:00:05 --:--:--   269
{'id': 'f335ea7b994a4692b62883941d2932e4', 'object': 'chat.completion', 'created': 1730447793, 'model': 'meta-llama/Llama-3.2-11B-Vision-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The image depicts a serene and peaceful landscape with a wooden boardwalk leading through a field of tall grass and trees. The boardwalk is made of light-colored wood and is surrounded by lush green grass on either side. In the background, there are trees and bushes that add to the natural beauty of the scene. The sky above is blue with some clouds, creating a sense of depth and atmosphere. The overall effect is one of tranquility and calmness, inviting the viewer to step into the idyllic setting.'}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 6463, 'total_tokens': 6567, 'completion_tokens': 104, 'prompt_tokens_details': None}}
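If you would rather not shell out to curl, the same request can be sent with the Python standard library. The sketch below rebuilds the JSON body from the curl example, posts it with urllib.request, and extracts the reply text and token usage from the chat.completion shape shown in the output above. build_chat_body, post_chat, and summarize are hypothetical helper names, and the network call is left commented out.

```python
import json
import urllib.request
from typing import Any, Dict

MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


def build_chat_body(prompt: str, image_url: str, max_tokens: int = 300) -> Dict[str, Any]:
    """Same JSON body as the curl example: one user turn with text + image."""
    return {
        "model": MODEL,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        "max_tokens": max_tokens,
    }


def post_chat(base_url: str, body: Dict[str, Any], timeout: float = 120.0) -> Dict[str, Any]:
    """POST body to /v1/chat/completions and decode the JSON reply (network call)."""
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": "Bearer None"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.load(resp)


def summarize(response: Dict[str, Any]) -> str:
    """Pull the assistant text and token counts out of a chat.completion reply."""
    text = response["choices"][0]["message"]["content"]
    usage = response["usage"]
    return f"{text} [prompt={usage['prompt_tokens']}, completion={usage['completion_tokens']}]"


body = build_chat_body("What is in this image?", IMAGE_URL)
# reply = post_chat("http://localhost:30010", body)  # requires the running server
# print(summarize(reply))
```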

Use the OpenAI-Compatible API

[3]:
import base64, requests
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30010/v1", api_key="None")


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def download_image(image_url, image_path):
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()
    with open(image_path, "wb") as f:
        f.write(response.content)


image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
image_path = "boardwalk.jpeg"
download_image(image_url, image_path)

base64_image = encode_image(image_path)

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What is in this image?",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ],
    max_tokens=300,
)

print_highlight(response.choices[0].message.content)
[2024-11-01 07:56:39 TP0] Prefill batch. #new-seq: 1, #new-token: 6463, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:56:39 TP0] Decode batch. #running-req: 1, #token: 6473, token usage: 0.02, gen throughput (token/s): 6.15, #queue-req: 0
[2024-11-01 07:56:39] INFO:     127.0.0.1:47312 - "POST /v1/chat/completions HTTP/1.1" 200 OK
This image depicts a serene landscape with a boardwalk meandering through a lush grassy field, surrounded by tall green grass and trees in the distance.
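The cell above hardcodes data:image/jpeg, which matches the downloaded JPEG. For local files in other formats, a small variation can infer the MIME type from the file extension with mimetypes before building the data URL. This is a sketch: to_data_url and file_to_data_url are hypothetical names, the JPEG fallback is an assumption, and whether the server accepts a given format is not covered here.

```python
import base64
import mimetypes
from pathlib import Path


def to_data_url(data: bytes, mime_type: str) -> str:
    """Wrap raw image bytes as a base64 data URL."""
    encoded = base64.b64encode(data).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def file_to_data_url(path: str, default_mime: str = "image/jpeg") -> str:
    """Read an image file, guessing its MIME type from the extension."""
    mime_type, _ = mimetypes.guess_type(path)
    if mime_type is None or not mime_type.startswith("image/"):
        # Assumed fallback when the extension is missing or not an image type.
        mime_type = default_mime
    return to_data_url(Path(path).read_bytes(), mime_type)


print(to_data_url(b"abc", "image/png"))
print(mimetypes.guess_type("boardwalk.jpeg")[0])
```

The result can be dropped into the "image_url": {"url": ...} field exactly where the cell above uses its f-string.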

Multiple Image Input

[4]:
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30010/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Are there any differences between these two images?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
                    },
                },
            ],
        }
    ],
    max_tokens=300,
)
print(response.choices[0])
[2024-11-01 07:56:40 TP0] Prefill batch. #new-seq: 1, #new-token: 12871, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:56:40 TP0] Decode batch. #running-req: 1, #token: 12891, token usage: 0.04, gen throughput (token/s): 24.88, #queue-req: 0
[2024-11-01 07:56:41 TP0] Decode batch. #running-req: 1, #token: 12931, token usage: 0.04, gen throughput (token/s): 104.23, #queue-req: 0
[2024-11-01 07:56:41] INFO:     127.0.0.1:47314 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='The two images depict a serene and peaceful natural setting, with the first featuring a serene lake and the second, a boardwalk in a field. Both images share a sense of tranquility and calmness, inviting the viewer to step into the natural world. The lake image evokes feelings of relaxation and tranquility, while the boardwalk image suggests a sense of adventure and exploration. Both images evoke a sense of connection to nature and the great outdoors.', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=128009)
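Note that both image_url entries in the cell above point to the same file, so any difference the reply describes is not present in the input. When comparing distinct images, it can help to build the content list programmatically: the text part first, then one image_url part per image, as in the cell above. The helper name build_multi_image_message and the example.com URLs are hypothetical placeholders.

```python
from typing import Any, Dict, List


def build_multi_image_message(prompt: str, image_urls: List[str]) -> Dict[str, Any]:
    """One user message: the text part first, then one image_url part per image."""
    content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
    for url in image_urls:
        content.append({"type": "image_url", "image_url": {"url": url}})
    return {"role": "user", "content": content}


message = build_multi_image_message(
    "Are there any differences between these two images?",
    ["https://example.com/a.jpg", "https://example.com/b.jpg"],  # placeholders
)
print(len(message["content"]))
# Use it as: client.chat.completions.create(model=..., messages=[message], max_tokens=300)
```

Each extra image adds to the prompt length; in the logs above, the two-image request prefilled 12871 tokens versus 6463 for a single image.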
[5]:
terminate_process(embedding_process)
os.remove(image_path)

Chat Template

As mentioned above, if you do not specify a chat template for a vision model, the server falls back to Hugging Face’s default template, which supports only text.
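Because a server launched without a vision chat template handles only text, a client that also controls the launch flags could catch the mistake before sending images. The sketch below is a coarse, hypothetical check (count_image_parts and may_accept_images are illustrative names): chat_template mirrors whatever value was passed to --chat-template, or None if the flag was omitted, and the check does not verify that the template is actually a vision template.

```python
from typing import Any, Dict, List, Optional


def count_image_parts(messages: List[Dict[str, Any]]) -> int:
    """Count image_url parts across OpenAI-style chat messages."""
    total = 0
    for message in messages:
        content = message.get("content")
        # String content is plain text; list content may mix text and images.
        if isinstance(content, list):
            total += sum(1 for part in content if part.get("type") == "image_url")
    return total


def may_accept_images(
    messages: List[Dict[str, Any]], chat_template: Optional[str]
) -> bool:
    """False when images are present but no --chat-template value was set."""
    return count_image_parts(messages) == 0 or chat_template is not None
```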

To use your own template, refer to the custom chat template documentation.

We list popular vision models with their chat templates: