
from collections import defaultdict

# Binary grammar productions, indexed by the FIRST symbol of the right-hand
# side.  An entry  B: [(A, C, p), ...]  stands for the production  A -> B C
# with probability p.
rules = {
    'NP': [
        ('S', 'VP', 1.0),    # S  -> NP VP
        ('NP', 'PP', 0.2),   # NP -> NP PP
    ],
    'VP': [
        ('VP', 'PP', 0.6),   # VP -> VP PP
    ],
    'V': [
        ('VP', 'NP', 0.4),   # VP -> V NP
    ],
    'P': [
        ('PP', 'NP', 1.0),   # PP -> P NP
    ],
    'DET': [
        ('NP', 'N', 0.8),    # NP -> DET N
    ],
}

# Lexical entries: each word maps to a list of (tag, probability) pairs.
lexicon = {
    # determiners
    'the': [('DET', 0.8)],
    'a': [('DET', 0.2)],
    # prepositions
    'in': [('P', 0.5)],
    'on': [('P', 0.5)],
    # nouns
    'student': [('N', 0.55)],
    'book': [('N', 0.25)],
    'library': [('N', 0.2)],
    # verbs
    'reads': [('V', 1.0)],
}

class Tree:
    """A labelled node of a derivation tree.

    `children` is a list whose items are rendered with ``str()``; they may be
    other `Tree` nodes or plain strings (e.g. the word under a tag).
    """

    def __init__(self, label, children):
        self.children = children
        self.label = label

    def __str__(self):
        """Render the tree in bracketed form, e.g. ``(NP (Det the) (N student))``."""
        # A node without children is printed as its bare label.
        if not self.children:
            return self.label
        rendered_children = ' '.join(map(str, self.children))
        return '(%s %s)' % (self.label, rendered_children)

# t = Tree('NP', [Tree('Det', [Tree('the', [])]), Tree('N', [Tree('student', [])])])
# print(t) => (NP (Det the) (N student))

def pcyk(rules, lexicon, words):
    """Find the most probable parse of `words` with probabilistic CYK.

    Args:
        rules: Binary productions indexed by the first right-hand-side
            symbol: ``rules[B]`` is a list of ``(A, C, prob)`` triples, each
            standing for the production ``A -> B C`` with probability
            ``prob``.
        lexicon: Maps a word to a list of ``(tag, prob)`` pairs.
        words: The sentence as a sequence of tokens.

    Returns:
        ``(prob, tree)`` for the highest-probability derivation rooted in
        ``'S'`` that spans the whole sentence, or ``False`` when no such
        derivation exists (this includes an empty sentence and sentences
        containing a word that is not in the lexicon).
    """
    words = list(words)
    n = len(words)
    # P[j, i] maps each nonterminal to the best probability found for a
    # derivation of words[j:i]; T[j, i] holds the matching derivation trees.
    P = defaultdict(dict)
    T = defaultdict(dict)
    for i, wi in enumerate(words, start=1):
        # Lexical step: seed the one-word span ending at position i.
        # A word without lexicon entries simply contributes no tags.
        for lhs, prob in lexicon.get(wi, ()):
            P[i - 1, i][lhs] = prob
            T[i - 1, i][lhs] = Tree(lhs, [wi])
        # Combine shorter spans into every longer span ending at i.  j runs
        # downwards so that P[k, i] (with k > j) is already complete when it
        # is needed as the right-hand part.
        for j in range(i - 2, -1, -1):
            best = P[j, i]
            best_trees = T[j, i]
            for k in range(j + 1, i):
                left = P[j, k]
                right = P[k, i]
                if not left or not right:
                    continue
                for b, left_prob in left.items():
                    for a, c, rule_prob in rules.get(b, ()):
                        if c not in right:
                            continue
                        candidate = rule_prob * left_prob * right[c]
                        # Keep only the best derivation per nonterminal;
                        # zero-probability derivations are never recorded.
                        if candidate > best.get(a, 0.0):
                            best[a] = candidate
                            best_trees[a] = Tree(a, [T[j, k][b], T[k, i][c]])
    # Use n rather than the loop variable, which is unbound for empty input.
    if n and 'S' in P[0, n]:
        return P[0, n]['S'], T[0, n]['S']
    return False

if __name__ == '__main__':
    # Demo: parse a sample sentence and print the probability and tree.
    result = pcyk(rules, lexicon, 'the student reads a book in the library'.split())
    # pcyk returns False (not a pair) when there is no parse, so check
    # before unpacking to avoid a TypeError.
    if result is False:
        print('no parse found')
    else:
        prob, tree = result
        print(prob)
        print(tree)

