from sim.Match import Match
from sim.MatchSegment import MatchSegment
from sim.Text import Text
import re
from sim.Token import Token
from rapidfuzz import fuzz, process
from datasketch import MinHash, MinHashLSH

class SimTexter:
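    # Finds similar passages between two input texts: both texts are tokenized,
    # every window of min_match_length tokens is turned into a "tag", matching
    # tags are linked through a forward-reference table (MinHash LSH plus
    # RapidFuzz also catches near-identical tags), and the longest fuzzy token
    # matches are returned as pairs of MatchSegment objects.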

    def __init__(self, min_match_length):
        self.min_match_length = min_match_length
        self.cache = {}

    def compare(self, input_texts):

        (texts, tokens) = self.read_input(input_texts)

        mts_tags = {}
        forward_references = {}
        existing_tags = []
        lsh = MinHashLSH(threshold=0.80, num_perm=128)

        for i, text in enumerate(texts):
            (mts_tags, forward_references, existing_tags, lsh) = self.make_forward_references(
                i, text, tokens, mts_tags, existing_tags, forward_references, lsh)

        similarities = self.get_similarities(tokens, texts, 0, 1, forward_references)

        # self.print_similarities(similarities, input_texts)

        return similarities

    def read_input(self, input_texts):
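        # Builds one Text descriptor per input string and one flat token list
        # shared by all texts; each Text records the half-open token slice
        # [tk_start_pos, tk_end_pos) it owns.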

        texts = []
        tokens = []

        for input_text in input_texts:
            nr_of_characters = len(input_text)
            nr_of_words = len(input_text.split())
            file_name = 'dummy'
            tk_start_pos = len(tokens)

            tokens.extend(self.tokenize_text(input_text))
            tk_end_pos = len(tokens)
            text = Text('Text', nr_of_characters, nr_of_words, file_name, tk_start_pos, tk_end_pos)
            texts.append(text)

        return texts, tokens

    def tokenize_text(self, input_text):
        cleaned_text = self.clean_text(input_text)

        tokens = []

        for match in re.finditer(r"\S+", cleaned_text):
            token = self.clean_word(match.group())

            if len(token) > 0:
                text_begin_pos = match.start()
                text_end_pos = match.end()

                tokens.append(Token(token, text_begin_pos, text_end_pos))

        return tokens

    def clean_text(self, input_text):
        # TODO: make this optional

        input_text = re.sub(r"[.?!,;:/()'+\-\[\]‚‘…]", " ", input_text)
        input_text = re.sub(r"[0-9]", " ", input_text)

        return input_text.lower()

    def clean_word(self, input_word):
        # TODO: replace umlauts, make this optional
        return input_word

    def make_forward_references(self, text_index, text, tokens, mts_tags, existing_tags, forward_references, lsh):
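        # Every window of min_match_length tokens is concatenated into a tag.
        # Tags of the first text are indexed in the LSH; for later texts the
        # LSH is queried, and a hit links the position of the earlier
        # (near-)identical tag to the current position via forward_references.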
        text_begin_pos = text.tk_start_pos
        text_end_pos = text.tk_end_pos

        for i in range(text_begin_pos, text_end_pos - self.min_match_length):
            # Concatenate the next min_match_length token texts into a tag.
            tag = ''.join(token.text for token in tokens[i: i + self.min_match_length])

            # TODO: can this be done fuzzily?

            # if tag in mts_tags:
            #    forward_references[mts_tags[tag]] = i

            # mts_tags[tag] = i

            # if text_index == 0:
            #     existing_tags.append(tag)
            # else:
            #     best_existing_tag = process.extractOne(tag, existing_tags, scorer=fuzz.ratio, score_cutoff=80)
            #
            #     if best_existing_tag:
            #         forward_references[mts_tags[best_existing_tag[0]]] = i

            # Character-level MinHash signature of the tag.
            min_hash = MinHash(num_perm=128)

            for d in set(tag):
                min_hash.update(d.encode('utf8'))

            if text_index == 0:
                # First text: index the tag so later texts can query it.
                lsh.insert(tag, min_hash, check_duplication=False)
            else:
                # Later texts: look for a near-duplicate tag from the first text.
                result = lsh.query(min_hash)

                if result:
                    closest_match = self.get_closest_match(result, tag)
                    if closest_match:
                        forward_references[mts_tags[closest_match]] = i

            mts_tags[tag] = i

        return mts_tags, forward_references, existing_tags, lsh

    def get_similarities(self, tokens, texts, source_text_index, target_text_index, forward_references):
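        # Greedy scan over the source text: for each start position the longest
        # fuzzy match in the target text is looked up and, if it spans at least
        # min_match_length tokens, recorded as a pair of MatchSegments (token
        # span plus character span); the scan then continues after the match.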
        source_token_start_pos = texts[source_text_index].tk_start_pos
        source_token_end_pos = texts[source_text_index].tk_end_pos

        similarities = []

        while source_token_start_pos + self.min_match_length <= source_token_end_pos:
            best_match = self.get_best_match(tokens, texts, source_text_index, target_text_index, source_token_start_pos,
                                        forward_references)

            if best_match and best_match.length > 0:
                source_character_start_pos = tokens[best_match.source_token_start_pos].start_pos
                source_character_end_pos = tokens[best_match.source_token_start_pos + best_match.length - 1].end_pos
                target_character_start_pos = tokens[best_match.target_token_start_pos].start_pos
                target_character_end_pos = tokens[best_match.target_token_start_pos + best_match.length - 1].end_pos

                similarities.append((MatchSegment(best_match.source_text_index, best_match.source_token_start_pos,
                                                  best_match.length, source_character_start_pos, source_character_end_pos),
                                     MatchSegment(best_match.target_text_index, best_match.target_token_start_pos,
                                                  best_match.length, target_character_start_pos, target_character_end_pos)))

                source_token_start_pos = source_token_start_pos + best_match.length
            else:
                source_token_start_pos = source_token_start_pos + 1

        return similarities

    def get_best_match(self, tokens, texts, source_text_index, target_text_index, source_token_start_pos, forward_references):
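        # Follows the forward-reference chain starting at the source position,
        # extends each candidate match token by token while the fuzzy ratio
        # stays above 80, and keeps the longest candidate; -1 marks the end of
        # the chain.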
        best_match_length = 0
        token_pos = source_token_start_pos

        source_token_pos = 0
        target_token_pos = 0

        best_match_token_pos = 0

        best_match = None

        # token_pos == -1 marks the end of the forward-reference chain; the
        # chain may legitimately start at token position 0.
        while 0 <= token_pos < len(tokens):

            if token_pos < texts[target_text_index].tk_start_pos:
                if token_pos in forward_references:
                    token_pos = forward_references[token_pos]
                else:
                    token_pos = -1
                continue

            min_match_length = self.min_match_length

            if best_match_length > 0:
                min_match_length = best_match_length + 1

            source_token_pos = source_token_start_pos + min_match_length - 1
            target_token_pos = token_pos + min_match_length - 1

            if (source_token_pos < texts[source_text_index].tk_end_pos
                    and source_token_pos + min_match_length <= target_token_pos < texts[target_text_index].tk_end_pos):

                cnt = min_match_length

                while cnt > 0 and self.fuzzy_match(tokens[source_token_pos].text, tokens[target_token_pos].text) > 80:
                    source_token_pos = source_token_pos - 1
                    target_token_pos = target_token_pos - 1
                    cnt = cnt - 1

                if cnt > 0:
                    if token_pos in forward_references:
                        token_pos = forward_references[token_pos]
                    else:
                        token_pos = -1
                    continue
            else:
                if token_pos in forward_references:
                    token_pos = forward_references[token_pos]
                else:
                    token_pos = -1
                continue

            new_match_length = min_match_length
            source_token_pos = source_token_start_pos + min_match_length
            target_token_pos = token_pos + min_match_length

            while (source_token_pos < texts[source_text_index].tk_end_pos
                   and source_token_pos + new_match_length < target_token_pos < texts[target_text_index].tk_end_pos
                   and self.fuzzy_match(tokens[source_token_pos].text, tokens[target_token_pos].text) > 80):

                source_token_pos = source_token_pos + 1
                target_token_pos = target_token_pos + 1
                new_match_length = new_match_length + 1

            if new_match_length >= self.min_match_length and new_match_length > best_match_length:
                best_match_length = new_match_length
                best_match_token_pos = token_pos
                best_match = Match(source_text_index, source_token_start_pos, target_text_index, best_match_token_pos,
                                   best_match_length)

            if token_pos in forward_references:
                token_pos = forward_references[token_pos]
            else:
                token_pos = -1

        return best_match

    def fuzzy_match(self, input1, input2):
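        # Thin wrapper around RapidFuzz's fuzz.ratio (0-100); the commented-out
        # code below would cache scores and skip pairs with very different
        # lengths.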

        # if input1 + input2 in self.cache:
        #    return self.cache[input1 + input2]

        # if abs(len(input1) - len(input2)) >= 3:
        #    self.cache[input1 + input2] = 0
        #    return 0

        ratio = fuzz.ratio(input1, input2)
        # self.cache[input1 + input2] = ratio
        return ratio

    def get_closest_match(self, candidates, word):
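        # An exact candidate wins immediately; otherwise RapidFuzz picks the
        # best candidate scoring at least 80, or None if none qualifies.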
        if word in candidates:
            return word

        best_existing_tag = process.extractOne(word, candidates, scorer=fuzz.ratio, score_cutoff=80)

        if best_existing_tag:
            return best_existing_tag[0]

        return None

    def print_similarities(self, similarities, input_texts):
        for similarity_tuple in similarities:
            similarity_literature = similarity_tuple[0]
            similarity_scientific = similarity_tuple[1]

            print('{0}, {1}'.format(similarity_literature, similarity_scientific))
            print(input_texts[0][similarity_literature.character_start_pos:similarity_literature.character_end_pos])
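

# Minimal usage sketch, not part of the library proper. Assumptions: the sim.*
# modules imported at the top are importable (e.g. run from the project root)
# and at least two input strings are passed, since compare() pairs text 0 with
# text 1.
if __name__ == '__main__':
    simtexter = SimTexter(min_match_length=3)

    matches = simtexter.compare([
        "the quick brown fox jumps over the lazy old dog",
        "a quick brown fox jumped over the lazy old dog",
    ])

    # Each entry is a (source MatchSegment, target MatchSegment) pair.
    for source_segment, target_segment in matches:
        print(source_segment, target_segment)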