-
Notifications
You must be signed in to change notification settings - Fork 0
/
nqueens.py
62 lines (52 loc) · 2.29 KB
/
nqueens.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from problem import Problem
class NQueensProblem(Problem):
    """The problem of placing N queens on an NxN board with none attacking
    each other.

    A state is represented as an N-element tuple, where a value of r in the
    c-th entry means there is a queen at column c, row r, and a value of -1
    means that the c-th column has not been filled in yet. Columns are
    filled in left to right.

    Every method accepts either a raw state tuple or a search node that
    exposes the tuple as its ``.state`` attribute.

    For example, with N = 8 the state (7, 3, 0, 2, 5, 1, 6, 4) is a solution.
    """

    def __init__(self, N):
        """Create an N-queens problem whose initial state has every column empty."""
        super().__init__(tuple([-1] * N))
        self.N = N

    @staticmethod
    def _cells(state):
        """Return the state tuple, unwrapping a node's ``.state`` if present."""
        return getattr(state, "state", state)

    def actions(self, state):
        """In the leftmost empty column, try all non-conflicting rows.

        Returns a list of row indices; empty once every column is filled.
        """
        cells = self._cells(state)
        # Columns fill left to right, so a filled last column means a full
        # board. An empty tuple (N == 0) is also full and has no successors.
        if not cells or cells[-1] != -1:
            return []
        col = cells.index(-1)
        return [row for row in range(self.N)
                if not self.conflicted(cells, row, col)]

    def result(self, state, row):
        """Place the next queen at the given row in the leftmost empty column.

        Returns the new state as a tuple. Raises ValueError if no column is
        empty.
        """
        cells = self._cells(state)
        col = cells.index(-1)
        new = list(cells)
        new[col] = row
        return tuple(new)

    def conflicted(self, state, row, col):
        """Would placing a queen at (row, col) conflict with any queen in a
        column to the left of col?"""
        cells = self._cells(state)
        return any(self.conflict(row, col, cells[c], c)
                   for c in range(col))

    def conflict(self, row1, col1, row2, col2):
        """Would putting two queens in (row1, col1) and (row2, col2) conflict?"""
        return (row1 == row2 or                 # same row
                col1 == col2 or                 # same column
                row1 - col1 == row2 - col2 or   # same \ diagonal
                row1 + col1 == row2 + col2)     # same / diagonal

    def is_goal(self, state):
        """Check that all columns are filled and no two queens conflict."""
        cells = self._cells(state)
        if -1 in cells:
            return False
        # Checking each queen against the queens to its left covers every pair.
        return not any(self.conflicted(cells, cells[col], col)
                       for col in range(len(cells)))

    def value(self, node):
        """Return the number of ordered pairs of mutually attacking queens.

        Each attacking pair is counted twice (once in each order). Empty
        columns (value -1) hold no queen and are ignored.
        """
        cells = self._cells(node)
        queens = [(col, row) for col, row in enumerate(cells) if row != -1]
        return sum(self.conflict(r1, c1, r2, c2)
                   for (c1, r1) in queens
                   for (c2, r2) in queens
                   if c1 != c2)