Skip to content

Commit

Permalink
naive baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
chanwutk committed Aug 10, 2023
1 parent 58e79ea commit d62fc76
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
5 changes: 4 additions & 1 deletion playground/run-ablation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,10 @@
" cache=ss_cache,\n",
" ))\n",
"\n",
" pipeline.add_filter(FromTracking2DAndRoad())\n",
" if geo_depth:\n",
" pipeline.add_filter(FromTracking2DAndRoad())\n",
" else:\n",
" pipeline.add_filter(FromTracking3DAndDepth())\n",
"\n",
" # Segment Trajectory\n",
" # pipeline.add_filter(FromTracking3D())\n",
Expand Down
8 changes: 6 additions & 2 deletions spatialyze/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
FromTracking2DAndRoad,
)
from .video_processor.stages.tracking_3d.tracking_3d import Metadatum as T3DMetadatum
from .video_processor.stages.depth_estimation import DepthEstimation
from .video_processor.stages.detection_3d.from_detection_2d_and_depth import FromDetection2DAndDepth
from .video_processor.stages.tracking_3d.from_tracking_2d_and_depth import FromTracking2DAndDepth
from .video_processor.utils.format_trajectory import format_trajectory
from .video_processor.utils.get_tracks import get_tracks
from .video_processor.utils.insert_trajectory import insert_trajectory
Expand Down Expand Up @@ -107,14 +110,15 @@ def _execute(world: "World", optimization=True):
InView(distance=50, predicate=world.predicates),
YoloDetection(),
objtypes_filter,
FromDetection2DAndRoad(),
DepthEstimation(),
FromDetection2DAndDepth(),
*(
[DetectionEstimation()]
if all(t in ["car", "truck"] for t in objtypes_filter.types)
else []
),
StrongSORT(),
FromTracking2DAndRoad(),
FromTracking2DAndDepth(),
]
)
else:
Expand Down
8 changes: 4 additions & 4 deletions tests/workflow/test_simple_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_simple_workflow():

objects, trackings = _execute(world, optimization=False)

# with open(os.path.join(OUTPUT_DIR, 'simple-workflow-trackings.json'), 'w') as f:
# json.dump(trackings, f, indent=1, cls=MetadataJSONEncoder)
with open(os.path.join(OUTPUT_DIR, 'simple-workflow-trackings.json'), 'w') as f:
json.dump(trackings, f, indent=1, cls=MetadataJSONEncoder)

with open(os.path.join(OUTPUT_DIR, 'simple-workflow-trackings.json'), 'r') as f:
trackings_groundtruth = json.load(f)
Expand All @@ -79,8 +79,8 @@ def test_simple_workflow():
assert p.object_type == g['object_type'], (p.object_type, g['object_type'])
assert str(p.timestamp) == g['timestamp'], (p.timestamp, g['timestamp'])

# with open(os.path.join(OUTPUT_DIR, 'simple-workflow-objects.json'), 'w') as f:
# json.dump(objects, f, indent=1)
with open(os.path.join(OUTPUT_DIR, 'simple-workflow-objects.json'), 'w') as f:
json.dump(objects, f, indent=1)

with open(os.path.join(OUTPUT_DIR, 'simple-workflow-objects.json'), 'r') as f:
objects_groundtruth = json.load(f)
Expand Down

0 comments on commit d62fc76

Please sign in to comment.