Skip to content

Commit

Permalink
feat: introduce log axes
Browse files Browse the repository at this point in the history
* feat: add _detect_axis_type to PlotlyGtk

Add unit tests for new function

Update IDE files and pyproject.toml

* fix: replace uses of axis["type"] with axis["_type"]

* fix: log axis ticks

* test: update tests in demo/broadway

Ensure STDERR is empty

* build: include demo tests in test report

* ci(flake8): use flake8-pyproject

* fix: numpy version

* style: run formatters

* fix: fix log axis ticks

* fix: test

* fix: fix log ticks
  • Loading branch information
slaclau authored Sep 15, 2024
1 parent d2875ef commit eb08bcf
Show file tree
Hide file tree
Showing 14 changed files with 443 additions and 98 deletions.
1 change: 1 addition & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions .idea/plotly-gtk.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions .idea/watcherTasks.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
[project]
name = "plotly-gtk"
dynamic = ["version", "readme"]
dependencies = ["numpy", "pygobject", "pandas"]
dependencies = ["numpy<2.0", "pygobject", "pandas"]

[build-system]
requires = ["setuptools>=64", "setuptools_scm>=8", "setuptools_scm_custom"]
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
doc = ["pylint", "sphinx", "sphinx-pyreverse", "pydata-sphinx-theme", "sphinx-rtd-theme", "plotly"]
lint = ["black", "isort", "flake8", "flake8-pylint", "flake8-json", "flake8-bugbear", "mypy", "mypy-json-report", "pygobject-stubs", "deptry"]
lint = ["black", "isort", "flake8", "flake8-pyproject", "flake8-pylint", "flake8-json", "flake8-bugbear", "mypy", "mypy-json-report", "pygobject-stubs", "deptry"]
test = ["plotly", "pytest", "pillow", "selenium","pytest-subtests", "pytest-cov", "pytest-html"]
dev = ["plotly-gtk[doc, lint, test]"]

Expand All @@ -26,6 +26,7 @@ version_scheme = "[{tag}?{distance}==0:{next_tag}][.dev{distance}?{distance}>0]"
local_scheme = "[+{node}?{distance}>0][[+?{distance}==0:.]d{node_date}?{dirty}==True]"

[tool.mypy]
exclude = "test"
# Disallow dynamic typing
disallow_any_unimported = true
# disallow_any_expr = true
Expand Down
3 changes: 3 additions & 0 deletions src/plotly_gtk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import logging

