-
Notifications
You must be signed in to change notification settings - Fork 0
/
Genome.pde
473 lines (397 loc) · 18.9 KB
/
Genome.pde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
class Genome {
ArrayList<connectionGene> genes = new ArrayList<connectionGene>();//a list of connections between nodes which represent the NN
ArrayList<Node> nodes = new ArrayList<Node>();//list of nodes
int inputs;
int outputs;
int layers =2;
int nextNode = 0;
int biasNode;
ArrayList<Node> network = new ArrayList<Node>();//a list of the nodes in the order that they need to be considered in the NN
//---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Genome(int in, int out) {
//set input number and output number
inputs = in;
outputs = out;
//create input nodes
for (int i = 0; i<inputs; i++) {
nodes.add(new Node(i));
nextNode ++;
nodes.get(i).layer =0;
}
//create output nodes
for (int i = 0; i < outputs; i++) {
nodes.add(new Node(i+inputs));
nodes.get(i+inputs).layer = 1;
nextNode++;
}
nodes.add(new Node(nextNode));//bias node
biasNode = nextNode;
nextNode++;
nodes.get(biasNode).layer = 0;
}
//-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//returns the node with a matching number
//sometimes the nodes will not be in order
Node getNode(int nodeNumber) {
for (int i = 0; i < nodes.size(); i++) {
if (nodes.get(i).number == nodeNumber) {
return nodes.get(i);
}
}
return null;
}
//---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//adds the conenctions going out of a node to that node so that it can acess the next node during feeding forward
void connectNodes() {
for (int i = 0; i< nodes.size(); i++) {//clear the connections
nodes.get(i).outputConnections.clear();
}
for (int i = 0; i < genes.size(); i++) {//for each connectionGene
genes.get(i).fromNode.outputConnections.add(genes.get(i));//add it to node
}
}
//---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//feeding in input values into the NN and returning output array
float[] feedForward(float[] inputValues) {
//set the outputs of the input nodes
for (int i =0; i < inputs; i++) {
nodes.get(i).outputValue = inputValues[i];
}
nodes.get(biasNode).outputValue = 1;//output of bias is 1
for (int i = 0; i< network.size(); i++) {//for each node in the network engage it(see node class for what this does)
network.get(i).engage();
}
//the outputs are nodes[inputs] to nodes [inputs+outputs-1]
float[] outs = new float[outputs];
for (int i = 0; i < outputs; i++) {
outs[i] = nodes.get(inputs + i).outputValue;
}
for (int i = 0; i < nodes.size(); i++) {//reset all the nodes for the next feed forward
nodes.get(i).inputSum = 0;
}
return outs;
}
//----------------------------------------------------------------------------------------------------------------------------------------
//sets up the NN as a list of nodes in order to be engaged
void generateNetwork() {
connectNodes();
network = new ArrayList<Node>();
//for each layer add the node in that layer, since layers cannot connect to themselves there is no need to order the nodes within a layer
for (int l = 0; l< layers; l++) {//for each layer
for (int i = 0; i< nodes.size(); i++) {//for each node
if (nodes.get(i).layer == l) {//if that node is in that layer
network.add(nodes.get(i));
}
}
}
}
//-----------------------------------------------------------------------------------------------------------------------------------------
//mutate the NN by adding a new node
//it does this by picking a random connection and disabling it then 2 new connections are added
//1 between the input node of the disabled connection and the new node
//and the other between the new node and the output of the disabled connection
void addNode(ArrayList<connectionHistory> innovationHistory) {
//pick a random connection to create a node between
if (genes.size() ==0) {
addConnection(innovationHistory);
return;
}
int randomConnection = floor(random(genes.size()));
while (genes.get(randomConnection).fromNode == nodes.get(biasNode) && genes.size() !=1 ) {//dont disconnect bias
randomConnection = floor(random(genes.size()));
}
genes.get(randomConnection).enabled = false;//disable it
int newNodeNo = nextNode;
nodes.add(new Node(newNodeNo));
nextNode ++;
//add a new connection to the new node with a weight of 1
int connectionInnovationNumber = getInnovationNumber(innovationHistory, genes.get(randomConnection).fromNode, getNode(newNodeNo));
genes.add(new connectionGene(genes.get(randomConnection).fromNode, getNode(newNodeNo), 1, connectionInnovationNumber));
connectionInnovationNumber = getInnovationNumber(innovationHistory, getNode(newNodeNo), genes.get(randomConnection).toNode);
//add a new connection from the new node with a weight the same as the disabled connection
genes.add(new connectionGene(getNode(newNodeNo), genes.get(randomConnection).toNode, genes.get(randomConnection).weight, connectionInnovationNumber));
getNode(newNodeNo).layer = genes.get(randomConnection).fromNode.layer +1;
connectionInnovationNumber = getInnovationNumber(innovationHistory, nodes.get(biasNode), getNode(newNodeNo));
//connect the bias to the new node with a weight of 0
genes.add(new connectionGene(nodes.get(biasNode), getNode(newNodeNo), 0, connectionInnovationNumber));
//if the layer of the new node is equal to the layer of the output node of the old connection then a new layer needs to be created
//more accurately the layer numbers of all layers equal to or greater than this new node need to be incrimented
if (getNode(newNodeNo).layer == genes.get(randomConnection).toNode.layer) {
for (int i = 0; i< nodes.size() -1; i++) {//dont include this newest node
if (nodes.get(i).layer >= getNode(newNodeNo).layer) {
nodes.get(i).layer ++;
}
}
layers ++;
}
connectNodes();
}
//------------------------------------------------------------------------------------------------------------------
//adds a connection between 2 nodes which aren't currently connected
void addConnection(ArrayList<connectionHistory> innovationHistory) {
//cannot add a connection to a fully connected network
if (fullyConnected()) {
println("connection failed");
return;
}
//get random nodes
int randomNode1 = floor(random(nodes.size()));
int randomNode2 = floor(random(nodes.size()));
while (randomConnectionNodesAreShit(randomNode1, randomNode2)) {//while the random nodes are no good
//get new ones
randomNode1 = floor(random(nodes.size()));
randomNode2 = floor(random(nodes.size()));
}
int temp;
if (nodes.get(randomNode1).layer > nodes.get(randomNode2).layer) {//if the first random node is after the second then switch
temp =randomNode2 ;
randomNode2 = randomNode1;
randomNode1 = temp;
}
//get the innovation number of the connection
//this will be a new number if no identical genome has mutated in the same way
int connectionInnovationNumber = getInnovationNumber(innovationHistory, nodes.get(randomNode1), nodes.get(randomNode2));
//add the connection with a random array
genes.add(new connectionGene(nodes.get(randomNode1), nodes.get(randomNode2), random(-1, 1), connectionInnovationNumber));//changed this so if error here
connectNodes();
}
//-------------------------------------------------------------------------------------------------------------------------------------------
boolean randomConnectionNodesAreShit(int r1, int r2) {
if (nodes.get(r1).layer == nodes.get(r2).layer) return true; // if the nodes are in the same layer
if (nodes.get(r1).isConnectedTo(nodes.get(r2))) return true; //if the nodes are already connected
return false;
}
//-------------------------------------------------------------------------------------------------------------------------------------------
//returns the innovation number for the new mutation
//if this mutation has never been seen before then it will be given a new unique innovation number
//if this mutation matches a previous mutation then it will be given the same innovation number as the previous one
int getInnovationNumber(ArrayList<connectionHistory> innovationHistory, Node from, Node to) {
boolean isNew = true;
int connectionInnovationNumber = nextConnectionNo;
for (int i = 0; i < innovationHistory.size(); i++) {//for each previous mutation
if (innovationHistory.get(i).matches(this, from, to)) {//if match found
isNew = false;//its not a new mutation
connectionInnovationNumber = innovationHistory.get(i).innovationNumber; //set the innovation number as the innovation number of the match
break;
}
}
if (isNew) {//if the mutation is new then create an arrayList of integers representing the current state of the genome
ArrayList<Integer> innoNumbers = new ArrayList<Integer>();
for (int i = 0; i< genes.size(); i++) {//set the innovation numbers
innoNumbers.add(genes.get(i).innovationNo);
}
//then add this mutation to the innovationHistory
innovationHistory.add(new connectionHistory(from.number, to.number, connectionInnovationNumber, innoNumbers));
nextConnectionNo++;
}
return connectionInnovationNumber;
}
//----------------------------------------------------------------------------------------------------------------------------------------
//returns whether the network is fully connected or not
boolean fullyConnected() {
int maxConnections = 0;
int[] nodesInLayers = new int[layers];//array which stored the amount of nodes in each layer
//populate array
for (int i =0; i< nodes.size(); i++) {
nodesInLayers[nodes.get(i).layer] +=1;
}
//for each layer the maximum amount of connections is the number in this layer * the number of nodes infront of it
//so lets add the max for each layer together and then we will get the maximum amount of connections in the network
for (int i = 0; i < layers-1; i++) {
int nodesInFront = 0;
for (int j = i+1; j < layers; j++) {//for each layer infront of this layer
nodesInFront += nodesInLayers[j];//add up nodes
}
maxConnections += nodesInLayers[i] * nodesInFront;
}
if (maxConnections == genes.size()) {//if the number of connections is equal to the max number of connections possible then it is full
return true;
}
return false;
}
//-------------------------------------------------------------------------------------------------------------------------------
//mutates the genome
void mutate(ArrayList<connectionHistory> innovationHistory) {
if (genes.size() ==0) {
addConnection(innovationHistory);
}
float rand1 = random(1);
if (rand1<0.8) { // 80% of the time mutate weights
for (int i = 0; i< genes.size(); i++) {
genes.get(i).mutateWeight();
}
}
//5% of the time add a new connection
float rand2 = random(1);
if (rand2<0.08) {
addConnection(innovationHistory);
}
//1% of the time add a node
float rand3 = random(1);
if (rand3<0.02) {
addNode(innovationHistory);
}
}
//---------------------------------------------------------------------------------------------------------------------------------
//called when this Genome is better that the other parent
Genome crossover(Genome parent2) {
Genome child = new Genome(inputs, outputs, true);
child.genes.clear();
child.nodes.clear();
child.layers = layers;
child.nextNode = nextNode;
child.biasNode = biasNode;
ArrayList<connectionGene> childGenes = new ArrayList<connectionGene>();//list of genes to be inherrited form the parents
ArrayList<Boolean> isEnabled = new ArrayList<Boolean>();
//all inherrited genes
for (int i = 0; i< genes.size(); i++) {
boolean setEnabled = true;//is this node in the chlid going to be enabled
int parent2gene = matchingGene(parent2, genes.get(i).innovationNo);
if (parent2gene != -1) {//if the genes match
if (!genes.get(i).enabled || !parent2.genes.get(parent2gene).enabled) {//if either of the matching genes are disabled
if (random(1) < 0.75) {//75% of the time disabel the childs gene
setEnabled = false;
}
}
float rand = random(1);
if (rand<0.5) {
childGenes.add(genes.get(i));
//get gene from this fucker
} else {
//get gene from parent2
childGenes.add(parent2.genes.get(parent2gene));
}
} else {//disjoint or excess gene
childGenes.add(genes.get(i));
setEnabled = genes.get(i).enabled;
}
isEnabled.add(setEnabled);
}
//since all excess and disjoint genes are inherrited from the more fit parent (this Genome) the childs structure is no different from this parent | with exception of dormant connections being enabled but this wont effect nodes
//so all the nodes can be inherrited from this parent
for (int i = 0; i < nodes.size(); i++) {
child.nodes.add(nodes.get(i).clone());
}
//clone all the connections so that they connect the childs new nodes
for ( int i =0; i<childGenes.size(); i++) {
child.genes.add(childGenes.get(i).clone(child.getNode(childGenes.get(i).fromNode.number), child.getNode(childGenes.get(i).toNode.number)));
child.genes.get(i).enabled = isEnabled.get(i);
}
child.connectNodes();
return child;
}
//----------------------------------------------------------------------------------------------------------------------------------------
//create an empty genome
Genome(int in, int out, boolean crossover) {
//set input number and output number
inputs = in;
outputs = out;
}
//----------------------------------------------------------------------------------------------------------------------------------------
//returns whether or not there is a gene matching the input innovation number in the input genome
int matchingGene(Genome parent2, int innovationNumber) {
for (int i =0; i < parent2.genes.size(); i++) {
if (parent2.genes.get(i).innovationNo == innovationNumber) {
return i;
}
}
return -1; //no matching gene found
}
//----------------------------------------------------------------------------------------------------------------------------------------
//prints out info about the genome to the console
void printGenome() {
println("Print genome layers:", layers);
println("bias node: " + biasNode);
println("nodes");
for (int i = 0; i < nodes.size(); i++) {
print(nodes.get(i).number + ",");
}
println("Genes");
for (int i = 0; i < genes.size(); i++) {//for each connectionGene
println("gene " + genes.get(i).innovationNo, "From node " + genes.get(i).fromNode.number, "To node " + genes.get(i).toNode.number,
"is enabled " +genes.get(i).enabled, "from layer " + genes.get(i).fromNode.layer, "to layer " + genes.get(i).toNode.layer, "weight: " + genes.get(i).weight);
}
println();
}
//----------------------------------------------------------------------------------------------------------------------------------------
//returns a copy of this genome
Genome clone() {
Genome clone = new Genome(inputs, outputs, true);
for (int i = 0; i < nodes.size(); i++) {//copy nodes
clone.nodes.add(nodes.get(i).clone());
}
//copy all the connections so that they connect the clone new nodes
for ( int i =0; i<genes.size(); i++) {//copy genes
clone.genes.add(genes.get(i).clone(clone.getNode(genes.get(i).fromNode.number), clone.getNode(genes.get(i).toNode.number)));
}
clone.layers = layers;
clone.nextNode = nextNode;
clone.biasNode = biasNode;
clone.connectNodes();
return clone;
}
//----------------------------------------------------------------------------------------------------------------------------------------
//draw the genome on the screen
void drawGenome(int startX, int startY, int w, int h) {
//i know its ugly but it works (and is not that important) so I'm not going to mess with it
ArrayList<ArrayList<Node>> allNodes = new ArrayList<ArrayList<Node>>();
ArrayList<PVector> nodePoses = new ArrayList<PVector>();
ArrayList<Integer> nodeNumbers= new ArrayList<Integer>();
//get the positions on the screen that each node is supposed to be in
//split the nodes into layers
for (int i = 0; i< layers; i++) {
ArrayList<Node> temp = new ArrayList<Node>();
for (int j = 0; j< nodes.size(); j++) {//for each node
if (nodes.get(j).layer == i ) {//check if it is in this layer
temp.add(nodes.get(j)); //add it to this layer
}
}
allNodes.add(temp);//add this layer to all nodes
}
//for each layer add the position of the node on the screen to the node posses arraylist
for (int i = 0; i < layers; i++) {
fill(255, 0, 0);
float x = startX + (float)((i)*w)/(float)(layers-1);
for (int j = 0; j< allNodes.get(i).size(); j++) {//for the position in the layer
float y = startY + ((float)(j + 1.0) * h)/(float)(allNodes.get(i).size() + 1.0);
nodePoses.add(new PVector(x, y));
nodeNumbers.add(allNodes.get(i).get(j).number);
if(i == layers -1){
println(i,j,x,y);
}
}
}
//draw connections
stroke(0);
strokeWeight(2);
for (int i = 0; i< genes.size(); i++) {
if (genes.get(i).enabled) {
stroke(0);
} else {
stroke(100);
}
PVector from;
PVector to;
from = nodePoses.get(nodeNumbers.indexOf(genes.get(i).fromNode.number));
to = nodePoses.get(nodeNumbers.indexOf(genes.get(i).toNode.number));
if (genes.get(i).weight > 0) {
stroke(255, 0, 0);
} else {
stroke(0, 0, 255);
}
strokeWeight(map(abs(genes.get(i).weight), 0, 1, 0, 5));
line(from.x, from.y, to.x, to.y);
}
//draw nodes last so they appear ontop of the connection lines
for (int i = 0; i < nodePoses.size(); i++) {
fill(255);
stroke(0);
strokeWeight(1);
ellipse(nodePoses.get(i).x, nodePoses.get(i).y, 20, 20);
textSize(10);
fill(0);
textAlign(CENTER, CENTER);
text(nodeNumbers.get(i), nodePoses.get(i).x, nodePoses.get(i).y);
}
}
}