from math import sqrt
import re
import os
import string

# Aufgabe 1
class Vector(dict):
    """Sparse numeric vector stored as a ``{dimension: value}`` dict.

    Dimensions that are not stored read as ``0.0`` through ``vec[key]``, so
    vectors over different key sets can be combined without padding.
    Arithmetic operators return new ``Vector`` objects and never modify
    their operands.
    """

    def __getitem__(self, key):
        """Return the stored value for ``key``, or ``0.0`` if it is absent."""
        return dict.get(self, key, 0.0)

    def __add__(self, other):
        """Return the component-wise sum over the union of both key sets."""
        # Key views cannot be concatenated with ``+`` on Python 3, so the
        # union of dimensions is built from explicit sets.
        return Vector((k, self[k] + other[k]) for k in set(self) | set(other))

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        return Vector((k, self[k] - other[k]) for k in set(self) | set(other))

    def __mul__(self, scalar):
        """Return a copy with every component multiplied by ``scalar``."""
        factor = float(scalar)
        return Vector((k, v * factor) for (k, v) in self.items())

    def __truediv__(self, scalar):
        """Return a copy with every component divided by ``scalar``.

        Raises:
            ZeroDivisionError: if ``scalar`` is zero and the vector is
                not empty.
        """
        divisor = float(scalar)
        return Vector((k, v / divisor) for (k, v) in self.items())

    # Python 3 dispatches ``/`` to __truediv__; the old name is kept so that
    # explicit ``vec.__div__(x)`` calls keep working.
    __div__ = __truediv__

    def dotproduct(self, other):
        """Return the dot product, summing only over keys present in both."""
        return sum(self[k] * other[k] for k in self.keys() if k in other)

    def norm(self):
        """Return the Euclidean length of the vector."""
        return sqrt(sum(v * v for v in self.values()))

    @staticmethod
    def cosine(first, second):
        """Return the cosine similarity of two vectors.

        Returns ``0.0`` when either vector has zero length, instead of
        dividing by zero.
        """
        # Compute the norms once; each call walks every component.
        denominator = first.norm() * second.norm()
        if denominator == 0:
            return 0.0
        return first.dotproduct(second) / denominator

    @staticmethod
    def euclid(first, second):
        """Return the Euclidean distance between two vectors."""
        dimensions = set(first) | set(second)
        return sqrt(sum((first[k] - second[k]) ** 2 for k in dimensions))

    @staticmethod
    def centroid(vecs):
        """Return the arithmetic mean of the vectors in ``vecs``.

        ``vecs`` may be any iterable, including a one-shot iterator. An
        empty iterable yields an empty ``Vector``.
        """
        total = Vector()
        count = 0
        for vec in vecs:
            for k, v in vec.items():
                # Missing dimensions read as 0.0, so this also creates them.
                total[k] += v
            count += 1
        if count == 0:
            return total
        return total / count


def parse_file(filename):
    """Lazily yield the normalised tokens of a text file.

    Each line is split on whitespace; every token has all characters in
    ``string.punctuation`` removed and is then lowercased. A token made up
    only of punctuation is yielded as the empty string.
    """
    # Three-argument form: no character mapping, just a set to delete.
    strip_punctuation = str.maketrans("", "", string.punctuation)
    with open(filename, 'r') as handle:
        for line in handle:
            for token in line.split():
                yield token.translate(strip_punctuation).lower()


