Skip to content

Commit

Permalink
Merge pull request #3657 from vitaly-krugl/nup2354-tmregion-capnp
Browse files Browse the repository at this point in the history
NUP-2354 Implement serialization for TMRegion
  • Loading branch information
vitaly-krugl authored Jun 2, 2017
2 parents 6d798f4 + 20ba811 commit 28789ff
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 20 deletions.
19 changes: 17 additions & 2 deletions src/nupic/algorithms/backtracking_tm_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self,
permanenceDec=0.10,
permanenceMax=1.0,
activationThreshold=12,
predictedSegmentDecrement=0,
predictedSegmentDecrement=0.0,
maxSegmentsPerCell=255,
maxSynapsesPerSegment=255,
globalDecay=0.10,
Expand Down Expand Up @@ -82,6 +82,21 @@ def __init__(self,
self.infActiveState = {"t": None}


@classmethod
def read(cls, proto):
"""
Intercepts TemporalMemory deserialization request in order to initialize
`self.infActiveState`
@param proto (DynamicStructBuilder) Proto object
@return (TemporalMemory) TemporalMemory shim instance
"""
tm = super(TMShimMixin, cls).read(proto)
tm.infActiveState = {"t": None}
return tm


def compute(self, bottomUpInput, enableLearn, computeInfOutput=None):
"""
(From `backtracking_tm.py`)
Expand Down Expand Up @@ -171,7 +186,7 @@ def __init__(self,
permanenceDec=0.10,
permanenceMax=1.0,
activationThreshold=12,
predictedSegmentDecrement=0,
predictedSegmentDecrement=0.0,
maxSegmentsPerCell=255,
maxSynapsesPerSegment=255,
globalDecay=0.10,
Expand Down
19 changes: 19 additions & 0 deletions src/nupic/regions/tm_region.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@0xb9d11462f08c1dee;

using import "/nupic/proto/TemporalMemoryProto.capnp".TemporalMemoryProto;

# Next ID: 11
struct TMRegionProto {
temporalImp @0 :Text;
temporalMemory @1 :TemporalMemoryProto;
columnCount @2 :UInt32;
inputWidth @3 :UInt32;
cellsPerColumn @4 :UInt32;
learningMode @5 :Bool;
inferenceMode @6 :Bool;
anomalyMode @7 :Bool;
topDownMode @8 :Bool;
computePredictedActiveCellIndices @9 :Bool;
orColumnOutputs @10 :Bool;
}

95 changes: 78 additions & 17 deletions src/nupic/regions/tm_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@

import numpy
import os

try:
import capnp
except ImportError:
capnp = None

from nupic.bindings.regions.PyRegion import PyRegion

from nupic.algorithms import (anomaly, backtracking_tm, backtracking_tm_cpp,
backtracking_tm_shim)
if capnp:
from nupic.regions.tm_region_capnp import TMRegionProto

from nupic.support import getArgumentDescriptions



gDefaultTemporalImp = 'py'


Expand Down Expand Up @@ -183,7 +194,7 @@ def getConstraints(arg):
cells per column must also be specified and the output size of the region
should be set the same as columnCount""",
accessMode='Read',
dataType='UInt32',
dataType='Bool',
count=1,
constraints='bool'),

