'''
Created on May 8, 2012

@author: regneri
'''
from Graph import Graph, Edge
import sys


class NetworkEdge(Edge):
    """
        Edge of a flow network. Besides the source/target/weight attributes
        inherited from Edge, it has a capacity and a current flow;
        residual() gives the remaining free capacity (capacity - flow).

        Edges compare and hash by identity. Their capacity and flow change
        while they sit in sets and serve as dict keys (e.g. during max-flow
        computation), so a value-based hash would go stale. Also, two distinct
        edges that happen to have equal values (such as a residual reverse
        edge and a real edge in the opposite direction) must never be
        confused with each other.
    """
    def __init__(self, src, tgt, cap, wt=1):
        Edge.__init__(self, src, tgt, wt)
        self.capacity = cap
        self.flow = 0

    def increase_flow(self, toAdd=1):
        """Add toAdd to the current flow (the capacity is not checked)."""
        self.flow += toAdd

    def decrease_flow(self, toReduce=1):
        """Subtract toReduce from the current flow (no lower bound is checked)."""
        self.flow -= toReduce

    def get_current_flow(self):
        """Return the current flow."""
        return self.flow

    def get_capacity(self):
        """Return the capacity."""
        return self.capacity

    def residual(self):
        """Return the residual (free) capacity: capacity - flow."""
        return self.capacity - self.flow

    def __str__(self):
        # str() on the endpoints so that non-string node labels work too.
        return (str(self.source) + ' --> ' + str(self.target) + '  ('
                + str(self.weight) + '), Capacity: ' + str(self.capacity)
                + " Flow: " + str(self.flow))

    def __repr__(self):
        return ("Edge(" + str(self.source) + ',' + str(self.target) + ','
                + str(self.capacity) + "," + str(self.flow) + ')')

    def __hash__(self):
        # Identity hash: stays valid while capacity/flow are mutated.
        return object.__hash__(self)

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Explicit, so a value-based __ne__ possibly defined on Edge is not used.
        return self is not other

