import copy
import unittest

from l8_cyk import *

class TestCYK(unittest.TestCase):
    """Unit tests for binarize, recognize_cyk and parse_cyk from l8_cyk."""

    # Maps each word to a list of (probability, part-of-speech tag) pairs.
    lexicon = {
        'the': [(1.0, 'DET')],
        'some': [(1.0, 'DET')],
        'salespeople': [(0.55, 'N')],
        'dog': [(0.25, 'N')],
        'biscuits': [(0.2, 'N')],
        'sold': [(1.0, 'V')],
    }

    # Rules as (lhs, (probability, rhs_symbols)); every rhs has exactly
    # two symbols, i.e. the grammar is already binarized.
    binarized_grammar = [
        ('S', (1.0, ['NP', 'VP'])),
        ('NP', (0.7, ['DET', 'N'])),
        ('N-N', (1.0, ['N', 'N'])),
        ('NP', (0.25, ['DET', 'N-N'])),
        ('NP', (0.05, ['NP', 'NP'])),
        ('VP', (0.8, ['V', 'NP'])),
        ('NP-N', (1.0, ['NP', 'N'])),
        ('VP', (0.2, ['V', 'NP-N'])),
    ]

    def setUp(self):
        """Give each test private copies of the shared class-level fixtures.

        The fixtures are passed directly to the functions under test; if one
        of them mutates its arguments, deep copies keep that change from
        leaking into later tests and making results order-dependent.
        """
        self.lexicon = copy.deepcopy(TestCYK.lexicon)
        self.binarized_grammar = copy.deepcopy(TestCYK.binarized_grammar)

    def test_binarize(self):
        """binarize must split ternary rules into binary ones.

        NOTE: To simplify testing, any new non-terminal must follow the
        naming scheme X-Y, where X-Y -> X Y. Both a right-factored and a
        left-factored binarization are accepted.
        """
        grammar = [
            ('S', (1.0, ['A'])),
            ('A', (0.5, ['B', 'C', 'D'])),
            ('A', (0.5, ['C', 'D', 'E'])),
        ]
        valid_binarized_grammars = [
            # Right-factored: A -> B C D  becomes  A -> B C-D, C-D -> C D.
            [
                ('S', (1.0, ['A'])),
                ('A', (0.5, ['B', 'C-D'])),
                ('C-D', (1.0, ['C', 'D'])),
                ('A', (0.5, ['C', 'D-E'])),
                ('D-E', (1.0, ['D', 'E'])),
            ],
            # Left-factored: A -> B C D  becomes  A -> B-C D, B-C -> B C.
            [
                ('S', (1.0, ['A'])),
                ('A', (0.5, ['B-C', 'D'])),
                ('B-C', (1.0, ['B', 'C'])),
                ('A', (0.5, ['C-D', 'E'])),
                ('C-D', (1.0, ['C', 'D'])),
            ],
        ]

        result = sorted(binarize(grammar))
        # Rule order is irrelevant, so compare sorted rule lists.
        self.assertIn(
            result,
            [sorted(valid) for valid in valid_binarized_grammars],
            "Grammar is not in correct form!",
        )

    def test_recognize(self):
        """recognize_cyk accepts a grammatical sentence and rejects another."""
        valid_sentence = "some salespeople sold the dog biscuits".split()
        invalid_sentence = "some dog biscuits the dog".split()

        (valid, prob) = recognize_cyk(
            'S', self.binarized_grammar, self.lexicon, valid_sentence)
        self.assertTrue(valid)
        # The sentence has two parses. Both share NP("some salespeople") = 0.385:
        #   VP -> V NP(DET N-N):  0.385 * 0.8 * 0.0125 = 0.00385
        #   VP -> V NP-N:         0.385 * 0.2 * 0.035  = 0.002695
        # The expected value is their sum, 0.006545.
        self.assertAlmostEqual(0.006545, prob, places=15)

        (valid, prob) = recognize_cyk(
            'S', self.binarized_grammar, self.lexicon, invalid_sentence)
        self.assertFalse(valid)

    def test_parse(self):
        """parse_cyk returns the most probable tree and its probability."""
        valid_sentence = "some salespeople sold the dog biscuits".split()

        expected_tree = Tree('S', [
            Tree('NP', [
                Tree('DET', [Tree('some', [])]),
                Tree('N', [Tree('salespeople', [])]),
            ]),
            Tree('VP', [
                Tree('V', [Tree('sold', [])]),
                Tree('NP', [
                    Tree('DET', [Tree('the', [])]),
                    Tree('N-N', [
                        Tree('N', [Tree('dog', [])]),
                        Tree('N', [Tree('biscuits', [])]),
                    ]),
                ]),
            ]),
        ])
        # The larger of the two parse probabilities (0.00385 vs 0.002695).
        expected_prob = 0.00385

        (tree, prob) = parse_cyk(
            'S', self.binarized_grammar, self.lexicon, valid_sentence)
        self.assertEqual(expected_tree, tree, "Parse trees are not equal")
        self.assertAlmostEqual(expected_prob, prob, places=15)

if __name__ == "__main__":
    # When executed as a script, hand this module's TestCase classes to the
    # standard unittest command-line runner.
    unittest.main()





