Native APIs#

Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:

  • /generate (text generation model)

  • /get_model_info

  • /get_server_info

  • /health

  • /health_generate

  • /flush_cache

  • /update_weights_from_disk

  • /encode (embedding model)

  • /classify (reward model)

We mainly use the Python requests library to test these APIs in the following examples. You can also use curl.

Launch A Server#

[1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

import requests

server_process = execute_shell_command(
    """
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010
"""
)

wait_for_server("http://localhost:30010")
[2024-12-04 19:26:11] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=30010, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=816517830, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2024-12-04 19:26:26 TP0] Init torch distributed begin.
[2024-12-04 19:26:26 TP0] Load weight begin. avail mem=78.59 GB
[2024-12-04 19:26:27 TP0] lm_eval is not installed, GPTQ may not be usable
[2024-12-04 19:26:27 TP0] Using model weights format ['*.safetensors']
[2024-12-04 19:26:27 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.94it/s]

[2024-12-04 19:26:28 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=76.17 GB
[2024-12-04 19:26:28 TP0] Memory pool end. avail mem=7.42 GB
[2024-12-04 19:26:28 TP0] Capture cuda graph begin. This can take up to several minutes.
[2024-12-04 19:26:35 TP0] Capture cuda graph end. Time elapsed: 7.12 s
[2024-12-04 19:26:35 TP0] max_total_num_tokens=2186821, max_prefill_tokens=16384, max_running_requests=4097, context_len=131072
[2024-12-04 19:26:35] INFO:     Started server process [801041]
[2024-12-04 19:26:35] INFO:     Waiting for application startup.
[2024-12-04 19:26:35] INFO:     Application startup complete.
[2024-12-04 19:26:35] INFO:     Uvicorn running on http://127.0.0.1:30010 (Press CTRL+C to quit)
[2024-12-04 19:26:36] INFO:     127.0.0.1:37814 - "GET /v1/models HTTP/1.1" 200 OK
[2024-12-04 19:26:36] INFO:     127.0.0.1:37824 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-12-04 19:26:36 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:26:37] INFO:     127.0.0.1:37840 - "POST /generate HTTP/1.1" 200 OK
[2024-12-04 19:26:37] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.

Generate (text generation model)#

Generate completions. This is similar to /v1/completions in the OpenAI API. Detailed parameters can be found in the sampling parameters documentation.

[2]:
url = "http://localhost:30010/generate"
data = {"text": "What is the capital of France?"}

response = requests.post(url, json=data)
print_highlight(response.json())
[2024-12-04 19:26:41 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 1, cache hit rate: 6.67%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:26:41 TP0] Decode batch. #running-req: 1, #token: 41, token usage: 0.00, gen throughput (token/s): 7.41, #queue-req: 0
[2024-12-04 19:26:41] INFO:     127.0.0.1:37856 - "POST /generate HTTP/1.1" 200 OK
{'text': ' Paris\nThe capital of France is indeed Paris, a city known for its rich history, art, fashion, and culture. Would you like to learn more about Paris or an alternative city in France?', 'meta_info': {'prompt_tokens': 8, 'completion_tokens': 41, 'completion_tokens_wo_jump_forward': 41, 'cached_tokens': 1, 'finish_reason': {'type': 'stop', 'matched': 128009}, 'id': '66cfb6515350409eb77f0a5a17e4f5dd'}}
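The request body can also include a sampling_params object to control decoding. Below is a minimal sketch, assuming the server launched above is still running; see the sampling parameters documentation for the full list of fields:

url = "http://localhost:30010/generate"
data = {
    "text": "What is the capital of France?",
    "sampling_params": {
        "temperature": 0.7,  # higher values give more diverse completions
        "max_new_tokens": 64,  # cap the completion length
    },
}

response = requests.post(url, json=data)
print_highlight(response.json()["text"])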

Get Model Info#

Get the information of the model.

  • model_path: The path/name of the model.

  • is_generation: Whether the model is a generation model or an embedding model.

  • tokenizer_path: The path/name of the tokenizer.

[3]:
url = "http://localhost:30010/get_model_info"

response = requests.get(url)
response_json = response.json()
print_highlight(response_json)
assert response_json["model_path"] == "meta-llama/Llama-3.2-1B-Instruct"
assert response_json["is_generation"] is True
assert response_json["tokenizer_path"] == "meta-llama/Llama-3.2-1B-Instruct"
assert response_json.keys() == {"model_path", "is_generation", "tokenizer_path"}
[2024-12-04 19:26:41] INFO:     127.0.0.1:37870 - "GET /get_model_info HTTP/1.1" 200 OK
{'model_path': 'meta-llama/Llama-3.2-1B-Instruct', 'tokenizer_path': 'meta-llama/Llama-3.2-1B-Instruct', 'is_generation': True}

Get Server Info#

Gets the server information, including CLI arguments, token limits, and memory pool sizes.

Note: get_server_info merges the following deprecated endpoints:

  • get_server_args

  • get_memory_pool_size

  • get_max_total_num_tokens

[4]:
# get_server_info

url = "http://localhost:30010/get_server_info"

response = requests.get(url)
print_highlight(response.text)
[2024-12-04 19:26:41] INFO:     127.0.0.1:37876 - "GET /get_server_info HTTP/1.1" 200 OK
{"model_path":"meta-llama/Llama-3.2-1B-Instruct","tokenizer_path":"meta-llama/Llama-3.2-1B-Instruct","tokenizer_mode":"auto","skip_tokenizer_init":false,"load_format":"auto","trust_remote_code":false,"dtype":"auto","kv_cache_dtype":"auto","quantization":null,"context_length":null,"device":"cuda","served_model_name":"meta-llama/Llama-3.2-1B-Instruct","chat_template":null,"is_embedding":false,"revision":null,"host":"127.0.0.1","port":30010,"mem_fraction_static":0.88,"max_running_requests":null,"max_total_tokens":null,"chunked_prefill_size":8192,"max_prefill_tokens":16384,"schedule_policy":"lpm","schedule_conservativeness":1.0,"cpu_offload_gb":0,"tp_size":1,"stream_interval":1,"random_seed":816517830,"constrained_json_whitespace_pattern":null,"watchdog_timeout":300,"download_dir":null,"base_gpu_id":0,"log_level":"info","log_level_http":null,"log_requests":false,"show_time_cost":false,"enable_metrics":false,"decode_log_interval":40,"api_key":null,"file_storage_pth":"SGLang_storage","enable_cache_report":false,"dp_size":1,"load_balance_method":"round_robin","dist_init_addr":null,"nnodes":1,"node_rank":0,"json_model_override_args":"{}","enable_double_sparsity":false,"ds_channel_config_path":null,"ds_heavy_channel_num":32,"ds_heavy_token_num":256,"ds_heavy_channel_type":"qk","ds_sparse_decode_threshold":4096,"lora_paths":null,"max_loras_per_batch":8,"attention_backend":"flashinfer","sampling_backend":"flashinfer","grammar_backend":"outlines","disable_radix_cache":false,"disable_jump_forward":false,"disable_cuda_graph":false,"disable_cuda_graph_padding":false,"disable_outlines_disk_cache":false,"disable_custom_all_reduce":false,"disable_mla":false,"disable_overlap_schedule":false,"enable_mixed_chunk":false,"enable_dp_attention":false,"enable_torch_compile":false,"torch_compile_max_bs":32,"cuda_graph_max_bs":160,"torchao_config":"","enable_nan_detection":false,"enable_p2p_check":false,"triton_attention_reduce_in_fp32":false,"num_continuous_decode_steps":1,"delete_ckpt_after_loading":false,"status":"ready","max_total_num_tokens":2186821,"version":"0.4.0"}

Health Check#

  • /health: Check the health of the server.

  • /health_generate: Check the health of the server by generating one token.

[5]:
url = "http://localhost:30010/health_generate"

response = requests.get(url)
print_highlight(response.text)
[2024-12-04 19:26:41 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 0, cache hit rate: 6.25%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:26:41] INFO:     127.0.0.1:37878 - "GET /health_generate HTTP/1.1" 200 OK
[6]:
url = "http://localhost:30010/health"

response = requests.get(url)
print_highlight(response.text)
[2024-12-04 19:26:41] INFO:     127.0.0.1:37884 - "GET /health HTTP/1.1" 200 OK
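In deployment scripts, /health is often polled until the server responds. Below is a minimal sketch of such a loop; the helper name, timeout, and interval are illustrative and not part of SGLang (sglang.utils.wait_for_server, used earlier, plays a similar role):

import time


def wait_until_healthy(base_url, timeout_s=60, interval_s=1):
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            # /health returns 200 once the server is up
            if requests.get(f"{base_url}/health", timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not reachable yet
        time.sleep(interval_s)
    return False


print_highlight(wait_until_healthy("http://localhost:30010"))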

Flush Cache#

Flush the radix cache. It is also triggered automatically when the model weights are updated via the /update_weights_from_disk API.

[7]:
# flush cache

url = "http://localhost:30010/flush_cache"

response = requests.post(url)
print_highlight(response.text)
[2024-12-04 19:26:41] INFO:     127.0.0.1:37892 - "POST /flush_cache HTTP/1.1" 200 OK
Cache flushed.
Please check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)
[2024-12-04 19:26:41 TP0] Cache flushed successfully!

Update Weights From Disk#

Update model weights from disk without restarting the server. This is only applicable to models with the same architecture and parameter size.

SGLang supports the update_weights_from_disk API for continuous evaluation during training: save a checkpoint to disk, then update the server's weights from it.

[8]:
# successful update with same architecture and size

url = "http://localhost:30010/update_weights_from_disk"
data = {"model_path": "meta-llama/Llama-3.2-1B"}

response = requests.post(url, json=data)
print_highlight(response.text)
assert response.json()["success"] is True
assert response.json()["message"] == "Succeeded to update model weights."
assert response.json().keys() == {"success", "message"}
[2024-12-04 19:26:41 TP0] Update engine weights online from disk begin. avail mem=4.87 GB
[2024-12-04 19:26:41 TP0] Using model weights format ['*.safetensors']
[2024-12-04 19:26:41 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.30it/s]

[2024-12-04 19:26:41 TP0] Update weights end.
[2024-12-04 19:26:41 TP0] Cache flushed successfully!
[2024-12-04 19:26:41] INFO:     127.0.0.1:37896 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
{"success":true,"message":"Succeeded to update model weights."}
[9]:
# failed update with different parameter size or wrong name

url = "http://localhost:30010/update_weights_from_disk"
data = {"model_path": "meta-llama/Llama-3.2-1B-wrong"}

response = requests.post(url, json=data)
response_json = response.json()
print_highlight(response_json)
assert response_json["success"] is False
assert response_json["message"] == (
    "Failed to get weights iterator: "
    "meta-llama/Llama-3.2-1B-wrong"
    " (repository not found)."
)
[2024-12-04 19:26:41 TP0] Update engine weights online from disk begin. avail mem=4.87 GB
[2024-12-04 19:26:42 TP0] Failed to get weights iterator: meta-llama/Llama-3.2-1B-wrong (repository not found).
[2024-12-04 19:26:42] INFO:     127.0.0.1:37912 - "POST /update_weights_from_disk HTTP/1.1" 400 Bad Request
{'success': False, 'message': 'Failed to get weights iterator: meta-llama/Llama-3.2-1B-wrong (repository not found).'}
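In a training loop, the successful-update pattern above is typically wrapped in a small helper that is called whenever a new checkpoint is written to disk. A sketch follows; the function name, checkpoint directory, and save calls are hypothetical placeholders:

def refresh_server_weights(checkpoint_dir, base_url="http://localhost:30010"):
    # Assumes the checkpoint was saved in Hugging Face format, e.g. via
    # model.save_pretrained(checkpoint_dir) and tokenizer.save_pretrained(checkpoint_dir).
    response = requests.post(
        f"{base_url}/update_weights_from_disk",
        json={"model_path": checkpoint_dir},
    )
    result = response.json()
    assert result["success"], result["message"]
    return result


# refresh_server_weights("/path/to/checkpoint")  # hypothetical checkpoint path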

Encode (embedding model)#

Encode text into embeddings. Note that this API is only available for embedding models and will raise an error for generation models. Therefore, we launch a new server to serve an embedding model.

[10]:
terminate_process(server_process)

embedding_process = execute_shell_command(
    """
python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \
    --port 30020 --host 0.0.0.0 --is-embedding
"""
)

wait_for_server("http://localhost:30020")
[2024-12-04 19:26:54] server_args=ServerArgs(model_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='Alibaba-NLP/gte-Qwen2-7B-instruct', chat_template=None, is_embedding=True, revision=None, host='0.0.0.0', port=30020, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=1067949106, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2024-12-04 19:26:59] Downcasting torch.float32 to torch.float16.
[2024-12-04 19:27:09 TP0] Downcasting torch.float32 to torch.float16.
[2024-12-04 19:27:10 TP0] Overlap scheduler is disabled for embedding models.
[2024-12-04 19:27:10 TP0] Downcasting torch.float32 to torch.float16.
[2024-12-04 19:27:10 TP0] Init torch distributed begin.
[2024-12-04 19:27:10 TP0] Load weight begin. avail mem=78.59 GB
[2024-12-04 19:27:10 TP0] lm_eval is not installed, GPTQ may not be usable
[2024-12-04 19:27:11 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:01<00:10,  1.70s/it]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:03<00:08,  1.71s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:05<00:07,  1.80s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:07<00:05,  1.94s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:09<00:04,  2.05s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:11<00:01,  1.92s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:12<00:00,  1.64s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:12<00:00,  1.78s/it]

[2024-12-04 19:27:23 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.float16, avail mem=64.18 GB
[2024-12-04 19:27:23 TP0] Memory pool end. avail mem=7.43 GB
[2024-12-04 19:27:24 TP0] max_total_num_tokens=1025173, max_prefill_tokens=16384, max_running_requests=4005, context_len=131072
[2024-12-04 19:27:24] INFO:     Started server process [801879]
[2024-12-04 19:27:24] INFO:     Waiting for application startup.
[2024-12-04 19:27:24] INFO:     Application startup complete.
[2024-12-04 19:27:24] INFO:     Uvicorn running on http://0.0.0.0:30020 (Press CTRL+C to quit)
[2024-12-04 19:27:25] INFO:     127.0.0.1:45768 - "GET /v1/models HTTP/1.1" 200 OK
[2024-12-04 19:27:25] INFO:     127.0.0.1:45774 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-12-04 19:27:25 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:27:26] INFO:     127.0.0.1:45784 - "POST /encode HTTP/1.1" 200 OK
[2024-12-04 19:27:26] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
[11]:
# successful encode for embedding model

url = "http://localhost:30020/encode"
data = {"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "text": "Once upon a time"}

response = requests.post(url, json=data)
response_json = response.json()
print_highlight(f"Text embedding (first 10): {response_json['embedding'][:10]}")
[2024-12-04 19:27:30 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:27:30] INFO:     127.0.0.1:45798 - "POST /encode HTTP/1.1" 200 OK
Text embedding (first 10): [0.00830841064453125, 0.0006804466247558594, -0.00807952880859375, -0.000682830810546875, 0.01438140869140625, -0.009002685546875, 0.01239013671875, 0.0020999908447265625, 0.006214141845703125, -0.0030345916748046875]
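The returned vector can be used like any other sentence embedding. For example, a cosine-similarity comparison between two /encode calls might look like the sketch below; the helper functions are illustrative and not part of SGLang:

import math


def embed(text, base_url="http://localhost:30020"):
    resp = requests.post(
        f"{base_url}/encode",
        json={"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "text": text},
    )
    return resp.json()["embedding"]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


e1 = embed("Once upon a time")
e2 = embed("A long time ago")
print_highlight(f"cosine similarity: {cosine(e1, e2):.4f}")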

Classify (reward model)#

The SGLang Runtime also supports reward models. Here we use a reward model to score the quality of pairwise generations.

[12]:
terminate_process(embedding_process)

# Note that SGLang now treats embedding models and reward models as the same type of models.
# This will be updated in the future.

reward_process = execute_shell_command(
    """
python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --port 30030 --host 0.0.0.0 --is-embedding
"""
)

wait_for_server("http://localhost:30030")
[2024-12-04 19:27:40] server_args=ServerArgs(model_path='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', tokenizer_path='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', chat_template=None, is_embedding=True, revision=None, host='0.0.0.0', port=30030, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=909662746, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2024-12-04 19:27:56 TP0] Overlap scheduler is disabled for embedding models.
[2024-12-04 19:27:56 TP0] Init torch distributed begin.
[2024-12-04 19:27:56 TP0] Load weight begin. avail mem=78.59 GB
[2024-12-04 19:27:56 TP0] lm_eval is not installed, GPTQ may not be usable
[2024-12-04 19:27:57 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.16it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.06it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.08it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.64it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.38it/s]

[2024-12-04 19:28:00 TP0] Load weight end. type=LlamaForSequenceClassification, dtype=torch.bfloat16, avail mem=64.48 GB
[2024-12-04 19:28:00 TP0] Memory pool end. avail mem=8.35 GB
[2024-12-04 19:28:00 TP0] max_total_num_tokens=450929, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2024-12-04 19:28:00] INFO:     Started server process [802709]
[2024-12-04 19:28:00] INFO:     Waiting for application startup.
[2024-12-04 19:28:00] INFO:     Application startup complete.
[2024-12-04 19:28:00] INFO:     Uvicorn running on http://0.0.0.0:30030 (Press CTRL+C to quit)
[2024-12-04 19:28:01] INFO:     127.0.0.1:47736 - "GET /v1/models HTTP/1.1" 200 OK
[2024-12-04 19:28:01] INFO:     127.0.0.1:47752 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-12-04 19:28:01 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:28:02] INFO:     127.0.0.1:47758 - "POST /encode HTTP/1.1" 200 OK
[2024-12-04 19:28:02] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
[13]:
from transformers import AutoTokenizer

PROMPT = (
    "What is the range of the numeric output of a sigmoid node in a neural network?"
)

RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."

CONVS = [
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]

tokenizer = AutoTokenizer.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)

url = "http://localhost:30030/classify"
data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": prompts}

responses = requests.post(url, json=data).json()
for response in responses:
    print_highlight(f"reward: {response['embedding'][0]}")
[2024-12-04 19:28:12 TP0] Prefill batch. #new-seq: 2, #new-token: 136, #cached-token: 2, cache hit rate: 1.38%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-12-04 19:28:12] INFO:     127.0.0.1:49378 - "POST /classify HTTP/1.1" 200 OK
reward: -24.25
reward: 1.15625
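The scores can then be used to rank candidate responses. A short sketch reusing CONVS and responses from the cell above:

scores = [r["embedding"][0] for r in responses]
best_idx = max(range(len(scores)), key=lambda i: scores[i])
print_highlight(
    f"Preferred response: {CONVS[best_idx][1]['content']!r} (reward={scores[best_idx]})"
)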
[14]:
terminate_process(reward_process)