-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts.go
94 lines (80 loc) · 1.64 KB
/
mcts.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package main
import (
"math"
// "math/rand"
)
type (
	// MCNode is one node of the search tree. Field order matters: expand
	// builds nodes with a positional composite literal.
	MCNode struct {
		state   State  // state this node stands for
		visits  int    // visit counter, read by selectNode and selectfn
		utility int    // accumulated score, read by selectfn
		parent  *MCNode // nil when the node has no parent
		children []MCNode // child nodes, stored by value
	}

	// State carries no fields yet; it is passed to findLegals,
	// findNextState, findReward and findTerminal.
	State struct{}

	// Action carries no fields yet; findLegals returns a slice of them and
	// findNextState consumes one.
	Action struct{}
)
// selectNode descends from node and returns the node chosen for the next
// step of the search:
//
//   - node itself when it has never been visited or has no children,
//   - otherwise the first child that has never been visited,
//   - otherwise the result of recursing into the child with the highest
//     selectfn score (the first one wins ties).
//
// The best score is seeded from the first child rather than from 0, so a
// child is always picked even when every score is zero, negative or NaN.
// Previously that case, and a visited node without children, made the
// function recurse on the same node forever.
//
// NOTE(review): nodes are passed and returned by value, so the result is a
// copy; changes made to it do not reach the tree held by the caller.
func selectNode(node MCNode) MCNode {
	if node.visits == 0 || len(node.children) == 0 {
		return node
	}

	// Index instead of ranging by value to avoid copying every child twice.
	for i := range node.children {
		if node.children[i].visits == 0 {
			return node.children[i]
		}
	}

	best := 0
	bestScore := selectfn(node.children[0])
	for i := 1; i < len(node.children); i++ {
		if score := selectfn(node.children[i]); score > bestScore {
			best, bestScore = i, score
		}
	}
	return selectNode(node.children[best])
}
// expand asks findLegals for the actions available in node.state and, for
// each one, builds a child node holding the state returned by findNextState
// with zeroed counters and an empty child list, then appends it to
// node.children.
//
// NOTE(review): node is received by value. The append therefore grows the
// local copy's children slice, and each child's parent pointer refers to
// that local copy, so the caller's node is left without the new children.
// Making this effective needs a *MCNode parameter, which would change the
// signature and every call site.
func expand(node MCNode) {
	for _, legal := range findLegals(node.state) {
		child := MCNode{
			state:    findNextState(node.state, legal),
			parent:   &node,
			children: []MCNode{},
		}
		node.children = append(node.children, child)
	}
}
// roles is ranged over by simulate, which applies one randomly chosen move
// per entry in each playout round. It starts out empty, and nothing visible
// in this file adds entries to it.
var roles = []int{}
// simulate plays random moves starting from state and returns findReward of
// the state where the playout stops.
//
// Each round checks findTerminal first. If the state is not terminal, one
// move is applied per entry in roles: the legal actions come from
// findLegals, one is picked with r.Intn, and findNextState produces the
// following state.
//
// Compared with the earlier recursive form:
//   - the playout is a loop, so its length no longer grows the call stack;
//   - an empty legal-action list is skipped instead of being passed to
//     r.Intn (math/rand's Intn panics for n <= 0; r is declared outside
//     this block, presumably such a source — confirm);
//   - a round in which no move could be applied (roles is empty, or no role
//     had a legal action) ends the playout, because the state can no longer
//     change and the loop would otherwise never finish.
func simulate(state State) int {
	current := state
	for !findTerminal(current) {
		moved := false
		for range roles {
			options := findLegals(current)
			if len(options) == 0 {
				continue
			}
			pick := r.Intn(len(options))
			current = findNextState(current, options[pick])
			moved = true
		}
		if !moved {
			break
		}
	}
	return findReward(current)
}
// backPropagate adds one visit and score to every ancestor of node by
// following the parent pointers until a nil parent is reached.
//
// The earlier version incremented the by-value node and then recursed on
// *node.parent, which is again a copy, so none of its updates survived the
// call. Walking the pointers makes the ancestor updates stick.
//
// NOTE(review): node itself arrives by value, so its own visits/utility
// cannot be updated from here; the caller must do that on its own node, or
// the signature should be changed to take *MCNode.
func backPropagate(node MCNode, score int) {
	for ancestor := node.parent; ancestor != nil; ancestor = ancestor.parent {
		ancestor.visits++
		ancestor.utility += score
	}
}
// findReward returns the reward for state. The body does not look at state
// and always yields 0; it appears to be a stub (confirm).
func findReward(state State) int {
	const reward = 0
	return reward
}
// findLegals returns the actions available in state. The body does not look
// at state and always yields an empty, non-nil slice; it appears to be a
// stub (confirm).
func findLegals(state State) []Action {
	legals := make([]Action, 0)
	return legals
}
// findNextState returns the state that follows from applying action in
// state. The body ignores both arguments and always yields the zero State;
// it appears to be a stub (confirm).
func findNextState(state State, action Action) State {
	var next State
	return next
}
// selectfn scores node for selection as
//
//	utility + 2*sqrt(ln(parent.visits)/visits)
//
// Edge cases that used to panic or produce NaN are handled explicitly:
//   - a node with no visits scores +Inf, so it sorts ahead of visited nodes
//     (the formula already gave +Inf here whenever parent.visits > 1);
//   - a node without a parent, or whose parent has fewer than one visit,
//     scores just its utility, since the exploration term is undefined.
//
// NOTE(review): the first term is the raw accumulated utility, whereas the
// usual UCB1 formula uses the mean (utility/visits) — confirm this is
// intended.
func selectfn(node MCNode) float64 {
	if node.visits <= 0 {
		return math.Inf(1)
	}
	if node.parent == nil || node.parent.visits < 1 {
		return float64(node.utility)
	}
	return float64(node.utility) + 2.0*math.Sqrt(math.Log(float64(node.parent.visits))/float64(node.visits))
}
// findTerminal reports whether state ends the game. The body does not look
// at state and always yields false; it appears to be a stub (confirm).
func findTerminal(state State) bool {
	const terminal = false
	return terminal
}