# Tests for Assignment 2
from unittest import TestCase
import unittest

# To test ex. 1-3
from l2_graphs import *

""" For exercise 4: simply comment out the above import
and uncomment the import below. The method signatures (return types,
argument types) should be the same.
Ignore my comment in the tutorial about returning Edge or Node classes.
"""

# To test ex. 4. All the tests should pass!
#from l2_graphs_oo import *

def star_graph(n):
    """Build a star-shaped graph with ``n`` nodes.

    The nodes are labelled with the strings '0' through str(n-1).
    Node '0' is the centre: one edge is added from it to each of the
    other nodes. No edges are added when ``n`` is 0 or 1.
    """
    labels = list(map(str, range(n)))

    star = Graph()
    for label in labels:
        star.add_node(label)

    # labels[0] is only evaluated when there is at least one leaf
    for leaf in labels[1:]:
        star.add_edge(labels[0], leaf)

    return star

class GraphTests(TestCase):
    """Tests for the directed Graph: nodes, edges, degrees, BFS and DFS."""

    def test_empty_graph(self):
        # create empty graph
        graph = Graph()

        # NOTE: graph.nodes() and graph.edges() may return any iterable.
        # assertCountEqual is used so that the order doesn't matter.
        self.assertCountEqual([], graph.nodes())
        self.assertCountEqual([], graph.edges())

    def test_node_methods(self):
        graph = Graph()
        nodes = {'node1', 'node2'}
        for node in nodes:
            graph.add_node(node)

        self.assertCountEqual(nodes, graph.nodes())

    def test_add_duplicate_node(self):
        # adding the same label twice must still result in a single node
        graph = Graph()
        label = 'not_unique'
        graph.add_node(label)
        graph.add_node(label)
        self.assertCountEqual([label], list(graph.nodes()))

    def test_integer_labels(self):
        # labels are not required to be strings
        graph = Graph()
        n = 10
        for i in range(n):
            graph.add_node(i)
        self.assertCountEqual(range(n), graph.nodes())

    def test_edge_methods(self):
        # 1 --> 2 --> 3
        graph = Graph()

        edges = {('node1', 'node2'), ('node2', 'node3')}
        nodes = ['node1', 'node2', 'node3']

        for edge in edges:
            graph.add_edge(edge[0], edge[1])

        # ensure the nodes were added
        self.assertCountEqual(nodes, graph.nodes())

        # check the edges are present
        self.assertCountEqual(edges, graph.edges())

        # loop through all pairs of nodes, except if they are equal
        for src in nodes:
            for dst in nodes:
                if src != dst:
                    present = (src, dst) in edges
                    self.assertEqual(present, graph.has_edge(src, dst))

    def test_degree_methods(self):
        # in a star graph the centre only has outgoing edges and every
        # other node has exactly one incoming edge
        n = 10
        graph = star_graph(n)
        self.assertEqual(0, graph.indeg('0'))
        self.assertEqual(n-1, graph.outdeg('0'))
        for i in range(1, n):
            self.assertEqual(1, graph.indeg(str(i)))
            self.assertEqual(0, graph.outdeg(str(i)))

    # Helper Method
    def check_layers(self, graph, root, layers):
        """Assert that graph.bfs(root) yields the nodes layer by layer.

        graph  -- object providing a bfs(root) method returning an iterable
        root   -- node the traversal is started from
        layers -- list of lists; layers[i] holds the nodes expected at
                  depth i. The order inside a layer is not checked, but
                  all nodes of layer i must come before layer i+1.

        Fails if the traversal yields too few or too many nodes.
        """
        current_depth_list = []
        current_depth = 0
        for node in graph.bfs(root):
            # Fail with a readable message instead of an IndexError when
            # the traversal produces more nodes than the layers contain.
            if current_depth >= len(layers):
                self.fail("bfs yielded unexpected extra node %r" % (node,))

            current_depth_list.append(node)
            if len(current_depth_list) == len(layers[current_depth]):
                self.assertCountEqual(layers[current_depth], current_depth_list)

                #current_depth_list.clear() # only works in python 3.3
                del current_depth_list[:]
                current_depth += 1

        # A traversal that stops early must not pass silently.
        self.assertEqual([], current_depth_list,
                         "bfs stopped in the middle of a layer")
        self.assertEqual(len(layers), current_depth,
                         "bfs did not reach every expected layer")

    def test_bfs(self):
        # This graph contains cycles...
        # If the test does not terminate, your code does not
        # handle cycles correctly

        # Note: The order of the nodes returned is only partially determined
        # As long as the nodes in depth i appear before depth i+1,
        # then it is correct

        graph = Graph()
        d0 = ['0']
        d1 = ['1a', '1b', '1c']
        d2 = ['2']

        for x in d1:
            graph.add_edge('0', x)
            graph.add_edge(x, '2')
            graph.add_edge(x, '0')

        partial_order = [d0, d1, d2]
        self.check_layers(graph, '0', partial_order)

    # Helper method
    def check_paths(self, graph, root, paths):
        """Assert that graph.dfs(root) walks each path without interruption.

        graph -- object providing a dfs(root) method returning an iterable
        root  -- node the traversal is started from; must be yielded first
        paths -- list of non-empty lists of nodes. The order is defined
                 along a path, but not between paths.

        Fails if a node does not start any path, if a path is left
        part-way, or if some path is never (or repeatedly) traversed.
        """
        completed_paths = []
        current_path = []
        current_node_in_path = 0

        # iter() lets dfs() return any iterable, not only an iterator
        dfs = iter(graph.dfs(root))

        # ensure root is correct
        self.assertEqual(root, next(dfs))
        for node in dfs:
            # check if we start another path
            if current_node_in_path >= len(current_path):
                for path in paths:
                    if path[0] == node:
                        current_path = path
                        current_node_in_path = 0
                        break
                else:
                    # else clause of for-loop is executed if no break is
                    # called; this is the case if the node doesn't match
                    # the start of any path.
                    # str() keeps this working for non-string labels.
                    self.fail(str(node) + " didn't match any path")

            # make sure this is the next node on the path
            self.assertEqual(current_path[current_node_in_path], node)
            current_node_in_path += 1

            if current_node_in_path == len(current_path):
                completed_paths.append(current_path)

        # A traversal that ends in the middle of a path is incomplete.
        self.assertEqual(len(current_path), current_node_in_path,
                         "dfs stopped part-way along a path")
        # Every path must have been traversed exactly once.
        self.assertCountEqual(paths, completed_paths)

    def test_dfs(self):
        graph = Graph()
        p1 = ['1a', '2a']
        p2 = ['1b', '2b', '3b']
        p3 = ['1c', '2c']

        paths = [p1, p2, p3]
        for path in paths:
            for i in range(len(path) - 1):
                graph.add_edge(path[i], path[i+1])

        # add a loop
        graph.add_edge('2b', '0')

        # connect to center node
        for path in paths:
            graph.add_edge('0', path[0])

        self.check_paths(graph, '0', paths)

                
