Skip to content

Commit

Permalink
Switch from raw text generation to tomlkit
Browse files Browse the repository at this point in the history
I was hitting a few edge cases:
- The package name "ruamel.yaml" needs to be escaped with quotes
- Inline tables use {key = value} syntax rather than JSON's {key: value}.
  • Loading branch information
maresb committed Sep 15, 2024
1 parent 07a9048 commit d0c94da
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 82 deletions.
4 changes: 2 additions & 2 deletions conda_lock/conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,11 +1978,11 @@ def do_render_lockspec(
required_categories=required_categories if filter_categories else None,
)
if "pixi.toml" in kinds:
pixi_toml_lines = render_pixi_toml(
pixi_toml = render_pixi_toml(
lock_spec=lock_spec, with_cuda=with_cuda, project_name=pixi_project_name
)
if stdout:
print("\n".join(pixi_toml_lines))
print(pixi_toml.as_string())
else:
raise NotImplementedError("Only stdout is supported at the moment.")

Expand Down
178 changes: 98 additions & 80 deletions conda_lock/export_lock_spec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import warnings

from collections import defaultdict
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union

from tomlkit import TOMLDocument, comment, document, inline_table, item, table
from tomlkit.items import InlineTable, Table

from conda_lock._export_lock_spec_compute_platform_indep import (
unify_platform_independent_deps,
Expand All @@ -19,12 +21,12 @@
class TomlTableKey(NamedTuple):
"""Represents a key in a pixi.toml table.
It can be rendered into a TOML header line using `toml_header_line`.
It can be rendered into a TOML header using `toml_header_sequence`.
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="dev", platform="linux-64", manager="pip")
... )
'[feature.dev.target.linux-64.pypi-dependencies]'
['feature', 'dev', 'target', 'linux-64', 'pypi-dependencies']
"""

category: str
Expand All @@ -37,8 +39,18 @@ def render_pixi_toml(
lock_spec: LockSpecification,
project_name: Optional[str] = None,
with_cuda: Optional[str] = None,
) -> List[str]:
"""Render a pixi.toml from a LockSpecification as a list of lines."""
) -> TOMLDocument:
"""Render a pixi.toml from a LockSpecification as a tomlkit TOMLDocument."""
pixi_toml = document()
for line in (
"This file was generated by conda-lock for the pixi environment manager.",
"For more information, see <https://github.com/conda/conda-lock> "
"and <https://pixi.sh>.",
"Source files:",
*(f"- {src_file}" for src_file in lock_spec.sources),
):
pixi_toml.add(comment(line))

if project_name is None:
project_name = "project-name-placeholder"
all_platforms = lock_spec.dependencies.keys()
Expand All @@ -52,25 +64,15 @@ def render_pixi_toml(
# is a conflict.
raise ValueError("Cannot have both 'main' and 'default' as categories/extras")

# The header block consists of explanatory comments and a list of source files.
lines: List[str] = [
"# This file was generated by conda-lock for the pixi environment manager.",
"# For more information, see <https://github.com/conda/conda-lock> "
"and <https://pixi.sh>.",
"# Source files:",
]
lines.extend(f"# - {source}" for source in lock_spec.sources)

# The project table
lines.extend(
[
"",
"[project]",
f'name = "{project_name}"',
f"platforms = {json.dumps(sorted(all_platforms))}",
f"channels = {json.dumps([channel.url for channel in lock_spec.channels])}",
"",
]
pixi_toml.add(
"project",
item(
dict(
name=project_name,
platforms=list(all_platforms),
channels=[channel.url for channel in lock_spec.channels],
)
),
)
if len(lock_spec.channels) == 0:
warnings.warn(
Expand All @@ -87,51 +89,59 @@ def render_pixi_toml(
# The dependency tables
arranged_deps = arrange_for_toml(lock_spec)
for key, deps_by_name in arranged_deps.items():
lines.append(toml_header_line(key))
lines.extend(toml_dependency_line(dep) for name, dep in deps_by_name.items())
lines.append("")
header_sequence: List[str] = toml_header_sequence(key)
inner_dict = {
name: toml_dependency_value(dep) for name, dep in deps_by_name.items()
}
# Construct the outer dictionary by nesting the inner dictionary within
# the header sequence.
outer_dict = inner_dict
for header in reversed(header_sequence):
outer_dict = {header: outer_dict}
pixi_toml.update(outer_dict)

# The environments table
if len(all_categories) > 1:
lines.extend(define_environments(all_categories))
pixi_toml.add("environments", toml_environments_table(all_categories))

# The system requirements table
if with_cuda:
lines.extend(["[system-requirements]", f'cuda = "{with_cuda}"', ""])
return lines
pixi_toml.add("system-requirements", item(dict(cuda=with_cuda)))

return pixi_toml


def toml_dependency_line(dep: Dependency) -> str:
def toml_dependency_value(dep: Dependency) -> Union[str, InlineTable]:
"""Render a conda-lock Dependency as a pixi.toml line as VersionSpec or matchspec.
The result is suitable for the values used in the `dependencies` or
`pypi-dependencies` tables of a pixi TOML file.
>>> toml_dependency_line(VersionedDependency(name="numpy", version="2.1.1"))
'numpy = "2.1.1"'
>>> toml_dependency_value(VersionedDependency(name="numpy", version="2.1.1"))
'2.1.1'
>>> toml_dependency_line(VersionedDependency(name="numpy", version=""))
'numpy = "*"'
>>> toml_dependency_value(VersionedDependency(name="numpy", version=""))
'*'
>>> toml_dependency_line(
>>> toml_dependency_value(
... VersionedDependency(
... name="numpy",
... version="2.1.1",
... conda_channel="conda-forge",
... build="py313h4bf6692_0"
... )
... )
'numpy = {"version": "2.1.1", "build": "py313h4bf6692_0", "channel": "conda-forge"}'
{'version': '2.1.1', 'build': 'py313h4bf6692_0', 'channel': 'conda-forge'}
>>> toml_dependency_line(
>>> toml_dependency_value(
... VersionedDependency(
... name="xarray",
... version="",
... extras=["io", "parallel"],
... manager="pip",
... )
... )
'xarray = {"version": "*", "extras": ["io", "parallel"]}'
{'version': '*', 'extras': ['io', 'parallel']}
"""
matchspec: Dict[str, Any] = {}
if isinstance(dep, VersionedDependency):
Expand All @@ -149,9 +159,9 @@ def toml_dependency_line(dep: Dependency) -> str:
matchspec["extras"] = dep.extras
if len(matchspec) == 1:
# Use the simpler VersionSpec format if there's only a version.
return f'{dep.name} = "{matchspec["version"]}"'
return matchspec["version"]
else:
return dep.name + " = " + json.dumps(matchspec)
return _dict_to_inline_table(matchspec)
elif isinstance(dep, URLDependency):
raise NotImplementedError(f"URL not yet supported in {dep}")
elif isinstance(dep, VCSDependency):
Expand All @@ -160,6 +170,13 @@ def toml_dependency_line(dep: Dependency) -> str:
raise ValueError(f"Unknown dependency type {dep}")


def _dict_to_inline_table(d: Dict[str, Any]) -> InlineTable:
"""Convert a dictionary to a TOML inline table."""
table = inline_table()
table.update(d)
return table


def arrange_for_toml(
lock_spec: LockSpecification,
) -> Dict[TomlTableKey, Dict[str, Dependency]]:
Expand Down Expand Up @@ -235,58 +252,61 @@ def toml_ordering(item: Tuple[TomlTableKey, dict]) -> Tuple[str, str, str]:
return category, platform, key.manager


def toml_header_line(key: TomlTableKey) -> str:
def toml_header_sequence(key: TomlTableKey) -> List[str]:
"""Generates a TOML header based on the dependency type, platform, and manager.
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="main", platform=None, manager="conda")
... )
'[dependencies]'
['dependencies']
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="main", platform="linux-64", manager="conda")
... )
'[target.linux-64.dependencies]'
['target', 'linux-64', 'dependencies']
>>> toml_header_line(TomlTableKey(category="main", platform=None, manager="pip"))
'[pypi-dependencies]'
>>> toml_header_sequence(
... TomlTableKey(category="main", platform=None, manager="pip")
... )
['pypi-dependencies']
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="main", platform="linux-64", manager="pip")
... )
'[target.linux-64.pypi-dependencies]'
['target', 'linux-64', 'pypi-dependencies']
>>> toml_header_line(TomlTableKey(category="dev", platform=None, manager="conda"))
'[feature.dev.dependencies]'
>>> toml_header_sequence(
... TomlTableKey(category="dev", platform=None, manager="conda")
... )
['feature', 'dev', 'dependencies']
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="dev", platform="linux-64", manager="conda")
... )
'[feature.dev.target.linux-64.dependencies]'
['feature', 'dev', 'target', 'linux-64', 'dependencies']
>>> toml_header_line(TomlTableKey(category="dev", platform=None, manager="pip"))
'[feature.dev.pypi-dependencies]'
>>> toml_header_sequence(TomlTableKey(category="dev", platform=None, manager="pip"))
['feature', 'dev', 'pypi-dependencies']
>>> toml_header_line(
>>> toml_header_sequence(
... TomlTableKey(category="dev", platform="linux-64", manager="pip")
... )
'[feature.dev.target.linux-64.pypi-dependencies]'
['feature', 'dev', 'target', 'linux-64', 'pypi-dependencies']
"""
parts = []
if key.category not in ["main", "default"]:
parts.extend(["feature", key.category])
if key.platform:
parts.extend(["target", key.platform])
parts.append("dependencies" if key.manager == "conda" else "pypi-dependencies")
return "[" + ".".join(parts) + "]"
return parts


def define_environments(all_categories: Set[str]) -> List[str]:
def toml_environments_table(all_categories: Set[str]) -> Table:
r"""Define the environments section of a pixi.toml file.
>>> lines = define_environments({"main", "dev", "docs"})
>>> print("\n".join(lines))
[environments]
>>> environments_table = toml_environments_table({"main", "dev", "docs"})
>>> print(environments_table.as_string())
# Redefine the default environment to include all categories.
default = ["dev", "docs"]
# Define a minimal environment with only the default feature.
Expand All @@ -300,10 +320,11 @@ def define_environments(all_categories: Set[str]) -> List[str]:
if len(non_default_categories) == 0:
raise ValueError("Expected at least one non-default category")

lines = []
lines.append("[environments]")
lines.append("# Redefine the default environment to include all categories.")
lines.append("default = " + json.dumps(non_default_categories))
environments = table()
environments.add(
comment("Redefine the default environment to include all categories.")
)
environments.add("default", non_default_categories)

MINIMAL_ENVIRONMENT_NAMES = ["minimal", "prod", "main"]
minimal_category_name = next(
Expand All @@ -316,15 +337,12 @@ def define_environments(all_categories: Set[str]) -> List[str]:
+ "' are already defined. Skipping."
)
else:
lines.extend(
[
"# Define a minimal environment with only the default feature.",
f"{minimal_category_name} = []",
]
environments.add(
comment("Define a minimal environment with only the default feature.")
)
lines.append("# Create an environment for each feature.")
lines.extend(
f"{category} = " + json.dumps([category]) for category in non_default_categories
)
lines.append("")
return lines
environments.add(minimal_category_name, [])

environments.add(comment("Create an environment for each feature."))
for category in non_default_categories:
environments.add(category, [category])
return environments

0 comments on commit d0c94da

Please sign in to comment.