from pyro.vision import *

def normalizeImage(data):
    """
    This function takes a grayscale image with values between 0 and a
    maximum of 255, and normalizes the values between 0 and 1.

    data is a mutable sequence of pixel values.  It is modified in
    place: every value is divided by the largest value found in data,
    so the brightest pixel becomes 1.0.  Nothing is returned.

    An empty sequence, or one whose largest value is 0 (for example an
    all-black image), is left unchanged instead of raising an error.
    """
    if len(data) == 0:
        # max() of an empty sequence raises ValueError; nothing to scale.
        return
    maxValue = float(max(data))
    if maxValue == 0:
        # Dividing by a zero maximum would raise ZeroDivisionError.
        return
    for i, value in enumerate(data):
        data[i] = value/maxValue
    
def getImages(directory, imagesFilenames):
    """
    This function can be used to read in a series of pgm images from
    files into a format that is suitable for neural network training.
    The directory parameter should be the location of the pgm files.
    It is prepended to each file name as-is, so it must end with a
    path separator (for example "images/").
    The imagesFilenames parameter should be the name of a file that
    contains image file names, one per line, that can be found within
    the given directory.  Reading stops at the end of the file or at
    the first blank line, whichever comes first.

    Returns a list with one entry per image; each entry is that
    image's data after being scaled by normalizeImage.
    """
    inputs = []
    names = open(imagesFilenames, "r")
    try:
        for line in names:
            name = line.strip()
            if len(name) == 0:
                # A blank line ends the list of file names.
                break
            image = PyroImage(depth=1)
            image.loadFromFile(directory+name)
            normalizeImage(image.data)
            inputs.append(image.data)
    finally:
        # Always release the file handle, even if loading an image fails.
        names.close()
    return inputs

def weightsToImage(weights, index, filename, width=32, height=30):
    """
    This function takes the hidden weights from a conx connection,
    the index for a particular hidden unit, and a filename.
    Use the method getWeights('input','hidden') to obtain the
    weights in the format required by this function.
    It then converts the weights into a pgm image and stores it in
    the filename.  This image can then be viewed using the linux
    command xv.

    The weights going into the chosen hidden unit (weights[i][index]
    for every i) are rescaled linearly so that the smallest becomes 0
    and the largest becomes 255.  If all of those weights are equal
    there is no range to scale over, and every pixel is written as 0.

    The optional width and height parameters give the dimensions of
    the image that is created; they default to 32 by 30.  Presumably
    width*height should match len(weights) -- verify against PyroImage.

    Raises ValueError if weights is empty.
    """
    grayscales = [row[index] for row in weights]
    minVal = min(grayscales)
    # float() prevents truncating integer division if the weights are ints.
    dist = float(max(grayscales) - minVal)
    if dist == 0:
        # All weights identical: avoid dividing by zero, use a flat image.
        grayscales = [0] * len(grayscales)
    else:
        grayscales = [int(round(((value-minVal)/dist)*255))
                      for value in grayscales]
    image = PyroImage(width=width, height=height, depth=1)
    image.data = grayscales
    image.saveToFile(filename)