class UndirectedGraphTests(TestCase):
    """Tests for the graph returned by get_undirected_graph()."""

    def test_edge_direction(self):
        # every directed edge (x, y) must exist as (y, x) in the
        # undirected version
        directed = star_graph(7)
        undirected = directed.get_undirected_graph()

        for src, dst in directed.edges():
            self.assertTrue(undirected.has_edge(dst, src))

    def test_deg_methods(self):
        # Ensure in-degree and out-degree are equal for all nodes
        directed = star_graph(15)
        undirected = directed.get_undirected_graph()

        for label in directed.nodes():
            incoming = undirected.indeg(label)
            outgoing = undirected.outdeg(label)
            self.assertEqual(incoming, outgoing)

    def test_add_edge(self):
        # create super simple graph
        original = Graph()
        original.add_edge('0', '1')

        # get undirected copy and extend only the copy
        undirected = original.get_undirected_graph()
        undirected.add_edge('A', 'B')

        # check that edges exist in both directions
        expected_pairs = [('A', 'B'), ('B', 'A'), ('0', '1'), ('1', '0')]
        for src, dst in expected_pairs:
            self.assertTrue(undirected.has_edge(src, dst))

        # check that original graph doesn't have extra edge!
        self.assertCountEqual(['0', '1'], original.nodes())
        self.assertCountEqual([('0', '1')], original.edges())

class GraphCycleTest(TestCase):
    """Tests for has_cycle() on directed graphs and their undirected versions."""

    def test_tiny_cycle(self):
        # 0 --> 1 --> 0 is a directed cycle, but the undirected version
        # consists of a single edge and therefore has no cycle
        graph = Graph()
        graph.add_edge(0, 1)
        graph.add_edge(1, 0)
        # has_cycle must be *called*: the bound method itself is always
        # truthy, so asserting on it could never fail
        self.assertTrue(graph.has_cycle())
        self.assertFalse(graph.get_undirected_graph().has_cycle())

    def test_directed_cycle(self):
        # simple chain 0 --> 1 --> ... --> n-1
        graph = Graph()
        n = 10
        for i in range(1, n):
            graph.add_edge(i-1, i)
        self.assertFalse(graph.has_cycle())

        # add edge in 'wrong' direction for cycle
        graph.add_edge(1, 3)
        self.assertFalse(graph.has_cycle())

        # complete cycle
        graph.add_edge(n-1, 0)
        self.assertTrue(graph.has_cycle())

    def test_undirected_cycle(self):
        # 0 --> 1 --> 2 and 0 --> 2: no directed cycle, but a triangle
        # once the edge directions are ignored
        graph = Graph()
        graph.add_edge(0, 1)
        graph.add_edge(1, 2)
        graph.add_edge(0, 2)
        self.assertFalse(graph.has_cycle())
        self.assertTrue(graph.get_undirected_graph().has_cycle())
        
        
# Run the tests of this module when the file is executed as a script
if __name__ == '__main__':
    unittest.main()
