"""CS75 Spring 2011
Project 2b: Expression Parser with AST
"""


from nltk import Tree
from pprint import pprint
from formatAST import *
from scanner import *
import sys

class ExpressionParser(object):
    """Recursive-descent parser that turns a token stream into a list-based AST.

    Grammar handled (each expression is terminated by a 'semi' token):

        E0 -> E1 ('assign' E0)?                  right-associative assignment
        E1 -> E2 (('add' | 'sub') E2)*           left-associative
        E2 -> E3 (('mul' | 'div') E3)*           left-associative
        E3 -> 'sub' E3 | E4                      unary minus
        E4 -> 'lparen' E0 'rparen' | 'id' | 'num'

    AST nodes are nested lists whose first element is the node tag, e.g.
    ['add', ['id', 'x'], ['num', 1]]; the whole program is
    ['expressions', expr1, expr2, ...].
    """

    # Token types accepted at each binary-operator precedence level.
    _ADDITIVE_OPS = ('add', 'sub')
    _MULTIPLICATIVE_OPS = ('mul', 'div')

    def __init__(self, filename, debug=False):
        # LexicalAnalyzer comes from the project's scanner module.
        self.scanner = LexicalAnalyzer(filename)
        self.debug = debug
        self.token = None   # type of the current lookahead token
        self.value = None   # attribute of the lookahead (identifier name, number, error text)
        self.ast = ['expressions']

    def _trace(self, label):
        # Print a grammar-rule trace line when debugging is enabled.
        if self.debug:
            print(label)

    def match(self, tokenType):
        """Consume the lookahead if it is *tokenType* and fetch the next token.

        Raises:
            SyntaxError: (via emitError) if the lookahead is a scanner 'err'
                token or any token other than *tokenType*.
        """
        if tokenType == self.token:
            if self.debug:
                line = "MATCH: %s" % (self.token,)
                if self.value is not None:
                    line += " %s" % (self.value,)
                print(line)
            self.token, self.value = self.scanner.getToken()
        elif self.token == 'err':
            # The scanner reports lexical errors as an 'err' token whose value
            # is the diagnostic text.
            self.emitError(self.value)
        else:
            self.emitError("expected %s, but found %s" %
                           (tokenType, self.token))

    def emitError(self, msg):
        """Print *msg* with the current line number and line text, then abort.

        Raises:
            SyntaxError: always; the exception carries *msg*.
        """
        print("Error on line %d: %s" % (self.scanner.getLineNum(), msg))
        print(self.scanner.getFile().getLineText())
        raise SyntaxError(msg)

    def parse(self):
        """Parse every 'semi'-terminated expression until the 'done' token.

        Returns:
            The AST list ['expressions', expr1, expr2, ...] (also kept in
            self.ast).

        On a syntax error, prints a message and exits the process with
        status -1 (raises SystemExit).
        """
        self.token, self.value = self.scanner.getToken()
        try:
            while True:
                self.ast.append(self.E0())
                self.match('semi')
                if self.token == 'done':
                    break
        except SyntaxError:
            print("Unrecoverable error, parser exiting")
            sys.exit(-1)
        print("Parse was successful")
        return self.ast

    def E0(self):
        """E0 -> E1 E0rest  (assignment level)."""
        self._trace("E0")
        return self.E0rest(self.E1())

    def E0rest(self, expr):
        """If an 'assign' follows *expr*, build a right-associative assignment."""
        self._trace("E0rest")
        if self.token == 'assign':
            self.match('assign')
            # Recursing into E0 makes 'a = b = c' parse as a = (b = c).
            return ['assign', expr, self.E0()]
        return expr

    def E1(self):
        """E1 -> E2 E1rest  (additive level)."""
        self._trace("E1")
        return self.E1rest(self.E2())

    def E1rest(self, expr):
        """Fold trailing ('add'|'sub') E2 pairs onto *expr*, left-associatively.

        This is a loop rather than recursion, so long chains cannot exhaust
        Python's recursion limit.
        """
        while True:
            self._trace("E1rest")
            if self.token not in self._ADDITIVE_OPS:
                return expr
            op = self.token
            self.match(op)
            expr = [op, expr, self.E2()]

    def E2(self):
        """E2 -> E3 E2rest  (multiplicative level)."""
        self._trace("E2")
        return self.E2rest(self.E3())

    def E2rest(self, expr):
        """Fold trailing ('mul'|'div') E3 pairs onto *expr*, left-associatively."""
        while True:
            self._trace("E2rest")
            if self.token not in self._MULTIPLICATIVE_OPS:
                return expr
            op = self.token
            self.match(op)
            expr = [op, expr, self.E3()]

    def E3(self):
        """E3 -> 'sub' E3 | E4.  Unary minus becomes a one-child 'sub' node."""
        self._trace("E3")
        if self.token == 'sub':
            self.match('sub')
            return ['sub', self.E3()]
        return self.E4()

    def E4(self):
        """E4 -> '(' E0 ')' | id | num.

        Raises:
            SyntaxError: (via emitError) if no primary expression starts here.
        """
        self._trace("E4")
        if self.token == 'lparen':
            self.match('lparen')
            expr = self.E0()
            self.match('rparen')
            return expr
        if self.token == 'id':
            idName = self.value
            self.match('id')
            return ['id', idName]
        if self.token == 'num':
            numValue = self.value
            self.match('num')
            return ['num', numValue]
        if self.token == 'err':
            # Report the scanner's own diagnostic, as match() does.
            self.emitError(self.value)
        self.emitError("invalid expression")


def main():
    """Parse the file named on the command line, print its AST, and draw it.

    Prints the list-of-lists AST, then the string produced by formatAST(),
    then draws the tree with NLTK. If no filename argument is given, writes
    a usage error to stderr and exits with status 1.
    """
    if len(sys.argv) < 2:
        sys.stderr.write("Error: must provide filename to parse\n")
        sys.exit(1)
    p = ExpressionParser(sys.argv[1])

    # Parser returns the AST as a list of lists
    listAST = p.parse()
    pprint(listAST, indent=2)

    # Convert the AST into its string form for NLTK
    stringAST = formatAST(listAST)
    print(stringAST)

    # Draw the converted AST. NLTK 3 renamed Tree.parse to Tree.fromstring,
    # so prefer the new name and fall back to the old one on older NLTK.
    treeFromString = getattr(Tree, 'fromstring', None) or Tree.parse
    pic = treeFromString(stringAST)
    pic.draw()
    

if __name__ == "__main__":
    # Run the command-line driver only when executed as a script, not on import.
    main()