From c49a926a2fb9ac3e8b8c7d3593577b5444343e9e Mon Sep 17 00:00:00 2001
From: Wes McKinney
Date: Wed, 28 Feb 2024 06:35:34 -0600
Subject: [PATCH] Move table shape into data explorer get_state request and
 test pandas state requests (posit-dev/positron-python#393)

* Move table shape into get_state request and test pandas state requests

* Handle parameter change, cleaning

* fix pyright

* Rename shape again

* Dictify state result
---
 .../positron_ipykernel/data_explorer.py      | 33 ++++---
 .../positron_ipykernel/data_explorer_comm.py | 36 +++++---
 .../tests/test_data_explorer.py              | 88 ++++++++++++++-----
 3 files changed, 114 insertions(+), 43 deletions(-)

diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
index 76ee2e00d3c..2190258a3f0 100644
--- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
+++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py
@@ -23,7 +23,6 @@
 from .access_keys import decode_access_key
 from .data_explorer_comm import (
-    BackendState,
     ColumnFilter,
     ColumnFilterCompareOp,
     ColumnSchema,
@@ -42,6 +41,8 @@
     SetSortColumnsRequest,
     TableData,
     TableSchema,
+    TableShape,
+    TableState,
 )
 from .positron_comm import CommMessage, PositronComm
 from .third_party import pd_
@@ -128,7 +129,7 @@ def get_column_profile(self, request: GetColumnProfileRequest):
         return self._get_column_profile(request.params.profile_type, request.params.column_index)
 
     def get_state(self, request: GetStateRequest):
-        return self._get_state()
+        return self._get_state().dict()
 
     def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
         raise NotImplementedError
@@ -154,7 +155,7 @@ def _get_column_profile(
     ) -> None:
         raise NotImplementedError
 
-    def _get_state(self) -> BackendState:
+    def _get_state(self) -> TableState:
         raise NotImplementedError
@@ -185,6 +186,7 @@ class PandasView(DataExplorerTableView):
         "float64": "number",
         "mixed-integer": "number",
         "mixed-integer-float": "number",
+        "mixed": "unknown",
         "decimal": "number",
         "complex": "number",
         "categorical": "categorical",
@@ -291,11 +293,7 @@ def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
             )
             column_schemas.append(col_schema)
 
-        return TableSchema(
-            columns=column_schemas,
-            num_rows=self.table.shape[0],
-            total_num_columns=self.table.shape[1],
-        )
+        return TableSchema(columns=column_schemas)
 
     def _get_data_values(
         self, row_start: int, num_rows: int, column_indices: Sequence[int]
@@ -420,8 +418,12 @@ def _get_column_profile(
     ) -> None:
         pass
 
-    def _get_state(self) -> BackendState:
-        return BackendState(filters=self.filters, sort_keys=self.sort_keys)
+    def _get_state(self) -> TableState:
+        return TableState(
+            table_shape=TableShape(num_rows=self.table.shape[0], num_columns=self.table.shape[1]),
+            filters=self.filters,
+            sort_keys=self.sort_keys,
+        )
 
 
 COMPARE_OPS = {
@@ -503,7 +505,13 @@ def shutdown(self) -> None:
         for comm_id in list(self.comms.keys()):
             self._close_explorer(comm_id)
 
-    def register_table(self, table, title, variable_path=None, comm_id=None):
+    def register_table(
+        self,
+        table,
+        title,
+        variable_path: Optional[List[str]] = None,
+        comm_id=None,
+    ):
         """
         Set up a new comm and data explorer table query wrapper to handle
         requests and manage state.
@@ -552,6 +560,9 @@ def close_callback(msg):
         base_comm.on_close(close_callback)
 
         if variable_path is not None:
+            if not isinstance(variable_path, list):
+                raise ValueError(variable_path)
+
             key = tuple(variable_path)
             self.comm_id_to_path[comm_id] = key
diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer_comm.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer_comm.py
index 843424cda02..25d1133d4d9 100644
--- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer_comm.py
+++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer_comm.py
@@ -113,14 +113,6 @@ class TableSchema(BaseModel):
         description="Schema for each column in the table",
     )
 
-    num_rows: int = Field(
-        description="Numbers of rows in the unfiltered dataset",
-    )
-
-    total_num_columns: int = Field(
-        description="Total number of columns in the unfiltered dataset",
-    )
-
 
 class TableData(BaseModel):
     """
@@ -211,11 +203,15 @@ class FreqtableCounts(BaseModel):
     )
 
 
-class BackendState(BaseModel):
+class TableState(BaseModel):
     """
-    The current backend state
+    The current backend table state
     """
 
+    table_shape: TableShape = Field(
+        description="Provides number of rows and columns in table",
+    )
+
     filters: List[ColumnFilter] = Field(
         description="The set of currently applied filters",
     )
@@ -225,6 +221,20 @@
     sort_keys: List[ColumnSortKey] = Field(
         description="The set of currently sorted columns",
     )
 
 
+class TableShape(BaseModel):
+    """
+    Provides number of rows and columns in table
+    """
+
+    num_rows: int = Field(
+        description="Numbers of rows in the unfiltered dataset",
+    )
+
+    num_columns: int = Field(
+        description="Number of columns in the unfiltered dataset",
+    )
+
+
 class ColumnSchema(BaseModel):
     """
     Schema for a column in a table
     """
@@ -548,7 +558,7 @@ class GetColumnProfileRequest(BaseModel):
 
 class GetStateRequest(BaseModel):
     """
-    Request the current backend state (applied filters and sort columns)
+    Request the current table state (applied filters and sort columns)
     """
 
     method: Literal[DataExplorerBackendRequest.GetState] = Field(
@@ -606,7 +616,9 @@ class SchemaUpdateParams(BaseModel):
 
 FreqtableCounts.update_forward_refs()
 
-BackendState.update_forward_refs()
+TableState.update_forward_refs()
+
+TableShape.update_forward_refs()
 
 ColumnSchema.update_forward_refs()
diff --git a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
index 453229da5c8..8dc657ca69d 100644
--- a/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
+++ b/extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py
@@ -11,7 +11,12 @@
 from ..access_keys import encode_access_key
 from .._vendor.pydantic import BaseModel
 from ..data_explorer import COMPARE_OPS, DataExplorerService
-from ..data_explorer_comm import ColumnSchema, ColumnSortKey, FilterResult
+from ..data_explorer_comm import (
+    ColumnFilter,
+    ColumnSchema,
+    ColumnSortKey,
+    FilterResult,
+)
 from .conftest import DummyComm, PositronShell
 from .utils import json_rpc_notification, json_rpc_request, json_rpc_response
@@ -58,6 +63,17 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
 # Test basic service functionality
 
 
+class MyData:
+    def __init__(self, value):
+        self.value = value
+
+    def __str__(self):
+        return str(self.value)
+
+    def __repr__(self):
+        return repr(self.value)
+
+
 SIMPLE_PANDAS_DF = pd.DataFrame(
     {
         "a": [1, 2, 3, 4, 5],
@@ -73,6 +89,7 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
                 "2024-01-05 00:00:00",
             ]
         ),
+        "f": [None, MyData(5), MyData(-1), None, None],
     }
 )
@@ -216,6 +233,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
     # Do a simple update and make sure that sort keys are preserved
     x_comm_id = list(de_service.path_to_comm_ids[path_x])[0]
+    x_sort_keys = [{"column_index": 0, "ascending": True}]
     msg = json_rpc_request(
         "set_sort_columns",
         params={"sort_keys": [{"column_index": 0, "ascending": True}]},
@@ -227,9 +245,15 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
     _check_update_variable("x", update_type="data")
 
     tv = de_service.table_views[x_comm_id]
-    assert tv.sort_keys == [ColumnSortKey(column_index=0, ascending=True)]
+    assert tv.sort_keys == [ColumnSortKey(**k) for k in x_sort_keys]
     assert tv._need_recompute
 
+    pf = PandasFixture(de_service)
+    new_state = pf.get_state("x")
+    assert new_state["table_shape"]["num_rows"] == 5
+    assert new_state["table_shape"]["num_columns"] == 1
+    assert new_state["sort_keys"] == [ColumnSortKey(**k) for k in x_sort_keys]
+
     # Execute code that triggers an update event for big_x because it's large
     shell.run_cell("print('hello world')")
     _check_update_variable("big_x", update_type="data")
@@ -281,17 +305,30 @@ def test_shutdown(de_service: DataExplorerService):
 class PandasFixture:
     def __init__(self, de_service: DataExplorerService):
         self.de_service = de_service
-        self._table_ids = {}
 
         self.register_table("simple", SIMPLE_PANDAS_DF)
 
     def register_table(self, table_name: str, table):
         comm_id = guid()
-        self.de_service.register_table(table, table_name, comm_id=comm_id)
-        self._table_ids[table_name] = comm_id
+
+        paths = self.de_service.get_paths_for_variable(table_name)
+        for path in paths:
+            for old_comm_id in list(self.de_service.path_to_comm_ids[path]):
+                self.de_service._close_explorer(old_comm_id)
+
+        self.de_service.register_table(
+            table,
+            table_name,
+            comm_id=comm_id,
+            variable_path=[encode_access_key(table_name)],
+        )
 
     def do_json_rpc(self, table_name, method, **params):
-        comm_id = self._table_ids[table_name]
+        paths = self.de_service.get_paths_for_variable(table_name)
+        assert len(paths) == 1
+
+        comm_id = list(self.de_service.path_to_comm_ids[paths[0]])[0]
+
         request = json_rpc_request(
             method,
             params=params,
@@ -313,6 +350,9 @@ def get_schema(self, table_name, start_index, num_columns):
             num_columns=num_columns,
         )
 
+    def get_state(self, table_name):
+        return self.do_json_rpc(table_name, "get_state")
+
     def get_data_values(self, table_name, **params):
         return self.do_json_rpc(table_name, "get_data_values", **params)
@@ -372,10 +412,26 @@ def _wrap_json(model: Type[BaseModel], data: JsonRecords):
     return [model(**d).dict() for d in data]
 
 
+def test_pandas_get_state(pandas_fixture: PandasFixture):
+    result = pandas_fixture.get_state("simple")
+    assert result["table_shape"]["num_rows"] == 5
+    assert result["table_shape"]["num_columns"] == 6
+
+    sort_keys = [
+        {"column_index": 0, "ascending": True},
+        {"column_index": 1, "ascending": False},
+    ]
+    filters = [_compare_filter(0, ">", 0), _compare_filter(0, "<", 5)]
+    pandas_fixture.set_sort_columns("simple", sort_keys=sort_keys)
+    pandas_fixture.set_column_filters("simple", filters=filters)
+
+    result = pandas_fixture.get_state("simple")
+    assert result["sort_keys"] == sort_keys
+    assert result["filters"] == [ColumnFilter(**f) for f in filters]
+
+
 def test_pandas_get_schema(pandas_fixture: PandasFixture):
     result = pandas_fixture.get_schema("simple", 0, 100)
-    assert result["num_rows"] == 5
-    assert result["total_num_columns"] == 5
 
     full_schema = [
         {
@@ -403,20 +459,15 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
             "type_name": "datetime64[ns]",
             "type_display": "datetime",
         },
+        {"column_name": "f", "type_name": "mixed", "type_display": "unknown"},
     ]
 
     assert result["columns"] == _wrap_json(ColumnSchema, full_schema)
 
     result = pandas_fixture.get_schema("simple", 2, 100)
-    assert result["num_rows"] == 5
-    assert result["total_num_columns"] == 5
-
     assert result["columns"] == _wrap_json(ColumnSchema, full_schema[2:])
 
-    result = pandas_fixture.get_schema("simple", 5, 100)
-    assert result["num_rows"] == 5
-    assert result["total_num_columns"] == 5
-
+    result = pandas_fixture.get_schema("simple", 6, 100)
     assert result["columns"] == []
 
     # Make a really big schema
@@ -426,13 +477,9 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
     pandas_fixture.register_table(bigger_name, bigger_df)
 
     result = pandas_fixture.get_schema(bigger_name, 0, 100)
-    assert result["num_rows"] == 5
-    assert result["total_num_columns"] == 500
     assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[:100])
 
     result = pandas_fixture.get_schema(bigger_name, 10, 10)
-    assert result["num_rows"] == 5
-    assert result["total_num_columns"] == 500
     assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[10:20])
@@ -466,7 +513,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
         "simple",
         row_start_index=0,
         num_rows=20,
-        column_indices=list(range(5)),
+        column_indices=list(range(6)),
     )
 
     # TODO: pandas pads all values to fixed width, do we want to do
@@ -483,6 +530,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
            "2024-01-04 00:00:00",
            "2024-01-05 00:00:00",
         ],
+        ["None", "5", "-1", "None", "None"],
     ]
 
     assert _trim_whitespace(result["columns"]) == expected_columns
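Note (not part of the patch): a minimal sketch of the reworked get_state round trip, using the "simple" pandas table registered by the test fixture above. The table_shape values come straight from test_pandas_get_state; the empty filters and sort_keys defaults are an assumption about a freshly registered table.

    # get_state now answers with TableState.dict(), so table dimensions live
    # under "table_shape" instead of being repeated in every get_schema reply.
    state = pandas_fixture.get_state("simple")
    # Roughly (filters/sort_keys assumed empty before anything is applied):
    # {
    #     "table_shape": {"num_rows": 5, "num_columns": 6},
    #     "filters": [],
    #     "sort_keys": [],
    # }
    assert state["table_shape"]["num_rows"] == 5
    assert state["table_shape"]["num_columns"] == 6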