import unittest
from l4_nfa import *

class NFATests(unittest.TestCase):
    """Behavioural tests for the ``NFA`` class exported by ``l4_nfa``.

    Every test runs against the same small automaton, rebuilt in ``setUp``:
    start state -1, an epsilon (``''``) move from -1 to 0, then
    0 -h-> 1 -a-> 2, an epsilon move from 2 back to 0, and 2 -!-> 3, with 3
    as the only finish state.  The tests expect it to accept ``'haha!'`` and
    to reject ``'hah!'``.
    """

    # Shared automaton description: (source, symbol, target) triples, where
    # the empty string marks an epsilon transition.
    START = -1
    TRANSITIONS = [
        (-1, '', 0),
        (0, 'h', 1),
        (1, 'a', 2),
        (2, '', 0),
        (2, '!', 3),
    ]
    FINISH = [3]

    def setUp(self):
        """Build a fresh automaton before each test.

        Copies of the class-level lists are handed to the constructor so
        that an implementation which keeps or mutates its arguments cannot
        leak state from one test into the next.
        """
        self.m = NFA(self.START, list(self.TRANSITIONS), list(self.FINISH))

    def test_states(self):
        """``states()`` reports the five states used to build the automaton."""
        self.assertEqual({-1, 0, 1, 2, 3}, self.m.states())

    def test_transitions(self):
        """``has_transition`` is true for every given triple, false otherwise."""
        # Check all the transitions are present.
        for src, char, tgt in self.TRANSITIONS:
            self.assertTrue(
                self.m.has_transition(src, char, tgt),
                "Missing transition: " + repr((src, char, tgt)))

        # Make sure the method works for non-existent transitions.
        self.assertFalse(self.m.has_transition(-1, 'h', 0))

    def test_recognize(self):
        """``recognize`` accepts ``'haha!'`` and rejects ``'hah!'``."""
        self.assertTrue(self.m.recognize('haha!'))
        self.assertFalse(self.m.recognize('hah!'))

    def test_remove_epsilon(self):
        """``removeEpsilon`` keeps the language and drops all '' transitions."""
        self.m.removeEpsilon()

        # The new NFA should accept the same words.
        self.assertTrue(self.m.recognize('haha!'))
        self.assertFalse(self.m.recognize('hah!'))

        # The new NFA should have no '' transition between any pair of states.
        nodes = self.m.states()
        for src in nodes:
            for tgt in nodes:
                self.assertFalse(
                    self.m.has_transition(src, '', tgt),
                    "Epsilon transition left: " + repr((src, '', tgt)))

    def test_generate(self):
        """``generate`` yields the shortest words early and no invalid word.

        The test fails if ``max_iterations`` words have been drawn and the
        expected words are not all among them.  This implies the generator
        CANNOT enumerate words depth-first.
        """
        max_iterations = 100
        expected_words = {'ha!', 'haha!', 'hahaha!', 'hahahaha!'}
        forbidden_words = {'', 'a!', 'ha', '!', 'haa!', 'hha!'}

        # Sentinel returned by next() once the generator has nothing left, so
        # exhaustion becomes a readable failure instead of a StopIteration
        # error.
        exhausted = object()
        words_gen = self.m.generate()

        # Always scan the full budget: a forbidden word must be caught even
        # if it shows up after every expected word has already been seen.
        for drawn in range(max_iterations):
            word = next(words_gen, exhausted)
            if word is exhausted:
                self.fail(
                    "generate() stopped after " + str(drawn) + " words; "
                    + str(max_iterations) + " were requested")
            # print(word)  # uncomment to see what words the NFA generates
            if word in forbidden_words:
                self.fail("Generated invalid word: " + word)
            expected_words.discard(word)

        self.assertEqual(
            0, len(expected_words),
            "Words " + str(expected_words) + " were not generated!")



# Run the tests in this module through unittest's command-line runner when
# the file is executed directly; nothing runs when it is merely imported.
if __name__ == '__main__':
    unittest.main()
