diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 5b270fb..9a62bc6 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -17,6 +17,48 @@ FILENAME_FIELD = [FILENAME, FIELD] CONCAT_FIELDS = FIELD_SEPARATOR.join(FILENAME_FIELD) +SPLIT_ANCHORS = [ + "color", + "data", + "plot_height", + "plot_width", + "shape", + "stroke_dash", + "title", + "tooltip", + "x_label", + "y_label", + "zoom_and_pan", +] +OPTIONAL_ANCHORS = [ + "color", + "column", + "group_by_x", + "group_by_y", + "group_by", + "pivot_field", + "plot_height", + "plot_width", + "row", + "shape", + "stroke_dash", + "tooltip", + "zoom_and_pan", +] +OPTIONAL_ANCHOR_RANGES: Dict[str, Union[List[str], List[List[int]]]] = { + "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], + "color": [ + "#945dd6", + "#13adc7", + "#f46837", + "#48bb78", + "#4299e1", + "#ed8936", + "#f56565", + ], + "shape": ["circle", "square", "triangle", "diamond"], +} + class VegaRenderer(Renderer): """Renderer for vega plots.""" @@ -46,25 +88,6 @@ def __init__(self, datapoints: List, name: str, **properties): self.properties.get("template", None), self.properties.get("template_dir", None), ) - self._optional_anchor_ranges: Dict[ - str, - Union[ - List[str], - List[List[int]], - ], - ] = { - "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], - "color": [ - "#945dd6", - "#13adc7", - "#f46837", - "#48bb78", - "#4299e1", - "#ed8936", - "#f56565", - ], - "shape": ["circle", "square", "triangle", "diamond"], - } self._split_content: Dict[str, str] = {} @@ -126,19 +149,7 @@ def get_partial_filled_template(self): Returns a partially filled template along with the split out anchor content """ content = self.get_filled_template( - split_anchors=[ - "color", - "data", - "plot_height", - "plot_width", - "shape", - "stroke_dash", - "title", - "tooltip", - "x_label", - "y_label", - "zoom_and_pan", - ], + split_anchors=SPLIT_ANCHORS, strict=True, ) return content, {"anchor_definitions": self._split_content} @@ -206,23 +217,7 @@ def get_revs(self): def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ - anchor - for anchor in [ - "color", - "column", - "group_by_x", - "group_by_y", - "group_by", - "pivot_field", - "plot_height", - "plot_width", - "row", - "shape", - "stroke_dash", - "tooltip", - "zoom_and_pan", - ] - if self.template.has_anchor(anchor) + anchor for anchor in OPTIONAL_ANCHORS if self.template.has_anchor(anchor) ] if not optional_anchors: return None @@ -443,7 +438,7 @@ def _get_optional_anchor_mapping( name: str, domain: List[str], ): - full_range_values: List[Any] = self._optional_anchor_ranges.get(name, []) + full_range_values: List[Any] = OPTIONAL_ANCHOR_RANGES.get(name, []) anchor_range_values = full_range_values.copy() anchor_range = [] @@ -454,6 +449,7 @@ def _get_optional_anchor_mapping( anchor_range.append(range_value) legend = ( + # fix stroke dash and shape legend entry appearance (use empty shapes) {"legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}} if name != "color" else {} diff --git a/tests/test_vega.py b/tests/test_vega.py index 0e3a26c..9c6682c 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -3,7 +3,7 @@ import pytest -from dvc_render.vega import BadTemplateError, VegaRenderer +from dvc_render.vega import OPTIONAL_ANCHOR_RANGES, BadTemplateError, VegaRenderer from dvc_render.vega_templates import NoFieldInDataError, Template # pylint: disable=missing-function-docstring, C1803, C0302 @@ -339,7 +339,10 @@ def test_fill_anchor_in_string(tmp_dir): ["rev", "acc", "step", "filename"], { "field": "filename", - "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -388,7 +391,10 @@ def test_fill_anchor_in_string(tmp_dir): ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", - "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["acc", "acc_norm"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -454,7 +460,7 @@ def test_fill_anchor_in_string(tmp_dir): "field": "filename::field", "scale": { "domain": ["test::acc", "test::acc_norm", "train::acc"], - "range": [[1, 0], [8, 8], [8, 4]], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], }, "legend": { "symbolFillColor": "transparent", @@ -492,7 +498,7 @@ def test_optional_anchors_linear( assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == { "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, } assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field @@ -763,7 +769,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test", "train"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -831,7 +837,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test_acc", "train_acc"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -895,7 +901,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test::test_acc", "train::train_acc"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -932,7 +938,7 @@ def test_optional_anchors_scatter( assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == { "field": "rev", - "scale": {"domain": ["B", "C"], "range": ["#945dd6", "#13adc7"]}, + "scale": {"domain": ["B", "C"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:2]}, } assert plot_content["encoding"]["shape"] == shape_encoding assert plot_content["encoding"]["tooltip"] == tooltip_encoding @@ -996,7 +1002,10 @@ def test_optional_anchors_scatter( ["rev", "acc", "step", "field"], { "field": "filename", - "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -1029,7 +1038,10 @@ def test_optional_anchors_scatter( ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", - "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["acc", "acc_norm"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -1072,7 +1084,7 @@ def test_optional_anchors_scatter( "field": "filename::field", "scale": { "domain": ["test::acc", "test::acc_norm", "train::acc"], - "range": [[1, 0], [8, 8], [8, 4]], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], }, "legend": { "symbolFillColor": "transparent", @@ -1103,7 +1115,7 @@ def test_partial_filled_template( expected_split = { Template.anchor("color"): { "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, }, Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), Template.anchor("plot_height"): 300,