# Exercise 1 - Python 2 
# Nicholas Merritt
# 17.04.13

# This is a SAMPLE solution, not a perfect solution! 
# Reports of mistakes and any other comments are appreciated.

import re
from string import whitespace

class Tree:
    """A node of an ordered, labelled tree.

    Every node stores a ``label`` and a list of ``children`` (themselves
    ``Tree`` instances).  A node whose ``children`` list is empty is a leaf.
    """

    def __init__(self, label, children):
        """Create a node with the given label and list of child trees."""
        self.label = label
        self.children = children

    # Not specified in the assignment. Required for the tests to work correctly.
    def __eq__(self, other):
        """Return True when both trees have equal labels and equal children.

        Comparing against something that is not a ``Tree`` returns
        ``NotImplemented`` so that Python falls back to its default handling
        (``==`` gives False, ``!=`` gives True) instead of failing with an
        AttributeError on the missing ``label`` attribute.
        """
        if not isinstance(other, Tree):
            return NotImplemented
        return self.label == other.label and self.children == other.children

    # Defining __eq__ already makes instances unhashable in Python 3; state it
    # explicitly because trees are mutable and compared by value.
    __hash__ = None

    def __iter__(self):
        """Yield this node followed by all nodes of each child subtree (preorder)."""
        yield self
        for child in self.children:
            # Explicit loop instead of ``yield from`` (Python >= 3.3 only).
            for node in child:
                yield node

    # Not required for assignment.
    def postorder(self):
        """Yield all nodes of each child subtree first, then this node."""
        for child in self.children:
            for node in child.postorder():
                yield node
        yield self

    def __str__(self):
        """Return the bracketed form, e.g. ``(S (NP he) runs)``; a leaf is just its label."""
        if self.children:
            inner = " ".join([str(child) for child in self.children])
            return "(%s %s)" % (self.label, inner)
        return self.label

    def __repr__(self):
        """Return a constructor-style representation of the whole tree."""
        # %r quotes and escapes the label, so labels containing a quote
        # character still produce a valid expression.
        inner = ", ".join([repr(child) for child in self.children])
        return "Tree(%r, [%s])" % (self.label, inner)

    def leaves(self):
        """Yield the nodes without children, in left-to-right order."""
        if not self.children:
            yield self
            return
        for child in self.children:
            # Explicit loop instead of ``yield from`` (Python >= 3.3 only).
            for leaf in child.leaves():
                yield leaf

    def depth(self):
        """Return 0 for a leaf, otherwise one more than the deepest child's depth."""
        if not self.children:
            return 0
        return max(child.depth() for child in self.children) + 1

class ParserError(Exception):
    """Raised when a token stream cannot be parsed into well-formed trees."""

def trees(tokens):
    """Parse bracket tokens into ``Tree`` objects and yield the top-level trees.

    ``tokens`` is any iterable of strings in which ``'('`` opens a subtree,
    the token right after it is that subtree's label, ``')'`` closes the
    subtree, and every other token is a leaf.  Several top-level trees
    (a forest) are yielded one after another.

    This is a generator: the whole token stream is consumed and checked
    before the first tree is yielded, and errors surface when the generator
    is first advanced.

    Raises:
        ParserError: if a parenthesis is used as a label, if the input ends
            right after ``'('``, or if the parentheses are unbalanced.
    """
    # next() below needs a real iterator.  iter() returns a generator
    # unchanged, and additionally allows lists or other iterables as input.
    tokens = iter(tokens)

    tree_lists = [[]]   # one list of finished children per open subtree
    roots = []          # labels of the currently open subtrees
    open_par = 0        # number of '(' not yet closed

    for t in tokens:
        if t == '(':  # new tree
            open_par += 1
            # The label is the token directly following the '('.
            try:
                root = next(tokens)
            except StopIteration:
                raise ParserError("Missing label or unbalanced parenthesis")
            if root == '(' or root == ')':
                raise ParserError("PARENTESIS AS LABEL")
            roots.append(root)
            tree_lists.append([])

        elif t == ')':  # finished a tree
            if open_par <= 0:
                raise ParserError("Unbalanced parenthesis")
            open_par -= 1
            # Close the innermost subtree and attach it to its parent's list.
            current = Tree(roots.pop(), tree_lists.pop())
            tree_lists[-1].append(current)

        else:  # this is a leaf
            current = Tree(t, [])
            tree_lists[-1].append(current)

    if open_par != 0:
        raise ParserError("Unbalanced parenthesis")

    for t in tree_lists[-1]:  # iterator, in case it's a forest.
        yield t


class Parser:
    """Iterator that hands out the trees parsed from a token stream."""

    def __init__(self, tokens):
        """Remember the token source and prepare the tree generator for it."""
        self._trees = trees(tokens)
        self.tokens = tokens

    def __next__(self):
        """Return the next parsed tree; propagates whatever the generator raises."""
        tree = next(self._trees)
        return tree

    def __iter__(self):
        """A parser is its own iterator."""
        return self
        
def tokenize(string):
    """Split a bracketed-tree string into tokens and yield them one by one.

    Each ``'('`` and ``')'`` is a token of its own.  Any other run of
    characters delimited by parentheses, by characters from
    ``string.whitespace``, or by the start or end of the input is yielded as
    one token.  Whitespace itself is never yielded.
    """
    buf = []  # characters of the token currently being collected
    for char in string:
        if char == '(' or char == ')':
            # A parenthesis terminates the pending token and is a token itself.
            if buf:
                yield "".join(buf)
                buf = []
            yield char
        elif char in whitespace:
            if buf:
                yield "".join(buf)
                buf = []
        else:
            buf.append(char)
    # Flush a token that runs up to the very end of the input; without this
    # a trailing token such as the "c" in "(a b) c" would be silently lost.
    if buf:
        yield "".join(buf)
            
    
def count_nonterminals(filename):
    """Count how often each nonterminal label occurs in the trees of a file.

    The file is read completely, tokenized, and parsed.  Every node that has
    at least one child counts as a nonterminal.

    Args:
        filename: path of the text file containing the bracketed trees.

    Returns:
        A dict mapping each nonterminal label to its number of occurrences.
    """
    with open(filename) as f:
        text = f.read()

    nonterm = {}
    for tree in Parser(tokenize(text)):
        for node in tree:
            # Tree.leaves() yields exactly the nodes without children, so
            # testing ``node.children`` selects the same nodes as a membership
            # test against the list of leaves, without a linear search (and a
            # recursive equality comparison) for every node.
            if node.children:
                nonterm[node.label] = nonterm.get(node.label, 0) + 1
    return nonterm

# Script entry point: print the nonterminal counts for the file 'baeume.txt'.
if __name__ == '__main__':
    counts = count_nonterminals('baeume.txt')
    for label, count in counts.items():
        print(label, ":", count)