logging.basicConfig()
37 changes: 20 additions & 17 deletions src/plotly_gtk/_chart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains a private class to handle plotting for :class:`plotly_gtk.chart.PlotlyGTK`."""
"""Contains a private class to handle plotting for
:class:`plotly_gtk.chart.PlotlyGTK`."""

import gi
import numpy as np
Expand All @@ -23,8 +24,7 @@ def __init__(self, fig: dict):
self.set_draw_func(self._on_draw)

def update(self, fig: dict[str, plotly_types.Data, plotly_types.Layout]):
"""
Update the plot with a new figure.
"""Update the plot with a new figure.
Parameters
----------
Expand Down Expand Up @@ -85,16 +85,15 @@ def _draw_gridlines(self, context, width, height, axis):
if "_range" not in self.layout[axis]:
return
if axis.startswith("x"):
self.layout[axis]["_ticksobject"].update_length(
(width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"])
* (self.layout[axis]["domain"][-1] - self.layout[axis]["domain"][0])
)
self.layout[axis]["_ticksobject"].length = (
width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"]
) * (self.layout[axis]["domain"][-1] - self.layout[axis]["domain"][0])

self.layout[axis]["_ticksobject"].calculate()
else:
self.layout[axis]["_ticksobject"].update_length(
(height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"])
* (self.layout[axis]["domain"][-1] - self.layout[axis]["domain"][0])
)
self.layout[axis]["_ticksobject"].length = (
height - self.layout["_margin"]["t"] - self.layout["_margin"]["b"]
) * (self.layout[axis]["domain"][-1] - self.layout[axis]["domain"][0])
self.layout[axis]["_ticksobject"].calculate()
if "anchor" in self.layout[axis] and self.layout[axis]["anchor"] != "free":
anchor = (
Expand Down Expand Up @@ -180,7 +179,9 @@ def _draw_ticks(
+ self.layout[axis]["anchor"][1:]
)
y = self.layout[yaxis]["_range"][0]
x_pos, y_pos = self._calc_pos(x, y, width, height, axis, yaxis)
x_pos, y_pos = self._calc_pos(
x, y, width, height, axis, yaxis, ignore_log_y=True
)
else:
y_pos = self.layout["_margin"]["t"] + (
1 - self.layout[axis]["position"]
Expand All @@ -189,7 +190,7 @@ def _draw_ticks(

for tick, text in zip(x_pos, ticktext):
context.move_to(tick, y_pos)
layout.set_text(text)
layout.set_markup(text)
layout_size = layout.get_pixel_size()
context.rel_move_to(-layout_size[0] / 2, 0)
PangoCairo.show_layout(context, layout)
Expand All @@ -203,7 +204,9 @@ def _draw_ticks(
+ self.layout[axis]["anchor"][1:]
)
x = self.layout[xaxis]["_range"][0]
x_pos, y_pos = self._calc_pos(x, y, width, height, xaxis, axis)
x_pos, y_pos = self._calc_pos(
x, y, width, height, xaxis, axis, ignore_log_x=True
)
else:
x_pos = self.layout["_margin"]["l"] + self.layout[axis]["position"] * (
width - self.layout["_margin"]["l"] - self.layout["_margin"]["r"]
Expand All @@ -212,7 +215,7 @@ def _draw_ticks(

for tick, text in zip(y_pos, ticktext):
context.move_to(x_pos, tick)
layout.set_text(text)
layout.set_markup(text)
layout_size = layout.get_pixel_size()
context.rel_move_to(-layout_size[0], -layout_size[1] / 2)
PangoCairo.show_layout(context, layout)
Expand Down Expand Up @@ -263,9 +266,9 @@ def _calc_pos(
)

if log_x and not ignore_log_x:
x = np.log(x)
x = np.log10(x)
if log_y and not ignore_log_y:
y = np.log(y)
y = np.log10(y)

x_pos = []
y_pos = []
Expand Down
134 changes: 114 additions & 20 deletions src/plotly_gtk/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
:class:`plotly.graph_objects.Figure` using GTK."""

import datetime
import numbers
from datetime import timezone
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from plotly_gtk._chart import _PlotlyGtk
Expand Down Expand Up @@ -85,7 +87,7 @@ def _update_ranges(self):
for axis in axes:
if "autorange" in self.layout[axis]:
autorange = self.layout[axis]["autorange"]
elif "range" in self.layout[axis] and len(self.layout[axis]["range"] == 2):
if "range" in self.layout[axis] and len(self.layout[axis]["range"]) == 2:
autorange = False
else:
autorange = True
Expand Down Expand Up @@ -130,6 +132,16 @@ def _update_ranges(self):
_range[0] = _range[0] - 1
_range[-1] = _range[1] + 1
self.layout[axis]["_range"] = _range
else:
if self.layout[axis]["type"] == "log":
self.layout[axis]["_range"] = np.array(
[
10 ** self.layout[axis]["range"][0],
10 ** self.layout[axis]["range"][-1],
]
)
else:
self.layout[axis]["_range"] = np.array(self.layout[axis]["range"])