class Network(Graph):
    """
      A flow network: a Graph whose edges are NetworkEdges with a capacity
      and a flow, extended by max-flow and min-cut computations.
    """
    def __init__(self):
        Graph.__init__(self)

    def add_edge_between(self, n1, n2, capacity, wt=1):
        """Add a new NetworkEdge n1 -> n2 with the given capacity and weight."""
        edge = NetworkEdge(n1, n2, capacity, wt)
        Graph.add_edge(self, edge)

    def capacity(self, src, tgt):
        """Return the capacity of the edge src -> tgt."""
        return self.get_edge_between(src, tgt).get_capacity()

    def flow(self, src, tgt):
        """Return the current flow on the edge src -> tgt."""
        return self.get_edge_between(src, tgt).get_current_flow()

    def setCapacity(self, src, tgt, capacity):
        """Set the capacity of the edge src -> tgt."""
        self.get_edge_between(src, tgt).capacity = capacity

    def addFlow(self, src, tgt, flow):
        """Increase the flow on the edge src -> tgt by flow."""
        self.get_edge_between(src, tgt).increase_flow(flow)

    def get_outgoing_edges(self, node, backward_edges=None):
        """
            Return the list of edges leaving node.

            In the residual network used by max_flow, the reverse edges are
            not added to the graph itself. They are kept separately in the
            dict backward_edges, which maps an edge to its reverse partner.
            The reverse of each incoming edge of node that has an entry there
            also leaves node; these come first, followed by the node's
            ordinary outgoing edges. Without backward_edges the method
            returns the same edges as Graph.get_outgoing_edges.
        """
        if backward_edges is None:
            backward_edges = {}
        ret = [backward_edges[inc] for inc in Graph.get_incoming_edges(self, node)
               if inc in backward_edges]
        ret += Graph.get_outgoing_edges(self, node)
        return ret

    def get_incoming_edges(self, node, backward_edges=None):
        """
           Return the list of edges entering node. This mirrors
           get_outgoing_edges: the reverse edges (from backward_edges) of the
           node's outgoing edges come first, followed by the ordinary
           incoming edges.
        """
        if backward_edges is None:
            backward_edges = {}
        ret = [backward_edges[out] for out in Graph.get_outgoing_edges(self, node)
               if out in backward_edges]
        ret += Graph.get_incoming_edges(self, node)
        return ret

    def get_free_capacity_path(self, n1, n2, backwards=None):
        """
          Find a path from n1 to n2 on which every edge has free (residual)
          capacity.
          n1 - source
          n2 - target
          backwards - optional dict of reverse edges (edge -> reverse edge),
                      as used by max_flow

          Returns a pair (capacity, edges): the bottleneck capacity of the
          path found and the set of its edges, or (0, set()) if no such path
          exists. If n1 == n2 the result is (sys.maxsize, set()).
        """
        if backwards is None:
            backwards = {}
        return self.get_free_capacity_path_rec(n1, n2, set(), sys.maxsize,
                                               backwards, set())

    def get_free_capacity_path_rec(self, n1, n2, alledges, cur_capacity,
                                   backwards, visited):
        """
           Recursive helper of get_free_capacity_path. It searches backwards
           from n2 towards n1 over incoming edges with residual capacity,
           trying edges with the largest residual first.
           n1 - source
           n2 - current node (initially the target)
           alledges - edges that are included in the returned edge set
           cur_capacity - bottleneck capacity of the partial path so far
           backwards - reverse edges used by max_flow (may be None)
           visited - nodes already seen; updated in place

           Returns (capacity, edges) for a path found, or (0, set()).
           The set passed as alledges is not modified.
        """
        if n1 == n2:
            visited.add(n1)
            # A fresh set, so callers can extend it without aliasing alledges.
            return (cur_capacity, set(alledges))
        visited.add(n2)
        incoming = self.get_incoming_edges(n2, backwards)
        # Largest residual first (reversed ascending sort, ties included).
        for edge in reversed(sorted(incoming, key=lambda e: e.residual())):
            if edge.residual() and edge.source not in visited:
                cap = min(cur_capacity, edge.residual())
                found, path = self.get_free_capacity_path_rec(
                    n1, edge.source, alledges, cap, backwards, visited)
                if found:
                    path.add(edge)
                    return (found, path)
        # No incoming edge with free capacity leads back to n1.
        return (0, set())

    def reset_flow(self):
        """
            Reset the flow of every edge to 0.
        """
        for edge in self._edges:
            edge.flow = 0

    def max_flow(self, source, sink):
        """
            Compute a maximum flow from source to sink by repeatedly
            augmenting along paths with free capacity in the residual
            network (Ford-Fulkerson).
            source - source node
            sink - sink node

            Returns (value of the flow, self._edges); the flow itself is
            stored in the edges' flow attributes.
            Raises ValueError if source == sink, since the augmentation loop
            would then never terminate.
        """
        if source == sink:
            raise ValueError("source and sink must be different nodes")
        self.reset_flow()
        # Maps every real edge used so far to its reverse residual edge and
        # vice versa. Edges act as keys while their flow changes, so edge
        # hashing/equality must not depend on the flow.
        backwards_edges = {}
        all_cap = 0
        cap, edges = self.get_free_capacity_path(source, sink, backwards_edges)
        while cap:  # some capacity is left from source to sink
            all_cap += cap
            for edg in edges:
                edg.increase_flow(cap)
                if edg not in backwards_edges:
                    # First use of a real edge: create its reverse with
                    # flow == capacity, i.e. zero residual capacity.
                    gde = NetworkEdge(edg.target, edg.source, edg.capacity, edg.weight)
                    gde.increase_flow(edg.capacity)
                    backwards_edges[edg] = gde
                    backwards_edges[gde] = edg
                # Pushing flow along edg frees the same amount on its partner.
                # For a reverse edge, this cancels flow on the real edge.
                backwards_edges[edg].decrease_flow(cap)
            cap, edges = self.get_free_capacity_path(source, sink, backwards_edges)
        return (all_cap, self._edges)

    def cut_rec(self, start, visited, candidates):
        """
            Compute the edges of a minimum cut; to be called after max_flow.

            Collects every node reachable from start in the residual network.
            The search moves forward along edges with residual capacity left,
            and backward along edges that carry flow, since that flow could be
            cancelled. The cut consists of the candidate edges that lead out of
            this reachable set. Such edges have no residual capacity, because
            otherwise their target would have been reached.
            start - node to start from (the source)
            visited - set of visited nodes; filled with the reachable nodes
            candidates - edges whose capacity was exhausted by max_flow

            Returns the set of cut edges.
        """
        reached = []
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            reached.append(node)
            for edge in self.get_outgoing_edges(node):
                if edge.residual() > 0:
                    stack.append(edge.target)
            for edge in self.get_incoming_edges(node):
                if edge.get_current_flow() > 0:
                    stack.append(edge.source)
        return {edge
                for node in reached
                for edge in self.get_outgoing_edges(node)
                if edge.target not in visited and edge in candidates}

    def wccs_rec(self, start, visited, forbidden, src=True):
        """
            Compute the part of the network that remains connected to start
            once the edges in forbidden (the cut) are removed.
            start - node to start from
            visited - set of visited nodes; updated in place
            forbidden - edges of the cut, which are never traversed
            src - if True, follow edges forward (source side); otherwise
                  follow them backward (sink side)

            Returns a pair (nodes, edges): the nodes reached and the
            non-forbidden edges traversed from them.
        """
        nodes, edges = set(), set()
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            nodes.add(node)
            nextset = self.get_outgoing_edges(node) if src else self.get_incoming_edges(node)
            for edge in nextset:
                if edge not in forbidden:
                    edges.add(edge)
                    # Going backward, the neighbour is the edge's source.
                    stack.append(edge.target if src else edge.source)
        return (nodes, edges)

    def min_cut(self, source, sink):
        """
          Compute a minimum cut between source and sink.
          This runs max_flow, determines the cut edges with cut_rec, and
          returns a pair of Networks (S, T):
          S - the part reachable from source without crossing the cut
          T - the part that reaches sink without crossing the cut
        """
        _, edges = self.max_flow(source, sink)
        saturated = {edge for edge in edges if edge.residual() == 0}
        cut = self.cut_rec(source, set(), saturated)

        src_nodes, src_edges = self.wccs_rec(source, set(), cut)
        sink_nodes, sink_edges = self.wccs_rec(sink, set(), cut, False)

        return (self.makeSubgraph(src_nodes, src_edges),
                self.makeSubgraph(sink_nodes, sink_edges))

    def makeSubgraph(self, node_set, edge_set):
        """
          Build a new Network from the given nodes and edges. The edges are
          copied as new NetworkEdges with the same endpoints, capacity and
          weight; their flow starts at 0.
        """
        ret = Network()
        ret._nodes = set(node_set)
        for edge in edge_set:
            ret.add_edge_between(edge.source, edge.target, edge.capacity, edge.weight)

        return ret
    
    
    
