Skip to content

Commit

Permalink
Merge pull request hail-is#33 from ammekk/color-scale-discrete
Browse files Browse the repository at this point in the history
Color scales and position arguments to bar charts
  • Loading branch information
johnc1231 authored Jan 20, 2022
2 parents 9c94ddf + deae2cc commit a970b3d
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 140 deletions.
17 changes: 17 additions & 0 deletions hail/python/hail/ir/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
from .ir import Coalesce, ApplyUnaryPrimOp, FalseIR

import hail as hl


def check_scale_continuity(scale, dtype, aes_key):
    """Check that a scale and an aesthetic's dtype agree on discreteness.

    Raises ``ValueError`` naming ``aes_key`` when ``scale.is_discrete()``
    and ``is_discrete_type(dtype)`` disagree; returns ``None`` otherwise.
    """
    scale_is_discrete = scale.is_discrete()
    dtype_is_discrete = is_discrete_type(dtype)

    if scale_is_discrete:
        # Discrete scale paired with a dtype that is not discrete.
        if not dtype_is_discrete:
            raise ValueError(f"Aesthetic {aes_key} has continuous dtype but non continuous scale")
    elif dtype_is_discrete:
        # Non-discrete scale paired with a discrete dtype.
        raise ValueError(f"Aesthetic {aes_key} has non continuous dtype but continuous scale")


def is_continuous_type(dtype):
    """Return whether ``dtype`` is one of the types treated as continuous.

    The continuous types are ``hl.tint32``, ``hl.tint64``, ``hl.tfloat32``
    and ``hl.tfloat64``; anything else yields ``False``.
    """
    continuous_types = (hl.tint32, hl.tint64, hl.tfloat32, hl.tfloat64)
    return dtype in continuous_types


def is_discrete_type(dtype):
    """Return ``True`` for every dtype that ``is_continuous_type`` rejects."""
    if is_continuous_type(dtype):
        return False
    return True


def filter_predicate_with_keep(ir_pred, keep):
    """Build the IR predicate used by a keep/remove style filter.

    When ``keep`` is truthy the predicate is used unchanged; otherwise it is
    wrapped in ``ApplyUnaryPrimOp('!', ...)``. The result is then wrapped in
    ``Coalesce`` with ``FalseIR()`` as the second argument.
    """
    # NOTE(review): presumably the Coalesce makes a missing predicate
    # evaluate to false -- confirm against the IR semantics.
    if keep:
        predicate = ir_pred
    else:
        predicate = ApplyUnaryPrimOp('!', ir_pred)
    return Coalesce(predicate, FalseIR())
Expand Down
8 changes: 6 additions & 2 deletions hail/python/hail/plot2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from .geoms import geom_line, geom_point, geom_text, geom_bar, geom_histogram, geom_hline, geom_func, geom_vline, geom_tile
from .labels import ggtitle, xlab, ylab
from .scale import scale_x_continuous, scale_y_continuous, scale_x_discrete, scale_y_discrete, scale_x_genomic, \
scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse
scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse, scale_color_discrete, scale_color_identity,\
scale_color_continuous

__all__ = [
"aes",
Expand All @@ -30,5 +31,8 @@
"scale_x_log10",
"scale_y_log10",
"scale_x_reverse",
"scale_y_reverse"
"scale_y_reverse",
"scale_color_continuous",
"scale_color_identity",
"scale_color_discrete"
]
216 changes: 87 additions & 129 deletions hail/python/hail/plot2/geoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .utils import categorical_strings_to_colors, continuous_nums_to_colors
import hail as hl

from ..ir.utils import is_continuous_type


class FigureAttribute:
pass
Expand All @@ -31,33 +33,30 @@ def __init__(self, aes, color):

def apply_to_fig(self, parent, agg_result, fig_so_far):

def plot_one_color(one_color_data, color, legend_name):
def plot_group(data, color=None):
scatter_args = {
"x": [element["x"] for element in one_color_data],
"y": [element["y"] for element in one_color_data],
"x": [element["x"] for element in data],
"y": [element["y"] for element in data],
"mode": "lines",
"name": legend_name,
"line_color": color

}
if "color_legend" in data[0]:
scatter_args["name"] = data[0]["color_legend"]

if "tooltip" in parent.aes or "tooltip" in self.aes:
scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data]
scatter_args["hovertext"] = [element["tooltip"] for element in data]
fig_so_far.add_scatter(**scatter_args)

if self.color is not None:
plot_one_color(agg_result, self.color, None)
plot_group(agg_result, self.color)
elif "color" in parent.aes or "color" in self.aes:
if isinstance(agg_result[0]["color"], int):
# Should show colors in continuous scale.
raise ValueError("Do not currently support continuous color changing of lines")
else:
categorical_strings = set([element["color"] for element in agg_result])
unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent)

for category in categorical_strings:
filtered_data = [element for element in agg_result if element["color"] == category]
plot_one_color(filtered_data, unique_color_mapping[category], category)
groups = set([element["group"] for element in agg_result])
for group in groups:
just_one_group = [element for element in agg_result if element["group"] == group]
plot_group(just_one_group, just_one_group[0]["color"])
else:
plot_one_color(agg_result, "black", None)
plot_group(agg_result, "black")

