Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support parsing of SQL code blocks in Markdown files #598

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

### Features

- Add support for formatting SQL code blocks in Markdown files. Introduces a new extra install (`pipx install shandy-sqlfmt[markdownfmt]`) and CLI option (`--no-markdownfmt`) ([#593](https://github.com/tconbeer/sqlfmt/issues/593) - thank you, [@michael-the1](https://github.com/michael-the1)).

## [0.21.3] - 2024-04-25

### Bug Fixes
Expand Down
199 changes: 96 additions & 103 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ packages = [

[tool.poetry.extras]
jinjafmt = ["black"]
markdownfmt = ["mistletoe"]
sqlfmt_primer = ["gitpython"]

[tool.poetry.dependencies]
Expand All @@ -40,6 +41,7 @@ jinja2 = "^3.0"

black = { version = "*", optional = true }
gitpython = { version = "^3.1.24", optional = true }
mistletoe = { version = '*', optional = true}

[tool.poetry.group.dev.dependencies]
pre-commit = ">=2.20,<4.0"
Expand Down
38 changes: 36 additions & 2 deletions src/sqlfmt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ def format_string(source_string: str, mode: Mode) -> str:
return result


def format_markdown_string(source_string: str, mode: Mode) -> str:
"""
Takes a Markdown string and a mode as input, returns the the Markdown string with all of its SQL code blocks formatted.
"""
if mode.no_markdownfmt:
return source_string

from mistletoe import Document
from mistletoe.block_token import BlockToken, CodeFence
from mistletoe.markdown_renderer import MarkdownRenderer
Copy link
Owner

@tconbeer tconbeer Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's wrap this in try ... except ImportError and either return the source or re-raise a SqlfmtError to create a nicer error message for users who don't have the extra installed. I know right now we shouldn't hit that codepath due to the Mode property, but would be good to safeguard against that changing in the future.

Copy link
Author

@michael-the1 michael-the1 Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the try ... except guard. I went ahead and subclassed SqlfmtError since I didn't see any bare raise SqlfmtError anywhere.


def format_sql_code_blocks(token: BlockToken):
"""Walk through the AST and replace SQL code blocks with its formatted version."""
if isinstance(token, CodeFence) and token.language == "sql":
raw_text = token.children[0]
raw_text.content = format_string(raw_text.content, mode)

for child in token.children:
if isinstance(child, BlockToken):
format_sql_code_blocks(child)

with MarkdownRenderer() as renderer:
doc = Document(source_string)
format_sql_code_blocks(doc)
formatted_markdown_string = renderer.render(doc)

return formatted_markdown_string


def run(
files: Collection[Path],
mode: Mode,
Expand Down Expand Up @@ -154,7 +183,7 @@ def _get_included_paths(paths: Iterable[Path], mode: Mode) -> Set[Path]:
for p in paths:
if p == STDIN_PATH:
include_set.add(p)
elif p.is_file() and str(p).endswith(tuple(mode.SQL_EXTENSIONS)):
elif p.is_file() and str(p).endswith(mode.included_file_extensions):
include_set.add(p)
elif p.is_dir():
include_set |= _get_included_paths(p.iterdir(), mode)
Expand Down Expand Up @@ -233,10 +262,15 @@ def _format_one(path: Path, mode: Mode) -> SqlFormatResult:
"""
Runs format_string on the contents of a single file (found at path). Handles
potential user errors in formatted code, and returns a SqlfmtResult

If the file is a Markdown file, only format its SQL code blocks.
"""
source, encoding, utf_bom = _read_path_or_stdin(path, mode)
try:
formatted = format_string(source, mode)
if path.is_file() and path.suffix == ".md":
formatted = format_markdown_string(source, mode)
else:
formatted = format_string(source, mode)
return SqlFormatResult(
source_path=path,
source_string=source,
Expand Down
9 changes: 9 additions & 0 deletions src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@
"or if black was already available in this environment."
),
)
@click.option(
"--no-markdownfmt",
envvar="SQLFMT_NO_MARKDOWNFMT",
is_flag=True,
help=(
"Do not format sql code blocks in markdown files. Only necessary "
"to specify this flag if sqlfmt was installed with the markdownfmt extra."
),
)
@click.option(
"-l",
"--line-length",
Expand Down
15 changes: 13 additions & 2 deletions src/sqlfmt/mode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from dataclasses import dataclass, field
from importlib.util import find_spec
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple

from sqlfmt.dialect import ClickHouse, Polyglot
from sqlfmt.exception import SqlfmtConfigError
Expand All @@ -14,7 +15,6 @@ class Mode:
report config. For more info on each option, see cli.py
"""

SQL_EXTENSIONS: List[str] = field(default_factory=lambda: [".sql", ".sql.jinja"])
dialect_name: str = "polyglot"
line_length: int = 88
check: bool = False
Expand All @@ -25,6 +25,7 @@ class Mode:
fast: bool = False
single_process: bool = False
no_jinjafmt: bool = False
no_markdownfmt: bool = False
reset_cache: bool = False
verbose: bool = False
quiet: bool = False
Expand All @@ -46,6 +47,16 @@ def __post_init__(self) -> None:
"which is not supported. Did you mean 'polyglot'?"
)

@property
def included_file_extensions(self) -> Tuple[str, ...]:
"""List of file extensions to parse.

Only parses Markdown files if mistletoe is installed and no_markdownfmt is not set.
"""
if not self.no_markdownfmt and find_spec("mistletoe"):
return (".sql", ".sql.jinja", ".md")
return (".sql", ".sql.jinja")

@property
def color(self) -> bool:
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/data/fast/preformatted/007_markdown_file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Hello

```sql
select 1
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Hello again

```python
import antigravity
```

```
SELECT 1
```
11 changes: 11 additions & 0 deletions tests/data/fast/unformatted/107_markdown_file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Hello

```sql
SELECT 1
```
)))))__SQLFMT_OUTPUT__(((((
# Hello

```sql
select 1
```
19 changes: 19 additions & 0 deletions tests/data/preformatted/008_markdown_file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Hello

This is some SQL code in a `sql` code block:

```sql
select 1
```

SQL code, but it's not a `sql` code block so it shouldn't get formatted:

```
SELECT 2
```

And finally, some bash code

```bash
echo "Hello, world!"
```
6 changes: 6 additions & 0 deletions tests/data/preformatted/009_markdown_fmt_off.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Hello

```sql
-- fmt: off
SELECT 1
```
11 changes: 11 additions & 0 deletions tests/data/unformatted/500_markdown_file.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Hello

```sql
SELECT 1
```
)))))__SQLFMT_OUTPUT__(((((
# Hello

```sql
select 1
```
20 changes: 19 additions & 1 deletion tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from sqlfmt.api import format_string
from sqlfmt.api import format_markdown_string, format_string
from sqlfmt.mode import Mode
from tests.util import check_formatting, read_test_data

Expand Down Expand Up @@ -97,3 +97,21 @@ def test_formatting(p: str) -> None:

second_pass = format_string(actual, mode)
check_formatting(expected, second_pass, ctx=f"2nd-{p}")


@pytest.mark.parametrize(
"p",
[
michael-the1 marked this conversation as resolved.
Show resolved Hide resolved
"unformatted/500_markdown_file.md",
],
)
def test_markdown_formatting(p: str) -> None:
mode = Mode()

source, expected = read_test_data(p)
actual = format_markdown_string(source, mode)

check_formatting(expected, actual, ctx=p)

second_pass = format_markdown_string(actual, mode)
check_formatting(expected, second_pass, ctx=f"2nd-{p}")
32 changes: 29 additions & 3 deletions tests/unit_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_perform_safety_check,
_read_path_or_stdin,
_update_source_files,
format_markdown_string,
format_string,
get_matching_paths,
initialize_progress_bar,
Expand Down Expand Up @@ -47,8 +48,11 @@ def all_files(file_discovery_dir: Path) -> Set[Path]:
files = {
p / "top_level_file.sql",
p / "top_level_file.two.sql",
p / "top_level_markdown_file.md",
p / "a_directory/one_file.sql",
p / "a_directory/one_markdown_file.md",
p / "a_directory/nested_directory/another_file.sql",
p / "a_directory/nested_directory/another_markdown_file.md",
p / "a_directory/nested_directory/j2_extension.sql.jinja",
p / "a_directory/symlink_source_directory/symlink_file.sql",
p / "a_directory/symlink_target_directory/symlink_file.sql",
Expand All @@ -75,11 +79,12 @@ def test_file_discovery(
"exclude",
[
["**/*_file*"],
["**/*.sql"],
["**/*.sql", "**/*.md"],
[
"**/top*",
"**/a_directory/*",
"**/a_directory/**/another_file.sql",
"**/a_directory/**/another_markdown_file.md",
"**/a_directory/**/symlink_file.sql",
],
],
Expand All @@ -95,7 +100,10 @@ def test_file_discovery_with_excludes(
def test_file_discovery_with_abs_excludes(
file_discovery_dir: Path, sql_jinja_files: Set[Path]
) -> None:
exclude = [str(file_discovery_dir / "**/*.sql")]
exclude = [
str(file_discovery_dir / "**/*.sql"),
str(file_discovery_dir / "**/*.md"),
]
mode = Mode(exclude=exclude, exclude_root=None)
res = get_matching_paths(file_discovery_dir.iterdir(), mode)
assert res == sql_jinja_files
Expand All @@ -114,7 +122,7 @@ def test_file_discovery_with_invalid_excludes(
def test_file_discovery_with_excludes_no_root(
file_discovery_dir: Path, all_files: Set[Path], sql_jinja_files: Set[Path]
) -> None:
mode = Mode(exclude=["**/*.sql"], exclude_root=None)
mode = Mode(exclude=["**/*.sql", "**/*.md"], exclude_root=None)

# relative to here, excludes shouldn't do anything.
cwd = os.getcwd()
Expand Down Expand Up @@ -142,6 +150,12 @@ def test_format_empty_string(all_output_modes: Mode) -> None:
assert expected == actual


def test_format_markdown_empty_string(all_output_modes: Mode) -> None:
source = expected = ""
actual = format_markdown_string(source, all_output_modes)
assert expected == actual


@pytest.mark.parametrize(
"source,exception",
[
Expand Down Expand Up @@ -356,6 +370,18 @@ def print_dot(_: Any) -> None:
assert "." * expected_dots in captured.out


def test_run_no_markdownfmt_mode(unformatted_files: List[Path]) -> None:
unformatted_markdown_files = [
file for file in unformatted_files if file.suffix == ".md"
]
mode = Mode(no_markdownfmt=True)
report = run(files=unformatted_markdown_files, mode=mode)
assert report.number_changed == 0
assert report.number_unchanged == len(unformatted_markdown_files)
assert report.number_errored == 0
assert not any([res.from_cache for res in report.results])


def test_initialize_progress_bar(default_mode: Mode) -> None:
total = 100
progress_bar, progress_callback = initialize_progress_bar(
Expand Down
8 changes: 3 additions & 5 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_preformatted_short_lines_env(
)
assert results.exit_code == 0
print(results.stderr)
assert "5 files formatted" in results.stderr
assert "6 files formatted" in results.stderr

# test that CLI flag overrides ENV VAR
args = f"{preformatted_dir.as_posix()} -l 88 --check"
Expand All @@ -105,7 +105,7 @@ def test_preformatted_short_lines_env(
)
assert results.exit_code == 0
print(results.stderr)
assert "6 files passed formatting check" in results.stderr
assert "8 files passed formatting check" in results.stderr


def test_unformatted_check(sqlfmt_runner: CliRunner, unformatted_dir: Path) -> None:
Expand Down Expand Up @@ -153,9 +153,7 @@ def test_preformatted_config_file(
def test_preformatted_exclude_all(
sqlfmt_runner: CliRunner, preformatted_dir: Path
) -> None:
args = (
f"{preformatted_dir.as_posix()} --exclude {preformatted_dir.as_posix()}/*.sql"
)
args = f"{preformatted_dir.as_posix()} --exclude {preformatted_dir.as_posix()}/*.sql --exclude {preformatted_dir.as_posix()}/*.md"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0
assert results.stderr.startswith("0 files left unchanged")
Expand Down
2 changes: 1 addition & 1 deletion tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def check_formatting(expected: str, actual: str, ctx: str = "") -> None:

def discover_test_files(relpaths: Iterable[Union[str, Path]]) -> Iterator[Path]:
for p in [BASE_DIR / p for p in relpaths]:
if p.is_file() and p.suffix == ".sql":
if p.is_file() and p.suffix in [".sql", ".md"]:
yield p
elif p.is_dir():
yield from (discover_test_files(p.iterdir()))
Expand Down
Loading