Skip to content

Commit

Permalink
feat: implement functionality required for plotly autoshift
Browse files Browse the repository at this point in the history
PR: #13
  • Loading branch information
slaclau committed Nov 7, 2024
1 parent de5d013 commit ad4e3d3
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 66 deletions.
84 changes: 22 additions & 62 deletions src/plotly_gtk/_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
gi.require_version("Gtk", "4.0")
from gi.repository import ( # pylint: disable=wrong-import-position,wrong-import-order
Gtk,
Pango,
PangoCairo,
)

Expand Down Expand Up @@ -168,7 +169,6 @@ def _draw_ticks(
if "tickfont" in self.layout[axis]:
font_dict = update_dict(font_dict, self.layout[axis]["tickfont"])


context.set_source_rgb(*parse_color(font_dict["color"]))
font = parse_font(font_dict)
layout = PangoCairo.create_layout(context)
Expand All @@ -187,7 +187,7 @@ def _draw_ticks(
)
else:
y_pos = self.layout["_margin"]["t"] + (
1 - self.layout[axis]["position"]
1 - self.layout[axis]["_position"]
) * (height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"])
x_pos, _ = self._calc_pos(x, [], width, height, axis, None)

Expand Down Expand Up @@ -215,9 +215,17 @@ def _draw_ticks(
x_pos, y_pos = self._calc_pos(
x, y, width, height, xaxis, axis, ignore_log_x=True
)
x_pos += self.layout[axis]["_shift"]
else:
x_pos = self.layout["_margin"]["l"] + self.layout[axis]["position"] * (
width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"]
x_pos = (
self.layout["_margin"]["l"]
+ self.layout[axis]["_position"]
* (
width
- self.layout["_margin"]["l"]
- self.layout["_margin"]["r"]
)
+ self.layout[axis]["_shift"]
)
_, y_pos = self._calc_pos([], y, width, height, None, axis)

Expand All @@ -244,40 +252,8 @@ def _draw_axis(self, context, width, height, axis):
context.set_source_rgb(*parse_color("green"))

axis_letter = axis[0 : axis.find("axis")]
overlaying_axis = (
(
self.layout[axis]["overlaying"][0]
+ "axis"
+ self.layout[axis]["overlaying"][1:]
)
if "overlaying" in self.layout[axis]
else ""
)
anchor_axis = (
"free"
if "anchor" not in self.layout[axis]
or self.layout[axis]["anchor"] == "free"
else (
self.layout[axis]["anchor"][0]
+ "axis"
+ self.layout[axis]["anchor"][1:]
)
)
domain = (
self.layout[axis]["domain"]
if "overlaying" not in self.layout[axis]
else self.layout[overlaying_axis]["domain"]
)
position = (
self.layout[axis]["position"]
if "anchor" not in self.layout[axis] or anchor_axis == "free"
else (
self.layout[anchor_axis]["domain"][0]
if self.layout[axis]["side"] == "left"
or self.layout[axis]["side"] == "bottom"
else self.layout[anchor_axis]["domain"][-1]
)
)
domain = self.layout[axis]["_domain"]
position = self.layout[axis]["_position"]

if axis_letter == "x":
context.move_to(
Expand All @@ -300,15 +276,17 @@ def _draw_axis(self, context, width, height, axis):
context.move_to(
self.layout["_margin"]["l"]
+ position
* (width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"]),
* (width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"])
+ self.layout[axis]["_shift"],
self.layout["_margin"]["t"]
+ (1 - domain[0])
* (height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"]),
)
context.line_to(
self.layout["_margin"]["l"]
+ position
* (width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"]),
* (width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"])
+ self.layout[axis]["_shift"],
self.layout["_margin"]["t"]
+ (1 - domain[-1])
* (height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"]),
Expand Down Expand Up @@ -346,16 +324,7 @@ def _calc_pos(
y_pos = []

if xaxis is not None:
x_overlaying_axis = (
(xaxis["overlaying"][0] + "axis" + xaxis["overlaying"][1:])
if "overlaying" in xaxis
else ""
)
xdomain = (
xaxis["domain"]
if "overlaying" not in xaxis
else self.layout[x_overlaying_axis]["domain"]
)
xdomain = xaxis["_domain"]
xaxis_start = (
xdomain[0]
* (width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"])
Expand All @@ -374,24 +343,15 @@ def _calc_pos(
) + xaxis_start

if yaxis is not None:
y_overlaying_axis = (
(yaxis["overlaying"][0] + "axis" + yaxis["overlaying"][1:])
if "overlaying" in yaxis
else ""
)
ydomain = (
yaxis["domain"]
if "overlaying" not in yaxis
else self.layout[y_overlaying_axis]["domain"]
)
ydomain = yaxis["_domain"]
yaxis_start = (
-(ydomain[0])
-ydomain[0]
* (height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"])
+ height
- self.layout["_margin"]["b"]
)
yaxis_end = (
-(ydomain[-1])
-ydomain[-1]
* (height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"])
+ height
- self.layout["_margin"]["b"]
Expand Down
126 changes: 126 additions & 0 deletions src/plotly_gtk/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def update(self, fig: dict[str, plotly_types.Data | plotly_types.Layout]):
A dictionary representing a plotly figure
"""
self._update_ranges()
self._update_positions_and_domains()
for overlay in self.overlays:
self.remove_overlay(overlay)
self.overlays = []
Expand Down Expand Up @@ -606,6 +607,131 @@ def _update_layout(self):
self.layout = update_dict(template, self.layout)
self.layout = update_dict(defaults, self.layout)

def _update_positions_and_domains(self):
axes = [k for k in self.layout if "axis" in k]
axes_order = []
overlayed = []
for axis in axes:
if "overlaying" in self.layout[axis]:
ax = self.layout[axis]["overlaying"]
overlayed.append(ax[0] + "axis" + ax[1:])
overlayed = set(overlayed)
for axis in overlayed:
overlayed_by = [
k
for k in axes
if "overlaying" in self.layout[k]
and self.layout[k]["overlaying"] == axis.replace("axis", "")
]
left = [
k
for k in overlayed_by
if "side" in self.layout[k]
and self.layout[k]["side"] == "left"
or "side" not in self.layout[k]
]
right = [
k
for k in overlayed_by
if "side" in self.layout[k] and self.layout[k]["side"] == "right"
]
right = sorted(right)
left = sorted(left)

if "side" in self.layout[axis] and self.layout[axis]["side"] == "right":
right = [axis] + right
else:
left = [axis] + left

for side in [left, right]:
self.layout[side[0]]["_overlaying"] = ""
for i in range(1, len(side)):
self.layout[side[i]]["_overlaying"] = side[i - 1]

axes_order += left
axes_order += right

other_axes = set(axes) - set(axes_order)

axes_order = sorted(list(other_axes)) + axes_order

for axis in axes_order:
if "linecolor" not in self.layout[axis]:
continue
overlaying_axis = (
self.layout[axis]["_overlaying"]
if "overlaying" in self.layout[axis]
else ""
)
anchor_axis = (
"free"
if "anchor" not in self.layout[axis]
or self.layout[axis]["anchor"] == "free"
else (
self.layout[axis]["anchor"][0]
+ "axis"
+ self.layout[axis]["anchor"][1:]
)
)
domain = (
self.layout[axis]["domain"]
if "overlaying" not in self.layout[axis] or overlaying_axis == ""
else self.layout[overlaying_axis]["domain"]
)
position = (
self.layout[overlaying_axis]["_position"]
if "autoshift" in self.layout[axis]
and self.layout[axis]["autoshift"]
and anchor_axis == "free"
else (
self.layout[axis]["position"]
if "anchor" not in self.layout[axis] or anchor_axis == "free"
else (
self.layout[anchor_axis]["domain"][0]
if self.layout[axis]["side"] == "left"
or self.layout[axis]["side"] == "bottom"
else self.layout[anchor_axis]["domain"][-1]
)
)
)
self.layout[axis]["_domain"] = domain
self.layout[axis]["_position"] = position

if "autoshift" in self.layout[axis] and self.layout[axis]["autoshift"]:
shift = (
self.layout[axis]["shift"]
if "shift" in self.layout[axis]
else 3 if self.layout[axis]["side"] == "right" else -3
)
font_extra = 0
tickfont = update_dict(
self.layout["font"], self.layout[axis]["tickfont"]
)
tickfont = parse_font(tickfont)
ctx = self.get_pango_context()
layout = Pango.Layout(ctx)
layout.set_font_description(tickfont)

metrics = ctx.get_metrics(tickfont)
font_height = (
metrics.get_ascent() + metrics.get_descent()
) / Pango.SCALE

for tick in self.layout[overlaying_axis]["_ticktext"]:
layout.set_text(tick)
font_extra = max(layout.get_pixel_size()[0], font_extra)
autoshift = (
font_extra
if self.layout[axis]["side"] == "right"
else -font_extra
- self.layout[axis]["title"]["standoff"]
- font_height
) + self.layout[overlaying_axis]["_shift"]
else:
shift = 0
autoshift = 0
self.layout[axis]["_shift"] = shift + autoshift

@staticmethod
def _detect_axis_type(data):
if any(isinstance(i, list) or isinstance(i, np.ndarray) for i in data):
Expand Down
41 changes: 39 additions & 2 deletions src/plotly_gtk/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"two_y_axes",
"multiple_y_axes_subplots",
"multiple_axes",
# "autoshift",
"autoshift",
# "shift_by_pixels",
# "syncticks",
]
Expand Down Expand Up @@ -253,6 +253,43 @@ def _get_multiple_axes_test_figure(reference):
title_text="multiple y-axes example",
width=800,
)
elif reference == "autoshift":
fig = go.Figure()

fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6], name="yaxis data"))

fig.add_trace(go.Scatter(x=[2, 3, 4], y=[40, 50, 60], name="yaxis2 data", yaxis="y2"))

fig.add_trace(
go.Scatter(x=[4, 5, 6], y=[1000, 2000, 3000], name="yaxis3 data", yaxis="y3")
)

fig.add_trace(
go.Scatter(x=[3, 4, 5], y=[400, 500, 600], name="yaxis4 data", yaxis="y4")
)

fig.update_layout(
xaxis=dict(domain=[0.25, 0.75]),
yaxis=dict(
title="yaxis title",
),
yaxis2=dict(
title="yaxis2 title",
overlaying="y",
side="right",
),
yaxis3=dict(title="yaxis3 title", anchor="free", overlaying="y", autoshift=True),
yaxis4=dict(
title="yaxis4 title",
anchor="free",
overlaying="y",
autoshift=True,
),
)

fig.update_layout(
title_text="Shifting y-axes with autoshift",
)
else:
return
return fig
Expand All @@ -263,7 +300,7 @@ def test(app):
paned = Gtk.Paned()
window.set_content(paned)

fig = get_test_figure("multiple_axes")
fig = get_test_figure("autoshift")
print(fig)
# print(fig["layout"]["template"])

Expand Down
4 changes: 2 additions & 2 deletions src/plotly_gtk/widgets/axis_title.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, plot: "PlotlyGtk", axis: Spec, axis_name: str):
else (axis["anchor"][0] + "axis" + axis["anchor"][1:])
)
position = (
axis["position"]
axis["_position"]
if "anchor" not in axis or anchor_axis == "free"
else (
plot.layout[anchor_axis]["domain"][0]
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, plot: "PlotlyGtk", axis: Spec, axis_name: str):
-standoff - ticklen - font_extra - x_size_error
if axis["side"] == "left"
else standoff + ticklen + font_extra + x_size_error
)
) + axis["_shift"]
yoffset = 0
angle = 270 # angle = 270 if axis["side"] == "left" else 90
self.set_orientation(Gtk.Orientation.VERTICAL)
Expand Down

0 comments on commit ad4e3d3

Please sign in to comment.