@abc.abstractmethod
def get_stat(self):
Expand Down Expand Up @@ -112,15 +111,20 @@ def make_agg(self, mapping):
# aesthetics in geom_struct are just pointers to x_expr, that's fine.
# Maybe I just do a `take(1) for every field of parent_struct and geom_struct?
# Or better, a collect_as_set where I error if size is greater than 1?
return hl.agg.group_by(mapping["x"], hl.struct(count=hl.agg.count(), other=hl.agg.take(mapping.drop("x"), 1)))
#group by all discrete variables and x
discrete_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys()
if not is_continuous_type(mapping[aes_key].dtype)}
discrete_variables["x"] = mapping["x"]
return hl.agg.group_by(hl.struct(**discrete_variables), hl.agg.count())


def listify(self, agg_result):
unflattened_items = agg_result.items()
res = []
for x, agg_result in unflattened_items:
other_list = agg_result.other
assert len(other_list) == 1
new_struct = hl.Struct(x=x, y=agg_result.count, **other_list[0])
for discrete_variables, count in unflattened_items:
arg_dict = {key: value for key, value in discrete_variables.items()}
arg_dict["y"] = count
new_struct = hl.Struct(**arg_dict)
res.append(new_struct)
return res

Expand Down Expand Up @@ -155,50 +159,32 @@ def __init__(self, aes, color=None):
self.color = color

def apply_to_fig(self, parent, agg_result, fig_so_far):
def plot_one_color(one_color_data, color, legend_name):
scatter_args = {
"x": [element["x"] for element in one_color_data],
"y": [element["y"] for element in one_color_data],
"mode": "markers",
"marker_color": color,
"name": legend_name
}
if "size" in parent.aes or "size" in self.aes:
scatter_args["marker_size"] = [element["size"] for element in one_color_data]
if "tooltip" in parent.aes or "tooltip" in self.aes:
scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data]
fig_so_far.add_scatter(**scatter_args)

def plot_continuous_color(data, colors):
def plot_group(data, color=None):
scatter_args = {
"x": [element["x"] for element in data],
"y": [element["y"] for element in data],
"mode": "markers",
"marker_color": colors
"marker_color": color if color is not None else [element["color"] for element in data]
}

if "color_legend" in data[0]:
scatter_args["name"] = data[0]["color_legend"]

if "size" in parent.aes or "size" in self.aes:
scatter_args["marker_size"] = [element["size"] for element in data]
if "tooltip" in parent.aes or "tooltip" in self.aes:
scatter_args["hovertext"] = [element["tooltip"] for element in data]
fig_so_far.add_scatter(**scatter_args)

if self.color is not None:
plot_one_color(agg_result, self.color, None)
plot_group(agg_result, self.color)
elif "color" in parent.aes or "color" in self.aes:
if isinstance(agg_result[0]["color"], int):
# Should show colors in continuous scale.
input_color_nums = [element["color"] for element in agg_result]
color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale)
plot_continuous_color(agg_result, color_mapping)

else:
categorical_strings = set([element["color"] for element in agg_result])
unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent)

for category in categorical_strings:
filtered_data = [element for element in agg_result if element["color"] == category]
plot_one_color(filtered_data, unique_color_mapping[category], category)
groups = set([element["group"] for element in agg_result])
for group in groups:
just_one_group = [element for element in agg_result if element["group"] == group]
plot_group(just_one_group)
else:
plot_one_color(agg_result, "black", None)
plot_group(agg_result, "black")

def get_stat(self):
return StatIdentity()
Expand Down Expand Up @@ -232,50 +218,34 @@ def __init__(self, aes, color=None):
self.color = color

def apply_to_fig(self, parent, agg_result, fig_so_far):
def plot_one_color(one_color_data, color, legend_name):
def plot_group(data, color=None):
scatter_args = {
"x": [element["x"] for element in one_color_data],
"y": [element["y"] for element in one_color_data],
"text": [element["label"] for element in one_color_data],
"x": [element["x"] for element in data],
"y": [element["y"] for element in data],
"text": [element["label"] for element in data],
"mode": "text",
"name": legend_name,
"textfont_color": color
}

if "size" in parent.aes or "size" in self.aes:
scatter_args["textfont_size"] = [element["size"] for element in one_color_data]
if "tooltip" in parent.aes or "tooltip" in self.aes:
scatter_args["hovertext"] = [element["tooltip"] for element in one_color_data]
fig_so_far.add_scatter(**scatter_args)
if "color_legend" in data[0]:
scatter_args["name"] = data[0]["color_legend"]

def plot_continuous_color(data, colors):
scatter_args = {
"x": [element["x"] for element in data],
"y": [element["y"] for element in data],
"mode": "markers",
"marker_color": colors
}
if "tooltip" in parent.aes or "tooltip" in self.aes:
scatter_args["hovertext"] = [element["tooltip"] for element in data]

