# Tests for Exercise 1
# Python-II Course 2013

import os # Used for deleting a test file
import unittest # import testing framework
from l1_trees import * # import code to test

class TestTree(unittest.TestCase):
    """Unit tests for Tree: construction, iteration, str/repr, leaves, depth."""

    def test_init(self):
        # Deliberately trivial test that mainly shows how a test is laid out:
        # a new Tree must expose exactly the label and children it was given.
        expected_children = []
        expected_label = "root"
        node = Tree(expected_label, expected_children)

        self.assertEqual(expected_label, node.label)
        self.assertEqual(expected_children, node.children)

    def test_iter(self):
        # iter(tree) must be supported and must walk the tree in pre-order:
        # the node itself first, then its children from left to right.
        first = Tree("left", [])
        second = Tree("right", [])
        parent = Tree("root", [first, second])

        expected_order = [parent, first, second]
        # Order matters, so compare as lists rather than sets.
        self.assertEqual(expected_order, list(iter(parent)))

    def test_str(self):
        # A leaf prints as its bare label (no parentheses); an inner node
        # prints as "(label child1 child2 ...)".
        subject = Tree("Er", [])
        verb = Tree("schnarcht", [])
        sentence = Tree("S", [Tree("NP", [subject]), Tree("VP", [verb])])

        self.assertEqual("Er", str(subject))
        self.assertEqual("(S (NP Er) (VP schnarcht))", str(sentence))

    def test_rep(self):
        # repr() must return the constructor-call form of the tree; note the
        # expected string has no space after the comma between children.
        expected = "Tree('root', [Tree('left', []),Tree('right', [])])"
        children = [Tree('left', []), Tree('right', [])]
        self.assertEqual(expected, repr(Tree('root', children)))

    def test_leaves(self):
        # A childless tree is its own (and only) leaf.
        lone = Tree("root", [])
        self.assertEqual([lone], list(lone.leaves()))

        # Leaves must also be found below an intermediate node, in order:
        #   root
        #     |
        #   trunk
        # |   |   |
        # l1  l2  l3
        expected_leaves = [Tree(name, []) for name in ('l1', 'l2', 'l3')]
        root = Tree('root', [Tree('trunk', expected_leaves)])
        self.assertEqual(expected_leaves, list(root.leaves()))

    def test_depth(self):
        # A single node has depth 0.
        self.assertEqual(0, Tree("root", []).depth())

        # With unequal branches the longest root-to-leaf path decides:
        #      d0
        #    |    |
        #  d1-a  d1-b
        #    |
        #   d2
        deep_branch = Tree("d1-a", [Tree("d2", [])])
        shallow_branch = Tree("d1-b", [])
        lopsided = Tree("d0", [deep_branch, shallow_branch])
        self.assertEqual(2, lopsided.depth())

class TestParser(unittest.TestCase):
    """Tests for tokenize() and for the Parser that turns tokens into Trees."""

    def test_tokenize(self):
        # tokenize must emit each parenthesis as its own token, split labels
        # on whitespace, and keep non-ASCII words ("schläft") intact.
        expected_tokens = [
            '(', 'S', '(', 'NP', 'Peter', ')',
            '(', 'VP', 'schläft', ')', ')',
        ]
        actual_tokens = list(tokenize("(S (NP Peter) (VP schläft))"))
        self.assertEqual(expected_tokens, actual_tokens)

    def test_valid_input(self):
        # Parsing a forest (several top-level trees in one string) must yield
        # one Tree per top-level bracket, in input order. The comparison
        # relies on Tree implementing value equality (__eq__).
        tree1 = Tree('S', [Tree('NP', [Tree('Peter', [])]),
                           Tree('VP', [Tree('schläft', [])]),
                           ])

        tree2 = Tree('S', [Tree('NP', [Tree('Er', [])]),
                           Tree('VP', [Tree('schnarcht', [])]),
                           ])

        input_str = "(S (NP Peter) (VP schläft)) (S (NP Er) (VP schnarcht))"
        parser = Parser(tokenize(input_str))

        # Iterating the parser yields the parsed trees; collect all of them.
        actual = list(iter(parser))

        self.assertEqual([tree1, tree2], actual)

    def test_extra_open_par(self):
        # A dangling "(" at the end of the input must raise ParserError,
        # either when the Parser is built or while iterating to the end.
        tokens = tokenize("(root leaf1(")
        with self.assertRaises(ParserError):
            list(Parser(tokens))

    def test_missing_par(self):
        # Input lacking the opening "(" (but closing one) must raise
        # ParserError, either at construction or while iterating to the end.
        tokens = tokenize("root leaf1)")
        with self.assertRaises(ParserError):
            list(Parser(tokens))

        

class TestCountNonterminals(unittest.TestCase):
    """Tests for count_nonterminals(), which reads a forest from a file."""

    def test_non_term_count(self):
        # count_nonterminals(filename) should return a dict mapping each
        # nonterminal label to the number of times it occurs in the forest
        # stored in the file. Leaf words (terminals) are not counted.
        import tempfile  # only this test needs scratch space on disk

        # Contents of the test file.
        # Note: "\n" appears in baueme.txt, hence it is tested here as well.
        test_str = "(S(NP Peter) (VP schläft))\n(S (NP Er) (VP schnarcht))"
        expected = {'S': 2,
                    'NP': 2,
                    'VP': 2,
                    }

        # Write into a private temporary directory so the test can never
        # overwrite (and then delete) an unrelated 'test_trees.txt' in the
        # current working directory; the directory is removed on exit.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'test_trees.txt')
            TestCountNonterminals.make_test_file(filename, test_str)
            try:
                actual = count_nonterminals(filename)
            finally:
                # Delete the file even if count_nonterminals raises.
                TestCountNonterminals.delete_test_file(filename)

        self.assertEqual(expected, actual)

    @staticmethod
    def make_test_file(filename, data):
        # Helper for test_non_term_count: write `data` to `filename`,
        # replacing any existing content.
        # NOTE(review): uses the platform default encoding, presumably the
        # same one count_nonterminals reads with — confirm before changing.
        with open(filename, 'w') as f:
            f.write(data)

    @staticmethod
    def delete_test_file(filename):
        # Helper for test_non_term_count: remove the file written by
        # make_test_file (raises OSError if it does not exist).
        os.remove(filename)
        
        
            
    
    
    
        

# Entry point: run every TestCase in this module when the file is executed
# directly as a script; importing the module runs nothing.
if __name__ == '__main__':
    unittest.main()
            
            
        
