Skip to content

Commit

Permalink
update: fix type
Browse files Browse the repository at this point in the history
  • Loading branch information
chanwutk committed Apr 16, 2024
1 parent 4412f6e commit 7d6b781
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 19 deletions.
5 changes: 3 additions & 2 deletions spatialyze/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from .video_processor.types import Float33

if TYPE_CHECKING:
from psycopg2 import connection as Connection
from psycopg2 import cursor as Cursor
from psycopg2._psycopg import connection as Connection
from psycopg2._psycopg import cursor as Cursor

from .predicate import PredicateNode

Expand Down Expand Up @@ -313,6 +313,7 @@ def predicate(self, predicate: "PredicateNode", temporal: bool = True):
def sql(self, query: str) -> pd.DataFrame:
results, cursor = self.execute_and_cursor(query)
description = cursor.description
assert description is not None
cursor.close()
return pd.DataFrame(results, columns=[d.name for d in description])

Expand Down
4 changes: 2 additions & 2 deletions spatialyze/predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def __or__(self, other):
],
)

def __eq__(self, other):
def __eq__(self, other): # pyright: ignore [reportIncompatibleMethodOverride]
other = wrap_literal(other)
return CompOpNode(self, "eq", other)

def __ne__(self, other):
def __ne__(self, other): # pyright: ignore [reportIncompatibleMethodOverride]
other = wrap_literal(other)
return CompOpNode(self, "ne", other)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@ def _r(acc: "tuple[int, list[tuple[int, int]]]", frames: int):
return None, {self.classname(): metadata}
except BaseException:
_, output = DecodeFrame()._run(payload)
return None, {self.classname(): DecodeFrame.get(output)}
images = DecodeFrame.get(output)
assert images is not None
return None, {self.classname(): images}
2 changes: 1 addition & 1 deletion spatialyze/video_processor/stages/depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def eval_all(self, input_images: "list[npt.NDArray | None]"):
# Load image and preprocess
input_image = pil.fromarray(im[:, :, [2, 1, 0]])
original_width, original_height = input_image.size
input_image = input_image.resize((self.feed_width, self.feed_height), pil.LANCZOS)
input_image = input_image.resize((self.feed_width, self.feed_height), pil.Resampling.LANCZOS)
input_image = transforms.ToTensor()(input_image).unsqueeze(0)

# PREDICTION
Expand Down
12 changes: 6 additions & 6 deletions spatialyze/video_processor/stages/detection_2d/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __init__(self, df_annotations: "pd.DataFrame"):
self.class_to_id = {c: i for i, c in enumerate(self.id_to_classes)}

def _run(self, payload: "Payload"):
metadata: "list[Metadatum | None]" = []
metadata: list[Metadatum] = []
dimension = payload.video.dimension
for i, cc in enumerate(payload.video._camera_configs):
fid = cc.frame_id
Expand Down Expand Up @@ -222,10 +222,10 @@ def _run(self, payload: "Payload"):
if len(tensor) == 0:
metadata.append(Metadatum(torch.Tensor([]), yolo_classes, []))
else:
metadata.append(
Metadatum(
torch.Tensor(tensor), yolo_classes, [DetectionId(i, _id) for _id in ids]
)
)
metadata.append(Metadatum(
torch.Tensor(tensor),
yolo_classes,
[DetectionId(i, _id) for _id in ids],
))

return None, {self.classname(): metadata}
12 changes: 6 additions & 6 deletions spatialyze/video_processor/stages/detection_3d/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, df_annotations: "pd.DataFrame"):
self.class_to_id = {c: i for i, c in enumerate(self.id_to_classes)}

def _run(self, payload: "Payload"):
metadata: "list[Metadatum | None]" = []
metadata: list[Metadatum] = []
dimension = payload.video.dimension
for i, cc in enumerate(payload.video._camera_configs):
fid = cc.frame_id
Expand Down Expand Up @@ -79,10 +79,10 @@ def _run(self, payload: "Payload"):
if len(tensor) == 0:
metadata.append(Metadatum(torch.Tensor([]), yolo_classes, []))
else:
metadata.append(
Metadatum(
torch.Tensor(tensor), yolo_classes, [DetectionId(i, _id) for _id in ids]
)
)
metadata.append(Metadatum(
torch.Tensor(tensor),
yolo_classes,
[DetectionId(i, _id) for _id in ids],
))

return None, {self.classname(): metadata}
4 changes: 3 additions & 1 deletion spatialyze/video_processor/stages/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,6 @@ def tqdm(cls, iterable: "Iterable[_T2]", *args, **kwargs) -> "Iterable[_T2]":
def _get_classnames(cls: "type") -> "list[str]":
if cls == Stage:
return []
return [*_get_classnames(cls.__base__), cls.__name__]
base = cls.__base__
assert base is not None
return [*_get_classnames(base), cls.__name__]

0 comments on commit 7d6b781

Please sign in to comment.