Feat/add_search_dataset_endpoint #16

Open · wants to merge 2 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -55,3 +55,6 @@ docs/_build/
 
 # PyCharm
 .idea/
+
+# Parquet files
+timeseries/
48 changes: 37 additions & 11 deletions fishnet_cod/api/routers/datasets.py
@@ -1,3 +1,5 @@
+from fastapi import APIRouter, Query
+from typing import Optional, List, Union
 import asyncio
 from typing import Annotated, Awaitable, List, Optional, Union
 
@@ -43,21 +45,33 @@
 async def get_datasets(
     view_as: Optional[str] = None,
     by: Optional[str] = None,
+    search: Optional[str] = None,
     page: int = 1,
     page_size: int = 20,
-) -> List[DatasetResponse]:
+) -> dict:
Contributor: Should be a Pydantic type for the API docs
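A minimal sketch of what the reviewer is asking for. The `PaginatedDatasetResponse` name and the stub fields are assumptions, not part of this PR; in the real code `DatasetResponse` is imported from the project's schemas.

```python
from typing import List, Optional

from pydantic import BaseModel


class DatasetResponse(BaseModel):
    # Stub so the sketch is self-contained; the real model lives in the
    # project's schema module.
    item_hash: Optional[str] = None
    name: Optional[str] = None
    permission_status: Optional[str] = None


class PaginatedDatasetResponse(BaseModel):
    # Hypothetical wrapper mirroring the dict this endpoint now returns.
    total: int
    page: int
    data: List[DatasetResponse]
```

Declaring the endpoint as `-> PaginatedDatasetResponse` would then surface the paginated shape in the generated OpenAPI docs.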

"""
Get all datasets. Returns a list of tuples of datasets and their permission status for the given `view_as` user.
If `view_as` is not given, the permission status will be `none` for all datasets.
If `by` is given, it will return all datasets owned by that user.
"""
dataset_resp: Union[PageableRequest[Dataset], PageableResponse[Dataset]]

if by:
dataset_resp = Dataset.filter(owner=by)
else:
dataset_resp = Dataset.fetch_objects()
datasets = await dataset_resp.page(page=page, page_size=page_size)

all_datasets = await dataset_resp.all()
Contributor: This can be a very costly operation, only use it if you cannot avoid it


+    if search:
+        all_datasets = [dataset for dataset in all_datasets if search.lower() in dataset.name.lower()]
+
+    total_count = len(all_datasets)
+
+    start = (page - 1) * page_size
+    end = start + page_size
+    datasets = all_datasets[start:end]
Contributor: Use the .page operation instead of fetching all datasets
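One way to act on both comments, sketched under the assumption that the SDK exposes only the `.filter`/`.fetch_objects`/`.page`/`.all` calls visible in this diff: a search term still forces a full scan, but the plain listing path stays cheap. The `paged_or_filtered` helper name is hypothetical.

```python
async def paged_or_filtered(dataset_resp, search, page, page_size):
    # Sketch only: take the cheap .page path unless a search term forces
    # an in-memory scan over every record.
    if search:
        all_datasets = await dataset_resp.all()  # costly: fetches everything
        hits = [d for d in all_datasets if search.lower() in d.name.lower()]
        return len(hits), hits[(page - 1) * page_size : page * page_size]
    datasets = await dataset_resp.page(page=page, page_size=page_size)
    # Total is unknown here without a separate count query (assumption).
    return None, datasets
```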


     async def fetch_dataset_with_permissions(dataset):
         ts_ids = [ts_id for ts_id in dataset.timeseriesIDs]
         permissions = await fetch_permissions(dataset.item_hash, ts_ids, view_as)
@@ -67,10 +81,15 @@ async def fetch_dataset_with_permissions(dataset):
         )
 
     if view_as:
-        return await asyncio.gather(*[fetch_dataset_with_permissions(dataset) for dataset in datasets])
+        data = await asyncio.gather(*[fetch_dataset_with_permissions(dataset) for dataset in datasets])
     else:
-        return [DatasetResponse(**dataset.dict(), permission_status=None) for dataset in datasets]
+        data = [DatasetResponse(**dataset.dict(), permission_status=None) for dataset in datasets]
+
+    return {
+        "total": total_count,
+        "page": page,
+        "data": data
+    }
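For illustration, a client call against the new parameters might look like this; the `/datasets` prefix and local host are assumptions based on the router's filename, and with `page=2` and `page_size=20` the slice above returns items 21 through 40.

```python
import requests  # assumes the API is served locally during development

resp = requests.get(
    "http://localhost:8000/datasets",  # hypothetical host and route prefix
    params={"search": "weather", "page": 2, "page_size": 20},
)
payload = resp.json()
# Keys match the dict returned by the endpoint above.
print(payload["total"], payload["page"], len(payload["data"]))
```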

@router.put("")
async def upload_dataset(
Expand Down Expand Up @@ -124,7 +143,8 @@ async def get_dataset(dataset_id: str, view_as: Optional[str] = None) -> Dataset

     return DatasetResponse(
         **dataset.dict(),
-        permission_status=get_dataset_permission_status(dataset, permissions) if view_as else None,
+        permission_status=get_dataset_permission_status(
+            dataset, permissions) if view_as else None,
     )


@@ -138,11 +158,13 @@ async def get_dataset_permissions(dataset_id: str) -> List[Permission]:
         raise HTTPException(status_code=404, detail="No Dataset found")
     ts_ids = [ts_id for ts_id in dataset.timeseriesIDs]
     matched_permission_records = [
-        Permission.filter(timeseriesID=ts_id, status=PermissionStatus.GRANTED).all()
+        Permission.filter(timeseriesID=ts_id,
+                          status=PermissionStatus.GRANTED).all()
         for ts_id in ts_ids
     ] + [Permission.filter(datasetID=dataset_id, status=PermissionStatus.GRANTED).all()]
     records = await asyncio.gather(*matched_permission_records)
-    permission_records = [element for row in records for element in row if element]
+    permission_records = [
+        element for row in records for element in row if element]
 
     return permission_records

@@ -167,7 +189,8 @@ async def get_dataset_metaplex_dataset(dataset_id: str) -> FungibleAssetStandard
         attributes=[
             Attribute(trait_type="Owner", value=dataset.owner),
             Attribute(trait_type="Last Updated", value=dataset.timestamp),
-            Attribute(trait_type="Columns", value=str(len(dataset.timeseriesIDs))),
+            Attribute(trait_type="Columns", value=str(
+                len(dataset.timeseriesIDs))),
         ],
     )

@@ -198,7 +221,8 @@ async def upload_dataset_with_timeseries(
     dataset = await upload_dataset(dataset_req=req.dataset, user=user)
     return UploadDatasetTimeseriesResponse(
         dataset=dataset,
-        timeseries=[ts for ts in timeseries if not isinstance(ts, BaseException)],
+        timeseries=[ts for ts in timeseries if not isinstance(
+            ts, BaseException)],
     )


@@ -252,7 +276,8 @@ async def does_dataset_cost(dataset_id: str) -> bool:
     Depends(
         ConditionalJWTWalletAuth(
             jwt_credentials_manager,
-            lambda request: does_dataset_cost(request.path_params["dataset_id"]),
+            lambda request: does_dataset_cost(
+                request.path_params["dataset_id"]),
         )
     ),
 ]
@@ -272,7 +297,8 @@ async def get_dataset_timeseries_csv(
 
     timeseries = await Timeseries.fetch(dataset.timeseriesIDs).all()
 
-    df = get_harmonized_timeseries_df(timeseries, column_names=ColumnNameType.name)
+    df = get_harmonized_timeseries_df(
+        timeseries, column_names=ColumnNameType.name)

     if user:
         await check_access(timeseries, user)