Skip to content

Commit

Permalink
Implemented --schema-multi, closes #791
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Feb 27, 2025
1 parent e20bd45 commit 7e819c2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/help.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ Options:
Attachment with explicit mimetype
-o, --option <TEXT TEXT>... key/value options for the model
--schema TEXT JSON schema, filepath or ID
--schema-multi TEXT JSON schema to use for multiple results
-t, --template TEXT Template to use
-p, --param <TEXT TEXT>... Parameters for template
--no-stream Do not stream output
Expand Down
5 changes: 5 additions & 0 deletions docs/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,9 @@ To return multiple items matching your schema, use the `--schema-multi` option.
```bash
llm --schema-multi 'name,description,fave_toy' 'invent 3 dogs'
```
Using this option a simpler version of the New York Times example above is the following:
```bash
curl https://www.nytimes.com/ | uvx strip-tags | llm --schema-multi 'headline, summary' | jq
```

The Python utility function `llm.utils.build_json_schema(schema)` can be used to convert this syntax into the equivalent JSON schema dictionary.
16 changes: 16 additions & 0 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def cli():
help="key/value options for the model",
)
@schema_option
@click.option(
"--schema-multi",
help="JSON schema to use for multiple results",
)
@click.option("-t", "--template", help="Template to use")
@click.option(
"-p",
Expand Down Expand Up @@ -246,6 +250,7 @@ def prompt(
attachment_types,
options,
schema_input,
schema_multi,
template,
param,
no_stream,
Expand Down Expand Up @@ -295,8 +300,19 @@ def prompt(
db = sqlite_utils.Database(log_path)
migrate(db)

if schema_multi:
schema_input = schema_multi

schema = resolve_schema_input(db, schema_input)

if schema_multi:
# Convert that schema into multiple "items" of the same schema
schema = {
"type": "object",
"properties": {"items": {"type": "array", "items": schema}},
"required": ["items"],
}

model_aliases = get_model_aliases()

def read_prompt():
Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def execute(self, prompt, stream, response, conversation):
break
except IndexError:
break
response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
response.set_usage(
input=len((prompt.prompt or "").split()), output=len(gathered)
)


class MockKeyModel(llm.KeyModel):
Expand Down Expand Up @@ -120,7 +122,9 @@ async def execute(self, prompt, stream, response, conversation):
break
except IndexError:
break
response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
response.set_usage(
input=len((prompt.prompt or "").split()), output=len(gathered)
)


class EmbedDemo(llm.EmbeddingModel):
Expand Down
43 changes: 36 additions & 7 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,24 +616,53 @@ def test_schema_via_cli(mock_model, tmpdir, monkeypatch, use_filename):
assert result2.exit_code == 0


def test_schema_using_dsl(mock_model, tmpdir, monkeypatch):
@pytest.mark.parametrize(
"args,expected",
(
(
["--schema", "name, age int"],
{
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
"required": ["name", "age"],
},
),
(
["--schema-multi", "name, age int"],
{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
"required": ["name", "age"],
},
}
},
"required": ["items"],
},
),
),
)
def test_schema_using_dsl(mock_model, tmpdir, monkeypatch, args, expected):
user_path = tmpdir / "user"
mock_model.enqueue([json.dumps(dog)])
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
runner = CliRunner()
result = runner.invoke(
cli,
["--schema", "name, age int", "prompt", "-m", "mock"],
["prompt", "-m", "mock"] + args,
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == '{"name": "Cleo", "age": 10}\n'
rows = list(sqlite_utils.Database(str(user_path / "logs.db"))["schemas"].rows)
assert json.loads(rows[0]["content"]) == {
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
"required": ["name", "age"],
}
assert json.loads(rows[0]["content"]) == expected


@pytest.mark.asyncio
Expand Down

0 comments on commit 7e819c2

Please sign in to comment.