if "size" in parent.aes or "size" in self.aes:
scatter_args["marker_size"] = [element["size"] for element in data]
fig_so_far.add_scatter(**scatter_args)

if self.color is not None:
plot_one_color(agg_result, self.color, None)
plot_group(agg_result, self.color)
elif "color" in parent.aes or "color" in self.aes:
if isinstance(agg_result[0]["color"], int):
# Should show colors in continuous scale.
input_color_nums = [element["color"] for element in agg_result]
color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale)
plot_continuous_color(agg_result, color_mapping)

else:
categorical_strings = set([element["color"] for element in agg_result])
unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent)

for category in categorical_strings:
filtered_data = [element for element in agg_result if element["color"] == category]
plot_one_color(filtered_data, unique_color_mapping[category], category)
groups = set([element["group"] for element in agg_result])
for group in groups:
just_one_group = [element for element in agg_result if element["group"] == group]
plot_group(just_one_group, just_one_group[0]["color"])
else:
plot_group(agg_result, "black")

def get_stat(self):
return StatIdentity()
Expand All @@ -287,35 +257,43 @@ def geom_text(mapping=aes(), *, color=None):

class GeomBar(Geom):

def __init__(self, aes, color=None):
def __init__(self, aes, color=None, position="stack"):
super().__init__(aes)
self.color = color
self.position = position

def apply_to_fig(self, parent, agg_result, fig_so_far):
item_list = agg_result
def plot_group(data, color=None):
bar_args = {
"x": [element["x"] for element in data],
"y": [element["y"] for element in data],
"marker_color": color if color is not None else [element["color"] for element in data]
}
if "color_legend" in data[0]:
bar_args["name"] = data[0]["color_legend"]

fig_so_far.add_bar(**bar_args)

if self.color is not None:
plot_group(agg_result, self.color)

if self.color:
color = self.color
elif "color" in parent.aes or "color" in self.aes:
categorical_strings = set([item["color"] for item in item_list])
color_mapping = categorical_strings_to_colors(categorical_strings, parent)
color = [color_mapping[item["color"]] for item in item_list]
groups = set([element["group"] for element in agg_result])
for group in groups:
just_one_group = [element for element in agg_result if element["group"] == group]
plot_group(just_one_group)
else:
color = "black"
plot_group(agg_result, "black")

bar_args = {
"x": [item["x"] for item in item_list],
"y": [item["y"] for item in item_list],
"marker_color": color
}
fig_so_far.add_bar(**bar_args)
ggplot_to_plotly = {'dodge': 'group', 'stack': 'stack'}
fig_so_far.update_layout(barmode=ggplot_to_plotly[self.position])

def get_stat(self):
return StatCount()


def geom_bar(mapping=aes(), *, color=None):
return GeomBar(mapping, color=color)
def geom_bar(mapping=aes(), *, color=None, position="stack"):
return GeomBar(mapping, color=color, position=position)


class GeomHistogram(Geom):
Expand Down Expand Up @@ -418,21 +396,8 @@ def __init__(self, aes):
self.aes = aes

def apply_to_fig(self, parent, agg_result, fig_so_far):
def plot_continuous_color(agg_results, colors):
for idx, row in enumerate(agg_results):
x_center = row['x']
y_center = row['y']
width = row['width']
height = row['height']
alpha= row.get('alpha', 1.0)
x_left = x_center - width / 2
x_right = x_center + width / 2
y_up = y_center + height / 2
y_down = y_center - height / 2
fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=colors[idx], opacity=alpha)

def plot_one_color(agg_results, color):
for idx, row in enumerate(agg_results):
def plot_group(data, fill=None):
for idx, row in enumerate(data):
x_center = row['x']
y_center = row['y']
width = row['width']
Expand All @@ -442,23 +407,16 @@ def plot_one_color(agg_results, color):
x_right = x_center + width / 2
y_up = y_center + height / 2
y_down = y_center - height / 2
fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=color, opacity=alpha)
fill = fill if fill is not None else row['fill']
fig_so_far.add_shape(type="rect", x0=x_left, y0=y_down, x1=x_right, y1=y_up, fillcolor=fill, opacity=alpha)

if "fill" in parent.aes or "fill" in self.aes:
if isinstance(agg_result[0]["fill"], int):
input_color_nums = [element["fill"] for element in agg_result]
color_mapping = continuous_nums_to_colors(input_color_nums, parent.continuous_color_scale)
plot_continuous_color(agg_result, color_mapping)
else:
categorical_strings = set([element["fill"] for element in agg_result])
unique_color_mapping = categorical_strings_to_colors(categorical_strings, parent)

for category in categorical_strings:
filtered_data = [element for element in agg_result if element["fill"] == category]
plot_one_color(filtered_data, unique_color_mapping[category])

groups = set([element["group"] for element in agg_result])
for group in groups:
just_one_group = [element for element in agg_result if element["group"] == group]
plot_group(just_one_group)
else:
plot_one_color(agg_result, "black")
plot_group(agg_result, "black")

def get_stat(self):
return StatIdentity()
Expand Down
Loading

0 comments on commit a970b3d

Please sign in to comment.