# Do matching
matched_to_axes = {
Expand All @@ -155,13 +167,33 @@ def _update_ranges(self):
for axis in axes:
if "_range" not in self.layout[axis]:
continue
if "type" in self.layout[axis] and self.layout[axis]["type"] == "log":
self.layout[axis]["_range"] = np.log(self.layout[axis]["_range"])
if self.layout[axis]["_type"] == "log":
self.layout[axis]["_range"] = np.log10(self.layout[axis]["_range"])
range_length = (
self.layout[axis]["_range"][-1] - self.layout[axis]["_range"][0]
)
if (
"range" in self.layout[axis]
and len(self.layout[axis]["range"]) == 2
):
range_addon = range_length * 0.001
else:
range_addon = range_length * 0.125 / 2
self.layout[axis]["_range"] = [
self.layout[axis]["_range"][0] - range_addon,
self.layout[axis]["_range"][-1] + range_addon,
]
else:
range_length = (
self.layout[axis]["_range"][-1] - self.layout[axis]["_range"][0]
)
range_addon = range_length * 0.125 / 2
if (
"range" in self.layout[axis]
and len(self.layout[axis]["range"]) == 2
):
range_addon = range_length * 0.001
else:
range_addon = range_length * 0.125 / 2
self.layout[axis]["_range"] = [
self.layout[axis]["_range"][0] - range_addon,
self.layout[axis]["_range"][-1] + range_addon,
Expand All @@ -178,14 +210,12 @@ def _update_ranges(self):

def _prepare_data(self):
for plot in self.data:
plot["x"] = [
(
x.replace(tzinfo=timezone.utc).timestamp()
if isinstance(x, datetime.datetime)
else x
)
for x in plot["x"]
]
if self._detect_axis_type(plot["x"]) == "date":
plot["x"] = np.array(plot["x"], dtype="datetime64")
plot["x"] = pd.to_datetime(plot["x"])
plot["x"] = [
x.replace(tzinfo=timezone.utc).timestamp() for x in plot["x"]
]
plots = []
for plot in self.data:
if plot["type"] in ["scatter", "scattergl"]:
Expand Down Expand Up @@ -349,20 +379,18 @@ def automargin(self):
self.queue_allocate()

def _update_layout(self):
xaxes = [
xaxes = {
trace["xaxis"].replace("x", "xaxis")
for trace in self.data
if "xaxis" in trace
]
yaxes = [
}
yaxes = {
trace["yaxis"].replace("y", "yaxis")
for trace in self.data
if "yaxis" in trace
]
xaxes.append("xaxis")
yaxes.append("yaxis")
xaxes = set(xaxes)
yaxes = set(yaxes)
}
xaxes.add("xaxis")
yaxes.add("yaxis")

template = self.layout["template"]["layout"]
defaults = dict(
Expand Down Expand Up @@ -513,14 +541,80 @@ def _update_layout(self):
),
)
for xaxis in xaxes:
if "type" not in self.layout[xaxis]:
first_plot_on_axis = [
trace
for trace in self.data
if trace["xaxis"] == xaxis.replace("axis", "")
][0]
self.layout[xaxis]["_type"] = self._detect_axis_type(
first_plot_on_axis["x"]
)
else:
self.layout[xaxis]["_type"] = self.layout[xaxis]["type"]
template[xaxis] = template["xaxis"]
defaults[xaxis] = defaults["xaxis"]
for yaxis in yaxes:
if "type" not in self.layout[yaxis]:
first_plot_on_axis = [
trace
for trace in self.data
if trace["yaxis"] == yaxis.replace("axis", "")
][0]
self.layout[yaxis]["_type"] = self._detect_axis_type(
first_plot_on_axis["y"]
)
else:
self.layout[yaxis]["_type"] = self.layout[yaxis]["type"]
template[yaxis] = template["yaxis"]
defaults[yaxis] = defaults["yaxis"]
self.layout = update_dict(template, self.layout)
self.layout = update_dict(defaults, self.layout)

@staticmethod
def _detect_axis_type(data):
if any(isinstance(i, list) or isinstance(i, np.ndarray) for i in data):
return "multicategory"
if not isinstance(data, np.ndarray):
data = np.array(data)

length = len(data)
if length >= 1000:
start = np.random.randint(0, length / 1000)
index = np.arange(start, length, length / 1000).astype(np.int32)
data = data[index]

data = set(data)

def to_type(d):
try:
d = np.datetime64(d)
return "date"
except ValueError:
if isinstance(d, numbers.Number):
return "linear"
else:
return "category"

data_types = [to_type(d) for d in data]
data_types = {d: data_types.count(d) for d in set(data_types)}
if len(data_types) == 1:
return list(data_types)[0]
if "linear" not in data_types:
if data_types["date"] > data_types["category"]:
return "date"
return "category"

if "date" in data_types and data_types["date"] > 2 * data_types["linear"]:
return "date"
if (
"category" in data_types
and data_types["category"] > 2 * data_types["linear"]
):
return "category"

return "linear"

def _draw_buttons(self):
if "updatemenus" not in self.layout:
return
Expand Down
Loading

0 comments on commit eb08bcf

Please sign in to comment.