import re

from datasketch import MinHash, MinHashLSH
from rapidfuzz import fuzz, process

from sim.Match import Match
from sim.MatchSegment import MatchSegment
from sim.Text import Text
from sim.Token import Token


# noinspection PyMethodMayBeStatic
class SimTexter:
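    """Detect fuzzy text reuse between two texts.

    Candidate positions are found via MinHash LSH over token windows and
    verified with rapidfuzz token-level comparisons.
    """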

    def __init__(self, min_match_length, max_gap):
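        """Store matching parameters: the minimum number of tokens per match and
        the maximum number of tokens a match may skip over."""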
        self.min_match_length = min_match_length
        self.max_gap = max_gap

    def compare(self, input_texts):
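        """Compare the input texts and return merged (source, target) MatchSegment pairs.

        Matches separated by a gap of one or two tokens, or by an '[...]' ellipsis
        in the target text, are merged into single segments.
        """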

        (texts, tokens) = self.__read_input(input_texts)

        mts_tags = {}
        forward_references = {}
        lsh = MinHashLSH(threshold=0.80, num_perm=128)

        for i, text in enumerate(texts):
            (mts_tags, forward_references, lsh) = self.__make_forward_references(i, text, tokens, mts_tags,
                                                                                 forward_references, lsh)
        # Only the first two texts are compared: index 0 is the source, index 1 the target.
        similarities = self.__get_similarities(tokens, texts, 0, 1, forward_references)

        # self.__print_similarities(similarities, input_texts)

        cleaned_similarities = []

        pos = 0
        current_match_segment = None

        while pos < len(similarities):
            # Last match with no pending merged segment: keep it as-is and stop.
            if pos + 1 >= len(similarities) and not current_match_segment:
                cleaned_similarities.append(similarities[pos])
                break

            if current_match_segment:
                # A merged segment is pending: try to extend it with the match at pos.
                extend_existing = True
                current_source_sim = current_match_segment[0]
                next_source_sim = similarities[pos][0]
                current_target_sim = current_match_segment[1]
                next_target_sim = similarities[pos][1]
            else:
                # No pending segment: consider merging the matches at pos and pos + 1.
                extend_existing = False
                current_source_sim = similarities[pos][0]
                next_source_sim = similarities[pos + 1][0]
                current_target_sim = similarities[pos][1]
                next_target_sim = similarities[pos + 1][1]

            source_gap = next_source_sim.token_start_pos - (
                    current_source_sim.token_start_pos + current_source_sim.token_length)
            target_gap = next_target_sim.token_start_pos - (
                    current_target_sim.token_start_pos + current_target_sim.token_length)

            # Merge when both texts have a small gap (one or two tokens) between the
            # segments, or when the target segments are adjacent and the skipped
            # characters contain an ellipsis marker.
            if (1 <= source_gap <= 2 and 1 <= target_gap <= 2) or (
                    target_gap == 0
                    and '[...]' in input_texts[1][tokens[next_target_sim.token_start_pos - 1].end_pos:
                                                  tokens[next_target_sim.token_start_pos].start_pos]):

                current_match_segment = (MatchSegment(current_source_sim.text_index, current_source_sim.token_start_pos,
                                                      current_source_sim.token_length + next_source_sim.token_length,
                                                      current_source_sim.character_start_pos,
                                                      next_source_sim.character_end_pos),
                                         MatchSegment(current_target_sim.text_index, current_target_sim.token_start_pos,
                                                      current_target_sim.token_length + next_target_sim.token_length,
                                                      current_target_sim.character_start_pos,
                                                      next_target_sim.character_end_pos))
                if extend_existing:
                    pos = pos + 1
                else:
                    pos = pos + 2

                if pos >= len(similarities):
                    cleaned_similarities.append(current_match_segment)
            else:
                if current_match_segment:
                    # The pending segment could not be extended: flush it and
                    # re-examine the match at pos on the next iteration.
                    cleaned_similarities.append(current_match_segment)
                    current_match_segment = None
                else:
                    cleaned_similarities.append(similarities[pos])
                    pos = pos + 1

        return cleaned_similarities

    def __read_input(self, input_texts):
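        """Tokenize all input texts into one shared token list and create a Text
        record holding each text's token span."""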

        texts = []
        tokens = []

        for input_text in input_texts:
            nr_of_characters = len(input_text)
            nr_of_words = len(input_text.split())
            file_name = 'dummy'
            tk_start_pos = len(tokens)

            tokens.extend(self.tokenize_text(input_text))
            tk_end_pos = len(tokens)
            text = Text('Text', nr_of_characters, nr_of_words, file_name, tk_start_pos, tk_end_pos)
            texts.append(text)

        return texts, tokens

    def tokenize_text(self, input_text):
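        """Split the cleaned text on whitespace and return Tokens carrying their
        character offsets."""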
        cleaned_text = self.clean_text(input_text)

        tokens = []

        for match in re.finditer(r"\S+", cleaned_text):
            token = self.__clean_word(match.group())

            if len(token) > 0:
                text_begin_pos = match.start()
                text_end_pos = match.end()

                tokens.append(Token(token, text_begin_pos, text_end_pos))

        return tokens

    def clean_text(self, input_text):
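        """Replace everything except letters with spaces and lowercase the result."""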
        # TODO: make this configurable

        # Each substitution replaces exactly one character with one space, so
        # character offsets into the original text are preserved.
        input_text = re.sub("[^a-zA-Z0-9äüöÄÜÖß ]", " ", input_text)
        # input_text = re.sub("[.?!,‚‘'’»«<>;:/()+\\-–\\[\\]…\"_\r\n]", " ", input_text)
        input_text = re.sub("[0-9]", " ", input_text)

        return input_text.lower()

    def __clean_word(self, input_word):
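        """Normalize German spelling variants within a single token."""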
        # TODO: make this configurable
        # Normalize umlauts, ß and the historical 'ey' spelling so that variant
        # spellings of the same word compare as equal.
        input_word = input_word.replace('ß', 'ss')
        input_word = input_word.replace('ä', 'ae')
        input_word = input_word.replace('ö', 'oe')
        input_word = input_word.replace('ü', 'ue')
        input_word = input_word.replace('ey', 'ei')

        return input_word

    def __make_forward_references(self, text_index, text, tokens, mts_tags, forward_references, lsh):
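        """Index token windows of the first text in the LSH; for later texts, point
        the latest earlier occurrence of a similar window to the current position."""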
        text_begin_pos = text.tk_start_pos
        text_end_pos = text.tk_end_pos

        for i in range(text_begin_pos, text_end_pos - self.min_match_length):
            # The tag is the concatenation of min_match_length consecutive tokens.
            tag = ''.join(token.text for token in tokens[i: i + self.min_match_length])

            # MinHash over the tag's character set, used for approximate lookups.
            min_hash = MinHash(num_perm=128)
            for d in set(tag):
                min_hash.update(d.encode('utf8'))

            if text_index == 0:
                # First text: index the tag so later texts can look it up.
                lsh.insert(tag, min_hash, check_duplication=False)
            else:
                # Later texts: link the position of the closest earlier occurrence
                # of a similar tag to the current position.
                result = lsh.query(min_hash)

                if result:
                    closest_match = self.__get_closest_match(result, tag)
                    if closest_match:
                        forward_references[mts_tags[closest_match]] = i

            mts_tags[tag] = i

        return mts_tags, forward_references, lsh

    def __get_similarities(self, tokens, texts, source_text_index, target_text_index, forward_references):
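        """Scan the source text and collect the best fuzzy match for each position
        as (source, target) MatchSegment pairs with token and character spans."""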
        source_token_start_pos = texts[source_text_index].tk_start_pos
        source_token_end_pos = texts[source_text_index].tk_end_pos

        similarities = []

        while source_token_start_pos + self.min_match_length <= source_token_end_pos:
            best_match = self.__get_best_match(tokens, texts, source_text_index, target_text_index,
                                               source_token_start_pos, forward_references)

            if best_match and best_match.source_length > 0:
                source_character_start_pos = tokens[best_match.source_token_start_pos].start_pos
                source_character_end_pos = tokens[
                    best_match.source_token_start_pos + best_match.source_length - 1].end_pos
                target_character_start_pos = tokens[best_match.target_token_start_pos].start_pos
                target_character_end_pos = tokens[
                    best_match.target_token_start_pos + best_match.target_length - 1].end_pos

                similarities.append((MatchSegment(best_match.source_text_index, best_match.source_token_start_pos,
                                                  best_match.source_length, source_character_start_pos,
                                                  source_character_end_pos),
                                     MatchSegment(best_match.target_text_index, best_match.target_token_start_pos,
                                                  best_match.target_length, target_character_start_pos,
                                                  target_character_end_pos)))

                # Continue scanning after the matched source span.
                source_token_start_pos = source_token_start_pos + best_match.source_length
            else:
                source_token_start_pos = source_token_start_pos + 1

        return similarities

    def __get_best_match(self, tokens, texts, source_text_index, target_text_index, source_token_start_pos,
                         forward_references):
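        """Follow the candidate chain for source_token_start_pos and return the
        longest fuzzy Match against the target text, or None if nothing qualifies."""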
        best_match_length = 0
        token_pos = source_token_start_pos
        best_match = None
        offset_source = 0
        offset_target = 0
        has_skipped = False

        while 0 <= token_pos < len(tokens):

            if token_pos < texts[target_text_index].tk_start_pos:
                # Position still lies in the source text: follow its forward
                # reference into the target text (-1 ends the search).
                token_pos = forward_references.get(token_pos, -1)
                continue

            min_match_length = self.min_match_length

            if best_match_length > 0:
                min_match_length = best_match_length + 1

            source_token_pos = source_token_start_pos + min_match_length - 1
            target_token_pos = token_pos + min_match_length - 1

            if (source_token_pos < texts[source_text_index].tk_end_pos
                    and source_token_pos + min_match_length <= target_token_pos < texts[target_text_index].tk_end_pos):

                cnt = min_match_length

                # Verify the window backwards token by token, allowing a gap of up
                # to max_gap source tokens when a direct comparison fails.
                while cnt > 0:
                    if self.__fuzzy_match(tokens[source_token_pos].text, tokens[target_token_pos].text) > 80:
                        source_token_pos = source_token_pos - 1
                        target_token_pos = target_token_pos - 1
                        cnt = cnt - 1
                    else:
                        found = False
                        for i in range(1, self.max_gap + 1):
                            if (source_token_pos - i >= 0
                                    and self.__fuzzy_match(tokens[source_token_pos - i].text,
                                                           tokens[target_token_pos].text) > 80):
                                source_token_pos = source_token_pos - 1 - i
                                target_token_pos = target_token_pos - 1
                                cnt = cnt - 1
                                found = True
                                break
                        if not found:
                            break

                if cnt > 0:
                    # Backward verification failed: try the next candidate position.
                    token_pos = forward_references.get(token_pos, -1)
                    continue
            else:
                token_pos = forward_references.get(token_pos, -1)
                continue

            # Extend the verified window forwards. Offsets and the skip flag are
            # per-candidate state and are reset for every candidate position.
            new_match_length = min_match_length
            source_token_pos = source_token_start_pos + min_match_length
            target_token_pos = token_pos + min_match_length
            offset_source = 0
            offset_target = 0
            has_skipped = False

            while (source_token_pos < texts[source_text_index].tk_end_pos
                   and source_token_pos + new_match_length < target_token_pos < texts[target_text_index].tk_end_pos):

                if self.__fuzzy_match(tokens[source_token_pos].text, tokens[target_token_pos].text) > 80:
                    source_token_pos = source_token_pos + 1
                    target_token_pos = target_token_pos + 1
                    new_match_length = new_match_length + 1
                elif (source_token_pos + 1 < len(tokens)
                      and self.__fuzzy_match(tokens[source_token_pos].text + tokens[source_token_pos + 1].text,
                                             tokens[target_token_pos].text) > 80):
                    # Two consecutive source tokens match a single target token.
                    source_token_pos = source_token_pos + 2
                    target_token_pos = target_token_pos + 1
                    new_match_length = new_match_length + 2
                    offset_target = offset_target + 1
                elif (target_token_pos + 1 < len(tokens)
                      and self.__fuzzy_match(tokens[source_token_pos].text,
                                             tokens[target_token_pos].text + tokens[target_token_pos + 1].text) > 80):
                    # A single source token matches two consecutive target tokens.
                    source_token_pos = source_token_pos + 1
                    target_token_pos = target_token_pos + 2
                    new_match_length = new_match_length + 2
                    offset_source = offset_source + 1
                elif not has_skipped:
                    # Allow one skip per candidate: either up to max_gap extra
                    # source tokens or one extra target token.
                    found = False
                    for i in range(1, self.max_gap + 1):
                        if (source_token_pos + i < len(tokens)
                                and self.__fuzzy_match(tokens[source_token_pos + i].text,
                                                       tokens[target_token_pos].text) > 80):
                            source_token_pos = source_token_pos + 1 + i
                            target_token_pos = target_token_pos + 1
                            new_match_length = new_match_length + 1 + i
                            offset_target = offset_target + i
                            found = True
                            has_skipped = True
                            break

                    if not found and target_token_pos + 1 < len(tokens):
                        if self.__fuzzy_match(tokens[source_token_pos].text, tokens[target_token_pos + 1].text) > 80:
                            source_token_pos = source_token_pos + 1
                            target_token_pos = target_token_pos + 2
                            new_match_length = new_match_length + 2
                            offset_source = offset_source + 1
                            found = True
                            has_skipped = True

                    if not found:
                        break
                else:
                    break

            if new_match_length >= self.min_match_length and new_match_length > best_match_length:
                best_match_length = new_match_length
                best_match_token_pos = token_pos
                best_match = Match(source_text_index, source_token_start_pos, target_text_index, best_match_token_pos,
                                   best_match_length - offset_source, best_match_length - offset_target)

            token_pos = forward_references.get(token_pos, -1)

        return best_match

    def __fuzzy_match(self, input1, input2):
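        """Return the rapidfuzz similarity ratio (0-100); tokens shorter than two
        characters never match."""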

        if len(input1) < 2 or len(input2) < 2:
            return 0

        return fuzz.ratio(input1, input2)

    def __get_closest_match(self, candidates, word):
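        """Return the candidate closest to word: an exact hit if present, otherwise
        the best fuzzy candidate scoring at least 80, otherwise None."""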

        if not candidates:
            return None

        if word in candidates:
            return word

        best_existing_tag = process.extractOne(word, candidates, scorer=fuzz.ratio, score_cutoff=80)

        if best_existing_tag:
            return best_existing_tag[0]

        return None

    def __print_similarities(self, similarities, input_texts):
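        """Debug helper: print the character range and matched text of every
        similarity pair for both inputs."""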
        literature_content = input_texts[0]
        scientific_content = input_texts[1]

        result = ''

        for similarity_tuple in similarities:
            similarity_literature = similarity_tuple[0]
            similarity_scientific = similarity_tuple[1]

            content = literature_content[
                      similarity_literature.character_start_pos:similarity_literature.character_end_pos]
            result += '\n' + str(similarity_literature.character_start_pos) + '\t' + str(
                similarity_literature.character_end_pos) + '\t' + content

            content = scientific_content[
                      similarity_scientific.character_start_pos:similarity_scientific.character_end_pos]
            result += '\n' + str(similarity_scientific.character_start_pos) + '\t' + str(
                similarity_scientific.character_end_pos) + '\t' + content

        print(result)
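

# Example usage (a minimal sketch; the texts and parameter values are
# illustrative, not prescribed by this module):
#
#     sim_texter = SimTexter(min_match_length=4, max_gap=1)
#     pairs = sim_texter.compare([source_text, target_text])
#     # Each entry is a (source MatchSegment, target MatchSegment) tuple with
#     # token positions and character offsets into the corresponding input text.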