Skip to content

Commit

Permalink
Initial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hannahkim24 committed Jun 1, 2022
1 parent f6bbdc7 commit 9bf4e69
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 45 deletions.
188 changes: 157 additions & 31 deletions examples/4-20_sort.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"outputs": [],
"source": [
"%load_ext autoreload\n",
"\n",
"%autoreload 2"
]
},
Expand All @@ -19,7 +20,8 @@
"import meerkat as mk\n",
"import os\n",
"import torch\n",
"import numpy as np"
"import numpy as np\n",
"import pandas as pd"
]
},
{
Expand All @@ -34,44 +36,44 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text (NumpyArrayColumn)</th>\n",
" <th>image (NumpyArrayColumn)</th>\n",
" <th>text (TensorColumn)</th>\n",
" <th>image (PandasSeriesColumn)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>tensor(1)</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>tensor(5)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>tensor(1)</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>tensor(4)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>tensor(3)</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>4</td>\n",
" <td>tensor(4)</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>4</td>\n",
" <td>tensor(4)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
Expand All @@ -88,8 +90,8 @@
],
"source": [
"dp2 = mk.DataPanel({\n",
" 'text': mk.NumpyArrayColumn([1,5,1,4,3,4,4]), \n",
" 'image': mk.NumpyArrayColumn([9,4,0,4,0,2,1]), \n",
" 'text': mk.TensorColumn([1,5,1,4,3,4,4]), \n",
" 'image': mk.PandasSeriesColumn([9,4,0,4,0,2,1]), \n",
" \n",
"}) \n",
"\n",
Expand All @@ -98,55 +100,63 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'meerkat.columns.numpy_column.NumpyArrayColumn'>\n",
"<class 'meerkat.columns.numpy_column.NumpyArrayColumn'>\n"
]
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text (NumpyArrayColumn)</th>\n",
" <th>image (NumpyArrayColumn)</th>\n",
" <th>text (NumpyArrayColumn)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -155,15 +165,74 @@
"DataPanel(nrows: 7, ncols: 2)"
]
},
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dp2.sort(by=[\"image\", \"text\"])\n",
"\n",
"# have to convert back to original type\n"
"dp2.sort(by=[\"text\", \"image\"], ascending=True)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>(NumpyArrayColumn)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"NumpyArrayColumn(array([1, 5, 1, 4, 3, 4, 4]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dp2[\"text\"]\n"
]
},
{
Expand All @@ -175,25 +244,82 @@
"test2 = mk.PandasSeriesColumn([3,2, 1])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"test2 = mk.DataPanel({\n",
" 'tensor': mk.TensorColumn([3,1,2]), \n",
" 'pandas': mk.PandasSeriesColumn([5,4,6]), \n",
" }) \n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"test1 = test2.sort(by=[\"tensor\"])\n",
"test3 = test2.sort(by=[\"pandas\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tensor (TensorColumn)</th>\n",
" <th>pandas (PandasSeriesColumn)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>tensor(1)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>tensor(2)</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>tensor(3)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"(3,)"
"DataPanel(nrows: 3, ncols: 2)"
]
},
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test2.shape"
"test1\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
46 changes: 32 additions & 14 deletions meerkat/datapanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,38 +762,56 @@ def sort(
"""

pandaColType = '<class \'meerkat.columns.pandas_column.PandasSeriesColumn\'>'
numpyColType = '<class \'meerkat.columns.numpy_column.NumpyArrayColumn\'>'
tensorColType = '<class \'meerkat.columns.tensor_column.TensorColumn\'>'
panda_col_type = '<class \'meerkat.columns.pandas_column.PandasSeriesColumn\'>'
numpy_col_type = '<class \'meerkat.columns.numpy_column.NumpyArrayColumn\'>'
tensor_col_type = '<class \'meerkat.columns.tensor_column.TensorColumn\'>'

keys = []


if len(by) > 1:
if len(by) > 1: # Sort with multiple column
tensor_col_name = "none"
panda_col_name = "none"
for col in by:
currColType = str(type(self[col]))
if currColType == pandaColType:

curr_col_type = str(type(self[col]))
print(curr_col_type)
# Convert all columns to numpy type
if curr_col_type == panda_col_type:
self[col] = self[col].values
panda_col_name = col

if currColType == tensorColType:
if curr_col_type == tensor_col_type:
self[col] = self[col].numpy()
tensor_col_name = col

for col in by[::-1]:
keys.append(self[col])

# Sort numpy columns
sorted_indices = np.lexsort(keys = keys)

# !! This doesn't update self!!
self = self.lz[sorted_indices]


# Convert columns to original types
if tensor_col_name != "none":
self[tensor_col_name] = torch.from_numpy(np.array(self[tensor_col_name]))
if panda_col_name != "none":
self[panda_col_name] = pd.Series(np.array(self[panda_col_name]))

else:
currColType = str(type(self[by[0]]))
if currColType == pandaColType:
else: # Sort with single column
curr_col_type = str(type(self[by[0]]))
if curr_col_type == panda_col_type:
sorted_indices = PandasSeriesColumn.argsort(self[by[0]], ascending=ascending, kind=kind)
if currColType == numpyColType:
if curr_col_type == numpy_col_type:
sorted_indices = NumpyArrayColumn.argsort(self[by[0]], ascending=ascending, kind=kind)
if currColType == tensorColType:
if curr_col_type == tensor_col_type:
sorted_indices = TensorColumn.argsort(self[by[0]], ascending=ascending, kind=kind)
self = self.lz[sorted_indices]

return self.lz[sorted_indices]

return self


def items(self):
Expand Down
Loading

0 comments on commit 9bf4e69

Please sign in to comment.