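# An example showing memory in an sRAAM (sequential RAAM): each sentence is
# compressed into a hidden-layer encoding, which is then decoded back into
# its words (in reverse order).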

from pyrobot.brain.conx import *

def decompressDone(pattern, delta=0.05):
    '''
    Return True when every value in pattern lies within delta of 0.5
    (inclusive on both ends), i.e. the pattern cannot be told apart from
    the empty pattern (all 0.5s).  An empty pattern counts as done; a
    single out-of-range value makes the result False.
    '''
    # The bounds are computed inside the generator so nothing is evaluated
    # for an empty pattern, and all() stops at the first value out of range.
    return all((0.5 - delta) <= value <= (0.5 + delta) for value in pattern)

def discretize(pattern):
    '''
    Threshold a pattern of real values into a list of bits.  Values
    strictly below 0.5 become 0; everything else (including exactly 0.5)
    becomes 1.  Always returns a new list.
    '''
    return [0 if value < 0.5 else 1 for value in pattern]

def display(pattern):
    '''
    Print a pattern on a single line: each value formatted with two
    decimal places, separated by single spaces, followed by a newline.
    '''
    # A single parenthesised argument parses identically under the
    # Python 2 print statement used elsewhere in this file.
    print(" ".join(["%.2f" % value for value in pattern]))

# Create network:
# An SRN used as a sequential RAAM.  Besides the usual SRN layers, an extra
# "outcontext" layer is fed from "hidden" and associated with "context"
# below, so the network is (presumably, via conx's associate()) trained to
# reproduce its previous context alongside the current word -- which is what
# later allows a hidden encoding to be unrolled again.

raam = SRN()
# NOTE(review): presumably presents the training sentences in random order,
# one sentence (segment) at a time -- confirm against the conx docs.
raam.setSequenceType("random-segmented")
# Vocabulary: each word is a one-hot pattern of width 4.
raam.setPatterns({"john"  : [0, 0, 0, 1],
                  "likes" : [0, 0, 1, 0],
                  "mary"  : [0, 1, 0, 0],
                  "is" : [1, 0, 0, 0],
                  })

# Width of one word pattern (4 here).
size = len(raam.getPattern("john"))
# Input/hidden/output sizes: the hidden layer is twice the word width.
# Presumably this also creates the "input", "hidden", "context" and
# "output" layers referenced by name below -- confirm in conx.
raam.addSRNLayers(size, size * 2, size)
# Extra output layer with the same width as the hidden layer, so it can be
# paired with the context layer.
raam.add( Layer("outcontext", size * 2) )
raam.connect("hidden", "outcontext")

# Pair each layer with the one whose activations it should reproduce
# (auto-association) -- presumably conx uses these pairs as training targets.
raam.associate('input', 'output')
raam.associate('context', 'outcontext')

# Training corpus: four three-word sentences, each a sequence of the word
# names defined by setPatterns above.
raam.setInputs([ [ "john", "likes", "mary" ],
                 [ "mary", "likes", "john" ],
                 [ "john", "is", "john" ],
                 [ "mary", "is", "mary" ],
               ])

# Network learning parameters:
# (the meanings noted below are inferred from the conx method names --
# confirm them against the conx documentation)

# Presumably: adjust weights after every word, not only at sentence end.
raam.setLearnDuringSequence(1) 
# Presumably: print a progress report every 10 epochs.
raam.setReportRate(10)
# Presumably the learning rate.
raam.setEpsilon(0.1) 
raam.setMomentum(0.0) 
# Presumably: online (per-pattern) updates rather than batch updates.
raam.setBatch(0)

# Ending criteria:
# Presumably: an output unit counts as correct when within 0.2 of its target.
raam.setTolerance(0.2)
# Presumably: stop once 100% of outputs are correct.
raam.setStopPercent(1.0)
# Presumably: re-initialize the weights if training has not finished after
# 5000 epochs.
raam.setResetEpoch(5000) 
# TODO confirm: whether 0 means "no resets" or "no limit on resets".
raam.setResetLimit(0)

# Train:
raam.train() 

# Test:
raam.setLearning(0)
# Find the encodings for each sentence
d = {}
for i in raam.loadOrder:
    datum = raam.getData(i)
    sentence = datum['input']
    key = ''
    for word in sentence:
        key += word
    (error, correct, total) = raam.step( **datum)
    encoding = raam.getLayer('hidden').getActivationsList()
    d[key] = encoding
# Test whether encodings can be properly decompressed
print "Testing encodings, should print sentence in reverse order"
raam.setLayerVerification(0)
for sentence in d.keys():
    print "decompressing ", sentence
    retval = raam.propagateFrom("hidden", hidden = d[sentence])
    while True:
        bits = discretize(retval['output'])
        word = raam.getWord(bits)
        if word == None or decompressDone(retval['outcontext']):
            break
        display(retval['output'])
        display(bits)
        print word
        retval = raam.propagateFrom("hidden", hidden = retval['outcontext'])





