[Docs] README revamp and documentation initialization #42

Merged 1 commit on Nov 16, 2024
39 changes: 39 additions & 0 deletions .github/workflows/documentation.yaml
@@ -0,0 +1,39 @@
name: Build Docs

on:
  push:
    branches:
      - main

jobs:
  test_linux:
    name: Deploy Docs
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
        with:
          submodules: recursive

      - name: Configuring build Environment
        run: |
          sudo apt-get update
          python -m pip install -U pip wheel

      - name: Setup Ruby
        uses: ruby/setup-ruby@v1
        with:
          ruby-version: '3.0'

      - name: Installing dependencies
        run: |
          python -m pip install -r docs/requirements.txt
          gem install jekyll jekyll-remote-theme

      - name: Deploying on GitHub Pages
        if: github.ref == 'refs/heads/main'
        run: |
          git remote set-url origin https://x-access-token:${{ secrets.MLC_GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY
          git config --global user.email "mlc-gh-actions-bot@nomail"
          git config --global user.name "mlc-gh-actions-bot"
          ./scripts/gh_deploy_site.sh
199 changes: 47 additions & 152 deletions README.md
@@ -1,152 +1,47 @@
## XGrammar

Cross-platform Near-zero Overhead Grammar-guided Generation for LLMs

- G1: Universal: support any common tokenizer and any common grammar
- G2: Efficient: grammar constraints should add no extra burden to generation
- G3: Cross-platform: a pure C++ implementation, portable to every platform, able to construct end-to-end pipelines everywhere
- G4: Easy to understand and maintain

This project is under active development.

### Compile and Install

```bash
# install requirements
sudo apt install cmake
python3 -m pip install ninja pybind11 torch

# build XGrammar core and Python bindings
# see scripts/config.cmake for configuration options
mkdir build
cd build
# specify your own CUDA architecture
cmake .. -G Ninja -DXGRAMMAR_CUDA_ARCHITECTURES=89
ninja

# install Python package
cd ../python
python3 -m pip install .

# optional: add the python directory to PATH
echo "export PATH=\$PATH:$(pwd)" >> ~/.bashrc
```
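
After installing, you can sanity-check the build with a minimal import test (a hypothetical smoke test; it assumes only that the steps above succeeded):

```python
# Hypothetical smoke test: confirm the freshly built bindings import.
import xgrammar

print("xgrammar imported from:", xgrammar.__file__)
```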

### Python Usage Guide

#### Step 1: Constructing grammars

```python
import json

from pydantic import BaseModel
from xgrammar import BNFGrammar, BuiltinGrammar

# Method 1: provide a GBNF grammar string
# For the specification, see https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
gbnf_grammar = """
root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*
term ::= num | "(" expr ")"
num ::= [0-9]+
"""

gbnf_grammar = BNFGrammar(gbnf_grammar)

# Method 2: unconstrained JSON
json_grammar = BuiltinGrammar.json()

# Method 3: provide a Pydantic model
class Person(BaseModel):
    name: str
    age: int

json_schema_pydantic = BuiltinGrammar.json_schema(Person)

# Method 4: provide a JSON schema string
person_schema = {
    "title": "Person",
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}
json_schema_str = BuiltinGrammar.json_schema(json.dumps(person_schema))
```

#### Step 2: Compiling grammars
Grammar compilation is multi-threaded, and the compiled result is cached for each grammar.

```python
from xgrammar import TokenizerInfo, CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
from transformers import AutoTokenizer

# 1. Convert the Hugging Face tokenizer to TokenizerInfo (once per model)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
```

Method 1: use `CachedGrammarCompiler` to avoid compiling the same grammar multiple times
```python
# 2. Construct CachedGrammarCompiler (once per model)
compiler = CachedGrammarCompiler(tokenizer_info, max_threads=8)

# 3. Fetch CompiledGrammar and construct GrammarMatcher (once per request)
compiled_grammar = compiler.compile_json_schema(json_schema_str)
matcher = GrammarMatcher(compiled_grammar)
```

Method 2: compile the grammar directly
```python
# 2. Construct CompiledGrammar directly (once per grammar)
# `gbnf_grammar` is any grammar constructed in Step 1
compiled_grammar = CompiledGrammar(gbnf_grammar, tokenizer_info, max_threads=8)

# 3. Construct GrammarMatcher (once per request)
matcher = GrammarMatcher(compiled_grammar)
```

#### Step 3: Grammar-guided generation

For single-batch generation:
```python
import torch

token_bitmask = GrammarMatcher.allocate_token_bitmask(matcher.vocab_size)
while True:
    logits = LLM.inference()  # logits is a tensor of shape (vocab_size,) on GPU
    matcher.fill_next_token_bitmask(token_bitmask)
    GrammarMatcher.apply_token_bitmask_inplace(logits, token_bitmask)

    prob = torch.softmax(logits, dim=-1)  # get probabilities from logits
    next_token_id = Sampler.sample(prob)  # use your own sampler

    matcher.accept_token(next_token_id)
    if matcher.is_terminated():  # or your own termination condition
        break
```
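
For intuition, applying the token bitmask amounts to masking out grammar-disallowed tokens before sampling. Below is a minimal sketch of the equivalent operation with a plain boolean mask; the real bitmask is a packed representation that `apply_token_bitmask_inplace` handles internally, so treat this as an illustration rather than the library's implementation:

```python
import torch

def apply_boolean_mask_inplace(logits: torch.Tensor, allowed: torch.Tensor) -> None:
    # Equivalent effect of bitmask application: disallowed tokens get -inf
    # logits, so softmax assigns them zero probability.
    # logits: (vocab_size,) float tensor; allowed: (vocab_size,) bool tensor,
    # True where the grammar permits the token.
    logits.masked_fill_(~allowed, float("-inf"))
```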

For multi-batch generation:
```python
from typing import List

import torch

matchers: List[GrammarMatcher]  # one grammar matcher per request, provided by your engine
batch_size = len(matchers)
token_bitmasks = GrammarMatcher.allocate_token_bitmask(matchers[0].vocab_size, batch_size)
while True:
    logits = LLM.inference()  # logits is a tensor of shape (batch_size, vocab_size) on GPU
    # This for loop is parallelizable with threading.Thread, but estimate the
    # overhead in your engine first; see the sketch after this example.
    for i in range(len(matchers)):
        matchers[i].fill_next_token_bitmask(token_bitmasks, i)
    GrammarMatcher.apply_token_bitmask_inplace(logits, token_bitmasks)

    prob = torch.softmax(logits, dim=-1)  # get probabilities from logits
    next_token_ids = Sampler.sample(prob)  # use your own sampler

    for i in range(len(matchers)):
        matchers[i].accept_token(next_token_ids[i])
        if matchers[i].is_terminated():  # or your own termination condition
            requests[i].terminate()  # engine-specific request handling
```
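
As the comment in the loop notes, bitmask filling can be parallelized across requests. A sketch using `concurrent.futures` is below, assuming the underlying call releases the GIL during the C++ work; measure the threading overhead in your engine before adopting it:

```python
from concurrent.futures import ThreadPoolExecutor

# Sketch: fill each request's row of the shared bitmask in parallel.
# Only worthwhile if the batch is large enough to amortize the threading
# overhead -- benchmark against the sequential loop first.
with ThreadPoolExecutor(max_workers=8) as pool:
    futures = [
        pool.submit(matchers[i].fill_next_token_bitmask, token_bitmasks, i)
        for i in range(len(matchers))
    ]
    for f in futures:
        f.result()  # propagate any exception raised in a worker
```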
<div align="center" id="top">

# XGrammar

[![Documentation](https://img.shields.io/badge/docs-latest-green)](https://xgrammar.mlc.ai/docs/)
[![License](https://img.shields.io/badge/license-apache_2-blue)](https://github.com/mlc-ai/xgrammar/blob/main/LICENSE)

**Flexible, Portable and Fast Structured Generation**


[Get Started](#get-started) | [Documentation](https://xgrammar.mlc.ai/docs/) <!-- TODO: | [Blogpost](https://blog.mlc.ai/TODO) -->

</div>

## Overview

XGrammar is an open-source solution for flexible, portable, and fast structured generation,
aiming to bring flexible, zero-overhead structured generation everywhere.
It supports general context-free grammars, enabling a broad range of structures,
and applies careful system optimizations to reach execution times at the level of tens of microseconds.
XGrammar features a minimal, portable C++ backend that can be easily integrated into multiple environments and frameworks.
It is co-designed with the LLM inference engine, which allows it to outperform existing structured
generation solutions and enables zero-overhead structured generation in LLM inference.

<!--
## Key Features

TODO
WebLLM reference https://github.com/mlc-ai/web-llm/#key-features, if we want to list key features.
-->


## Get Started

Please visit our [documentation](https://xgrammar.mlc.ai/docs/) to get started with XGrammar.
- [Installation](https://xgrammar.mlc.ai/docs/installation)
- [Quick start](https://xgrammar.mlc.ai/docs/quick_start)


<!--
## Links

TODO
- [Demo App: WebLLM Chat](https://chat.webllm.ai/)
- If you want to run LLM on native runtime, check out [MLC-LLM](https://github.com/mlc-ai/mlc-llm)
- You might also be interested in [Web Stable Diffusion](https://github.com/mlc-ai/web-stable-diffusion/).
-->
1 change: 1 addition & 0 deletions docs/.gitignore
@@ -0,0 +1 @@
_build/
20 changes: 20 additions & 0 deletions docs/Makefile
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= python -m sphinx
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
30 changes: 30 additions & 0 deletions docs/README.md
@@ -0,0 +1,30 @@
# XGrammar Documentation

The documentation is built with [Sphinx](https://www.sphinx-doc.org/en/master/).

## Dependencies

Run the following command in this directory to install the dependencies:

```bash
pip3 install -r requirements.txt
```

## Build the Documentation

Then you can build the documentation by running:

```bash
make html
```

## View the Documentation

Run the following command to start a simple HTTP server:

```bash
cd _build/html
python3 -m http.server
```

Then you can view the documentation in your browser at `http://localhost:8000`. To serve on a different port, pass the port number to the command above, e.g. `python3 -m http.server 8080`.