-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrule_generators.py
110 lines (89 loc) · 3.4 KB
/
rule_generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
class RandomRuleGeneratorMonkey:
    """
    Infinite iterator over rule (feature) indices in ``[0, num_rules)``.

    Rules are grouped into ``num_dims`` contiguous, equally sized dimensions:
    rule ``r`` belongs to dimension ``r // num_rules_per_dim``. Following the
    statistics of the monkey version of the task, every call to ``next``
    makes a 50/50 draw between an intra-dimensional shift (new rule in the
    current dimension) and an extra-dimensional shift (new rule in a
    different dimension). The returned rule always differs from the
    previous one.

    If the drawn shift type is impossible for the configuration (a single
    rule per dimension, or a single dimension), the other shift type is
    used instead.

    Args:
        seed: seed handed to ``np.random.default_rng``.
        num_rules: total number of rules.
        num_dims: number of dimensions; must divide ``num_rules`` evenly.

    Raises:
        ValueError: if ``num_rules`` is not divisible by ``num_dims``.
    """

    def __init__(self, seed, num_rules=12, num_dims=3):
        self.rng = np.random.default_rng(seed)
        self.num_rules = num_rules
        self.num_dims = num_dims
        if num_rules % num_dims != 0:
            raise ValueError(f"number of rules {num_rules} not divisible by number of dimensions {num_dims}")
        self.num_rules_per_dim = num_rules // num_dims
        # Starting rule is drawn uniformly; its dimension follows from the
        # contiguous grouping of rules.
        self.rule = self.rng.integers(0, num_rules, 1)[0]
        self.dimension = self.rule // self.num_rules_per_dim

    def __iter__(self):
        return self

    def __next__(self):
        """
        Draw the next rule, update ``self.rule`` / ``self.dimension`` and
        return the new rule.

        Raises:
            ValueError: if there is no other rule to switch to.
        """
        can_shift_intra = self.num_rules_per_dim > 1
        can_shift_extra = self.num_dims > 1
        if not (can_shift_intra or can_shift_extra):
            raise ValueError(
                f"cannot switch rules with only {self.num_rules} rule(s) available"
            )

        # The shift type is always drawn, even when only one type is feasible,
        # so the random stream is identical to the one used when both are.
        # 0 -> intra-dimensional shift, 1 -> extra-dimensional shift.
        extra_shift = self.rng.choice([0, 1]) == 1
        if not can_shift_extra:
            extra_shift = False
        elif not can_shift_intra:
            extra_shift = True

        if extra_shift:
            dimensions = np.arange(self.num_dims)
            self.dimension = self.rng.choice(dimensions[dimensions != self.dimension])

        # Candidates: rules of the (possibly new) dimension, minus the current rule.
        features = np.arange(self.num_rules)
        in_dimension = features // self.num_rules_per_dim == self.dimension
        self.rule = self.rng.choice(features[in_dimension & (features != self.rule)])
        return self.rule
class RandomRuleGeneratorHuman:
    """
    Infinite iterator over rule (feature) indices in ``[0, num_rules)``.

    It follows statistics from the human version of the task: each call to
    ``next`` picks the new rule uniformly among all rules other than the
    current one, with no regard to dimensions.

    Args:
        seed: seed handed to ``np.random.default_rng``.
        num_rules: total number of rules.
        num_dims: number of dimensions; stored but not used for sampling.
    """

    def __init__(self, seed, num_rules=12, num_dims=3):
        self.num_rules = num_rules
        self.num_dims = num_dims
        self.rng = np.random.default_rng(seed)
        # Starting rule, drawn uniformly from all rules.
        self.rule = self.rng.integers(0, num_rules, 1)[0]

    def __iter__(self):
        return self

    def __next__(self):
        """
        Switch to a random rule different from the current one and return it.
        """
        all_rules = np.arange(self.num_rules)
        other_rules = all_rules[np.not_equal(all_rules, self.rule)]
        self.rule = self.rng.choice(other_rules)
        return self.rule
class RandomRuleGeneratorValidRules:
    """
    Infinite iterator that draws rules only from a caller-supplied set.

    It follows statistics from the human version of the task: each call to
    ``next`` picks the new rule uniformly among the entries of
    ``valid_rules`` that differ from the current rule.

    Args:
        seed: seed handed to ``np.random.default_rng``.
        valid_rules: array-like (list, tuple or NumPy array) of allowed rules.
        num_rules: total number of rules; stored but not used for sampling.
        num_dims: number of dimensions; stored but not used for sampling.
    """

    def __init__(self, seed, valid_rules, num_rules=12, num_dims=3):
        self.rng = np.random.default_rng(seed)
        # Normalise to an ndarray: the boolean-mask filtering in __next__
        # does not work on plain Python sequences.
        self.valid_rules = np.asarray(valid_rules)
        self.rule = self.rng.choice(self.valid_rules)
        self.num_rules = num_rules
        self.num_dims = num_dims

    def __iter__(self):
        return self

    def __next__(self):
        """
        Switch to a random valid rule different from the current one and
        return it.

        Raises:
            ValueError: if ``valid_rules`` holds no rule other than the
                current one.
        """
        next_possible_rules = self.valid_rules[self.valid_rules != self.rule]
        if next_possible_rules.size == 0:
            raise ValueError(
                "valid_rules must contain at least two distinct rules to switch between"
            )
        self.rule = self.rng.choice(next_possible_rules)
        return self.rule
class ConstantRuleGenerator:
    """
    Infinite iterator that never switches rule: every ``next`` call returns
    the rule stored in ``self.rule``, which is set on initialization.

    Args:
        num_rules: total number of rules; stored but not otherwise used.
        rule: the rule to return on every iteration.
    """

    def __init__(self, num_rules, rule):
        self.rule = rule
        self.num_rules = num_rules

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the stored rule; no state is changed.
        """
        return self.rule