# Aufgabe 2
class VectorModel:
    """Distributional word model built from windowed co-occurrence counts.

    Every word passed to ``add_words`` owns one sparse ``Vector`` whose keys
    are context words and whose values are co-occurrence scores: raw counts
    at first, PMI values after ``convert_to_pmi``.
    """

    def __init__(self):
        # Maps each word to the Vector of its context-word scores.
        self.vectors = {}

    def add_words(self, w_iterator, window=5):
        """Add the co-occurrences of one word stream to the model.

        Two words co-occur when at most ``window - 1`` non-empty words lie
        between them. Counts are symmetric: each pair increments both
        words' vectors. Every call starts with an empty window, so no pair
        spans two calls.

        Args:
            w_iterator: iterable of word strings; empty strings are skipped.
            window: number of preceding words paired with each new word.

        Raises:
            ValueError: if ``window`` is negative.
        """
        from collections import deque

        # A bounded deque drops the oldest word once the window is full.
        recent = deque(maxlen=window)
        for word in w_iterator:
            if not word:
                # Skip empty tokens (e.g. stripped punctuation) so they
                # neither become a context word nor use a window slot.
                continue
            if word not in self.vectors:
                self.vectors[word] = Vector()
            current = self.vectors[word]
            for neighbour in recent:
                current[neighbour] += 1.0
                self.vectors[neighbour][word] += 1.0
            recent.append(word)

    def convert_to_pmi(self):
        """Replace every stored count by its pointwise mutual information.

        For a word ``w`` and context ``c`` the new score is
        ``log(count(w, c) * N / (count(w) * count(c)))`` using the natural
        logarithm, where ``count(w)`` is the sum of w's vector, ``count(c)``
        the sum of c's scores over all vectors and ``N`` the sum of all
        scores. Pairs that never co-occurred stay absent and keep reading
        as ``0.0``. Call this once, on raw counts; converting twice or
        adding words afterwards mixes counts with PMI values.
        """
        from math import log

        row_totals = {}
        column_totals = {}
        for word, vec in self.vectors.items():
            row_totals[word] = sum(vec.values())
            for context, count in vec.items():
                column_totals[context] = column_totals.get(context, 0.0) + count
        grand_total = sum(row_totals.values())
        if not grand_total:
            # Nothing has been counted yet, so there is nothing to convert.
            return
        for word, vec in self.vectors.items():
            # Iterate over a snapshot because the values are rewritten.
            for context, count in list(vec.items()):
                expected = row_totals[word] * column_totals[context]
                vec[context] = log(count * grand_total / expected)

    def get_similar_words(self, word, n=10, distance=None, reverse=True):
        """Return the ``n`` words whose vectors are closest to ``word``'s.

        Args:
            word: the query word.
            n: maximum number of results.
            distance: callable ``distance(vec_a, vec_b)`` returning a
                number; defaults to ``Vector.cosine``.
            reverse: ``True`` ranks the highest scores first (for
                similarity measures such as cosine); ``False`` ranks the
                lowest first (for distance measures such as Euclidean).

        Returns:
            A list of ``(other_word, score)`` tuples, best match first.
            The query word itself is excluded. The list is empty when
            ``word`` is unknown to the model.
        """
        if distance is None:
            # Resolved at call time so the class definition does not need
            # Vector to exist yet.
            distance = Vector.cosine
        target = self.vectors.get(word)
        if target is None:
            return []
        scored = [
            (other, distance(target, vec))
            for other, vec in self.vectors.items()
            if other != word
        ]
        scored.sort(key=lambda pair: pair[1], reverse=reverse)
        return scored[:n]

def testVectorModel(model, testdir="nyt199407", testword="man", no_of_items=10, distance=Vector.cosine):
    """Print the model's nearest neighbours of ``testword`` for one measure.

    First prints a header naming the word and the measure, then the list
    returned by ``model.get_similar_words``. ``testdir`` is accepted but
    not used by this function.
    """
    header = "Wort: {}, Aehnlichkeitsmass: {}".format(testword,distance)
    print(header)
    # The last argument is True only for cosine; presumably it tells the
    # model to rank high scores first -- confirm against get_similar_words.
    is_cosine = distance == Vector.cosine
    neighbours = model.get_similar_words(testword, no_of_items, distance, is_cosine)
    print(list(neighbours))

if __name__ == '__main__':
    # Directory holding the training corpus; every entry in it is parsed as
    # one text file.
    testdir = "nyt199407"

    def _train(window):
        """Return a new VectorModel trained on every file in ``testdir``."""
        trained = VectorModel()
        for f in os.listdir(testdir):
            trained.add_words(parse_file(os.path.join(testdir, f)), window=window)
        return trained

    # The test word is passed by keyword: the second positional parameter
    # of testVectorModel is ``testdir``, not ``testword``.

    # mid-sized window
    print("Training new model, window size 5...")
    model = _train(5)
    testVectorModel(model, testword="man")
    testVectorModel(model, testword="man", distance=Vector.euclid)

    print("Converting scores to PMI...")
    model.convert_to_pmi()
    testVectorModel(model, testword="man")

    # bigger window
    print("Training new model, window size 10...")
    model = _train(10)
    testVectorModel(model, testword="man")

    # small window
    print("Training new model, window size 3...")
    model = _train(3)
    testVectorModel(model, testword="man")
