Skip to content

Commit

Permalink
refactor: Update graph traversal method names and heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
SverreNystad committed Aug 19, 2024
1 parent b2024af commit db197c7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
19 changes: 9 additions & 10 deletions backend/graphtraversal/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" This module contains the factory functions for the graph traversal algorithms """

from typing import Callable
from graphtraversal.algorithms.heuristics import euclidean_distance, manhattan_distance
from graphtraversal.algorithms.uninformed_search import (
Expand All @@ -10,38 +11,36 @@


graph_traversal_function_map: dict[str, Pathfinder] = {
"a star": AStarPathfinder(),
"breadth first search": BFSPathfinder(),
"depth first search": DFSPathfinder(),
"A* (A Star)": AStarPathfinder(),
"Breadth First Search (BFS)": BFSPathfinder(),
"Depth First Search (DFS)": DFSPathfinder(),
}


def get_pathfinder(graph_method_name: str) -> Pathfinder:
"""
This function returns the pathfinder for the given graph method name
"""
print(f"graph_method_name in factory: {graph_method_name}")
return graph_traversal_function_map[graph_method_name]


def get_graph_traversal_methods() -> list[str]:
"""
This function returns a list of all available graph traversal methods
"""
method_names: list[str] = list(graph_traversal_function_map.keys())
return method_names
return list(graph_traversal_function_map.keys())


def get_heuristics(graph_method_name: str) -> list[str]:
"""
This function returns a list of heuristics that are valid for the given graph method name
"""
match graph_method_name.lower():
case "a star":
return ["manhattan", "euclidean"]
case "breadth first search":
case "a* (a star)":
return ["Manhattan", "Euclidean"]
case "breadth first search (bfs)":
return []
case "depth first search":
case "depth first search (dfs)":
return []
case _:
raise ValueError(f"Invalid graph method name: {graph_method_name}")
Expand Down
19 changes: 13 additions & 6 deletions backend/graphtraversal/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from rest_framework.parsers import JSONParser

from graphtraversal.algorithms.pathfinder import Pathfinder
from graphtraversal.serializers import GraphHeuristicsSerializer, GraphTraversalMethodSerializer, GraphTraversalSerializer
from graphtraversal.serializers import (
GraphHeuristicsSerializer,
GraphTraversalMethodSerializer,
GraphTraversalSerializer,
)
from graphtraversal.map import Map, RestMap, Node, Position
from graphtraversal.factory import (
get_heuristic_function,
Expand All @@ -16,6 +20,7 @@
get_pathfinder,
)


@swagger_auto_schema(
method="post",
request_body=GraphTraversalSerializer,
Expand Down Expand Up @@ -61,7 +66,7 @@ def post_graph_traversal(request):
)

# Parse data
algorithm: str = (algorithm.lower()).strip()
algorithm: str = (algorithm).strip()

start_point: Node = Node(
Position(int(start_point.get("x", 0)), int(start_point.get("y", 0))),
Expand Down Expand Up @@ -122,32 +127,34 @@ def post_graph_traversal(request):
else:
return Response(status=status.HTTP_400_BAD_REQUEST)


@swagger_auto_schema(
method="get",
operation_summary="Get all legal graph traversal methods",
operation_description="This endpoint retrieves all legal graph traversal methods that this service provides.",
responses={200: GraphTraversalMethodSerializer },
responses={200: GraphTraversalMethodSerializer},
)
@api_view(["GET"])
def fetch_graph_traversal_methods(request):
return Response(status=status.HTTP_200_OK, data=get_graph_traversal_methods())


@swagger_auto_schema(
method="post",
operation_summary="Get all legal heuristics for a given graph traversal method",
operation_description="This endpoint retrieves all legal heuristics for a given graph traversal method that this service provides.",
request_body=GraphTraversalMethodSerializer,
responses={200: GraphHeuristicsSerializer },
responses={200: GraphHeuristicsSerializer},
)
@api_view(["POST"])
def fetch_graph_heuristics(request):

serializer = GraphTraversalMethodSerializer(data=request.data)
if serializer.is_valid():
heuristic = request.data.get("method", None)
if heuristic is None:
return Response(status=status.HTTP_400_BAD_REQUEST)

heuristics: list[str] = get_heuristics(heuristic)
return Response(status=status.HTTP_200_OK, data=heuristics)

Expand Down

0 comments on commit db197c7

Please sign in to comment.