Add custom timestamp accessor serialization
kylebarron committed Sep 27, 2024
1 parent 6addb2e commit b346810
Showing 2 changed files with 69 additions and 15 deletions.
lonboard/_serialization.py (64 additions, 8 deletions)
@@ -2,16 +2,28 @@
 
 import math
 from io import BytesIO
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Union, overload
 
-import numpy as np
-from arro3.core import Array, ChunkedArray, RecordBatch, Table
+import arro3.compute as ac
+from arro3.core import (
+    Array,
+    ChunkedArray,
+    DataType,
+    RecordBatch,
+    Scalar,
+    Table,
+    list_array,
+    list_flatten,
+    list_offsets,
+)
 from traitlets import TraitError
 
+from lonboard._constants import MIN_INTEGER_FLOAT32
 from lonboard.models import ViewState
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from lonboard._layer import BaseArrowLayer
+    from lonboard.experimental._layer import TripsLayer
 
 
 DEFAULT_PARQUET_COMPRESSION = "ZSTD"
@@ -91,7 +103,20 @@ def serialize_pyarrow_column(
     return serialize_table_to_parquet(pyarrow_table, max_chunksize=max_chunksize)
 
 
-def serialize_accessor(data: Union[List[int], Tuple[int], NDArray[np.uint8]], obj):
+@overload
+def serialize_accessor(
+    data: ChunkedArray,
+    obj: BaseArrowLayer,
+) -> List[bytes]: ...
+@overload
+def serialize_accessor(
+    data: Union[str, int, float, list, tuple, bytes],
+    obj: BaseArrowLayer,
+) -> Union[str, int, float, list, tuple, bytes]: ...
+def serialize_accessor(
+    data: Union[str, int, float, list, tuple, bytes, ChunkedArray],
+    obj: BaseArrowLayer,
+):
     if data is None:
         return None
 
@@ -100,12 +125,12 @@ def serialize_accessor(data: Union[List[int], Tuple[int], NDArray[np.uint8]], obj):
     if isinstance(data, (str, int, float, list, tuple, bytes)):
         return data
 
-    assert isinstance(data, (ChunkedArray, Array))
+    assert isinstance(data, ChunkedArray)
     validate_accessor_length_matches_table(data, obj.table)
     return serialize_pyarrow_column(data, max_chunksize=obj._rows_per_chunk)
 
 
-def serialize_table(data, obj):
+def serialize_table(data: Table, obj: BaseArrowLayer):
     assert isinstance(data, Table), "expected Arrow table"
     return serialize_table_to_parquet(data, max_chunksize=obj._rows_per_chunk)
 
@@ -135,5 +160,36 @@ def serialize_view_state(data: Optional[ViewState], obj):
     return data._asdict()
 
 
+# timestamps = layer.get_timestamps
+def serialize_timestamp_accessor(
+    timestamps: ChunkedArray, obj: TripsLayer
+) -> List[bytes]:
+    """
+    Subtract off min timestamp to fit into f32 integer range.
+    Then cast to float32.
+    """
+    # Cast to int64 type
+    timestamps = timestamps.cast(DataType.list(DataType.int64()))
+
+    min_timestamp = ac.min(list_flatten(timestamps))
+    start_offset_adjustment = Scalar(
+        MIN_INTEGER_FLOAT32 - min_timestamp.as_py(), type=DataType.int64()
+    )
+
+    list_offsets_iter = list_offsets(timestamps)
+    inner_values_iter = list_flatten(timestamps)
+
+    offsetted_chunks = []
+    for offsets, inner_values in zip(list_offsets_iter, inner_values_iter):
+        offsetted_values = ac.add(inner_values, start_offset_adjustment)
+        f32_values = offsetted_values.cast(DataType.int64()).cast(DataType.float32())
+        offsetted_chunks.append(list_array(offsets, f32_values))
+
+    f32_timestamps_col = ChunkedArray(offsetted_chunks)
+    return serialize_accessor(f32_timestamps_col, obj)
+
+
 ACCESSOR_SERIALIZATION = {"to_json": serialize_accessor}
+TIMESTAMP_ACCESSOR_SERIALIZATION = {"to_json": serialize_timestamp_accessor}
 TABLE_SERIALIZATION = {"to_json": serialize_table}
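
Note (editorial, not part of the commit): the offset trick above exists because float32 has a 24-bit significand, so it represents integers exactly only in roughly [-2**24, 2**24], while epoch timestamps are on the order of 2**40. A minimal numpy sketch of the idea, assuming MIN_INTEGER_FLOAT32 = -(2**24) as the constant's name suggests:

```python
import numpy as np

# Assumed value; the name suggests the edge of float32's integer-exact range.
MIN_INTEGER_FLOAT32 = -(2**24)

# Epoch-millisecond timestamps are around 2**40, so a naive float32 cast
# collapses adjacent values onto the same float:
timestamps = np.array([1_727_400_000_000, 1_727_400_000_001], dtype=np.int64)
assert timestamps.astype(np.float32)[0] == timestamps.astype(np.float32)[1]

# Shifting so the minimum lands at MIN_INTEGER_FLOAT32 keeps every value
# inside the integer-exact window (while the spread stays under 2**25):
adjustment = MIN_INTEGER_FLOAT32 - int(timestamps.min())
shifted = (timestamps + adjustment).astype(np.float32)
assert shifted[0] != shifted[1]
```

Anchoring the minimum at the negative end of the window, rather than at zero, doubles the usable integer-exact range before precision degrades.
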
lonboard/experimental/traits.py (5 additions, 7 deletions)
@@ -2,7 +2,7 @@
 
 import math
 import warnings
-from typing import TYPE_CHECKING, Any, Tuple
+from typing import TYPE_CHECKING, Any
 
 import arro3.compute as ac
 from arro3.core import (
@@ -18,7 +18,7 @@
 from traitlets.traitlets import TraitType
 
 from lonboard._constants import MAX_INTEGER_FLOAT32, MIN_INTEGER_FLOAT32
-from lonboard._serialization import ACCESSOR_SERIALIZATION
+from lonboard._serialization import TIMESTAMP_ACCESSOR_SERIALIZATION
 from lonboard._utils import get_geometry_column_index
 from lonboard.traits import FixedErrorTraitType
 
@@ -53,7 +53,7 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
-        self.tag(sync=True, **ACCESSOR_SERIALIZATION)
+        self.tag(sync=True, **TIMESTAMP_ACCESSOR_SERIALIZATION)
 
     def reduce_precision(
         self, obj: BaseArrowLayer, value: ChunkedArray
@@ -189,7 +189,7 @@ def validate_timestamp_offsets(self, obj: BaseArrowLayer, value: ChunkedArray):
                 info="timestamp array's offsets to match geometry array's offsets.",
             )
 
-    def validate(self, obj: BaseArrowLayer, value) -> Tuple[Scalar, ChunkedArray]:
+    def validate(self, obj: BaseArrowLayer, value) -> ChunkedArray:
         if hasattr(value, "__arrow_c_array__"):
            value = ChunkedArray([Array.from_arrow(value)])
         elif hasattr(value, "__arrow_c_stream__"):
@@ -210,6 +210,4 @@ def validate(self, obj: BaseArrowLayer, value) -> Tuple[Scalar, ChunkedArray]:
         value = self.reduce_precision(obj, value)
         value = value.rechunk(max_chunksize=obj._rows_per_chunk)
         self.validate_timestamp_offsets(obj, value)
-
-        min_timestamp = ac.min(list_flatten(value))
-        return min_timestamp, value
+        return value
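
Note (editorial, not part of the commit): validate now returns only the ChunkedArray because serialize_timestamp_accessor recomputes the minimum timestamp at serialization time, so the trait no longer needs to hand back a (min, column) pair. The {"to_json": ...} dicts plug into the standard ipywidgets trait-serializer convention; a standalone sketch of that pattern, with placeholder names:

```python
from traitlets import Any, HasTraits

def demo_to_json(value, obj):
    # Placeholder for serialize_timestamp_accessor: ipywidgets invokes a
    # trait's tagged to_json hook as to_json(trait_value, widget) whenever
    # it syncs that trait's state to the frontend.
    return f"serialized {value!r}"

class DemoWidgetLike(HasTraits):
    # Mirrors self.tag(sync=True, **TIMESTAMP_ACCESSOR_SERIALIZATION); a
    # plain HasTraits never triggers the sync machinery, so this only
    # shows where the serializer metadata is attached.
    get_timestamps = Any().tag(sync=True, to_json=demo_to_json)
```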
