from votingData import *
from pyrobot.brain.conx import *
from random import *
from os import system
from time import sleep

# Held-out examples, filled by createValidationSet(): validationSet holds
# {'input': ..., 'output': ...} dicts and validationLabels the matching labels.
validationSet = []
validationLabels = []

# Training data, filled by processData(); the three lists are kept index-aligned
# (one entry per example).
votingLabels = []
votingInputs = []
votingOutputs = []

def convertInput(data):
    """Translate one vote symbol into a numeric value.

    'y' maps to 1, 'n' to 0 and '?' to 0.5. Any other value yields None.
    """
    # Scan a small table instead of branching; comparison is by equality,
    # so the first matching symbol wins.
    for symbol, value in (('y', 1), ('n', 0), ('?', 0.5)):
        if data == symbol:
            return value
    return None

def convertOutput(party):
    """Return a two-element target vector for a party name.

    'democrat' gives [1, 0], 'republican' gives [0, 1]; any other value
    gives None. A new list is built on every call.
    """
    knownParties = ('democrat', 'republican')
    for position, name in enumerate(knownParties):
        if party == name:
            target = [0] * len(knownParties)
            target[position] = 1
            return target
    return None

def createLabel(values):
    """Build a one-line text label for an example.

    Every value other than a party name is converted with str() and
    concatenated in order. A party name contributes the suffix 'Rep\\n'
    or 'Dem\\n', which is placed at the end of the label regardless of
    where the party appears in `values`; if several party names occur,
    the last one wins.

    If `values` contains no party name, the label is returned without any
    suffix (and therefore without a trailing newline) instead of raising
    UnboundLocalError as the previous implementation did.
    """
    party = ''  # default so a missing party no longer leaves the name unbound
    pieces = []
    for v in values:
        if v == 'republican':
            party = 'Rep\n'
        elif v == 'democrat':
            party = 'Dem\n'
        else:
            pieces.append(str(v))
    # Join once rather than growing the string with += on every value.
    return ''.join(pieces) + party

def processData(data):
    """Convert raw examples into labels, input vectors and target vectors.

    For each example the last value is dropped, a label is built from the
    remaining values, the value at index 12 is taken out as the party, and
    what is left is converted with convertInput. Results are appended to
    the module-level votingLabels, votingInputs and votingOutputs lists.

    NOTE(review): this depends on the order of example.values() --
    presumably the id comes last and the party sits at index 12; verify
    against the votingData module.
    """
    for example in data:
        fields = example.values()
        fields.pop()  # remove the id
        # The label is built while the party is still among the fields.
        votingLabels.append(createLabel(fields))
        affiliation = fields.pop(12)
        votingInputs.append([convertInput(vote) for vote in fields])
        votingOutputs.append(convertOutput(affiliation))

def createValidationSet(inputs, outputs, percentage):
    count = int(percentage * len(inputs))
    for i in range(count):
        index = randint(0, len(inputs)-1)
        validationLabels.append(votingLabels.pop(index))
        d = {}
        d['input'] = inputs.pop(index)
        d['output'] = outputs.pop(index)
        validationSet.append(d)
    print "Cross validation corpus created of length", len(validationSet)

# create the network
# Network and Layer come from the star import of pyrobot.brain.conx.
n = Network()

# add layers in the order they will be connected
# Sizes: 16 input units, 8 hidden units, 2 output units. The 2 outputs match
# the two-element vectors returned by convertOutput.
# NOTE(review): 16 presumably equals the number of vote fields left in each
# example after processData removes the id and party -- confirm against
# votingData.
n.add(Layer('input', 16))
n.add(Layer('hidden', 8))
n.add(Layer('output', 2))
n.connect('input', 'hidden')
n.connect('hidden', 'output')

# set learning parameters
# NOTE(review): epsilon is presumably the learning rate and tolerance the
# margin within which an output counts as correct -- confirm in the conx docs.
n.setEpsilon(0.1)
n.setMomentum(0.9)
n.setTolerance(0.2)

# provide training patterns (inputs and outputs)
# processData fills votingInputs/votingOutputs/votingLabels from `voting`
# (supplied by the votingData star import). createValidationSet then removes
# 15% of those examples, so the held-out examples are no longer in the lists
# passed to setInputs/setOutputs below.
processData(voting)
createValidationSet(votingInputs, votingOutputs, .15)
# validationSet is a list of {'input': ..., 'output': ...} dicts.
n.crossValidationCorpus = validationSet
n.setInputs(votingInputs)
n.setOutputs(votingOutputs)

# learn
f1 = open("trainError", "w")
f2 = open("validateError", "w")
mostCorrect = 0
count = 0
for epoch in range(200):
    (trainError, trainCorrect, trainTotal) = n.sweep()
    f1.write(str(epoch)+" "+str(trainError)+"\n")
    print "Train   ", epoch, trainError, trainCorrect/float(trainTotal)
    (validateError, validateCorrect, validateTotal) = n.sweepCrossValidation()
    f2.write(str(epoch)+" "+str(validateError)+"\n")
    correct = validateCorrect/float(validateTotal)
    print "Validate", epoch, validateError, correct
    if correct > mostCorrect:
        mostCorrect = correct
        print "**************************Improved correctness: ", correct
        n.saveWeightsToFile("voting.wts")
        count = 0
    else:
        count += 1
        if count > 30:
            print "Performance dropping on validation set, training ended."
            print "Best performance on validation set:", mostCorrect
            break
f1.close()
f2.close()

# save data for analysis
print "Saving hidden layer representations for analysis"
n.loadWeightsFromFile('voting.wts')
n.setLearning(0)
n.logLayer('hidden','votingHiddens')
n.sweepCrossValidation()
n.closeLog('hidden')
f = open('votingLabels', 'w')
for label in validationLabels:
    f.write(label)
f.close()
system('/usr/local/pyrobot/tools/cluster/cluster votingHiddens votingLabels')
