OpenAI Compatible API#

SGLang provides an OpenAI-compatible API for a smooth transition from OpenAI services. A full reference of the API is available in the OpenAI API Reference.

This tutorial covers the following popular APIs: Chat Completions, Completions, and Batches.

Chat Completions#

Usage#

As in send_request.ipynb, we can send a chat completion request to the SGLang server in the OpenAI API format.

[1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

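# Launch the server as a background process; wait_for_server blocks until it is ready to accept requests.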
server_process = execute_shell_command(
    command="python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0"
)

wait_for_server("http://localhost:30000")
[2024-11-01 07:53:11] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=507753914, constrained_json_whitespace_pattern=None, decode_log_interval=40, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, watchdog_timeout=600, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_penalizer=False, disable_nan_detection=False, enable_overlap_schedule=False, enable_mixed_chunk=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_p2p_check=False, triton_attention_reduce_in_fp32=False, num_continuous_decode_steps=1)
[2024-11-01 07:53:27 TP0] Init torch distributed begin.
[2024-11-01 07:53:27 TP0] Load weight begin. avail mem=78.59 GB
[2024-11-01 07:53:28 TP0] lm_eval is not installed, GPTQ may not be usable
INFO 11-01 07:53:29 weight_utils.py:243] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.22it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.09it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.08it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.31it/s]

[2024-11-01 07:53:32 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.50 GB
[2024-11-01 07:53:32 TP0] Memory pool end. avail mem=8.37 GB
[2024-11-01 07:53:32 TP0] Capture cuda graph begin. This can take up to several minutes.
[2024-11-01 07:53:41 TP0] max_total_num_tokens=442913, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2024-11-01 07:53:41] INFO:     Started server process [1240000]
[2024-11-01 07:53:41] INFO:     Waiting for application startup.
[2024-11-01 07:53:41] INFO:     Application startup complete.
[2024-11-01 07:53:41] INFO:     Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2024-11-01 07:53:42] INFO:     127.0.0.1:54538 - "GET /v1/models HTTP/1.1" 200 OK
[2024-11-01 07:53:42] INFO:     127.0.0.1:54554 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-11-01 07:53:42 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:42] INFO:     127.0.0.1:54568 - "POST /generate HTTP/1.1" 200 OK
[2024-11-01 07:53:42] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
[2]:
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2024-11-01 07:53:48 TP0] Prefill batch. #new-seq: 1, #new-token: 48, #cached-token: 1, cache hit rate: 1.79%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:48 TP0] Decode batch. #running-req: 1, #token: 82, token usage: 0.00, gen throughput (token/s): 5.53, #queue-req: 0
[2024-11-01 07:53:48] INFO:     127.0.0.1:46400 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='da60d77fcf2c441daf77b8cef540eaf2', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. **Country:** Japan\n**Capital:** Tokyo\n\n2. **Country:** Australia\n**Capital:** Canberra\n\n3. **Country:** Brazil\n**Capital:** Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=128009)], created=1730447628, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=46, prompt_tokens=49, total_tokens=95, completion_tokens_details=None, prompt_tokens_details=None))
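
Because the server follows the OpenAI wire format, the same request can also be issued over plain HTTP instead of through the openai client. The snippet below is a minimal sketch using the requests library against the server launched above; the field names mirror the standard Chat Completions request body.

import requests

# Raw-HTTP sketch of the same chat completion request against the running server.
raw_response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "List 3 countries and their capitals."},
        ],
        "temperature": 0,
        "max_tokens": 64,
    },
)
print(raw_response.json()["choices"][0]["message"]["content"])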

Parameters#

The chat completions API accepts the parameters of the OpenAI Chat Completions API. Refer to the OpenAI Chat Completions API for more details.

Here is an example of a detailed chat completion request:

[3]:
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "system",
            "content": "You are a knowledgeable historian who provides concise responses.",
        },
        {"role": "user", "content": "Tell me about ancient Rome"},
        {
            "role": "assistant",
            "content": "Ancient Rome was a civilization centered in Italy.",
        },
        {"role": "user", "content": "What were their major achievements?"},
    ],
    temperature=0.3,  # Lower temperature for more focused responses
    max_tokens=128,  # Reasonable length for a concise response
    top_p=0.95,  # Slightly higher for better fluency
    presence_penalty=0.2,  # Mild penalty to avoid repetition
    frequency_penalty=0.2,  # Mild penalty for more natural language
    n=1,  # Single response is usually more stable
    seed=42,  # Keep for reproducibility
)

print_highlight(response.choices[0].message.content)
[2024-11-01 07:53:48 TP0] Prefill batch. #new-seq: 1, #new-token: 48, #cached-token: 28, cache hit rate: 21.97%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:48 TP0] Decode batch. #running-req: 1, #token: 104, token usage: 0.00, gen throughput (token/s): 108.87, #queue-req: 0
[2024-11-01 07:53:49 TP0] Decode batch. #running-req: 1, #token: 144, token usage: 0.00, gen throughput (token/s): 132.15, #queue-req: 0
[2024-11-01 07:53:49 TP0] Decode batch. #running-req: 1, #token: 184, token usage: 0.00, gen throughput (token/s): 131.53, #queue-req: 0
[2024-11-01 07:53:49] INFO:     127.0.0.1:46400 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Ancient Rome's major achievements include:

1. **Engineering and Architecture**: Developed the arch, dome, aqueducts, roads (e.g., Appian Way), and public buildings like the Colosseum and Pantheon.
2. **Law and Governance**: Established the Twelve Tables (450 BCE), which formed the basis of Roman law, and created the concept of citizenship.
3. **Military Conquests**: Expanded from Italy to a vast empire, conquering much of Europe, North Africa, and parts of Asia.
4. **Language and Literature**: Developed Latin, which became a language of government, law, and literature,

Streaming mode is also supported.

[4]:
stream = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say this is a test"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")
[2024-11-01 07:53:49] INFO:     127.0.0.1:46400 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2024-11-01 07:53:49 TP0] Prefill batch. #new-seq: 1, #new-token: 15, #cached-token: 25, cache hit rate: 31.40%, token usage: 0.00, #running-req: 0, #queue-req: 0
This is only a test

Completions#

Usage#

The Completions API is similar to the Chat Completions API, but it takes a raw prompt instead of the messages parameter.

[5]:
response = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    prompt="List 3 countries and their capitals.",
    temperature=0,
    max_tokens=64,
    n=1,
    stop=None,
)

print_highlight(f"Response: {response}")
[2024-11-01 07:53:49 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 1, cache hit rate: 30.39%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:49 TP0] Decode batch. #running-req: 1, #token: 25, token usage: 0.00, gen throughput (token/s): 114.64, #queue-req: 0
[2024-11-01 07:53:50 TP0] Decode batch. #running-req: 1, #token: 65, token usage: 0.00, gen throughput (token/s): 142.05, #queue-req: 0
[2024-11-01 07:53:50] INFO:     127.0.0.1:46400 - "POST /v1/completions HTTP/1.1" 200 OK
Response: Completion(id='fccce55bb5f8476d8e052f017ad9e759', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' 1. 2. 3.\n1. United States - Washington D.C. 2. Japan - Tokyo 3. Australia - Canberra\nList 3 countries and their capitals. 1. 2. 3.\n1. China - Beijing 2. Brazil - Bras', matched_stop=None)], created=1730447630, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=64, prompt_tokens=9, total_tokens=73, completion_tokens_details=None, prompt_tokens_details=None))

Parameters#

The completions API accepts the parameters of the OpenAI Completions API. Refer to the OpenAI Completions API for more details.

Here is an example of a detailed completions request:

[6]:
response = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    prompt="Write a short story about a space explorer.",
    temperature=0.7,  # Moderate temperature for creative writing
    max_tokens=150,  # Longer response for a story
    top_p=0.9,  # Balanced diversity in word choice
    stop=["\n\n", "THE END"],  # Multiple stop sequences
    presence_penalty=0.3,  # Encourage novel elements
    frequency_penalty=0.3,  # Reduce repetitive phrases
    n=1,  # Generate one completion
    seed=123,  # For reproducible results
)

print_highlight(f"Response: {response}")
[2024-11-01 07:53:50 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 1, cache hit rate: 29.32%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:50 TP0] Decode batch. #running-req: 1, #token: 43, token usage: 0.00, gen throughput (token/s): 123.81, #queue-req: 0
[2024-11-01 07:53:50 TP0] Decode batch. #running-req: 1, #token: 83, token usage: 0.00, gen throughput (token/s): 133.46, #queue-req: 0
[2024-11-01 07:53:51 TP0] Decode batch. #running-req: 1, #token: 123, token usage: 0.00, gen throughput (token/s): 132.27, #queue-req: 0
[2024-11-01 07:53:51] INFO:     127.0.0.1:46400 - "POST /v1/completions HTTP/1.1" 200 OK
Response: Completion(id='104040efc23e4f56ba0390b76aacb745', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text="\xa0\nAstrid had always been fascinated by the vastness of space. As a child, she spent countless hours gazing up at the stars, dreaming of the day she could explore them for herself. Now, as a seasoned space explorer, she had finally achieved her goal.\nAstrid's ship, the Aurora, had been traveling through the galaxy for months, and she had seen some incredible sights along the way. From the swirling purple clouds of a distant gas giant to the gleaming silver surface of a moon that shone like a beacon in the darkness, Astrid had experienced it all.\nBut her mission was not just about sightseeing – it was about discovery. Astrid was on a quest to find a new habitable planet,", matched_stop=None)], created=1730447631, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=150, prompt_tokens=10, total_tokens=160, completion_tokens_details=None, prompt_tokens_details=None))

Batches#

We have implemented the batches API for chat completions and completions. You can upload your requests in JSONL files, create a batch job, and retrieve the results when the batch job completes. Batch processing takes longer but costs less.

The batches APIs are:

  • batches

  • batches/{batch_id}/cancel

  • batches/{batch_id}
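
These correspond to the standard OpenAI REST paths under /v1 on the server. As a rough sketch (the batch ID below is a hypothetical placeholder; use the ID returned when you create a batch), the retrieve and cancel endpoints can also be called directly over HTTP:

import requests

base_url = "http://localhost:30000/v1"
batch_id = "batch_xxxxxxxx"  # hypothetical placeholder

# Retrieve the status of a batch job.
status = requests.get(f"{base_url}/batches/{batch_id}").json()

# Cancel a running batch job.
cancelled = requests.post(f"{base_url}/batches/{batch_id}/cancel").json()

The rest of this section uses the openai client, which wraps the same endpoints.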

Here is an example of a batch job for chat completions; the flow for completions is similar.

[7]:
import json
import time
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="None")

requests = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/chat/completions",
        "body": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "messages": [
                {"role": "user", "content": "Tell me a joke about programming"}
            ],
            "max_tokens": 50,
        },
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/chat/completions",
        "body": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "What is Python?"}],
            "max_tokens": 50,
        },
    },
]

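# Write the requests to a local JSONL file.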
input_file_path = "batch_requests.jsonl"

with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

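# Upload the JSONL file to the server for batch processing.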
with open(input_file_path, "rb") as f:
    file_response = client.files.create(file=f, purpose="batch")

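# Create the batch job against the chat completions endpoint.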
batch_response = client.batches.create(
    input_file_id=file_response.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

print_highlight(f"Batch job created with ID: {batch_response.id}")
[2024-11-01 07:53:51] INFO:     127.0.0.1:46410 - "POST /v1/files HTTP/1.1" 200 OK
[2024-11-01 07:53:51] INFO:     127.0.0.1:46410 - "POST /v1/batches HTTP/1.1" 200 OK
[2024-11-01 07:53:51 TP0] Prefill batch. #new-seq: 2, #new-token: 20, #cached-token: 60, cache hit rate: 42.80%, token usage: 0.00, #running-req: 0, #queue-req: 0
Batch job created with ID: batch_ce6ad18e-0eda-495b-8c1b-4cf73603d6b1
[8]:
while batch_response.status not in ["completed", "failed", "cancelled"]:
    time.sleep(3)
    print(f"Batch job status: {batch_response.status}...trying again in 3 seconds...")
    batch_response = client.batches.retrieve(batch_response.id)

if batch_response.status == "completed":
    print("Batch job completed successfully!")
    print(f"Request counts: {batch_response.request_counts}")

    result_file_id = batch_response.output_file_id
    file_response = client.files.content(result_file_id)
    result_content = file_response.read().decode("utf-8")

    results = [
        json.loads(line) for line in result_content.split("\n") if line.strip() != ""
    ]

    for result in results:
        print_highlight(f"Request {result['custom_id']}:")
        print_highlight(f"Response: {result['response']}")

    print_highlight("Cleaning up files...")
    # Only delete the result file ID since file_response is just content
    client.files.delete(result_file_id)
else:
    print_highlight(f"Batch job failed with status: {batch_response.status}")
    if hasattr(batch_response, "errors"):
        print_highlight(f"Errors: {batch_response.errors}")
[2024-11-01 07:53:51 TP0] Decode batch. #running-req: 2, #token: 58, token usage: 0.00, gen throughput (token/s): 106.14, #queue-req: 0
[2024-11-01 07:53:51 TP0] Decode batch. #running-req: 1, #token: 83, token usage: 0.00, gen throughput (token/s): 166.70, #queue-req: 0
Batch job status: validating...trying again in 3 seconds...
[2024-11-01 07:53:54] INFO:     127.0.0.1:46410 - "GET /v1/batches/batch_ce6ad18e-0eda-495b-8c1b-4cf73603d6b1 HTTP/1.1" 200 OK
Batch job completed successfully!
Request counts: BatchRequestCounts(completed=2, failed=0, total=2)
[2024-11-01 07:53:54] INFO:     127.0.0.1:46410 - "GET /v1/files/backend_result_file-a3eef415-2825-439f-8bb6-c8bdda89e193/content HTTP/1.1" 200 OK
Request request-1:
Response: {'status_code': 200, 'request_id': 'request-1', 'body': {'id': 'request-1', 'object': 'chat.completion', 'created': 1730447631, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': {'index': 0, 'message': {'role': 'assistant', 'content': 'Why do programmers prefer dark mode?\n\nBecause light attracts bugs.'}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}, 'usage': {'prompt_tokens': 41, 'completion_tokens': 13, 'total_tokens': 54}, 'system_fingerprint': None}}
Request request-2:
Response: {'status_code': 200, 'request_id': 'request-2', 'body': {'id': 'request-2', 'object': 'chat.completion', 'created': 1730447631, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': {'index': 0, 'message': {'role': 'assistant', 'content': '**What is Python?**\n\nPython is a high-level, interpreted programming language that is widely used for various purposes such as web development, scientific computing, data analysis, artificial intelligence, and more. It was created in the late 1980s by'}, 'logprobs': None, 'finish_reason': 'length', 'matched_stop': None}, 'usage': {'prompt_tokens': 39, 'completion_tokens': 50, 'total_tokens': 89}, 'system_fingerprint': None}}
Cleaning up files...
[2024-11-01 07:53:54] INFO:     127.0.0.1:46410 - "DELETE /v1/files/backend_result_file-a3eef415-2825-439f-8bb6-c8bdda89e193 HTTP/1.1" 200 OK

Completing a batch job can take a while. You can use these two APIs to retrieve the batch job status or to cancel it:

  1. batches/{batch_id}: Retrieve the batch job status.

  2. batches/{batch_id}/cancel: Cancel the batch job.

Here is an example of checking the batch job status.

[9]:
import json
import time
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="None")

requests = []
for i in range(100):
    requests.append(
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "messages": [
                    {
                        "role": "system",
                        "content": f"{i}: You are a helpful AI assistant",
                    },
                    {
                        "role": "user",
                        "content": "Write a detailed story about topic. Make it very long.",
                    },
                ],
                "max_tokens": 500,
            },
        }
    )

input_file_path = "batch_requests.jsonl"
with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open(input_file_path, "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

print_highlight(f"Created batch job with ID: {batch_job.id}")
print_highlight(f"Initial status: {batch_job.status}")

time.sleep(10)

max_checks = 5
for i in range(max_checks):
    batch_details = client.batches.retrieve(batch_id=batch_job.id)

    print_highlight(
        f"Batch job details (check {i+1} / {max_checks}) // ID: {batch_details.id} // Status: {batch_details.status} // Created at: {batch_details.created_at} // Input file ID: {batch_details.input_file_id} // Output file ID: {batch_details.output_file_id}"
    )
    print_highlight(
        f"<strong>Request counts: Total: {batch_details.request_counts.total} // Completed: {batch_details.request_counts.completed} // Failed: {batch_details.request_counts.failed}</strong>"
    )

    time.sleep(3)
[2024-11-01 07:53:54] INFO:     127.0.0.1:46420 - "POST /v1/files HTTP/1.1" 200 OK
[2024-11-01 07:53:54] INFO:     127.0.0.1:46420 - "POST /v1/batches HTTP/1.1" 200 OK
Created batch job with ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1
Initial status: validating
[2024-11-01 07:53:54 TP0] Prefill batch. #new-seq: 23, #new-token: 690, #cached-token: 575, cache hit rate: 44.99%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:53:54 TP0] Prefill batch. #new-seq: 77, #new-token: 2310, #cached-token: 1925, cache hit rate: 45.33%, token usage: 0.00, #running-req: 23, #queue-req: 0
[2024-11-01 07:53:55 TP0] Decode batch. #running-req: 100, #token: 6525, token usage: 0.01, gen throughput (token/s): 1076.66, #queue-req: 0
[2024-11-01 07:53:55 TP0] Decode batch. #running-req: 100, #token: 10525, token usage: 0.02, gen throughput (token/s): 10714.16, #queue-req: 0
[2024-11-01 07:53:55 TP0] Decode batch. #running-req: 100, #token: 14525, token usage: 0.03, gen throughput (token/s): 10497.37, #queue-req: 0
[2024-11-01 07:53:56 TP0] Decode batch. #running-req: 100, #token: 18525, token usage: 0.04, gen throughput (token/s): 10265.91, #queue-req: 0
[2024-11-01 07:53:56 TP0] Decode batch. #running-req: 100, #token: 22525, token usage: 0.05, gen throughput (token/s): 10053.58, #queue-req: 0
[2024-11-01 07:53:56 TP0] Decode batch. #running-req: 100, #token: 26525, token usage: 0.06, gen throughput (token/s): 9815.04, #queue-req: 0
[2024-11-01 07:53:57 TP0] Decode batch. #running-req: 100, #token: 30525, token usage: 0.07, gen throughput (token/s): 9636.97, #queue-req: 0
[2024-11-01 07:53:57 TP0] Decode batch. #running-req: 100, #token: 34525, token usage: 0.08, gen throughput (token/s): 9448.08, #queue-req: 0
[2024-11-01 07:53:58 TP0] Decode batch. #running-req: 100, #token: 38525, token usage: 0.09, gen throughput (token/s): 9259.08, #queue-req: 0
[2024-11-01 07:53:58 TP0] Decode batch. #running-req: 100, #token: 42525, token usage: 0.10, gen throughput (token/s): 9076.97, #queue-req: 0
[2024-11-01 07:53:59 TP0] Decode batch. #running-req: 100, #token: 46525, token usage: 0.11, gen throughput (token/s): 8904.75, #queue-req: 0
[2024-11-01 07:53:59 TP0] Decode batch. #running-req: 100, #token: 50525, token usage: 0.11, gen throughput (token/s): 8729.19, #queue-req: 0
[2024-11-01 07:54:04] INFO:     127.0.0.1:53702 - "GET /v1/batches/batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 HTTP/1.1" 200 OK
Batch job details (check 1 / 5) // ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 // Status: completed // Created at: 1730447634 // Input file ID: backend_input_file-f4c920b6-dd3e-44a2-8983-4d3b76ddf189 // Output file ID: backend_result_file-dad4734a-d8eb-425b-b580-698f543957c7
Request counts: Total: 100 // Completed: 100 // Failed: 0
[2024-11-01 07:54:07] INFO:     127.0.0.1:53702 - "GET /v1/batches/batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 HTTP/1.1" 200 OK
Batch job details (check 2 / 5) // ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 // Status: completed // Created at: 1730447634 // Input file ID: backend_input_file-f4c920b6-dd3e-44a2-8983-4d3b76ddf189 // Output file ID: backend_result_file-dad4734a-d8eb-425b-b580-698f543957c7
Request counts: Total: 100 // Completed: 100 // Failed: 0
[2024-11-01 07:54:10] INFO:     127.0.0.1:53702 - "GET /v1/batches/batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 HTTP/1.1" 200 OK
Batch job details (check 3 / 5) // ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 // Status: completed // Created at: 1730447634 // Input file ID: backend_input_file-f4c920b6-dd3e-44a2-8983-4d3b76ddf189 // Output file ID: backend_result_file-dad4734a-d8eb-425b-b580-698f543957c7
Request counts: Total: 100 // Completed: 100 // Failed: 0
[2024-11-01 07:54:13] INFO:     127.0.0.1:53702 - "GET /v1/batches/batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 HTTP/1.1" 200 OK
Batch job details (check 4 / 5) // ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 // Status: completed // Created at: 1730447634 // Input file ID: backend_input_file-f4c920b6-dd3e-44a2-8983-4d3b76ddf189 // Output file ID: backend_result_file-dad4734a-d8eb-425b-b580-698f543957c7
Request counts: Total: 100 // Completed: 100 // Failed: 0
[2024-11-01 07:54:16] INFO:     127.0.0.1:53702 - "GET /v1/batches/batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 HTTP/1.1" 200 OK
Batch job details (check 5 / 5) // ID: batch_91ed089c-f0e4-40e5-bda9-38848eb741d1 // Status: completed // Created at: 1730447634 // Input file ID: backend_input_file-f4c920b6-dd3e-44a2-8983-4d3b76ddf189 // Output file ID: backend_result_file-dad4734a-d8eb-425b-b580-698f543957c7
Request counts: Total: 100 // Completed: 100 // Failed: 0

Here is an example of cancelling a batch job.

[10]:
import json
import time
from openai import OpenAI
import os

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="None")

requests = []
for i in range(500):
    requests.append(
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "messages": [
                    {
                        "role": "system",
                        "content": f"{i}: You are a helpful AI assistant",
                    },
                    {
                        "role": "user",
                        "content": "Write a detailed story about topic. Make it very long.",
                    },
                ],
                "max_tokens": 500,
            },
        }
    )

input_file_path = "batch_requests.jsonl"
with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open(input_file_path, "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

print_highlight(f"Created batch job with ID: {batch_job.id}")
print_highlight(f"Initial status: {batch_job.status}")

time.sleep(10)

try:
    cancelled_job = client.batches.cancel(batch_id=batch_job.id)
    print_highlight(f"Cancellation initiated. Status: {cancelled_job.status}")
    assert cancelled_job.status == "cancelling"

    # Monitor the cancellation process
    while cancelled_job.status not in ["failed", "cancelled"]:
        time.sleep(3)
        cancelled_job = client.batches.retrieve(batch_job.id)
        print_highlight(f"Current status: {cancelled_job.status}")

    # Verify final status
    assert cancelled_job.status == "cancelled"
    print_highlight("Batch job successfully cancelled")

except Exception as e:
    print_highlight(f"Error during cancellation: {e}")
    raise e

finally:
    try:
        del_response = client.files.delete(uploaded_file.id)
        if del_response.deleted:
            print_highlight("Successfully cleaned up input file")
        if os.path.exists(input_file_path):
            os.remove(input_file_path)
            print_highlight("Successfully deleted local batch_requests.jsonl file")
    except Exception as e:
        print_highlight(f"Error cleaning up: {e}")
        raise e
[2024-11-01 07:54:19] INFO:     127.0.0.1:60802 - "POST /v1/files HTTP/1.1" 200 OK
[2024-11-01 07:54:19] INFO:     127.0.0.1:60802 - "POST /v1/batches HTTP/1.1" 200 OK
Created batch job with ID: batch_ee93a6f4-3c9b-4884-a564-40725b15103b
Initial status: validating
[2024-11-01 07:54:19 TP0] Prefill batch. #new-seq: 42, #new-token: 42, #cached-token: 2268, cache hit rate: 60.44%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-01 07:54:19 TP0] Prefill batch. #new-seq: 330, #new-token: 8192, #cached-token: 9932, cache hit rate: 56.54%, token usage: 0.01, #running-req: 42, #queue-req: 128
[2024-11-01 07:54:20 TP0] Prefill batch. #new-seq: 129, #new-token: 3866, #cached-token: 3229, cache hit rate: 54.19%, token usage: 0.03, #running-req: 371, #queue-req: 1
[2024-11-01 07:54:20 TP0] Decode batch. #running-req: 500, #token: 23025, token usage: 0.05, gen throughput (token/s): 492.95, #queue-req: 0
[2024-11-01 07:54:21 TP0] Decode batch. #running-req: 500, #token: 43025, token usage: 0.10, gen throughput (token/s): 23992.40, #queue-req: 0
[2024-11-01 07:54:22 TP0] Decode batch. #running-req: 500, #token: 63025, token usage: 0.14, gen throughput (token/s): 22891.17, #queue-req: 0
[2024-11-01 07:54:23 TP0] Decode batch. #running-req: 500, #token: 83025, token usage: 0.19, gen throughput (token/s): 21879.73, #queue-req: 0
[2024-11-01 07:54:24 TP0] Decode batch. #running-req: 500, #token: 103025, token usage: 0.23, gen throughput (token/s): 20940.17, #queue-req: 0
[2024-11-01 07:54:25 TP0] Decode batch. #running-req: 500, #token: 123025, token usage: 0.28, gen throughput (token/s): 20076.86, #queue-req: 0
[2024-11-01 07:54:26 TP0] Decode batch. #running-req: 500, #token: 143025, token usage: 0.32, gen throughput (token/s): 19329.24, #queue-req: 0
[2024-11-01 07:54:27 TP0] Decode batch. #running-req: 500, #token: 163025, token usage: 0.37, gen throughput (token/s): 18630.48, #queue-req: 0
[2024-11-01 07:54:28 TP0] Decode batch. #running-req: 500, #token: 183025, token usage: 0.41, gen throughput (token/s): 17953.34, #queue-req: 0
[2024-11-01 07:54:29 TP0] Decode batch. #running-req: 500, #token: 203025, token usage: 0.46, gen throughput (token/s): 17335.25, #queue-req: 0
[2024-11-01 07:54:29] INFO:     127.0.0.1:43472 - "POST /v1/batches/batch_ee93a6f4-3c9b-4884-a564-40725b15103b/cancel HTTP/1.1" 200 OK
Cancellation initiated. Status: cancelling
[2024-11-01 07:54:32] INFO:     127.0.0.1:43472 - "GET /v1/batches/batch_ee93a6f4-3c9b-4884-a564-40725b15103b HTTP/1.1" 200 OK
Current status: cancelled
Batch job successfully cancelled
[2024-11-01 07:54:32] INFO:     127.0.0.1:43472 - "DELETE /v1/files/backend_input_file-a43aefb0-5138-4be1-a631-24cabb70eef5 HTTP/1.1" 200 OK
Successfully cleaned up input file
Successfully deleted local batch_requests.jsonl file
[11]:
terminate_process(server_process)