Expand All @@ -208,41 +219,44 @@ def getConstraints(arg):
# The last group is for parameters that aren't strictly spatial or temporal
otherSpec = dict(
learningMode=dict(
description='1 if the node is learning (default 1).',
description='True if the node is learning (default True).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=True,
constraints='bool'),

inferenceMode=dict(
description='1 if the node is inferring (default 0).',
description='True if the node is inferring (default False).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=False,
constraints='bool'),

computePredictedActiveCellIndices=dict(
description='1 if active and predicted active indices should be computed',
description='True if active and predicted active indices should be computed',
accessMode='Create',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=0,
defaultValue=False,
constraints='bool'),

anomalyMode=dict(
description='1 if an anomaly score is being computed',
description='True if an anomaly score is being computed',
accessMode='Create',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=0,
defaultValue=False,
constraints='bool'),

topDownMode=dict(
description='1 if the node should do top down compute on the next call '
'to compute into topDownOut (default 0).',
description='True if the node should do top down compute on the next call '
'to compute into topDownOut (default False).',
accessMode='ReadWrite',
dataType='UInt32',
dataType='Bool',
count=1,
defaultValue=False,
constraints='bool'),

activeOutputCount=dict(
Expand Down Expand Up @@ -319,7 +333,6 @@ def __init__(self,
computePredictedActiveCellIndices=False,

**kwargs):

# Which Temporal implementation?
TemporalClass = _getTPClass(temporalImp)

Expand Down Expand Up @@ -362,7 +375,7 @@ def __init__(self,
self._fpLogTPOutput = None

# Variables set up in initInNetwork()
self._tfdr = None # FDRTemporal instance
self._tfdr = None # FDRTemporal instance


#############################################################################
Expand Down Expand Up @@ -716,7 +729,6 @@ def setParameter(self, parameterName, index, parameterValue):
automatically by PyRegion's parameter set mechanism. The ones that need
special treatment are explicitly handled here.
"""

if parameterName in self._temporalArgNames:
setattr(self._tfdr, parameterName, parameterValue)

Expand All @@ -737,6 +749,7 @@ def setParameter(self, parameterName, index, parameterValue):
else:
raise Exception('Unknown parameter: ' + parameterName)


#############################################################################
#
# Commands
Expand Down Expand Up @@ -773,6 +786,54 @@ def finishLearning(self):
#############################################################################


@staticmethod
def getProtoType():
"""Return the pycapnp proto type that the class uses for serialization."""
return TMRegionProto


def writeToProto(self, proto):
"""Write state to proto object.
proto: TMRegionProto capnproto object
"""
proto.temporalImp = self.temporalImp
proto.columnCount = self.columnCount
proto.inputWidth = self.inputWidth
proto.cellsPerColumn = self.cellsPerColumn
proto.learningMode = self.learningMode
proto.inferenceMode = self.inferenceMode
proto.anomalyMode = self.anomalyMode
proto.topDownMode = self.topDownMode
proto.computePredictedActiveCellIndices = (
self.computePredictedActiveCellIndices)
proto.orColumnOutputs = self.orColumnOutputs

self._tfdr.write(proto.temporalMemory)


@classmethod
def readFromProto(cls, proto):
"""Read state from proto object.
proto: TMRegionProto capnproto object
"""
instance = cls(proto.columnCount, proto.inputWidth, proto.cellsPerColumn)

instance.temporalImp = proto.temporalImp
instance.learningMode = proto.learningMode
instance.inferenceMode = proto.inferenceMode
instance.anomalyMode = proto.anomalyMode
instance.topDownMode = proto.topDownMode
instance.computePredictedActiveCellIndices = (
proto.computePredictedActiveCellIndices)
instance.orColumnOutputs = proto.orColumnOutputs

instance._tfdr = _getTPClass(proto.temporalImp).read(proto.temporalMemory)

return instance


def __getstate__(self):
"""
Return serializable state. This function will return a version of the
Expand Down
23 changes: 22 additions & 1 deletion tests/integration/nupic/engine/network_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import unittest
import numpy

from nupic.regions.sp_region import SPRegion
from nupic.regions.record_sensor import RecordSensor
from nupic.regions.sp_region import SPRegion
from nupic.regions.tm_region import TMRegion

from network_creation_common import createAndRunNetwork

Expand Down Expand Up @@ -66,6 +67,26 @@ def testSPRegion(self):
"Row {0} not equal: {1} vs. {2}".format(i, result1, result2))


@unittest.skipUnless(
capnp, "pycapnp is not installed, skipping serialization test.")
def testTMRegion(self):
results1 = createAndRunNetwork(TMRegion, "bottomUpOut",
checkpointMidway=False,
temporalImp="tm_py")

results2 = createAndRunNetwork(TMRegion, "bottomUpOut",
checkpointMidway=True,
temporalImp="tm_py")

self.assertEqual(len(results1), len(results2))

for i in xrange(len(results1)):
result1 = list(results1[i].nonzero()[0])
result2 = list(results2[i].nonzero()[0])
self.assertEqual(result1, result2,
"Row {0} not equal: {1} vs. {2}".format(i, result1, result2))


def compareArrayResults(self, results1, results2):
self.assertEqual(len(results1), len(results2))

Expand Down

0 comments on commit 28789ff

Please sign in to comment.