# Tests for exercise 7
# Nicholas Merritt


import unittest
from collections import Counter
from l7_cnf import *
from l7_cyk import *
# Sample grammar containing a chain of unit rules (X -> A -> B -> C -> D),
# a lexical rule D -> 'x', an empty right-hand side for E, and two rules
# for the start symbol S.  The tests below define their own local ``rules``
# and do not reference this module-level list.
rules = [
    ('S', ['X', 'E', 'F', 'E']),
    ('X', ['A']),
    ('A', ['B']),
    ('B', ['C']),
    ('C', ['D']),
    ('D', ['x']),
    ('S', ['X', 'A']),
    ('E', []),
]

class CNFTests(unittest.TestCase):
    def test_is_cnf(self):
        start_symbol = 'S'

        too_many_rhs = [(start_symbol, ['A', 'A', 'A'])]

 #       start_on_rhs = [(start_symbol, ['A', start_symbol])]

        unit_rule = [(start_symbol, ['A', 'B']), ('A', ['B'])]

        cnf = [(start_symbol, ['A', 'B'])]

        self.assertFalse(is_cnf(too_many_rhs), 'Rule is too long for CNF')
#        self.assertFalse(is_cnf(start_on_rhs), 'Start symbol may not appear on rhs')
        self.assertFalse(is_cnf(unit_rule), 'Bonus Failed')
        self.assertTrue(is_cnf(cnf))

    #Helper Method
    def get_new_symbols(self, rules, existing_symbols):
        ''' helper method for renaming rules into standard form'''

        new_symbols = set()

        for left, right in rules:
            for nonterm in [left] + right:
                if nonterm not in existing_symbols:
                    new_symbols.add(nonterm)

        return new_symbols

    # Helper Method
    def rename_symbols(self, rules, mapping):
        ''' Maps all non-terminal symbols in the rules, to those specified by
        mapping in-place'''

        for i in range(len(rules)):
            lhs = mapping.get(rules[i][0], rules[i][0])
            rhs = [mapping.get(nonterm, nonterm) for nonterm in rules[i][1]]
            rules[i] = (lhs, rhs)

            
    def test_cnf(self):
        non_cnf = [('S', ['A', 'B', 'A'])]
        
        valid_cnf_forms = [
                [('S', ['A', 'C']), ('C', ['B', 'A'])],
                [('S', ['C', 'A']),('C', ['A', 'B'])]
                ]
        # rename 'extra' non-terminal to 'C' so we can compare
        result = cnf(non_cnf)
        new_symbols = self.get_new_symbols(rules = result, existing_symbols =
                ['S', 'A', 'B'])
        self.assertTrue(len(new_symbols) == 1)
        new_symbol = new_symbols.pop()
        mapping = {new_symbol : 'C'}
        self.rename_symbols(result, mapping)

        for valid in valid_cnf_forms:
            if sorted(valid) == sorted(result):
                break
        else:
            self.assertFalse("Not in valid cnf form")


    def tokenize(self, sentence):
        return " ".split(sentence)

    def test_simple_recognizer(self):
        start = 'S'
        rules = [('S', ['A', 'B'])]
        lexicon = [('A', 'a'), ('B', 'b')]

        self.assertTrue(recognize_cyk(start, rules, lexicon, ['a', 'b']))
        self.assertFalse(recognize_cyk(start, rules, lexicon,['b', 'a']))

    def test_non_cnf(self):
        try:
            recognize_cyk('S', [('S', ['A', 'A', 'A'])], [('A', 'a')], ['a', 'b'])
        except NotWellFormedException:
            return
        self.fail('Expected an exception for non-cnf grammar!')

    def test_recognizer(self):
        start = 'S'

        # rules in cnf
        rules = [ 
            (  'S', ['NP', 'VP']),
            (  'S', ['NP', 'VP-PP']),    
            (  'VP-PP', ['VP', 'PP']),
            ( 'NP', ['DET', 'N']),
            ( 'NP', ['POSS', 'N']),
            ( 'NP', ['NP', 'PP']),
            ( 'PP', ['P', 'NP']),
            ( 'VP', ['V', 'NP']),
            ('VP', ['VP', 'PP'])
        ]

        lexicon = [
            ('DET', 'the'),
            ('DET', 'a'),
            ('DET', 'an'),
            (  'POSS', 'his'),
            (  'V', 'shot'),
            (  'V', 'chases'),
            (  'N', 'boy'),
            (  'N', 'mouse'),
            (  'N', 'elephant'),
            (  'N', 'pajamas'),
            (  'N', 'cat'),
            (  'N', 'roof'),
            (  'N', 'house'),
            (  'N', 'library'),
            (  'N', 'university'),
            (  'P', 'in'),
            (  'P', 'of'),
            (  'P', 'on'),
            (  'P', 'near'),
        ]

        valid_sentence = "the boy shot an elephant in his pajamas".split()
        silly_valid_sentence = "the elephant shot a boy in his pajamas".split()
        invalid_sentence = "the boy in chases".split()

        self.assertTrue(recognize_cyk(start, rules, lexicon, valid_sentence))
        self.assertTrue(recognize_cyk(start, rules, lexicon, silly_valid_sentence))
        self.assertFalse(recognize_cyk(start, rules, lexicon, invalid_sentence))



if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()

