q_learning.py
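"""Helper functions for tabular Q-learning: action selection (random, greedy,
epsilon-greedy), q-table initialization, and the Q-learning value update."""
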
import itertools
import random
import numpy as np


def get_random_action(env):
    """
    Returns a random action for the given environment.
    :param env: OpenAI Gym environment; `action_space` is indexed with `[0]`
                throughout this module, so the env is assumed to expose a
                list of action spaces (e.g. a multi-agent env)
    :return: random action
    """
    return env.action_space[0].sample()


def get_best_action(q_table, state):
    """
    Returns the best action for the current state given the q table.
    :param q_table: q table
    :param state: current state
    :return: best action
    """
    return np.argmax(q_table[state])


def get_action(env, q_table, state, epsilon):
    """
    Returns an action following an epsilon-greedy policy for the current state
    given the q table: with probability epsilon a random action is chosen,
    otherwise the best known action.
    :param env: OpenAI Gym environment
    :param q_table: q table
    :param state: current state
    :param epsilon: exploration rate
    :return: chosen action
    """
    num_actions = env.action_space[0].n
    if np.random.random() < epsilon:
        # Explore: pick a uniformly random action.
        return random.randint(0, num_actions - 1)
    # Exploit: pick the action with the highest q value.
    return get_best_action(q_table, state)


def random_q_table(min_val, max_val, size):
    """
    Returns a randomly initialized n-dimensional q table.
    :param min_val: lower bound of values
    :param max_val: upper bound of values
    :param size: size (shape) of the q table
    :return: n-dimensional q table
    """
    return np.random.uniform(low=min_val, high=max_val, size=size)


def calculate_new_q_value(q_table, old_state, new_state, action, reward, lr=0.1, discount_factor=0.99):
    """
    Calculates the new q value for the current state given the new state,
    action and reward, using the standard Q-learning update:
        Q(s, a) <- (1 - lr) * Q(s, a) + lr * (reward + discount_factor * max_a' Q(s', a'))
    :param q_table: n-dimensional q table
    :param old_state: old (current) state, as a tuple of indices
    :param new_state: new (next) state
    :param action: action taken at state old_state
    :param reward: reward received for performing the action
    :param lr: learning rate
    :param discount_factor: discount factor
    :return: new q value for old_state and action
    """
    max_future_q = np.max(q_table[new_state])
    # States are tuples of indices, so appending the action selects the
    # single entry Q(old_state, action).
    current_q = q_table[old_state + (action,)]
    return (1 - lr) * current_q + lr * (reward + discount_factor * max_future_q)
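

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original helpers).
# `Corridor` and `_ActionSpace` below are hypothetical stand-ins written for
# this example; they mimic only the slice of the Gym interface the helpers
# touch: an `action_space` list whose first entry has `.n` and `.sample()`,
# and tuple-valued states.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ActionSpace:
        """Tiny stand-in for a discrete action space."""

        def __init__(self, n):
            self.n = n

        def sample(self):
            return random.randint(0, self.n - 1)

    class Corridor:
        """Toy 1-D corridor: move left/right over `length` cells; reward at the far end."""

        def __init__(self, length=10):
            self.length = length
            self.action_space = [_ActionSpace(2)]  # 0 = left, 1 = right
            self.position = 0

        def reset(self):
            self.position = 0
            return (self.position,)

        def step(self, action):
            step_dir = 1 if action == 1 else -1
            self.position = min(max(self.position + step_dir, 0), self.length - 1)
            done = self.position == self.length - 1
            reward = 1.0 if done else 0.0
            return (self.position,), reward, done

    env = Corridor()
    q_table = random_q_table(-1, 0, size=(env.length, env.action_space[0].n))
    epsilon = 0.1

    for episode in range(500):
        state = env.reset()
        for _ in range(200):  # cap episode length
            action = get_action(env, q_table, state, epsilon)
            new_state, reward, done = env.step(action)
            q_table[state + (action,)] = calculate_new_q_value(
                q_table, state, new_state, action, reward)
            state = new_state
            if done:
                break

    # After training, the greedy action in the start state should be "right" (1).
    print("Greedy action in start state:", get_best_action(q_table, (0,)))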