From 32d1aabb767abd1a20c3d678f97986f075a39218 Mon Sep 17 00:00:00 2001 From: ammekk Date: Tue, 18 Jan 2022 16:47:40 -0500 Subject: [PATCH 1/5] added color scales --- hail/python/hail/plot2/__init__.py | 8 +++-- hail/python/hail/plot2/geoms.py | 48 ++++++++++-------------------- hail/python/hail/plot2/ggplot.py | 25 ++++++++++++---- hail/python/hail/plot2/scale.py | 47 ++++++++++++++++++++++++++++- 4 files changed, 87 insertions(+), 41 deletions(-) diff --git a/hail/python/hail/plot2/__init__.py b/hail/python/hail/plot2/__init__.py index 36020f437dd..6466dd13cfe 100644 --- a/hail/python/hail/plot2/__init__.py +++ b/hail/python/hail/plot2/__init__.py @@ -4,7 +4,8 @@ from .geoms import geom_line, geom_point, geom_text, geom_bar, geom_histogram, geom_hline, geom_func, geom_vline, geom_tile from .labels import ggtitle, xlab, ylab from .scale import scale_x_continuous, scale_y_continuous, scale_x_discrete, scale_y_discrete, scale_x_genomic, \ - scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse + scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse, scale_color_discrete, scale_color_identity,\ + scale_color_continuous __all__ = [ "aes", @@ -30,5 +31,8 @@ "scale_x_log10", "scale_y_log10", "scale_x_reverse", - "scale_y_reverse" + "scale_y_reverse", + "scale_color_continuous", + "scale_color_identity", + "scale_color_discrete" ] \ No newline at end of file diff --git a/hail/python/hail/plot2/geoms.py b/hail/python/hail/plot2/geoms.py index 6de6c769457..01ad0fa264c 100644 --- a/hail/python/hail/plot2/geoms.py +++ b/hail/python/hail/plot2/geoms.py @@ -112,7 +112,9 @@ def make_agg(self, mapping): # aesthetics in geom_struct are just pointers to x_expr, that's fine. # Maybe I just do a `take(1) for every field of parent_struct and geom_struct? # Or better, a collect_as_set where I error if size is greater than 1? - return hl.agg.group_by(mapping["x"], hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) + + #go through mappings grab all non continuous fields + return hl.agg.group_by(mapping.select("x", "color", ), hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) def listify(self, agg_result): unflattened_items = agg_result.items() @@ -155,50 +157,32 @@ def __init__(self, aes, color=None): self.color = color def apply_to_fig(self, parent, agg_result, fig_so_far): - def plot_one_color(one_color_data, color, legend_name): - scatter_args = { - "x": [element["x"] for element in one_color_data], - "y": [element["y"] for element in one_color_data], - "mode": "markers", - "marker_color": color, - "name": legend_name - } - if "size" in parent.aes or "size" in self.aes: - scatter_args["marker_size"] = [element["size"] for element in one_color_data] - if "tooltip" in parent.aes or "tooltip" in self.aes: - scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data] - fig_so_far.add_scatter(**scatter_args) - - def plot_continuous_color(data, colors): + def plot_group(data, color=None): scatter_args = { "x": [element["x"] for element in data], "y": [element["y"] for element in data], "mode": "markers", - "marker_color": colors + "marker_color": color if color is not None else [element["color"] for element in data] } + if "color_legend" in data[0]: + scatter_args["name"] = data[0]["color_legend"] + if "size" in parent.aes or "size" in self.aes: scatter_args["marker_size"] = [element["size"] for element in data] + if "tooltip" in parent.aes or "tooltip" in self.aes: + scatter_args["hovertext"] = [element["tooltip"] for element in data] fig_so_far.add_scatter(**scatter_args) if self.color is not None: - plot_one_color(agg_result, self.color, None) + plot_group(agg_result, self.color) elif "color" in parent.aes or "color" in self.aes: - if isinstance(agg_result[0]["color"], int): - # Should show colors in continuous scale. - input_color_nums = [element["color"] for element in agg_result] - color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale) - plot_continuous_color(agg_result, color_mapping) - - else: - categorical_strings = set([element["color"] for element in agg_result]) - unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) - - for category in categorical_strings: - filtered_data = [element for element in agg_result if element["color"] == category] - plot_one_color(filtered_data, unique_color_mapping[category], category) + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group) else: - plot_one_color(agg_result, "black", None) + plot_group(agg_result, "black") def get_stat(self): return StatIdentity() diff --git a/hail/python/hail/plot2/ggplot.py b/hail/python/hail/plot2/ggplot.py index 503fcb5892b..3da3b2635da 100644 --- a/hail/python/hail/plot2/ggplot.py +++ b/hail/python/hail/plot2/ggplot.py @@ -9,7 +9,8 @@ from .coord_cartesian import CoordCartesian from .geoms import Geom, FigureAttribute from .labels import Labels -from .scale import Scale, ScaleContinuous, ScaleDiscrete, scale_x_continuous, scale_x_genomic, scale_y_continuous, scale_x_discrete, scale_y_discrete +from .scale import Scale, ScaleContinuous, ScaleDiscrete, scale_x_continuous, scale_x_genomic, scale_y_continuous, \ + scale_x_discrete, scale_y_discrete, scale_color_discrete, scale_color_continuous from .aes import Aesthetic, aes @@ -79,6 +80,10 @@ def is_genomic_type(dtype): raise ValueError("Don't yet support y axis genomic") else: self.scales["y"] = scale_y_discrete() + elif aesthetic_str == "color" and not is_continuous: + self.scales["color"] = scale_color_discrete() + elif aesthetic_str == "color" and is_continuous: + self.scales["color"] = scale_color_continuous() else: if is_continuous: self.scales[aesthetic_str] = ScaleContinuous(aesthetic_str) @@ -120,18 +125,26 @@ def render(self): for geom, (label, agg_result) in zip(self.geoms, aggregated.items()): listified_agg_result = labels_to_stats[label].listify(agg_result) - # Ok, need to identify every possible value of every discrete scale. if listified_agg_result: - relevant_aesthetics = [scale_name for scale_name in list(listified_agg_result[0]) if scale_name in self.scales and self.scales[scale_name].is_discrete()] - subsetted_to_relevant = tuple([one_struct.select(*relevant_aesthetics) for one_struct in listified_agg_result]) + # apply local scales here + # grab list of aesthetic names, loop through, grab scale associated with it, called new function with + # each scale + relevant_aesthetics =\ + [scale_name for scale_name in list(listified_agg_result[0]) if scale_name in self.scales] + for relevant_aesthetic in relevant_aesthetics: + listified_agg_result =\ + self.scales[relevant_aesthetic].transform_data_local(listified_agg_result, self) + + # Ok, need to identify every possible value of every discrete scale. + discrete_aesthetics = [scale_name for scale_name in relevant_aesthetics if self.scales[scale_name].is_discrete()] + subsetted_to_discrete = tuple([one_struct.select(*discrete_aesthetics) for one_struct in listified_agg_result]) group_counter = 0 - def increment_and_return_old(): nonlocal group_counter group_counter += 1 return group_counter - 1 numberer = defaultdict(increment_and_return_old) - numbering = [numberer[data_tuple] for data_tuple in subsetted_to_relevant] + numbering = [numberer[data_tuple] for data_tuple in subsetted_to_discrete] numbered_result = [] for s, i in zip(listified_agg_result, numbering): diff --git a/hail/python/hail/plot2/scale.py b/hail/python/hail/plot2/scale.py index 43749db9ecd..22302e93f8e 100644 --- a/hail/python/hail/plot2/scale.py +++ b/hail/python/hail/plot2/scale.py @@ -3,6 +3,8 @@ from hail import get_reference +from .utils import categorical_strings_to_colors, continuous_nums_to_colors + class Scale(FigureAttribute): def __init__(self, aesthetic_name): @@ -12,6 +14,9 @@ def __init__(self, aesthetic_name): def transform_data(self, field_expr): pass + def transform_data_local(self, data, parent): + return data + @abc.abstractmethod def is_discrete(self): pass @@ -123,6 +128,34 @@ def is_discrete(self): return True +class ScaleColorDiscrete(ScaleDiscrete): + def transform_data_local(self, data, parent): + categorical_strings = set([element["color"] for element in data]) + unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) + + updated_data = [] + for category in categorical_strings: + for data_entry in data: + if data_entry["color"] == category: + updated_data.append(data_entry.annotate(color=unique_color_mapping[category], color_legend=category)) + return updated_data + + +class ScaleColorContinuous(ScaleContinuous): + def transform_data_local(self, data, parent): + color_list = [element["color"] for element in data] + color_mapping = continuous_nums_to_colors(color_list, parent.continuous_color_scale) + updated_data = [] + for data_idx, data_entry in enumerate(data): + updated_data.append(data_entry.annotate(color=color_mapping[data_idx], color_legend=data_entry["color"])) + + return updated_data + +class ScaleColorDiscreteIdentity(ScaleDiscrete): + def transform_data_local(self, data, parent): + return data + + def scale_x_log10(): return PositionScaleContinuous("x", transformation="log10") @@ -156,4 +189,16 @@ def scale_y_discrete(name=None, breaks=None, labels=None): def scale_x_genomic(reference_genome, name=None): - return PositionScaleGenomic("x", reference_genome, name=name) \ No newline at end of file + return PositionScaleGenomic("x", reference_genome, name=name) + + +def scale_color_discrete(): + return ScaleColorDiscrete("color") + + +def scale_color_continuous(): + return ScaleColorContinuous("color") + + +def scale_color_identity(): + return ScaleColorDiscreteIdentity("color") From 25214005f063875673c4e5c740b7ae0b6cec1080 Mon Sep 17 00:00:00 2001 From: ammekk Date: Wed, 19 Jan 2022 13:02:28 -0500 Subject: [PATCH 2/5] working on changing geoms to plot_group --- hail/python/hail/plot2/geoms.py | 112 ++++++++++++++------------------ hail/python/hail/plot2/scale.py | 2 + 2 files changed, 52 insertions(+), 62 deletions(-) diff --git a/hail/python/hail/plot2/geoms.py b/hail/python/hail/plot2/geoms.py index 01ad0fa264c..581942ae8e7 100644 --- a/hail/python/hail/plot2/geoms.py +++ b/hail/python/hail/plot2/geoms.py @@ -31,33 +31,35 @@ def __init__(self, aes, color): def apply_to_fig(self, parent, agg_result, fig_so_far): - def plot_one_color(one_color_data, color, legend_name): + def plot_group(data, color=None): scatter_args = { - "x": [element["x"] for element in one_color_data], - "y": [element["y"] for element in one_color_data], + "x": [element["x"] for element in data], + "y": [element["y"] for element in data], "mode": "lines", - "name": legend_name, "line_color": color + } + if "color_legend" in data[0]: + scatter_args["name"] = data[0]["color_legend"] + if "tooltip" in parent.aes or "tooltip" in self.aes: - scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data] + scatter_args["hovertext"] = [element["tooltip"] for element in data] fig_so_far.add_scatter(**scatter_args) if self.color is not None: - plot_one_color(agg_result, self.color, None) + plot_group(agg_result, self.color) elif "color" in parent.aes or "color" in self.aes: if isinstance(agg_result[0]["color"], int): # Should show colors in continuous scale. raise ValueError("Do not currently support continuous color changing of lines") + # Groups messed up so that if there is a break in group, line not continuous else: - categorical_strings = set([element["color"] for element in agg_result]) - unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) - - for category in categorical_strings: - filtered_data = [element for element in agg_result if element["color"] == category] - plot_one_color(filtered_data, unique_color_mapping[category], category) + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group, just_one_group[0]["color"]) else: - plot_one_color(agg_result, "black", None) + plot_group(agg_result, "black") @abc.abstractmethod def get_stat(self): @@ -112,9 +114,7 @@ def make_agg(self, mapping): # aesthetics in geom_struct are just pointers to x_expr, that's fine. # Maybe I just do a `take(1) for every field of parent_struct and geom_struct? # Or better, a collect_as_set where I error if size is greater than 1? - - #go through mappings grab all non continuous fields - return hl.agg.group_by(mapping.select("x", "color", ), hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) + return hl.agg.group_by(mapping.select("x", "color"), hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) def listify(self, agg_result): unflattened_items = agg_result.items() @@ -162,7 +162,7 @@ def plot_group(data, color=None): "x": [element["x"] for element in data], "y": [element["y"] for element in data], "mode": "markers", - "marker_color": color if color is not None else [element["color"] for element in data] + "marker_color": color } if "color_legend" in data[0]: @@ -180,7 +180,7 @@ def plot_group(data, color=None): groups = set([element["group"] for element in agg_result]) for group in groups: just_one_group = [element for element in agg_result if element["group"] == group] - plot_group(just_one_group) + plot_group(just_one_group, just_one_group[0]["color"]) else: plot_group(agg_result, "black") @@ -216,50 +216,34 @@ def __init__(self, aes, color=None): self.color = color def apply_to_fig(self, parent, agg_result, fig_so_far): - def plot_one_color(one_color_data, color, legend_name): + def plot_group(data, color=None): scatter_args = { - "x": [element["x"] for element in one_color_data], - "y": [element["y"] for element in one_color_data], - "text": [element["label"] for element in one_color_data], + "x": [element["x"] for element in data], + "y": [element["y"] for element in data], + "text": [element["label"] for element in data], "mode": "text", - "name": legend_name, "textfont_color": color } - if "size" in parent.aes or "size" in self.aes: - scatter_args["textfont_size"] = [element["size"] for element in one_color_data] - if "tooltip" in parent.aes or "tooltip" in self.aes: - scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data] - fig_so_far.add_scatter(**scatter_args) + if "color_legend" in data[0]: + scatter_args["name"] = data[0]["color_legend"] - def plot_continuous_color(data, colors): - scatter_args = { - "x": [element["x"] for element in data], - "y": [element["y"] for element in data], - "mode": "markers", - "marker_color": colors - } + if "tooltip" in parent.aes or "tooltip" in self.aes: + scatter_args["hovertext"] = [element["tooltip"] for element in data] if "size" in parent.aes or "size" in self.aes: scatter_args["marker_size"] = [element["size"] for element in data] fig_so_far.add_scatter(**scatter_args) if self.color is not None: - plot_one_color(agg_result, self.color, None) + plot_group(agg_result, self.color) elif "color" in parent.aes or "color" in self.aes: - if isinstance(agg_result[0]["color"], int): - # Should show colors in continuous scale. - input_color_nums = [element["color"] for element in agg_result] - color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale) - plot_continuous_color(agg_result, color_mapping) - - else: - categorical_strings = set([element["color"] for element in agg_result]) - unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) - - for category in categorical_strings: - filtered_data = [element for element in agg_result if element["color"] == category] - plot_one_color(filtered_data, unique_color_mapping[category], category) + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group, just_one_group[0]["color"]) + else: + plot_group(agg_result, "black") def get_stat(self): return StatIdentity() @@ -276,23 +260,27 @@ def __init__(self, aes, color=None): self.color = color def apply_to_fig(self, parent, agg_result, fig_so_far): - item_list = agg_result + def plot_group(data, color=None): + bar_args = { + "x": [element["x"] for element in data], + "y": [element["y"] for element in data], + "marker_color": color + } + if "color_legend" in data[0]: + bar_args["name"] = data[0]["color_legend"] - if self.color: - color = self.color + fig_so_far.add_bar(**bar_args) + + if self.color is not None: + plot_group(agg_result, self.color) elif "color" in parent.aes or "color" in self.aes: - categorical_strings = set([item["color"] for item in item_list]) - color_mapping = categorical_strings_to_colors(categorical_strings, parent) - color = [color_mapping[item["color"]] for item in item_list] + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group, just_one_group[0]["color"]) else: - color = "black" + plot_group(agg_result, "black") - bar_args = { - "x": [item["x"] for item in item_list], - "y": [item["y"] for item in item_list], - "marker_color": color - } - fig_so_far.add_bar(**bar_args) def get_stat(self): return StatCount() diff --git a/hail/python/hail/plot2/scale.py b/hail/python/hail/plot2/scale.py index 22302e93f8e..087b1e5d3e1 100644 --- a/hail/python/hail/plot2/scale.py +++ b/hail/python/hail/plot2/scale.py @@ -151,6 +151,8 @@ def transform_data_local(self, data, parent): return updated_data + +# Legend names messed up for scale color identity class ScaleColorDiscreteIdentity(ScaleDiscrete): def transform_data_local(self, data, parent): return data From 8feb2d20f364c71bd35233ed2d6a3d856c7506db Mon Sep 17 00:00:00 2001 From: ammekk Date: Wed, 19 Jan 2022 13:12:20 -0500 Subject: [PATCH 3/5] debugging --- hail/python/hail/plot2/geoms.py | 6 +++++- hail/python/hail/plot2/ggplot.py | 1 - hail/python/hail/plot2/scale.py | 1 - 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/hail/python/hail/plot2/geoms.py b/hail/python/hail/plot2/geoms.py index 581942ae8e7..024524d159d 100644 --- a/hail/python/hail/plot2/geoms.py +++ b/hail/python/hail/plot2/geoms.py @@ -114,7 +114,7 @@ def make_agg(self, mapping): # aesthetics in geom_struct are just pointers to x_expr, that's fine. # Maybe I just do a `take(1) for every field of parent_struct and geom_struct? # Or better, a collect_as_set where I error if size is greater than 1? - return hl.agg.group_by(mapping.select("x", "color"), hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) + return hl.agg.group_by(mapping["x"], hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) def listify(self, agg_result): unflattened_items = agg_result.items() @@ -273,10 +273,14 @@ def plot_group(data, color=None): if self.color is not None: plot_group(agg_result, self.color) + + # continuous colors is working differently than in johnc-plotly, its because all bars are showing + # up in the same group elif "color" in parent.aes or "color" in self.aes: groups = set([element["group"] for element in agg_result]) for group in groups: just_one_group = [element for element in agg_result if element["group"] == group] + import pdb; pdb.set_trace() plot_group(just_one_group, just_one_group[0]["color"]) else: plot_group(agg_result, "black") diff --git a/hail/python/hail/plot2/ggplot.py b/hail/python/hail/plot2/ggplot.py index 3da3b2635da..fde50e4886a 100644 --- a/hail/python/hail/plot2/ggplot.py +++ b/hail/python/hail/plot2/ggplot.py @@ -134,7 +134,6 @@ def render(self): for relevant_aesthetic in relevant_aesthetics: listified_agg_result =\ self.scales[relevant_aesthetic].transform_data_local(listified_agg_result, self) - # Ok, need to identify every possible value of every discrete scale. discrete_aesthetics = [scale_name for scale_name in relevant_aesthetics if self.scales[scale_name].is_discrete()] subsetted_to_discrete = tuple([one_struct.select(*discrete_aesthetics) for one_struct in listified_agg_result]) diff --git a/hail/python/hail/plot2/scale.py b/hail/python/hail/plot2/scale.py index 087b1e5d3e1..1263b34e544 100644 --- a/hail/python/hail/plot2/scale.py +++ b/hail/python/hail/plot2/scale.py @@ -148,7 +148,6 @@ def transform_data_local(self, data, parent): updated_data = [] for data_idx, data_entry in enumerate(data): updated_data.append(data_entry.annotate(color=color_mapping[data_idx], color_legend=data_entry["color"])) - return updated_data From a4fa12883105d1dd2e9d5ba7423499e5febc834e Mon Sep 17 00:00:00 2001 From: ammekk Date: Thu, 20 Jan 2022 10:34:53 -0500 Subject: [PATCH 4/5] added scale checks --- hail/python/hail/ir/utils.py | 13 +++++++ hail/python/hail/plot2/geoms.py | 62 +++++++++----------------------- hail/python/hail/plot2/ggplot.py | 17 +++++++-- hail/python/hail/plot2/scale.py | 30 +++++++++++++--- 4 files changed, 69 insertions(+), 53 deletions(-) diff --git a/hail/python/hail/ir/utils.py b/hail/python/hail/ir/utils.py index 0c445beacb6..fc382864ac4 100644 --- a/hail/python/hail/ir/utils.py +++ b/hail/python/hail/ir/utils.py @@ -1,5 +1,18 @@ from .ir import Coalesce, ApplyUnaryPrimOp, FalseIR +import hail as hl + + +def check_scale_continuity(scale, dtype, aes_key): + if scale.is_discrete() and is_continuous_type(dtype): + raise ValueError(f"Aesthetic {aes_key} has continuous dtype but non continuous scale") + if not scale.is_discrete() and not is_continuous_type(dtype): + raise ValueError(f"Aesthetic {aes_key} has non continuous dtype but continuous scale") + + +def is_continuous_type(dtype): + return dtype in [hl.tint32, hl.tint64, hl.float32, hl.float64] + def filter_predicate_with_keep(ir_pred, keep): return Coalesce(ir_pred if keep else ApplyUnaryPrimOp('!', ir_pred), FalseIR()) diff --git a/hail/python/hail/plot2/geoms.py b/hail/python/hail/plot2/geoms.py index 024524d159d..e88a1e5996b 100644 --- a/hail/python/hail/plot2/geoms.py +++ b/hail/python/hail/plot2/geoms.py @@ -49,15 +49,10 @@ def plot_group(data, color=None): if self.color is not None: plot_group(agg_result, self.color) elif "color" in parent.aes or "color" in self.aes: - if isinstance(agg_result[0]["color"], int): - # Should show colors in continuous scale. - raise ValueError("Do not currently support continuous color changing of lines") - # Groups messed up so that if there is a break in group, line not continuous - else: - groups = set([element["group"] for element in agg_result]) - for group in groups: - just_one_group = [element for element in agg_result if element["group"] == group] - plot_group(just_one_group, just_one_group[0]["color"]) + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group, just_one_group[0]["color"]) else: plot_group(agg_result, "black") @@ -162,7 +157,7 @@ def plot_group(data, color=None): "x": [element["x"] for element in data], "y": [element["y"] for element in data], "mode": "markers", - "marker_color": color + "marker_color": color if color is not None else [element["color"] for element in data] } if "color_legend" in data[0]: @@ -180,7 +175,7 @@ def plot_group(data, color=None): groups = set([element["group"] for element in agg_result]) for group in groups: just_one_group = [element for element in agg_result if element["group"] == group] - plot_group(just_one_group, just_one_group[0]["color"]) + plot_group(just_one_group) else: plot_group(agg_result, "black") @@ -264,7 +259,7 @@ def plot_group(data, color=None): bar_args = { "x": [element["x"] for element in data], "y": [element["y"] for element in data], - "marker_color": color + "marker_color": color if color is not None else [element["color"] for element in data] } if "color_legend" in data[0]: bar_args["name"] = data[0]["color_legend"] @@ -274,14 +269,11 @@ def plot_group(data, color=None): if self.color is not None: plot_group(agg_result, self.color) - # continuous colors is working differently than in johnc-plotly, its because all bars are showing - # up in the same group elif "color" in parent.aes or "color" in self.aes: groups = set([element["group"] for element in agg_result]) for group in groups: just_one_group = [element for element in agg_result if element["group"] == group] - import pdb; pdb.set_trace() - plot_group(just_one_group, just_one_group[0]["color"]) + plot_group(just_one_group) else: plot_group(agg_result, "black") @@ -394,8 +386,8 @@ def __init__(self, aes): self.aes = aes def apply_to_fig(self, parent, agg_result, fig_so_far): - def plot_continuous_color(agg_results, colors): - for idx, row in enumerate(agg_results): + def plot_group(data, fill=None): + for idx, row in enumerate(data): x_center = row['x'] y_center = row['y'] width = row['width'] @@ -405,36 +397,16 @@ def plot_continuous_color(agg_results, colors): x_right = x_center + width / 2 y_up = y_center + height / 2 y_down = y_center - height / 2 - fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=colors[idx], opacity=alpha) - - def plot_one_color(agg_results, color): - for idx, row in enumerate(agg_results): - x_center = row['x'] - y_center = row['y'] - width = row['width'] - height = row['height'] - alpha= row.get('alpha', 1.0) - x_left = x_center - width / 2 - x_right = x_center + width / 2 - y_up = y_center + height / 2 - y_down = y_center - height / 2 - fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=color, opacity=alpha) + fill = fill if fill is not None else row['fill'] + fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=fill, opacity=alpha) if "fill" in parent.aes or "fill" in self.aes: - if isinstance(agg_result[0]["fill"], int): - input_color_nums = [element["fill"] for element in agg_result] - color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale) - plot_continuous_color(agg_result, color_mapping) - else: - categorical_strings = set([element["fill"] for element in agg_result]) - unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) - - for category in categorical_strings: - filtered_data = [element for element in agg_result if element["fill"] == category] - plot_one_color(filtered_data, unique_color_mapping[category]) - + groups = set([element["group"] for element in agg_result]) + for group in groups: + just_one_group = [element for element in agg_result if element["group"] == group] + plot_group(just_one_group) else: - plot_one_color(agg_result, "black") + plot_group(agg_result, "black") def get_stat(self): return StatIdentity() diff --git a/hail/python/hail/plot2/ggplot.py b/hail/python/hail/plot2/ggplot.py index fde50e4886a..c601a66f328 100644 --- a/hail/python/hail/plot2/ggplot.py +++ b/hail/python/hail/plot2/ggplot.py @@ -10,8 +10,10 @@ from .geoms import Geom, FigureAttribute from .labels import Labels from .scale import Scale, ScaleContinuous, ScaleDiscrete, scale_x_continuous, scale_x_genomic, scale_y_continuous, \ - scale_x_discrete, scale_y_discrete, scale_color_discrete, scale_color_continuous + scale_x_discrete, scale_y_discrete, scale_color_discrete, scale_color_continuous, scale_fill_discrete, \ + scale_fill_continuous from .aes import Aesthetic, aes +from ..ir.utils import is_continuous_type, check_scale_continuity class GGPlot: @@ -55,8 +57,6 @@ def __add__(self, other): return copied def add_default_scales(self, aesthetic): - def is_continuous_type(dtype): - return dtype in [hl.tint32, hl.tint64, hl.float32, hl.float64] def is_genomic_type(dtype): return isinstance(dtype, hl.tlocus) @@ -84,6 +84,10 @@ def is_genomic_type(dtype): self.scales["color"] = scale_color_discrete() elif aesthetic_str == "color" and is_continuous: self.scales["color"] = scale_color_continuous() + elif aesthetic_str == "fill" and not is_continuous: + self.scales["fill"] = scale_fill_discrete() + elif aesthetic_str == "fill" and is_continuous: + self.scales["fill"] = scale_fill_continuous() else: if is_continuous: self.scales[aesthetic_str] = ScaleContinuous(aesthetic_str) @@ -94,6 +98,12 @@ def copy(self): return GGPlot(self.ht, self.aes, self.geoms[:], self.labels, self.coord_cartesian, self.scales, self.discrete_color_scale, self.continuous_color_scale) + def verify_scales(self): + for geom_idx, geom in enumerate(self.geoms): + aesthetic_dict = geom.aes.properties + for aes_key in aesthetic_dict.keys(): + check_scale_continuity(self.scales[aes_key], aesthetic_dict[aes_key].dtype, aes_key) + def render(self): fields_to_select = {"figure_mapping": hl.struct(**self.aes)} for geom_idx, geom in enumerate(self.geoms): @@ -101,6 +111,7 @@ def render(self): fields_to_select[label] = hl.struct(**geom.aes.properties) selected = self.ht.select(**fields_to_select) + self.verify_scales() aggregators = {} labels_to_stats = {} diff --git a/hail/python/hail/plot2/scale.py b/hail/python/hail/plot2/scale.py index 1263b34e544..e24bdb4c1ad 100644 --- a/hail/python/hail/plot2/scale.py +++ b/hail/python/hail/plot2/scale.py @@ -130,24 +130,32 @@ def is_discrete(self): class ScaleColorDiscrete(ScaleDiscrete): def transform_data_local(self, data, parent): - categorical_strings = set([element["color"] for element in data]) + categorical_strings = set([element[self.aesthetic_name] for element in data]) unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent) updated_data = [] for category in categorical_strings: for data_entry in data: - if data_entry["color"] == category: - updated_data.append(data_entry.annotate(color=unique_color_mapping[category], color_legend=category)) + if data_entry[self.aesthetic_name] == category: + annotate_args = { + self.aesthetic_name: unique_color_mapping[category], + "color_legend": category + } + updated_data.append(data_entry.annotate(**annotate_args)) return updated_data class ScaleColorContinuous(ScaleContinuous): def transform_data_local(self, data, parent): - color_list = [element["color"] for element in data] + color_list = [element[self.aesthetic_name] for element in data] color_mapping = continuous_nums_to_colors(color_list, parent.continuous_color_scale) updated_data = [] for data_idx, data_entry in enumerate(data): - updated_data.append(data_entry.annotate(color=color_mapping[data_idx], color_legend=data_entry["color"])) + annotate_args = { + self.aesthetic_name: color_mapping[data_idx], + "color_legend": data_entry[self.aesthetic_name] + } + updated_data.append(data_entry.annotate(**annotate_args)) return updated_data @@ -203,3 +211,15 @@ def scale_color_continuous(): def scale_color_identity(): return ScaleColorDiscreteIdentity("color") + + +def scale_fill_discrete(): + return ScaleColorDiscrete("fill") + + +def scale_fill_continuous(): + return ScaleColorContinuous("fill") + + +def scale_fill_identity(): + return ScaleColorDiscreteIdentity("fill") From deae2cc310c69f5a341d96816de7e70dcedb25b7 Mon Sep 17 00:00:00 2001 From: ammekk Date: Thu, 20 Jan 2022 14:00:25 -0500 Subject: [PATCH 5/5] added postition arguments dodge and stack to bar charts --- hail/python/hail/ir/utils.py | 10 +++++++--- hail/python/hail/plot2/geoms.py | 26 ++++++++++++++++++-------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/hail/python/hail/ir/utils.py b/hail/python/hail/ir/utils.py index fc382864ac4..9aa3ad6e147 100644 --- a/hail/python/hail/ir/utils.py +++ b/hail/python/hail/ir/utils.py @@ -4,14 +4,18 @@ def check_scale_continuity(scale, dtype, aes_key): - if scale.is_discrete() and is_continuous_type(dtype): + if scale.is_discrete() and not is_discrete_type(dtype): raise ValueError(f"Aesthetic {aes_key} has continuous dtype but non continuous scale") - if not scale.is_discrete() and not is_continuous_type(dtype): + if not scale.is_discrete() and is_discrete_type(dtype): raise ValueError(f"Aesthetic {aes_key} has non continuous dtype but continuous scale") def is_continuous_type(dtype): - return dtype in [hl.tint32, hl.tint64, hl.float32, hl.float64] + return dtype in [hl.tint32, hl.tint64, hl.tfloat32, hl.tfloat64] + + +def is_discrete_type(dtype): + return not is_continuous_type(dtype) def filter_predicate_with_keep(ir_pred, keep): diff --git a/hail/python/hail/plot2/geoms.py b/hail/python/hail/plot2/geoms.py index e88a1e5996b..89d6087b1a8 100644 --- a/hail/python/hail/plot2/geoms.py +++ b/hail/python/hail/plot2/geoms.py @@ -5,6 +5,8 @@ from .utils import categorical_strings_to_colors, continuous_nums_to_colors import hail as hl +from ..ir.utils import is_continuous_type + class FigureAttribute: pass @@ -109,15 +111,20 @@ def make_agg(self, mapping): # aesthetics in geom_struct are just pointers to x_expr, that's fine. # Maybe I just do a `take(1) for every field of parent_struct and geom_struct? # Or better, a collect_as_set where I error if size is greater than 1? - return hl.agg.group_by(mapping["x"], hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1))) + #group by all discrete variables and x + discrete_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() + if not is_continuous_type(mapping[aes_key].dtype)} + discrete_variables["x"] = mapping["x"] + return hl.agg.group_by(hl.struct(**discrete_variables), hl.agg.count()) + def listify(self, agg_result): unflattened_items = agg_result.items() res = [] - for x, agg_result in unflattened_items: - other_list = agg_result.other - assert len(other_list) == 1 - new_struct = hl.Struct(x=x, y=agg_result.count, **other_list[0]) + for discrete_variables, count in unflattened_items: + arg_dict = {key: value for key, value in discrete_variables.items()} + arg_dict["y"] = count + new_struct = hl.Struct(**arg_dict) res.append(new_struct) return res @@ -250,9 +257,10 @@ def geom_text(mapping=aes(), *, color=None): class GeomBar(Geom): - def __init__(self, aes, color=None): + def __init__(self, aes, color=None, position="stack"): super().__init__(aes) self.color = color + self.position = position def apply_to_fig(self, parent, agg_result, fig_so_far): def plot_group(data, color=None): @@ -277,13 +285,15 @@ def plot_group(data, color=None): else: plot_group(agg_result, "black") + ggplot_to_plotly = {'dodge': 'group', 'stack': 'stack'} + fig_so_far.update_layout(barmode=ggplot_to_plotly[self.position]) def get_stat(self): return StatCount() -def geom_bar(mapping=aes(), *, color=None): - return GeomBar(mapping, color=color) +def geom_bar(mapping=aes(), *, color=None, position="stack"): + return GeomBar(mapping, color=color, position=position) class GeomHistogram(Geom):