def parseNetwork(file):
    """
       Read network edges from a text file and return the resulting Network.
       Each non-empty line holds four tab-separated fields:
       source    target    capacity    weight
       capacity and weight are parsed with int(). Blank lines are skipped.
       Raises ValueError for a line without exactly four fields; the message
       includes the line number.
    """
    n = Network()
    with open(file) as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. an empty last line.
                continue
            fields = line.split("\t")
            if len(fields) != 4:
                raise ValueError("%s:%d: expected 4 tab-separated fields, got %d"
                                 % (file, lineno, len(fields)))
            src, tgt, cap, wt = fields
            n.add_edge_between(src, tgt, int(cap), int(wt))
    return n

def test():
    """
    Demo run on the network file 'zuege.txt'. It prints all edges, then the
    maximum flow and the minimum cut between Saarbruecken and Frankfurt.
    """
    source, sink = "Saarbruecken", "Frankfurt"
    network = parseNetwork('zuege.txt')

    print("#4.1: Nets")
    for e in network.edges():
        print(e)

    print("#4.2: Max Flow")
    flow_value, flow_edges = network.max_flow(source, sink)
    print("max flow: ", str(flow_value))
    print("edges + capacity:", flow_edges)

    print("#4.3 Min Cut (S,T)")
    s_part, t_part = network.min_cut(source, sink)
    for label, part in (("S=", s_part), ("T=", t_part)):
        print(label, str(part.nodes()), str(part.edges()))


if __name__ == '__main__':
    test()
    
                
                    
        