Skip to content

Commit

Permalink
Use arro3 for dictionary encoding in apply_categorical_cmap (#601)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Aug 21, 2024
1 parent 213f26b commit 07187a0
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 80 deletions.
61 changes: 37 additions & 24 deletions lonboard/colormap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from arro3.compute import dictionary_encode
from arro3.core import (
Array,
ChunkedArray,
DataType,
dictionary_dictionary,
dictionary_indices,
)
from arro3.core.types import ArrowArrayExportable, ArrowStreamExportable

if TYPE_CHECKING:
import matplotlib as mpl
Expand Down Expand Up @@ -130,7 +139,14 @@ def apply_continuous_cmap(


def apply_categorical_cmap(
values: Union[NDArray, pd.Series, pa.Array, pa.ChunkedArray],
values: Union[
NDArray,
pd.Series,
pa.Array,
pa.ChunkedArray,
ArrowArrayExportable,
ArrowStreamExportable,
],
cmap: DiscreteColormap,
*,
alpha: Optional[int] = None,
Expand Down Expand Up @@ -163,30 +179,27 @@ def apply_categorical_cmap(
dimension will have a length of either `3` if `alpha` is `None`, or `4` is
each color has an alpha value.
"""
if isinstance(values, np.ndarray):
values = Array.from_numpy(values)

try:
import pyarrow as pa
import pyarrow.compute as pc
except ImportError as e:
raise ImportError(
"pyarrow required for apply_categorical_cmap.\n"
"Run `pip install pyarrow`."
) from e

# Import from PyCapsule interface
if hasattr(values, "__arrow_c_array__"):
values = pa.array(values)
elif hasattr(values, "__arrow_c_stream__"):
values = pa.chunked_array(values)

# Construct from non-arrow data
if not isinstance(values, (pa.Array, pa.ChunkedArray)):
values = pa.array(values)

if not pa.types.is_dictionary(values.type):
values = pc.dictionary_encode(values)
import pandas as pd

if isinstance(values, pd.Series):
values = Array.from_numpy(values)
except ImportError:
pass

values = ChunkedArray(values)

if not DataType.is_dictionary(values.type):
values = ChunkedArray(dictionary_encode(values))

dictionary = ChunkedArray(dictionary_dictionary(values))
indices = ChunkedArray(dictionary_indices(values))

# Build lookup table
lut = np.zeros((len(values.dictionary), 4), dtype=np.uint8)
lut = np.zeros((len(dictionary), 4), dtype=np.uint8)
if alpha is not None:
assert isinstance(alpha, int), "alpha must be an integer"
assert 0 <= alpha <= 255, "alpha must be between 0-255 (inclusive)."
Expand All @@ -195,7 +208,7 @@ def apply_categorical_cmap(
else:
lut[:, 3] = 255

for i, key in enumerate(values.dictionary):
for i, key in enumerate(dictionary):
color = cmap[key.as_py()]
if len(color) == 3:
lut[i, :3] = color
Expand All @@ -206,7 +219,7 @@ def apply_categorical_cmap(
"Expected color to be 3 or 4 values representing RGB or RGBA."
)

colors = lut[values.indices]
colors = lut[indices]

# If the alpha values are all 255, don't serialize
if (colors[:, 3] == 255).all():
Expand Down
92 changes: 46 additions & 46 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.8"
anywidget = "^0.9.0"
arro3-core = "^0.3.0-beta.1"
arro3-io = "^0.3.0-beta.1"
arro3-compute = "^0.3.0-beta.1"
arro3-core = "^0.3.0-beta.2"
arro3-io = "^0.3.0-beta.2"
arro3-compute = "^0.3.0-beta.2"
ipywidgets = ">=7.6.0"
numpy = ">=1.14"
# The same version pin as geopandas
Expand Down
12 changes: 5 additions & 7 deletions tests/test_colormap.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import pytest
from arro3.core import Array, DataType

from lonboard.colormap import apply_categorical_cmap


def test_discrete_cmap():
pd = pytest.importorskip("pandas")

values = ["red", "green", "blue", "blue", "red"]
df = pd.DataFrame({"val": values})
str_values = ["red", "green", "blue", "blue", "red"]
values = Array(str_values, type=DataType.string())
cmap = {
"red": [255, 0, 0],
"green": [0, 255, 0],
"blue": [0, 0, 255],
}
colors = apply_categorical_cmap(df["val"], cmap)
colors = apply_categorical_cmap(values, cmap)

for i, val in enumerate(values):
for i, val in enumerate(str_values):
assert list(colors[i]) == cmap[val]

0 comments on commit 07187a0

Please sign in to comment.