Quick Start: Sending Requests
This notebook provides a quick-start guide to sending chat completion requests to SGLang after installation.
For Vision Language Models, see OpenAI APIs - Vision.
For Embedding Models, see OpenAI APIs - Embedding and Encode (embedding model).
For Reward Models, see Classify (reward model).
Launch A Server
This code block is equivalent to executing
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30000 --host 0.0.0.0
in your terminal and waiting for the server to be ready. Once the server is running, you can send test requests using cURL or the Python requests library. The server implements OpenAI-compatible APIs.
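In this notebook, wait_for_server (used in the cell below) handles the waiting. If you instead start the server yourself in a separate terminal, a small polling loop can do the same job. The sketch below assumes the server counts as ready once GET /v1/models returns HTTP 200. The names is_ready and wait_until_ready are hypothetical, and this is not how sglang.utils.wait_for_server is implemented.

```python
import time
import urllib.request
from typing import Callable


def is_ready(base_url: str) -> bool:
    """Hypothetical probe: True if GET {base_url}/v1/models returns HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/v1/models", timeout=5) as resp:
            return resp.status == 200
    except OSError:
        # URLError (refused connection, HTTP error status) and socket timeouts
        # are all OSError subclasses; treat them as "not ready yet".
        return False


def wait_until_ready(
    base_url: str,
    timeout: float = 600.0,
    interval: float = 1.0,
    probe: Callable[[str], bool] = is_ready,
) -> None:
    """Call `probe` every `interval` seconds until it succeeds or `timeout` elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if probe(base_url):
            return
        time.sleep(interval)
    raise TimeoutError(f"{base_url} was not ready after {timeout} seconds")


# Usage with the server from this guide:
# wait_until_ready("http://localhost:30000")
```

The probe is a parameter so the loop can be exercised without a running server.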
[1]:
from sglang.utils import (
execute_shell_command,
wait_for_server,
terminate_process,
print_highlight,
)
server_process = execute_shell_command(
"""
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30000 --host 0.0.0.0
"""
)
wait_for_server("http://localhost:30000")
[2024-11-18 02:55:14] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=341598201, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_penalizer=False, enable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2024-11-18 02:55:29 TP0] Init torch distributed begin.
[2024-11-18 02:55:30 TP0] Load weight begin. avail mem=78.59 GB
[2024-11-18 02:55:30 TP0] lm_eval is not installed, GPTQ may not be usable
[2024-11-18 02:55:30 TP0] Ignore import error when loading sglang.srt.models.phi3_small. cannot import name 'maybe_prefix' from 'vllm.model_executor.models.utils' (/actions-runner/_work/_tool/Python/3.9.20/x64/lib/python3.9/site-packages/vllm/model_executor/models/utils.py)
INFO 11-18 02:55:31 weight_utils.py:243] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.25it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.12it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.07it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.45it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.31it/s]
[2024-11-18 02:55:34 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.50 GB
[2024-11-18 02:55:34 TP0] Memory pool end. avail mem=8.37 GB
[2024-11-18 02:55:34 TP0] Capture cuda graph begin. This can take up to several minutes.
[2024-11-18 02:55:42 TP0] max_total_num_tokens=442913, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2024-11-18 02:55:42] INFO: Started server process [1263277]
[2024-11-18 02:55:42] INFO: Waiting for application startup.
[2024-11-18 02:55:42] INFO: Application startup complete.
[2024-11-18 02:55:42] INFO: Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2024-11-18 02:55:43] INFO: 127.0.0.1:46714 - "GET /v1/models HTTP/1.1" 200 OK
[2024-11-18 02:55:43] INFO: 127.0.0.1:46716 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-11-18 02:55:43 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-18 02:55:43] INFO: 127.0.0.1:46718 - "POST /generate HTTP/1.1" 200 OK
[2024-11-18 02:55:43] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
Using cURL
[2]:
import json
import subprocess
# Build the curl command; the JSON body is sent with an explicit content type.
curl_command = """
curl -s http://localhost:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "What is the capital of France?"}]}'
"""
response = json.loads(subprocess.check_output(curl_command, shell=True))
print_highlight(response)
[2024-11-18 02:55:48 TP0] Prefill batch. #new-seq: 1, #new-token: 41, #cached-token: 1, cache hit rate: 2.04%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-18 02:55:48] INFO: 127.0.0.1:46730 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': '5470683742c242b1943846a4fabcbb95', 'object': 'chat.completion', 'created': 1731898548, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.'}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 42, 'total_tokens': 50, 'completion_tokens': 8, 'prompt_tokens_details': None}}
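The cURL cell wraps the JSON body in a single-quoted shell string, so a prompt that contains an apostrophe (for example "What's the capital of France?") would break the command. One way around this, sketched below, is to serialize the payload with json.dumps and pass curl its arguments as a list, so no shell quoting is involved. The helper names build_curl_args and post_with_curl are hypothetical, and actually sending the request requires curl on your PATH and the server from this guide.

```python
import json
import subprocess


def build_curl_args(url: str, payload: dict) -> list:
    """Hypothetical helper: curl argv that POSTs `payload` as JSON to `url`."""
    return [
        "curl",
        "-s",
        url,
        "-H",
        "Content-Type: application/json",
        "-d",
        json.dumps(payload),  # serialized once; no shell quoting involved
    ]


def post_with_curl(url: str, payload: dict) -> dict:
    """Run curl without a shell and parse its JSON output."""
    return json.loads(subprocess.check_output(build_curl_args(url, payload)))


payload = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "What's the capital of France?"}],
}
args = build_curl_args("http://localhost:30000/v1/chat/completions", payload)

# With the server running:
# response = post_with_curl("http://localhost:30000/v1/chat/completions", payload)
```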
Using Python Requests
[3]:
import requests
url = "http://localhost:30000/v1/chat/completions"
data = {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print_highlight(response.json())
[2024-11-18 02:55:48 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 41, cache hit rate: 46.15%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-18 02:55:48] INFO: 127.0.0.1:46738 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': 'fb6978fc5b3e471082a992d54fdfc13e', 'object': 'chat.completion', 'created': 1731898548, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.'}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 42, 'total_tokens': 50, 'completion_tokens': 8, 'prompt_tokens_details': None}}
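Both responses have the OpenAI chat-completion shape. The reply text is under choices[0]["message"]["content"], and usage reports token counts (42 prompt + 8 completion = 50 total above). The sketch below shows a small helper, with the hypothetical name summarize_chat_response, that pulls out the fields you usually need from the parsed dict.

```python
from typing import Any, Dict, Tuple


def summarize_chat_response(resp: Dict[str, Any]) -> Tuple[str, str, int]:
    """Hypothetical helper: return (reply text, finish_reason, total_tokens)."""
    choice = resp["choices"][0]
    usage = resp["usage"]
    # Sanity check: the total should equal prompt tokens plus completion tokens.
    if usage["total_tokens"] != usage["prompt_tokens"] + usage["completion_tokens"]:
        raise ValueError("inconsistent usage counts")
    return choice["message"]["content"], choice["finish_reason"], usage["total_tokens"]


# The dict printed above, trimmed to the fields the helper reads:
example = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "The capital of France is Paris."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 42, "total_tokens": 50, "completion_tokens": 8},
}
print(summarize_chat_response(example))
# ('The capital of France is Paris.', 'stop', 50)
```

With a live server, you would pass response.json() from the cell above instead of the hand-copied dict.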
Using OpenAI Python Client
[4]:
import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(response)
[2024-11-18 02:55:49 TP0] Prefill batch. #new-seq: 1, #new-token: 13, #cached-token: 30, cache hit rate: 53.73%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-18 02:55:49 TP0] Decode batch. #running-req: 1, #token: 62, token usage: 0.00, gen throughput (token/s): 6.02, #queue-req: 0
[2024-11-18 02:55:49] INFO: 127.0.0.1:46754 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ChatCompletion(id='d4d0585e094447a0ae8332611ed58f58', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n Capital: Tokyo\n\n2. Country: Australia\n Capital: Canberra\n\n3. Country: Brazil\n Capital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=128009)], created=1731898549, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=43, total_tokens=86, completion_tokens_details=None, prompt_tokens_details=None))
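Each "Prefill batch" line in the server log reports how many prompt tokens were computed (#new-token) and how many were reused from the prefix cache (#cached-token). For example, the second "What is the capital of France?" request reused 41 of its 42 prompt tokens. The "cache hit rate" values logged so far are consistent with a running ratio of all cached tokens to all prompt tokens seen since startup. The sketch below recomputes them from the logged counts under that assumption. This is an inference from these logs, not a description of SGLang's internals.

```python
import re
from typing import List

# Prefill lines copied (trimmed) from the server logs in this notebook.
LOG_LINES = [
    "Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%",
    "Prefill batch. #new-seq: 1, #new-token: 41, #cached-token: 1, cache hit rate: 2.04%",
    "Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 41, cache hit rate: 46.15%",
    "Prefill batch. #new-seq: 1, #new-token: 13, #cached-token: 30, cache hit rate: 53.73%",
]

PATTERN = re.compile(r"#new-token: (\d+), #cached-token: (\d+), cache hit rate: ([\d.]+)%")


def running_hit_rates(lines: List[str]) -> List[str]:
    """Assumed formula: cumulative cached tokens / cumulative prompt tokens."""
    cached_total = 0
    seen_total = 0
    rates = []
    for line in lines:
        match = PATTERN.search(line)
        if match is None:
            continue  # not a prefill line
        new, cached = int(match.group(1)), int(match.group(2))
        cached_total += cached
        seen_total += new + cached
        rate = cached_total / seen_total if seen_total else 0.0
        rates.append(f"{rate:.2%}")
    return rates


logged = [PATTERN.search(line).group(3) + "%" for line in LOG_LINES]
print(running_hit_rates(LOG_LINES))  # ['0.00%', '2.04%', '46.15%', '53.73%']
print(running_hit_rates(LOG_LINES) == logged)  # True
```

For example, the third value is (0 + 1 + 41) / (7 + 42 + 42) = 42 / 91, which is about 46.15%.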
Streaming
[5]:
import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
# Use stream=True for streaming responses
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
stream=True,
)
# Handle the streaming output
for chunk in response:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
[2024-11-18 02:55:49] INFO: 127.0.0.1:46756 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2024-11-18 02:55:49 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 42, cache hit rate: 64.41%, token usage: 0.00, #running-req: 0, #queue-req: 0
Here are 3 countries and their capitals:
1. Country: Japan
Capital[2024-11-18 02:55:49 TP0] Decode batch. #running-req: 1, #token: 60, token usage: 0.00, gen throughput (token/s): 110.51, #queue-req: 0
: Tokyo
2. Country: Australia
Capital: Canberra
3. Country: Brazil
Capital: Brasília
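The if check in the streaming loop is needed because not every chunk carries text. With OpenAI-style streaming, a chunk's delta.content may be empty or None, for instance in a chunk that only reports finish_reason. If you also want the complete reply after streaming, you can collect the pieces as you print them. The sketch below uses SimpleNamespace objects as hypothetical stand-ins for the client's chunk objects, so it runs without a server. The names collect_stream and fake_chunk are illustrative only.

```python
from types import SimpleNamespace
from typing import Iterable


def collect_stream(chunks: Iterable, echo: bool = True) -> str:
    """Print streamed deltas as they arrive and return the concatenated reply."""
    parts = []
    for chunk in chunks:
        if not chunk.choices:
            continue  # defensive: skip chunks without choices
        piece = chunk.choices[0].delta.content
        if piece:  # skip None and empty deltas
            parts.append(piece)
            if echo:
                print(piece, end="", flush=True)
    if echo:
        print()
    return "".join(parts)


def fake_chunk(content):
    """Hypothetical stand-in exposing chunk.choices[0].delta.content."""
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=content))])


fake_stream = [fake_chunk(""), fake_chunk("Paris"), fake_chunk(" is the capital."), fake_chunk(None)]
full_text = collect_stream(fake_stream)  # prints: Paris is the capital.

# With a fresh stream=True response from the client:
# full_text = collect_stream(response)
```

A stream can be iterated only once, so create a new stream=True request rather than reusing the one already consumed in the cell above.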
Using Native Generation APIs
You can also use the native /generate endpoint with requests, which provides more flexibility. An API reference is available at Sampling Parameters.
[6]:
import requests
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
print_highlight(response.json())
[2024-11-18 02:55:49 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 3, cache hit rate: 63.93%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-18 02:55:49 TP0] Decode batch. #running-req: 1, #token: 21, token usage: 0.00, gen throughput (token/s): 132.55, #queue-req: 0
[2024-11-18 02:55:49] INFO: 127.0.0.1:46768 - "POST /generate HTTP/1.1" 200 OK
{'text': ' a city of romance, art, fashion, and history. Paris is a must-visit destination for anyone who loves culture, architecture, and cuisine. From the', 'meta_info': {'prompt_tokens': 6, 'completion_tokens': 32, 'completion_tokens_wo_jump_forward': 32, 'cached_tokens': 3, 'finish_reason': {'type': 'length', 'length': 32}, 'id': '6d1e347fd3cc4f9c8675dcf8a8a5bf42'}}
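Unlike the OpenAI-compatible endpoints, /generate returns the raw continuation in text together with a meta_info dict. In the output above, finish_reason has type 'length' with length 32, which means generation stopped because it reached max_new_tokens rather than a stop token. The sketch below bundles the request body and a truncation check into two hypothetical helpers. It relies only on the fields visible in this output.

```python
from typing import Any, Dict


def generate_payload(prompt: str, max_new_tokens: int, temperature: float = 0.0) -> Dict[str, Any]:
    """Hypothetical helper building the /generate body used in the cell above."""
    return {
        "text": prompt,
        "sampling_params": {"temperature": temperature, "max_new_tokens": max_new_tokens},
    }


def was_truncated(result: Dict[str, Any]) -> bool:
    """True if /generate stopped because it reached max_new_tokens."""
    finish = result["meta_info"]["finish_reason"]
    return finish is not None and finish.get("type") == "length"


# The response printed above (text trimmed):
result = {
    "text": " a city of romance, art, fashion, and history. ...",
    "meta_info": {
        "prompt_tokens": 6,
        "completion_tokens": 32,
        "cached_tokens": 3,
        "finish_reason": {"type": "length", "length": 32},
    },
}
print(was_truncated(result))  # True

# If the output was cut off, you could resend with a larger budget, e.g.:
# requests.post("http://localhost:30000/generate",
#               json=generate_payload("The capital of France is", 64))
```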
Streaming
[7]:
import requests, json
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
"stream": True,
},
stream=True,
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"]
        print(output[prev:], end="", flush=True)
        prev = len(output)
[2024-11-18 02:55:49] INFO: 127.0.0.1:46786 - "POST /generate HTTP/1.1" 200 OK
[2024-11-18 02:55:49 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 5, cache hit rate: 64.55%, token usage: 0.00, #running-req: 0, #queue-req: 0
a city of romance, art, fashion, and history. Paris is a must-visit destination for anyone who loves culture[2024-11-18 02:55:50 TP0] Decode batch. #running-req: 1, #token: 30, token usage: 0.00, gen throughput (token/s): 131.41, #queue-req: 0
, architecture, and cuisine. From the
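In the streaming loop above, each event line has the form data: {...JSON...}. Its text field holds all output generated so far, which is why the loop prints output[prev:]. The stream ends with data: [DONE]. The same logic can be factored into a generator that yields only the new text, so it can be tested without a server. The name iter_generate_deltas is hypothetical, and the sample lines below are hand-written in the same format rather than captured server output.

```python
import json
from typing import Iterable, Iterator


def iter_generate_deltas(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield newly generated text from /generate server-sent-event lines.

    Assumes, as in the loop above, that each event's "text" field is cumulative.
    """
    prev = 0
    for raw in lines:
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue  # skip blank separator lines
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            break
        text = json.loads(body)["text"]
        yield text[prev:]
        prev = len(text)


sample = [
    b'data: {"text": " a city"}',
    b"",
    b'data: {"text": " a city of romance"}',
    b"data: [DONE]",
]
print("".join(iter_generate_deltas(sample)))  # prints " a city of romance"

# With the server, pass a fresh streaming response:
# for piece in iter_generate_deltas(response.iter_lines()):
#     print(piece, end="", flush=True)
```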
[8]:
terminate_process(server_process)