diff --git a/indiquo/cli/IndiQuoCLI.py b/indiquo/cli/IndiQuoCLI.py
index b86c35e09f258ffc4c1659b032d702ecb2f5de2a..bd0404a4c15d43d93cf3bffdae9d5ca37d36c952 100644
--- a/indiquo/cli/IndiQuoCLI.py
+++ b/indiquo/cli/IndiQuoCLI.py
@@ -1,12 +1,11 @@
 import logging
 import sys
-from argparse import ArgumentParser, BooleanOptionalAction
+from argparse import ArgumentParser
 from datetime import datetime
 from os import listdir
 from os.path import join, isfile, splitext, basename, isdir
 from pathlib import Path
 from typing import List
-from xml.etree import ElementTree
 
 import transformers
 from proquo.core import Helper
@@ -17,40 +16,24 @@ from quid.core.Quid import Quid
 
 from dramatist.Dramatist import Dramatist
 from indiquo.core.IndiQuo import IndiQuo
-from indiquo.core.reference.ReferencePredictor import ReferencePredictor
+from indiquo.core.ScenePredictor import ScenePredictor
 from indiquo.core.chunker.SentenceChunker import SentenceChunker
-from indiquo.core.reference.ReferenceSolver import ReferenceSolver
-from indiquo.core.reference.RegExReferencePredictor import RegExReferencePredictor
-from indiquo.core.reportedspeech.ReportedSpeechByRefPredictor import ReportedSpeechByRefPredictor
-from indiquo.core.reportedspeech.ReportedSpeechPredictor import ReportedSpeechPredictor
-from indiquo.core.similarity.DramaSimilarityPredictor import DramaSimilarityPredictor
-from indiquo.core.similarity.SimilarityPredictor import SimilarityPredictor
-from indiquo.model.reference.ReferenceVectorizer import ReferenceVectorizer
-from indiquo.training.reference import TrainReference
-from flair.models import SequenceTagger
+from indiquo.core.CandidatePredictor import CandidatePredictor
 from sentence_transformers import SentenceTransformer
 
 import csv
-from textwrap import wrap
-import plotly.graph_objects as go
 
 
 def __train(lower_case, base_model_name, model_type, loss_type, num_masked_tokens, use_sep, num_train_steps,
             num_training_examples, num_validation_examples, template_batch_size, max_length, batches_per_update,
             template_train_file_path, val_file_path, output_folder_path):
-
-    TrainReference.train(lower_case, base_model_name, model_type, loss_type, num_masked_tokens, use_sep,
-                         num_train_steps,
-                         num_training_examples, num_validation_examples, template_batch_size, max_length,
-                         batches_per_update,
-                         template_train_file_path, val_file_path, output_folder_path)
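+    # Training is currently disabled; the reference-model training code was removed in this change.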
+    pass
 
 
 def __process_file(pro_quo_lm, quid_matches, indi_quo: IndiQuo, filename, drama, target_text, output_folder_path):
     print(f'Processing {filename} ...')
 
     source_text = drama.get_text()
-
     short_matches: List[MatchRef] = pro_quo_lm.compare(source_text, target_text, quid_matches)
     all_matches = short_matches
 
@@ -58,127 +41,38 @@ def __process_file(pro_quo_lm, quid_matches, indi_quo: IndiQuo, filename, drama,
     all_matches.extend(long_matches)
     all_matches = Helper.remove_overlapping_matches(all_matches, target_text)
 
-    matches = indi_quo.compare(drama, target_text, all_matches)
+    matches = indi_quo.compare(target_text, all_matches)
 
     with open(join(output_folder_path, f'{filename}.tsv'), "w", encoding='utf-8') as output_file:
         writer = csv.writer(output_file, delimiter="\t", lineterminator="\n")
-        writer.writerow(['speech_start', 'speech_end', 'speech_text', 'ref_start', 'ref_end', 'ref_text',
-                         'direct_quote_start', 'direct_quote_end', 'direct_quote_text', 'act', 'scene', 'heatmap',
-                         'lines'])
-
-        for pos, m in enumerate(matches):
-            heatmap_name = 'NN'
-
-            if m.source_heatmap:
-                heatmap_name = f'heatmap_{pos + 1}'
-                __save_heatmap(m.target_text, m.source_heatmap.lines, m.act, m.scene, output_folder_path, heatmap_name)
-
-            speech_text = m.target_text.replace('\n', ' ')
-            ref_text = m.ref_text.replace('\n', ' ')
-            direct_quote_text = m.direct_quote_text.replace('\n', ' ')
-
-            lines = ''
-
-            for line_range in m.line_ranges:
-                if lines:
-                    lines += ','
-
-                if line_range[0] == line_range[1]:
-                    lines += f'{line_range[0]}'
-                else:
-                    lines += f'{line_range[0]}:{line_range[1]}'
-
-            writer.writerow([m.target_start, m.target_end, speech_text, m.ref_start, m.ref_end, ref_text,
-                             m.direct_quote_start, m.direct_quote_end, direct_quote_text,
-                             m.act, m.scene, heatmap_name, lines])
-
-
-def __save_heatmap(target_text, lines, act_nr, scene_nr, output_path, filename):
-    x_labels = []
-    y_values = []
-
-    for c in lines:
-        y_values.append([c.score])
-        x_labels.append(['<br>'.join(wrap(c.text, width=150))])
-
-    fig = go.Figure(data=go.Heatmap(
-                    z=y_values,
-                    text=x_labels,
-                    texttemplate="%{text}",
-                    colorscale='rdylbu',
-                    zmin=0,
-                    zmax=1,
-                    reversescale=True))
-
-    title_wrapped = '<br>'.join(wrap(target_text, width=150))
-    title_wrapped += f'<br>{act_nr}. Akt, {scene_nr}. Szene'
-
-    fig.update_layout(
-        title_text=f'"{title_wrapped}"',
-        xaxis=dict(zeroline=False, showgrid=False, visible=False),
-        yaxis=dict(zeroline=False, showgrid=False, visible=False),
-    )
+        writer.writerow(['start', 'end', 'text', 'scenes'])
 
-    # fig.show()
-    fig.write_html(join(output_path, f'{filename}.html'))
+        for m in matches:
+            scene_predictions = ''
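+            # Serialize the scene predictions as 'act:scene:score' entries joined by '#',
+            # e.g. '1:2:0.83#1:3:0.79'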
 
+            for sp in m.scene_predictions:
+                if scene_predictions:
+                    scene_predictions += '#'
 
-def __run_compare(source_file_path, target_path, tokenizer_folder_path, model_folder_path, output_folder_path, max_length, lower_case, mask_count,
-                  use_sep, templates, left_to_right, average_mode, only_next_token):
+                scene_predictions += f'{sp.act}:{sp.scene}:{sp.score}'
 
-    # ref_vectorizer = ReferenceVectorizer.from_saved(tokenizer_folder_path, max_length, lower_case, mask_count,
-    #                                                use_sep)
+            speech_text = m.target_text.replace('\n', ' ')
+            writer.writerow([m.target_start, m.target_end, speech_text, scene_predictions])
 
-    # root = ElementTree.parse(source_file_path).getroot()
 
+def __run_compare(source_file_path, target_path, similarity_model_path, output_folder_path):
     drama_processor = Dramatist()
     drama = drama_processor.from_file(source_file_path)
 
-    # sub_models = []
-    # sub_folders = listdir(model_folder_path)
-    # single_model = False
-    #
-    # if 'model' in sub_folders:
-    #     single_model = True
-    #     full_path = join(model_folder_path, 'model')
-    #     model = transformers.TFBertForMaskedLM.from_pretrained(full_path)
-    #     sub_models.append(model)
-    # else:
-    #     for file_or_folder in listdir(model_folder_path):
-    #         if not file_or_folder.startswith('model_'):
-    #             continue
-    #
-    #         full_path = join(model_folder_path, file_or_folder)
-    #         model = transformers.TFBertForMaskedLM.from_pretrained(full_path)
-    #         sub_models.append(model)
-
-    single_sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)
-
-    # reference_chunker = SentenceChunker(max_length=64, max_sentences=100)
-    # reference_predictor = ReferencePredictor(reference_chunker, sub_models, ref_vectorizer, templates, single_model,
-    #                                          left_to_right, mask_count, average_mode, only_next_token)
-    # reference_predictor.expand_templates()
-
-    # reference_solver = ReferenceSolver()
-    # reference_predictor = RegExReferencePredictor(reference_solver)
-
-    # reported_speech_model = SequenceTagger.load('de-historic-reported')
-    # reported_speech_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)
-    # reported_speech_predictor = ReportedSpeechPredictor(single_sentence_chunker, reported_speech_model)
-
-    similarity_model = SentenceTransformer('/Users/frede/Arbeit/HU/Indirect_citations/similarity/model/2024_02_02_10_18_09_deutsche-telekomgbert-large-paraphrase-cosine')
-    # similarity_predictor = SimilarityPredictor(similarity_model)
-
-    drama_similarity_predictor = DramaSimilarityPredictor(similarity_model, drama)
-
-    # reported_speech_by_ref_predictor = ReportedSpeechByRefPredictor(single_sentence_chunker)
+    sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)
+    similarity_model = SentenceTransformer(similarity_model_path)
 
-    # indi_quo = IndiQuo(reference_predictor, reference_solver, reported_speech_predictor,
-    #                    reported_speech_by_ref_predictor,  similarity_predictor)
+    candidate_predictor = CandidatePredictor(drama, similarity_model, sentence_chunker)
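+    # Report the top 5 most similar scenes per candidate. Note: CandidatePredictor and
+    # ScenePredictor each encode the drama's text blocks independently.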
+    scene_predictor = ScenePredictor(drama, similarity_model, 5)
 
-    indi_quo = IndiQuo(single_sentence_chunker, drama_similarity_predictor)
+    indi_quo = IndiQuo(candidate_predictor, scene_predictor)
 
-    link_vectorizer = LinkingVectorizer.from_saved(512, tokenizer_folder_path, lower_case)
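+    # from_saved(max_length, tokenizer_path, lower_case); tokenizer and model are loaded from the Hugging Face Hub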
+    link_vectorizer = LinkingVectorizer.from_saved(512, 'fredr0id/proquolm', True)
     link_model = transformers.TFBertForSequenceClassification.from_pretrained('fredr0id/proquolm', num_labels=2)
     pro_quo_lm = ProQuoLm(link_model, link_vectorizer)
     quid = Quid(min_match_length=2, keep_ambiguous_matches=True)
@@ -217,36 +111,36 @@ def main(argv=None):
 
     parser_train.add_argument('template_train_file_path', nargs=1, metavar='template-train-file-path',
                               help='Path to the txt file containing the training examples')
-    parser_train.add_argument('val_file_path', nargs=1, metavar='val-file-path',
-                              help='Path to the txt file containing the validation examples')
-    parser_train.add_argument('--output-folder-path', dest='output_folder_path',
-                              help='The output folder path. If this option is set the output will be saved to a file'
-                                   ' created in the specified folder')
+    # parser_train.add_argument('val_file_path', nargs=1, metavar='val-file-path',
+    #                           help='Path to the txt file containing the validation examples')
+    # parser_train.add_argument('--output-folder-path', dest='output_folder_path',
+    #                           help='The output folder path. If this option is set the output will be saved to a file'
+    #                                ' created in the specified folder')
     # parser_train.add_argument('--mlm-train-file-path', dest='mlm_train_file_path',
     #                              help='Path to the txt file containing the training examples', required=False)
-    parser_train.add_argument('--base-model-name', dest="base_model_name",
-                              default="bert-base-german-dbmdz-uncased", help="The model name")
-    parser_train.add_argument('--lower-case', dest="lower_case", default=True, action=BooleanOptionalAction,
-                              help="TBD")
-    parser_train.add_argument('--model-type', choices=['template', 'combined'], dest="model_type",
-                              default="combined", help="The model type")
-    parser_train.add_argument('--loss-type', choices=['mlm', 'weighted'], dest="loss_type",
-                              default="weighted", help="The loss type")
-    parser_train.add_argument('--num-examples', dest="num_examples", default=32, type=int, help="TBD")
-    parser_train.add_argument('--use-sep', dest="use_sep", default=False, action=BooleanOptionalAction,
-                              help="TBD")
-    parser_train.add_argument('--num-masked-tokens', dest="num_masked_tokens", default=10, type=int, help="TBD")
-    parser_train.add_argument('--batch-size', dest="batch_size", default=4, type=int, help="TBD")
-    parser_train.add_argument('--batches-per-update', dest="batches_per_update", default=1, type=int, help="TBD")
+    # parser_train.add_argument('--base-model-name', dest="base_model_name",
+    #                           default="bert-base-german-dbmdz-uncased", help="The model name")
+    # parser_train.add_argument('--lower-case', dest="lower_case", default=True, action=BooleanOptionalAction,
+    #                           help="TBD")
+    # parser_train.add_argument('--model-type', choices=['template', 'combined'], dest="model_type",
+    #                           default="combined", help="The model type")
+    # parser_train.add_argument('--loss-type', choices=['mlm', 'weighted'], dest="loss_type",
+    #                           default="weighted", help="The loss type")
+    # parser_train.add_argument('--num-examples', dest="num_examples", default=32, type=int, help="TBD")
+    # parser_train.add_argument('--use-sep', dest="use_sep", default=False, action=BooleanOptionalAction,
+    #                           help="TBD")
+    # parser_train.add_argument('--num-masked-tokens', dest="num_masked_tokens", default=10, type=int, help="TBD")
+    # parser_train.add_argument('--batch-size', dest="batch_size", default=4, type=int, help="TBD")
+    # parser_train.add_argument('--batches-per-update', dest="batches_per_update", default=1, type=int, help="TBD")
 
     parser_compare = subparsers_command.add_parser('compare', help='', description='')
 
+    parser_compare.add_argument("source_file_path", nargs=1, metavar="source-file-path",
+                                help="Path to the source xml file")
     parser_compare.add_argument('target_path', nargs=1, metavar='target-path',
                                 help='Path to the target text file or folder')
-    parser_compare.add_argument('tokenizer_folder_path', nargs=1, metavar='tokenizer-folder-path',
-                                help='Path to the relation tokenizer folder')
     parser_compare.add_argument('model_folder_path', nargs=1, metavar='model-folder-path',
-                                help='Path to the relation model folder')
+                                help='Path to the similarity model folder')
     parser_compare.add_argument('--output-folder-path', dest="output_folder_path",
                                 help="The output folder path. If this option is set the output will be saved to a file"
                                      " created in the specified folder")
@@ -258,73 +152,32 @@ def main(argv=None):
 
     if args.command == 'train':
         template_train_file_path = args.template_train_file_path[0]
-        val_file_path = args.val_file_path[0]
-        output_folder_path = args.output_folder_path
-        # mlm_train_file_path = args.mlm_train_file_path
-        model_type = args.model_type
-        loss_type = args.loss_type
-        num_training_examples = args.num_examples
-        use_sep = args.use_sep
-        template_batch_size = args.batch_size
-        batches_per_update = args.batches_per_update
-        base_model_name = args.base_model_name
-        lower_case = args.lower_case
-        num_masked_tokens = args.num_masked_tokens
-
-        # TODO: create arguments
-        # template config
-        template_max_length = 160
-        # epochs = 20
-        num_train_steps = 500
-        num_val_examples = 50
-
-        # aux mlm config
-        # mlm_aux = False
-        # mlm_chunk_size = 160
-        # mlm_batch_size = 3
-        # alpha = 0.9
-
-        now = datetime.now()
-        date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
-        output_folder_path = join(output_folder_path, date_time_string)
-        output_folder_path += f'_{model_type}_{loss_type}_{num_training_examples}_{num_train_steps}'
-        Path(output_folder_path).mkdir(parents=True, exist_ok=True)
-
-        __train(lower_case, base_model_name, model_type, loss_type, num_masked_tokens, use_sep, num_train_steps,
-                num_training_examples, num_val_examples, template_batch_size, template_max_length, batches_per_update,
-                template_train_file_path, val_file_path, output_folder_path)
+        # val_file_path = args.val_file_path[0]
+        # output_folder_path = args.output_folder_path
+        # # mlm_train_file_path = args.mlm_train_file_path
+        #
+        # now = datetime.now()
+        # date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
+        # output_folder_path = join(output_folder_path, date_time_string)
+        # output_folder_path += f'_{model_type}_{loss_type}_{num_training_examples}_{num_train_steps}'
+        # Path(output_folder_path).mkdir(parents=True, exist_ok=True)
+        #
+        # __train(lower_case, base_model_name, model_type, loss_type, num_masked_tokens, use_sep, num_train_steps,
+        #         num_training_examples, num_val_examples, template_batch_size, template_max_length, batches_per_update,
+        #         template_train_file_path, val_file_path, output_folder_path)
 
     elif args.command == 'compare':
-        source_file_path = '/Users/frede/Arbeit/HU/PDFExtraction/Dramen/Dantons_Tod/drama.xml'
-        # source_file_path = '/Users/frede/Arbeit/HU/PDFExtraction/Dramen/Iphigenie_auf_Tauris/drama.xml'
+        source_file_path = args.source_file_path[0]
         target_path = args.target_path[0]
-        tokenizer_folder_path = args.tokenizer_folder_path[0]
         model_folder_path = args.model_folder_path[0]
         output_folder_path = args.output_folder_path
 
-        max_length = 160
-        lower_case = True
-        mask_count = 10
-        use_sep = False
-        left_to_right = True
-        average_mode = 'all'
-        only_next_token = True
-
-        templates = [
-            "[MASK] bezeichnet einen Teil des Dramas.",
-            "[MASK] bezieht sich auf einen Teil des Stücks.",
-            "Referenz auf das Drama? [MASK]",
-            "Was bezeichnet einen Handlungsabschnitt? [MASK]",
-            "[MASK] bezeichnet einen Handlungsabschnitt."
-        ]
-
         now = datetime.now()
         date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
         output_folder_path = join(output_folder_path, date_time_string)
         Path(output_folder_path).mkdir(parents=True, exist_ok=True)
 
-        __run_compare(source_file_path, target_path, tokenizer_folder_path, model_folder_path, output_folder_path, max_length, lower_case, mask_count,
-                      use_sep, templates, left_to_right, average_mode, only_next_token)
+        __run_compare(source_file_path, target_path, model_folder_path, output_folder_path)
 
 
 if __name__ == '__main__':
diff --git a/indiquo/core/reference/Reference.py b/indiquo/core/Candidate.py
similarity index 83%
rename from indiquo/core/reference/Reference.py
rename to indiquo/core/Candidate.py
index d9122e26ab7430f9c5be2b62f55625d2d2482b3f..a7a547038d393a610b5e3c1f41ad44b5b67c685f 100644
--- a/indiquo/core/reference/Reference.py
+++ b/indiquo/core/Candidate.py
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 
 
 @dataclass
-class Reference:
+class Candidate:
     start: int
     end: int
     text: str
diff --git a/indiquo/core/CandidatePredictor.py b/indiquo/core/CandidatePredictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..12862d506392368dd468b1c1396a747b11e6ab1d
--- /dev/null
+++ b/indiquo/core/CandidatePredictor.py
@@ -0,0 +1,120 @@
+from typing import Tuple, List
+from sentence_transformers import util
+import re
+
+from dramatist.Drama import Drama
+from indiquo.core import Util
+from indiquo.core.Candidate import Candidate
+
+from indiquo.core.chunker.BaseChunker import BaseChunker
+
+
+# noinspection PyMethodMayBeStatic
+class CandidatePredictor:
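+    # Minimum similarity score for a chunk to be accepted as an indirect quotation candidate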
+    SIMILARITY_THRESHOLD = 0.78
+
+    def __init__(self, drama: Drama, model, chunker: BaseChunker):
+        self.drama = drama
+        self.model = model
+        self.chunker = chunker
+        self.all_text_blocks = []
+        self.source_text_blocks = []
+
+        for act_nr, act in enumerate(drama.acts):
+            for scene_nr, scene in enumerate(act.scenes):
+                text_blocks = scene.get_text_in_blocks(128)
+
+                for tbt in text_blocks:
+                    self.all_text_blocks.append((act_nr, scene_nr, tbt.text))
+                    self.source_text_blocks.append(tbt.text)
+
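+        # Encode all source blocks once up front so each candidate lookup needs only a single encode call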
+        self.source_embeddings = model.encode(self.source_text_blocks, convert_to_tensor=True)
+
+    def get_candidates(self, target_text, direct_quotes) -> List[Candidate]:
+        fn_ranges, fn_ranges_with_offset = self.__get_footnote_ranges(target_text)
+        target_text_wo_fn: str = self.__remove_footnotes(target_text)
+        chunks = self.chunker.chunk(target_text_wo_fn)
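+        # Chunk offsets refer to the text without footnotes; map them back to positions in the original text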
+
+        for chunk in chunks:
+            start = chunk.start
+            end = chunk.end
+            real_start, real_end = self.__map_to_real_pos(start, end, fn_ranges_with_offset)
+            chunk.start = real_start
+            chunk.end = real_end
+
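+        # Drop chunks that overlap a direct quotation; those passages are already covered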
+        filtered_chunks = []
+        for chunk in chunks:
+            found_match = False
+            for dq in direct_quotes:
+                overlap_length = Util.calculate_overlap(chunk.start, chunk.end, dq.target_span.start,
+                                                        dq.target_span.end)
+
+                if overlap_length > 0:
+                    found_match = True
+                    break
+
+            if not found_match:
+                filtered_chunks.append(chunk)
+
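+        # Keep only chunks whose most similar scene passage clears SIMILARITY_THRESHOLD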
+        candidates: List[Candidate] = []
+        for chunk in filtered_chunks:
+            sim_result = self.__predict(chunk.text)
+            if sim_result:
+                candidates.append(Candidate(chunk.start, chunk.end, chunk.text))
+
+        return candidates
+
+    def __predict(self, target_text):
+        target_embedding = self.model.encode([target_text], convert_to_tensor=True)
+        hits = util.semantic_search(target_embedding, self.source_embeddings, top_k=1)[0]
+
+        scene_scores = []
+        for hit in hits:
+            idx = hit['corpus_id']
+            score = hit['score']
+            act_nr = self.all_text_blocks[idx][0]
+            scene_nr = self.all_text_blocks[idx][1]
+            text = self.all_text_blocks[idx][2]
+            start_line, end_line = self.drama.acts[act_nr].scenes[scene_nr].get_line_range()
+            scene_scores.append((start_line, end_line, score, text))
+
+        # Return the best hit only if it clears the similarity threshold
+        if scene_scores and scene_scores[0][2] >= self.SIMILARITY_THRESHOLD:
+            return scene_scores[0]
+
+        return None
+
+    def __get_footnote_ranges(self, input_text: str) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
+        """
+        Takes a text and returns a list of tuples of start and end character positions of footnote ranges.
+        :param input_text: The input text
+        :return: A list of tuples of start and end character positions of footnote ranges
+        """
+        result: List[Tuple[int, int]] = []
+        result_with_offset: List[Tuple[int, int]] = []
+
+        offset = 0
+        for re_match in re.finditer(r'\[\[\[((?:.|\n)+?)]]]', input_text):
+            start = re_match.start()
+            end = re_match.end()
+            result.append((start, end))
+            result_with_offset.append((start - offset, end - offset))
+            offset += end - start
+
+        return result, result_with_offset
+
+    def __remove_footnotes(self, input_text: str):
+        result_text = re.sub(r'\[\[\[((?:.|\n)+?)]]]', '', input_text)
+        return result_text
+
+    def __map_to_real_pos(self, start, end, fn_ranges):
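+        # Shift the positions by the length of every footnote that precedes them in the original text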
+        start_offset = 0
+        end_offset = 0
+
+        for fn_range in fn_ranges:
+            if fn_range[0] < start:
+                start_offset += fn_range[1] - fn_range[0]
+                end_offset += fn_range[1] - fn_range[0]
+            elif fn_range[0] < end:
+                end_offset += fn_range[1] - fn_range[0]
+            else:
+                break
+
+        return start + start_offset, end + end_offset
diff --git a/indiquo/core/IndiQuo.py b/indiquo/core/IndiQuo.py
index 762eb84965b8dd1991c008d03f958a3cdf6b194e..08b5fd168a433cd20b12390271bf3b8a7c550e7c 100644
--- a/indiquo/core/IndiQuo.py
+++ b/indiquo/core/IndiQuo.py
@@ -1,364 +1,23 @@
-from typing import Tuple, List, Optional
-import re
-
-from dramatist.Drama import Drama
-from indiquo.core import Util
-from indiquo.core.chunker.ChunkRef import ChunkRef
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
+from typing import List
+from indiquo.core.CandidatePredictor import CandidatePredictor
+from indiquo.core.Candidate import Candidate
+from indiquo.core.ScenePredictor import ScenePredictor
 from indiquo.match.Match import Match
 
 
 # noinspection PyMethodMayBeStatic
 class IndiQuo:
 
-    MAX_REFERENCE_DIST = 10
-    MAX_QUOTE_DIST = 5
-
-    def __init__(self, sentence_chunker, similarity_predictor):
-        # self.reference_predictor = reference_predictor
-        # self.reference_solver = reference_solver
-        # self.reported_speech_predictor = reported_speech_predictor
-        # self.reported_speech_by_ref_predictor = reported_speech_by_ref_predictor
-        self.sentence_chunker = sentence_chunker
-        self.similarity_predictor = similarity_predictor
-
-    def compare(self, drama: Drama, target_text: str, pro_quo_matches) -> List[Match]:
-        """
-        Compare TBD
-        :param drama: TBD
-        :param target_text:
-        :param pro_quo_matches:
-        :return:
-        """
-        fn_ranges, fn_ranges_with_offset = self.__get_footnote_ranges(target_text)
-        target_text_wo_fn: str = self.__remove_footnotes(target_text)
-
-        # reference_candidates = self.reference_predictor.predict(target_text_wo_fn)
-
-        # reported_speech_candidates = self.reported_speech_predictor.predict(target_text_wo_fn)
-        reported_speech_candidates = self.sentence_chunker.chunk(target_text_wo_fn)
-
-        # reported_speech_by_ref_candidates = self.reported_speech_by_ref_predictor.predict(target_text_wo_fn,
-        #                                                                                   reference_candidates)
-        # for rc in reference_candidates:
-        #     start = rc.references[0].start
-        #     end = rc.references[0].end
-        #     real_start, real_end = self.__map_to_real_pos(start, end, fn_ranges_with_offset)
-        #     rc.references[0].start = real_start
-        #     rc.references[0].end = real_end
-        #
-        # reference_candidates.sort(key=lambda x: x.references[0].start)
-
-        for rsc in reported_speech_candidates:
-            start = rsc.start
-            end = rsc.end
-            real_start, real_end = self.__map_to_real_pos(start, end, fn_ranges_with_offset)
-            rsc.start = real_start
-            rsc.end = real_end
-
-        filtered_candidates = []
-        for rsc in reported_speech_candidates:
-            found_match = False
-            for match in pro_quo_matches:
-                overlap_length = Util.calculate_overlap(rsc.start, rsc.end, match.target_span.start,
-                                                        match.target_span.end)
-
-                if overlap_length > 0:
-                    found_match = True
-                    break
-
-            if not found_match:
-                filtered_candidates.append(rsc)
-
-        result_filtered = []
-        for rsc in filtered_candidates:
-            sim_result = self.similarity_predictor.predict(rsc.text)
-            if sim_result:
-                result_filtered.append(Match(rsc.start, rsc.end, rsc.text, -1, -1, '', -1, -1, '', -1, -1, None, []))
-
-        result = result_filtered
-        # current = None
-        # for next_r in result_filtered:
-        #     if not current:
-        #         current = next_r
-        #     else:
-        #         if current.target_end == next_r.target_start:
-        #             current = Match(current.target_start, next_r.target_end, current.target_text + ' ' + next_r.target_text,
-        #                             -1, -1, '', -1, -1, '', -1, -1, None, [])
-        #         else:
-        #             result.append(current)
-        #             current = next_r
-        #
-        # if current:
-        #     result.append(current)
-
-        # for rsc in reported_speech_by_ref_candidates:
-        #     start = rsc.start
-        #     end = rsc.end
-        #     real_start, real_end = self.__map_to_real_pos(start, end, fn_ranges_with_offset)
-        #     rsc.start = real_start
-        #     rsc.end = real_end
-        #
-        #     for ref in rsc.references:
-        #         ref_real_start, ref_real_end = self.__map_to_real_pos(ref.start, ref.end, fn_ranges_with_offset)
-        #         ref.start = ref_real_start
-        #         ref.end = ref_real_end
-        #
-        # reported_speech_by_ref_candidates = self.__filter_candidates(reported_speech_by_ref_candidates, pro_quo_matches)
-        #
-        # result = []
-        # for rsc in reported_speech_candidates:
-        #     candidate_text = rsc.text
-        #     overlapping_matches = []
-        #
-        #     for match in pro_quo_matches:
-        #         overlap_length = Util.calculate_overlap(rsc.start, rsc.end, match.target_span.start,
-        #                                                 match.target_span.end)
-        #
-        #         if overlap_length > 0:
-        #             match_length = len(match.target_span.text.split(' '))
-        #             if match_length >= 5:
-        #                 overlapping_matches.append(match)
-        #
-        #     direct_quote_start = -1
-        #     direct_quote_end = -1
-        #     direct_quote_text = ''
-        #
-        #     direct_quote_act = -1
-        #     direct_quote_scene = -1
-        #
-        #     if len(overlapping_matches) > 0:
-        #         if len(overlapping_matches) == 1:
-        #             match = overlapping_matches[0]
-        #             direct_quote_act, direct_quote_scene = drama.get_scene_act_for_position(match.source_span.start)
-        #
-        #             overlap_start = max(rsc.start, match.target_span.start)
-        #             overlap_end = min(rsc.end, match.target_span.end)
-        #             candidate_text = rsc.text[0:overlap_start-rsc.start] + rsc.text[overlap_end-rsc.start:]
-        #
-        #             if not candidate_text:
-        #                 continue
-        #
-        #             direct_quote_start = match.target_span.start
-        #             direct_quote_end = match.target_span.end
-        #             direct_quote_text = match.target_span.text
-        #         else:
-        #             print(f'Too many overlapping matches: {len(overlapping_matches)}\n{rsc}')
-        #
-        #     best_reference = self.__get_best_references(target_text, reference_candidates, rsc, fn_ranges,
-        #                                                 pro_quo_matches)
-        #
-        #     if candidate_text and best_reference:
-        #         # TODO: use verse, but how to match up lines and verses?
-        #         act_nr, scene_nr, _ = self.reference_solver.solve(best_reference.references[0].text)
-        #
-        #         ref_start = best_reference.references[0].start
-        #         ref_end = best_reference.references[0].end
-        #         ref_text = best_reference.references[0].text
-        #
-        #         if act_nr != -1 and scene_nr != -1:
-        #             if direct_quote_act != -1 and direct_quote_scene != -1:
-        #                 if act_nr != direct_quote_act or scene_nr != direct_quote_scene:
-        #                     print(f'Ref and Quote do not match!\n{rsc}')
-        #
-        #             heatmap = self.similarity_predictor.predict(drama, candidate_text, act_nr, scene_nr)
-        #             result.append(Match(rsc.start, rsc.end, candidate_text, ref_start, ref_end, ref_text,
-        #                                 direct_quote_start, direct_quote_end, direct_quote_text,
-        #                                 act_nr, scene_nr, heatmap))
-        #         else:
-        #             result.append(Match(rsc.start, rsc.end, candidate_text, ref_start, ref_end, ref_text,
-        #                                 direct_quote_start, direct_quote_end, direct_quote_text,
-        #                                 -1, -1, None))
-        #
-        # to_delete = set()
-        # for rsbrc in reported_speech_by_ref_candidates:
-        #     # TODO: Ignore candidates with multiple references. Or use only candidates with one reference which is not after a quote
-        #
-        #     ref_start = rsbrc.references[0].start
-        #     ref_end = rsbrc.references[0].end
-        #     ref_text = rsbrc.references[0].text
-        #
-        #     act_nr, scene_nr, _ = self.reference_solver.solve(ref_text)
-        #     heatmap = self.similarity_predictor.predict(drama, rsbrc.text, act_nr, scene_nr)
-        #
-        #     ne = Match(rsbrc.start, rsbrc.end, rsbrc.text, ref_start, ref_end, ref_text, -1, -1, '', act_nr, scene_nr,
-        #                heatmap)
-        #     for pos, r in enumerate(result):
-        #         if pos in to_delete:
-        #             continue
-        #
-        #         overlap_length = Util.calculate_overlap(ne.target_start, ne.target_end, r.target_start, r.target_end)
-        #
-        #         if overlap_length > 0:
-        #             to_delete.add(pos)
-        #             new_start = min(ne.target_start, r.target_start)
-        #             new_end = max(ne.target_end, r.target_end)
-        #             new_text = target_text[new_start:new_end]
-        #             ne = Match(new_start, new_end, new_text, -1, -1, '', -1, -1, '', r.act, r.scene, r.source_heatmap)
-        #
-        #     result.append(ne)
-        #
-        # result_updated = []
-        # for pos, r in enumerate(result):
-        #     if pos not in to_delete:
-        #         result_updated.append(r)
-        #
-        # for r in result_updated:
-        #     if not r.source_heatmap:
-        #         continue
-        #
-        #     ranges = []
-        #     for line in r.source_heatmap.lines:
-        #         if line.score > 0.5:
-        #             ranges.append((line.start, line.end))
-        #
-        #     r.line_ranges = ranges
-
-        return result
-
-    def __get_best_references(self, input_text, references, reported_speech_candidate, footnote_ranges: List,
-                              proquo_matches) -> Optional[ReferencePrediction]:
-        # TODO: rewrite and improve, currently not really ideal
-        # TODO: how to define 'best reference'?
-
-        speech_start = reported_speech_candidate.start
-        speech_end = reported_speech_candidate.end
-
-        for reference in references:
-            ref_start = reference.references[0].start
-            ref_end = reference.references[0].end
-
-            if ref_end < speech_start:
-                continue
-
-            if ref_start >= speech_start and ref_end <= speech_end:
-                found = False
-                for m in proquo_matches:
-                    m_start = m.target_span.start
-                    m_end = m.target_span.end
-
-                    overlap_q_ref = Util.calculate_overlap(m_start, m_end, ref_start, ref_end)
-                    if overlap_q_ref > 0:
-                        found = True
-                        break
-
-                    d_quote_dist = ref_start - m_end
-
-                    if 0 <= d_quote_dist <= self.MAX_QUOTE_DIST:
-                        found = True
-                        break
-
-                if found:
-                    continue
+    def __init__(self, candidate_predictor: CandidatePredictor, scene_predictor: ScenePredictor):
+        self.candidate_predictor = candidate_predictor
+        self.scene_predictor = scene_predictor
 
-                return reference
+    def compare(self, target_text: str, direct_quotes) -> List[Match]:
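+        """
+        Find indirect quotations of the drama in the target text and predict their source scenes.
+        :param target_text: The target text
+        :param direct_quotes: Direct quotation matches; candidates overlapping these are skipped
+        :return: A list of matches with scene predictions
+        """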
+        result: List[Match] = []
+        candidates: List[Candidate] = self.candidate_predictor.get_candidates(target_text, direct_quotes)
 
-            words_in_fn_count = self.__count_words_in_footnotes(input_text, footnote_ranges, speech_end, ref_start)
-            dist = len(input_text[speech_end:ref_start].split()) - words_in_fn_count
-
-            if dist > self.MAX_REFERENCE_DIST:
-                continue
-
-            found = False
-            for m in proquo_matches:
-                m_end = m.target_span.end
-                d_quote_dist = ref_start - m_end
-
-                if 0 <= d_quote_dist <= self.MAX_QUOTE_DIST:
-                    found = True
-                    break
-
-            if found:
-                continue
-
-            return reference
-
-        return None
-
-    def __count_words_in_footnotes(self, input_text, footnote_ranges, start, end):
-        count = 0
-
-        for fr in footnote_ranges:
-            if start <= fr[0] < end:
-                text = input_text[fr[0]:fr[1]]
-                count += len(text.split())
-
-        return count
-
-    def __get_footnote_ranges(self, input_text: str) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
-        """
-        Takes a text and returns a list of tuples of start and end character positions of footnote ranges.
-        :param input_text: The input text
-        :return: A list of tuples of start and end character positions of footnote ranges
-        """
-        result: List[Tuple[int, int]] = []
-        result_with_offset: List[Tuple[int, int]] = []
-
-        offset = 0
-        for re_match in re.finditer(r'\[\[\[((?:.|\n)+?)]]]', input_text):
-            start = re_match.start()
-            end = re_match.end()
-            result.append((start, end))
-            result_with_offset.append((start - offset, end - offset))
-            offset += end - start
-
-        return result, result_with_offset
-
-    def __remove_footnotes(self, input_text: str):
-        result_text = re.sub(r'\[\[\[((?:.|\n)+?)]]]', '', input_text)
-        return result_text
-
-    def __map_to_real_pos(self, start, end, fn_ranges):
-        start_offset = 0
-        end_offset = 0
-
-        for fn_range in fn_ranges:
-            if fn_range[0] < start:
-                start_offset += fn_range[1] - fn_range[0]
-                end_offset += fn_range[1] - fn_range[0]
-            elif fn_range[0] < end:
-                end_offset += fn_range[1] - fn_range[0]
-            else:
-                break
-
-        return start + start_offset, end + end_offset
-
-    def __filter_candidates(self, candidates, proquo_matches):
-        # check for a reference but not right after a quote
-        # TODO: how to handle multiple references?
-        result: List[ChunkRef] = []
         for candidate in candidates:
-            matches = []
-
-            for m in proquo_matches:
-                overlap_start = max(candidate.start, m.target_span.start)
-                overlap_end = min(candidate.end, m.target_span.end)
-                overlap_length = overlap_end - overlap_start
-
-                if overlap_length > 0:
-                    matches.append(m)
-
-            if len(matches) == 0:
-                result.append(candidate)
-            else:
-                for ref in candidate.references:
-                    found = False
-                    for m in matches:
-                        m_start = m.target_span.start
-                        m_end = m.target_span.end
-
-                        overlap_q_ref = Util.calculate_overlap(m_start, m_end, ref.start, ref.end)
-                        if overlap_q_ref > 0:
-                            found = True
-                            break
-
-                        dist = ref.start - m_end
-                        if 0 <= dist <= self.MAX_QUOTE_DIST:
-                            found = True
-                            break
-
-                    if not found:
-                        result.append(candidate)
-                        break
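+            # Rank the drama's scenes by similarity to the candidate text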
+            scene_predictions = self.scene_predictor.predict_scene(candidate.text)
+            result.append(Match(candidate.start, candidate.end, candidate.text, scene_predictions))
 
         return result
diff --git a/indiquo/core/similarity/Line.py b/indiquo/core/ScenePrediction.py
similarity index 54%
rename from indiquo/core/similarity/Line.py
rename to indiquo/core/ScenePrediction.py
index 3754b841f885e8a1608519b686cd18848d20624e..1b8bd5c06d83da87e5178ac89a9585ca89a5babb 100644
--- a/indiquo/core/similarity/Line.py
+++ b/indiquo/core/ScenePrediction.py
@@ -2,8 +2,7 @@ from dataclasses import dataclass
 
 
 @dataclass
-class Line:
-    start: int
-    end: int
-    text: str
+class ScenePrediction:
+    act: int
+    scene: int
     score: float
diff --git a/indiquo/core/similarity/DramaSimilarityPredictor.py b/indiquo/core/ScenePredictor.py
similarity index 60%
rename from indiquo/core/similarity/DramaSimilarityPredictor.py
rename to indiquo/core/ScenePredictor.py
index 380e4f7748148ce93da276d78f31e1b5dbaa795b..bab8b73d60246389d21127b304b60fb2325b9d5f 100644
--- a/indiquo/core/similarity/DramaSimilarityPredictor.py
+++ b/indiquo/core/ScenePredictor.py
@@ -1,13 +1,15 @@
+from dramatist.Drama import Drama
 from sentence_transformers import util
 
+from indiquo.core.ScenePrediction import ScenePrediction
 
-class DramaSimilarityPredictor:
 
-    SIMILARITY_THRESHOLD = 0.78
+class ScenePredictor:
 
-    def __init__(self, model, drama):
-        self.model = model
+    def __init__(self, drama: Drama, model, top_k):
         self.drama = drama
+        self.model = model
+        self.top_k = top_k
         self.all_text_blocks = []
         self.source_text_blocks = []
 
@@ -21,19 +23,16 @@ class DramaSimilarityPredictor:
 
         self.source_embeddings = model.encode(self.source_text_blocks, convert_to_tensor=True)
 
-    def predict(self, target_text):
-        target_embedding = self.model.encode([target_text], convert_to_tensor=True)
-        hits = util.semantic_search(target_embedding, self.source_embeddings, top_k=1)[0]
+    def predict_scene(self, text):
+        target_embedding = self.model.encode([text], convert_to_tensor=True)
+        # semantic_search returns one hit list per query; there is a single query here, hence [0]
+        hits = util.semantic_search(target_embedding, self.source_embeddings, top_k=self.top_k)[0]
 
-        scene_scores = []
+        predictions = []
         for hit in hits:
             idx = hit['corpus_id']
             score = hit['score']
             act_nr = self.all_text_blocks[idx][0]
             scene_nr = self.all_text_blocks[idx][1]
-            text = self.all_text_blocks[idx][2]
-            start_line, end_line = self.drama.acts[act_nr].scenes[scene_nr].get_line_range()
-            scene_scores.append((start_line, end_line, score, text))
+            predictions.append(ScenePrediction(act_nr, scene_nr, score))
 
-        if scene_scores[0][2] >= self.SIMILARITY_THRESHOLD:
-            return scene_scores[0]
+        return predictions
diff --git a/indiquo/core/chunker/ChunkRef.py b/indiquo/core/chunker/ChunkRef.py
deleted file mode 100644
index 89237163fbab51879f68c22cea5fbb3a3adbaf78..0000000000000000000000000000000000000000
--- a/indiquo/core/chunker/ChunkRef.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-
-from indiquo.core.chunker.Chunk import Chunk
-from indiquo.core.reference.Reference import Reference
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
-
-
-@dataclass
-class ChunkRef(Chunk):
-
-    references: List[Reference]
-
-    def __init__(self, start, end, text, references):
-        super().__init__(start, end, text)
-        self.references = references
diff --git a/indiquo/core/reference/BaseReferencePredictor.py b/indiquo/core/reference/BaseReferencePredictor.py
deleted file mode 100644
index bb718e7b3d59bec741d5dc87160b9fba2d429ba4..0000000000000000000000000000000000000000
--- a/indiquo/core/reference/BaseReferencePredictor.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
-
-
-class BaseReferencePredictor(ABC):
-
-    @abstractmethod
-    def predict(self, text: str) -> List[ReferencePrediction]:
-        pass
diff --git a/indiquo/core/reference/ReferencePrediction.py b/indiquo/core/reference/ReferencePrediction.py
deleted file mode 100644
index 2f5fbc752ddabaae2a45b7878ca10cf81f7ceb5c..0000000000000000000000000000000000000000
--- a/indiquo/core/reference/ReferencePrediction.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-
-from indiquo.core.reference.Reference import Reference
-
-
-@dataclass
-class ReferencePrediction:
-    predicted_text: str
-    references: List[Reference]
-
-    @property
-    def reference(self) -> Reference:
-        return self.references[0]
diff --git a/indiquo/core/reference/ReferencePredictor.py b/indiquo/core/reference/ReferencePredictor.py
deleted file mode 100644
index 6689e379e0d3694488f695d435f369aee12c523f..0000000000000000000000000000000000000000
--- a/indiquo/core/reference/ReferencePredictor.py
+++ /dev/null
@@ -1,339 +0,0 @@
-import re
-import warnings
-from typing import List
-
-import numpy as np
-from difflib import SequenceMatcher
-import tensorflow as tf
-
-from indiquo.core.chunker.BaseChunker import BaseChunker
-from indiquo.core.reference.BaseReferencePredictor import BaseReferencePredictor
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
-from indiquo.core.reference.Reference import Reference
-
-
-class ReferencePredictor(BaseReferencePredictor):
-
-    def __init__(self, chunker: BaseChunker, ref_models, ref_vectorizer, templates, single_model, left_to_right,
-                 mask_count, average_mode,
-                 only_next_token):
-        self.chunker = chunker
-        self.ref_models = ref_models
-        self.ref_vectorizer = ref_vectorizer
-        self.templates = templates
-        self.single_model = single_model
-        self.left_to_right = left_to_right
-        self.mask_count = mask_count
-        self.average_mode = average_mode
-        self.only_next_token = only_next_token
-
-        if only_next_token and not left_to_right:
-            raise Exception('\'Only next token\' only works left to right!')
-
-    # overriding abstract method
-    def predict(self, text) -> List[ReferencePrediction]:
-        sentences = self.chunker.chunk(text)
-        result = []
-        for s in sentences:
-            ref_prediction = self.predict_sentence(s.text)
-            start = ref_prediction.references[0].start
-            end = ref_prediction.references[0].end
-
-            if start != -1 and end != -1:
-                global_start = s.start + start
-                global_end = s.start + end
-                ref_prediction.references[0].start = global_start
-                ref_prediction.references[0].end = global_end
-                result.append(ref_prediction)
-
-        return result
-
-    def predict_sentence(self, context) -> ReferencePrediction:
-        none_id = self.ref_vectorizer.none_id
-        sep_id = self.ref_vectorizer.sep_id
-        pad_id = self.ref_vectorizer.pad_id
-        mask_token_id = self.ref_vectorizer.tokenizer.mask_token_id
-
-        input_context_ids = self.ref_vectorizer.tokenizer.encode(context)[1:-1]
-
-        input_strs = []
-        for template_pos, template in enumerate(self.templates):
-            if not self.single_model and template_pos >= len(self.ref_models):
-                break
-
-            input_strs.append(context + ' ' + template)
-
-        inputs_list = self.ref_vectorizer.vectorize_inference(input_strs)
-
-        # if there are no masks, the text was too long and we cannot do anything but warn
-        for input_ids in inputs_list['input_ids']:
-            mask_token_indexes = np.argwhere(input_ids == mask_token_id)
-            if len(mask_token_indexes) < self.mask_count:
-                warnings.warn('Too few masks found! Text is probably too long!')
-                return ReferencePrediction('<none>', [Reference(-1, -1, '<none>')])
-
-        remaining_ids = input_context_ids.copy()
-        remaining_ids += [none_id]
-
-        if self.left_to_right:
-            predicted_ids = []
-        else:
-            predicted_ids = {}
-
-        count = 0
-        after_sep = False
-        while count < self.mask_count:
-            count += 1
-
-            best_token_index, best_token_id, best_mask_token_pos = (
-                self.__get_top_tokens(inputs_list, input_context_ids, remaining_ids, predicted_ids, pad_id, none_id,
-                                      sep_id, mask_token_id))
-
-            if self.left_to_right:
-                predicted_ids.append(best_token_id)
-            else:
-                mask_token_indexes = np.argwhere(inputs_list['input_ids'][0] == mask_token_id)
-                mask_token_index = mask_token_indexes[best_mask_token_pos][0]
-                predicted_ids[mask_token_index] = best_token_id
-
-            # early stopping
-            # TODO: add early stopping for predictions with <sep>
-            if self.left_to_right and sep_id == -1 and (best_token_id == none_id or best_token_id == pad_id):
-                break
-
-            for inputs in inputs_list['input_ids']:
-                mask_token_indexes = np.argwhere(inputs == mask_token_id)
-                mask_token_index = mask_token_indexes[best_mask_token_pos][0]
-                inputs[mask_token_index] = best_token_id
-
-            if best_token_id != pad_id:
-                assert remaining_ids[best_token_index] != pad_id
-
-            if sep_id == -1:
-                if count == 1:
-                    if none_id in remaining_ids:
-                        del remaining_ids[remaining_ids.index(none_id)]
-
-                    if pad_id not in remaining_ids:
-                        remaining_ids += [pad_id]
-
-                if best_token_id != pad_id and best_token_id != none_id:
-                    del remaining_ids[best_token_index]
-            else:
-                # TODO: this needs proper testing
-                if best_token_id != pad_id:
-                    del remaining_ids[best_token_index]
-
-                if after_sep:
-                    remaining_ids += [pad_id]
-
-                if none_id in remaining_ids:
-                    del remaining_ids[remaining_ids.index(none_id)]
-
-                if count == 1:
-                    if sep_id not in remaining_ids:
-                        remaining_ids += [sep_id]
-
-                elif count >= 2 and best_token_id == sep_id:
-                    remaining_ids += [none_id]
-                    after_sep = True
-
-        if not self.left_to_right:
-            keys = sorted(predicted_ids.keys())
-            temp = []
-            for key in keys:
-                temp.append(predicted_ids[key])
-            predicted_ids = temp
-
-        predicted_text = self.ref_vectorizer.tokenizer.decode(predicted_ids)
-        positions = self.__get_positions_in_context(context, predicted_ids, none_id, pad_id, sep_id)
-
-        references = []
-
-        for pos in positions:
-            start = pos[0]
-            end = pos[1]
-            text = pos[2]
-            ref = Reference(start, end, text)
-            references.append(ref)
-
-        return ReferencePrediction(predicted_text, references)
-
-    def expand_templates(self):
-        expanded_templates = []
-
-        for template in self.templates:
-            template = template.replace('[MASK]', ' '.join(['[MASK]'] * self.mask_count))
-            expanded_templates.append(template)
-
-        self.templates = expanded_templates
-
-    def __get_top_tokens(self, inputs_list, context_ids, remaining_ids, predicted_ids, pad_id, none_id, sep_id,
-                         mask_token_id):
-        token_logits_lists = []
-
-        if self.single_model:
-            model = self.ref_models[0]
-            model_output = model(inputs_list)
-            token_logits_lists = model_output.logits
-        else:
-            assert len(self.ref_models) == len(inputs_list['input_ids'])
-            for pos, model in enumerate(self.ref_models):
-                input_ids = inputs_list['input_ids'][pos:pos+1]
-                attention_mask = inputs_list['attention_mask'][pos:pos+1]
-                token_type_ids = inputs_list['token_type_ids'][pos:pos+1]
-
-                token_logits = model(input_ids, attention_mask, token_type_ids).logits
-                token_logits_lists.append(token_logits[0])
-            token_logits_lists = tf.convert_to_tensor(token_logits_lists)
-
-        probs_per_input = [[] for _ in range(len(inputs_list['input_ids']))]
-
-        for pos, (token_logits, input_ids) in enumerate(zip(token_logits_lists, inputs_list['input_ids'])):
-            mask_token_indexes = np.argwhere(input_ids == mask_token_id)
-
-            if self.left_to_right:
-                mask_token_index = mask_token_indexes[0][0]
-                mask_token_logits = token_logits[mask_token_index, :]
-
-                if self.average_mode == 'all':
-                    probs = tf.nn.softmax(mask_token_logits).numpy()
-                    probs = np.take(probs, remaining_ids)
-                    probs_per_input[pos].append(probs)
-                else:
-                    probs = tf.nn.softmax(np.take(mask_token_logits, remaining_ids)).numpy()
-                    probs_per_input[pos].append(probs)
-
-            else:
-                for item in mask_token_indexes:
-                    mask_token_index = item[0]
-                    mask_token_logits = token_logits[mask_token_index, :]
-                    probs = tf.nn.softmax(mask_token_logits).numpy()
-                    probs = np.take(probs, remaining_ids)
-                    probs_per_input[pos].append(probs)
-
-        probs_per_mask = [list(i) for i in zip(*probs_per_input)]
-
-        if self.left_to_right:
-            best_mask_token_pos = 0
-            probs_list = probs_per_mask[0]
-        else:
-            probs_max_per_input = []
-
-            for my_input in probs_per_input:
-                max_values = []
-                for mask in my_input:
-                    max_values.append(np.max(mask))
-                probs_max_per_input.append(max_values)
-
-            probs_max_per_mask = [list(i) for i in zip(*probs_max_per_input)]
-            trans_ave = [sum(i) / len(i) for i in probs_max_per_mask]
-
-            best_mask_token_pos = trans_ave.index(max(trans_ave))
-            probs_list = probs_per_mask[best_mask_token_pos]
-
-        averaged_probs = np.mean(probs_list, axis=0)
-
-        if (self.only_next_token and len(predicted_ids) > 0 and predicted_ids[-1] != none_id and
-                predicted_ids[-1] != pad_id and predicted_ids[-1] != sep_id):
-
-            # only use most recent part which only consists of actual words
-            part_predicted_ids = []
-            for cid in reversed(predicted_ids):
-                if cid == none_id or cid == pad_id or cid == sep_id:
-                    break
-
-                part_predicted_ids.insert(0, cid)
-
-            matches = SequenceMatcher(None, part_predicted_ids, context_ids, autojunk=False).get_matching_blocks()
-            valid_ids = []
-
-            if pad_id in remaining_ids:
-                valid_ids.append(pad_id)
-
-            if sep_id in remaining_ids:
-                valid_ids.append(sep_id)
-
-            if none_id in remaining_ids:
-                valid_ids.append(none_id)
-
-            for m in matches:
-                # we are only interested in matches which have the same length as the predicted ids
-                if m.size < len(part_predicted_ids):
-                    continue
-
-                if m.b + m.size < len(context_ids):
-                    next_id = context_ids[m.b + m.size]
-                    valid_ids.append(next_id)
-
-            assert len(valid_ids) > 0
-
-            best_token_index = -1
-            best_token_id = -1
-            highest_prob = 0
-
-            for v_id in valid_ids:
-                if v_id not in remaining_ids:
-                    continue
-                index = remaining_ids.index(v_id)
-
-                assert remaining_ids[index] == v_id
-                prob = averaged_probs[index]
-
-                if prob > highest_prob:
-                    best_token_index = index
-                    best_token_id = v_id
-                    highest_prob = prob
-
-        else:
-            top_k = tf.math.top_k(averaged_probs, k=1)
-            best_token_index = top_k.indices[0]
-            best_token_id = remaining_ids[best_token_index]
-
-        return best_token_index, best_token_id, best_mask_token_pos
-
-    def __get_positions_in_context(self, context, predicted_ids, none_id, pad_id, sep_id):
-
-        if none_id in predicted_ids:
-            return [(-1, -1, '<none>')]
-
-        sub_sequences = []
-        sub_sequence = []
-
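-        # split the predicted ids into sub-sequences: <pad> ends the sequence, <sep> separates answers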
-        for pid in predicted_ids:
-            if pid == pad_id:
-                sub_sequences.append(sub_sequence)
-                sub_sequence = []
-                break
-            elif pid == sep_id:
-                sub_sequences.append(sub_sequence)
-                sub_sequence = []
-            else:
-                sub_sequence.append(pid)
-
-        if len(sub_sequence) > 0:
-            sub_sequences.append(sub_sequence)
-
-        result = []
-
-        for ss in sub_sequences:
-            if len(ss) == 0:
-                continue
-
-            context_ids = self.ref_vectorizer.tokenizer(context)
-            match = SequenceMatcher(None, context_ids['input_ids'], ss, autojunk=False).find_longest_match()
-
-            start = context_ids.token_to_chars(match.a).start
-            end = context_ids.token_to_chars(match.a + match.size - 1).end
-            text = context[start:end]
-
-            if len(text) == 1:
-                continue
-
-            if re.search(r'\w', text):
-                result.append((start, end, text))
-
-        if len(result) == 0:
-            return [(-1, -1, '<none>')]
-
-        return result
diff --git a/indiquo/core/reference/ReferenceSolver.py b/indiquo/core/reference/ReferenceSolver.py
deleted file mode 100644
index 4bc7d9e26f02562adfb7b740c42fe3de0aa158c3..0000000000000000000000000000000000000000
--- a/indiquo/core/reference/ReferenceSolver.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import re
-
-
-class ReferenceSolver:
-    roman_map = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000, 'IV': 4, 'IX': 9, 'XL': 40, 'XC': 90,
-                 'CD': 400, 'CM': 900}
-    nl_number_map = {'erste': 1, 'erster': 1, 'ersten': 1,
-                     'zweite': 2, 'zweiter': 2, 'zweiten': 2,
-                     'dritte': 3, 'dritter': 3, 'dritten': 3,
-                     'vierte': 4, 'vierter': 4, 'vierten': 4,
-                     'fünfte': 5, 'fünfter': 5, 'fünften': 5,
-                     'sechste': 6, 'sechster': 6, 'sechsten': 6,
-                     'siebte': 7, 'siebter': 7, 'siebten': 7}
-
-    nl_number_pattern = '(?:erste[rn]*|zweite[rn]*|dritte[rn]*|vierte[rn]*|fünfte[rn]*|sechste[rn]*|siebte[rn]*)'
-    nl_act_pattern = '(?:Akt|Aufzug)'
-    nl_scene_pattern = '(?:Szene|Auftritt)'
-
-    # base patterns
-    pattern_number = '[0-9]+'
-    pattern_roman = '(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})'
-    pattern_act_scene = f'(?P<act>{pattern_roman})(?:[/, ]|, ?)(?P<scene>{pattern_number})'
-    f_pattern = '(?: ?[f]{1,2}\\.?)'
-
-    patterns = [
-        # 123 (without context it is unclear what is meant)
-        f'{pattern_number}',
-
-        # verses only
-        # V. 494 # V. 1810 ff., 1832 f.
-        f'Vs?\\. ?(?P<verse>{pattern_number}){f_pattern}?',
-        # V. 1892-1915 # Vs. 1893-1936
-        f'Vs?\\. ?(?P<verse1>{pattern_number}) ?[\\-\u2013\u2014] ?(?P<verse2>{pattern_number})',
-        # TODO: V. 1810 ff., 1832 f.
-
-        # III 1 # V 3 # III,2 # V/3
-        f'{pattern_act_scene}',
-
-        # V/3, 1968f. # I,3:293f
-        f'{pattern_act_scene}[,:](?: ?V\\.?)? ?(?P<verse>{pattern_number}){f_pattern}?',
-
-        # V/3, 1851-1853 # III 1, V. 1094-1117
-        f'{pattern_act_scene}[,:](?: ?V\\.?)? ?(?P<verse1>{pattern_number}) ?[\\-\u2013\u2014] ?(?P<verse2>{pattern_number})',
-
-        # 1. Akt # 3. Aufzug
-        f'(?P<act>{pattern_number})\\.? ?{nl_act_pattern}',
-
-        # 1. Szene # 3. Auftritt
-        f'(?P<scene>{pattern_number})\\.? ?{nl_scene_pattern}',
-
-        # Erster Akt # dritten Aufzug
-        f'(?P<act>{nl_number_pattern}) {nl_act_pattern}',
-
-        # Erste Szene # dritten Auftritt
-        f'(?P<scene>{nl_number_pattern}) {nl_scene_pattern}'
-    ]
-
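-    # returns (act, scene, verse); -1 marks a component that could not be resolved,
-    # e.g. solve('III 1') -> (3, 1, -1) and solve('V. 494') -> (-1, -1, 494)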
-    def solve(self, text):
-        for pattern in self.patterns:
-            match = re.match(f'^{pattern}$', text, flags=re.IGNORECASE)
-
-            if match:
-                act = -1
-                scene = -1
-                verse = -1
-
-                if 'act' in match.groupdict() and match.group('act'):
-                    matched = match.group('act')
-
-                    if matched.isdigit():
-                        act = int(matched)
-                    elif self.__is_roman(matched):
-                        act = self.__roman_to_integer(matched)
-                    else:
-                        act = self.__nl_to_integer(matched)
-
-                if 'scene' in match.groupdict() and match.group('scene'):
-                    matched = match.group('scene')
-
-                    if matched.isdigit():
-                        scene = int(matched)
-                    else:
-                        scene = self.__nl_to_integer(matched)
-
-                if 'verse' in match.groupdict() and match.group('verse'):
-                    verse = int(match.group('verse'))
-
-                if 'verse1' in match.groupdict() and match.group('verse1'):
-                    verse = int(match.group('verse1'))
-
-                if 'verse2' in match.groupdict() and match.group('verse2'):
-                    # verse = int(match.group('verse1'))
-                    # TODO: implement
-                    pass
-
-                return act, scene, verse
-
-        print(f'Could not solve reference: {text}')
-        return -1, -1, -1
-
-    def __is_roman(self, roman_input) -> bool:
-        for s in roman_input:
-            if s not in self.roman_map:
-                return False
-
-        return True
-
-    def __roman_to_integer(self, roman_input) -> int:
-
-        i = 0
-        num = 0
-        while i < len(roman_input):
-            if i + 1 < len(roman_input) and roman_input[i:i + 2] in self.roman_map:
-                num += self.roman_map[roman_input[i:i + 2]]
-                i += 2
-            else:
-                num += self.roman_map[roman_input[i]]
-                i += 1
-        return num
-
-    def __nl_to_integer(self, nl_input) -> int:
-        nl_input_lower = nl_input.lower()
-
-        if nl_input_lower in self.nl_number_map:
-            return self.nl_number_map[nl_input_lower]
-
-        return -1
diff --git a/indiquo/core/reference/RegExReferencePredictor.py b/indiquo/core/reference/RegExReferencePredictor.py
deleted file mode 100644
index ce4536525b02fd7130d6c34cf8fbf761da9b0bb6..0000000000000000000000000000000000000000
--- a/indiquo/core/reference/RegExReferencePredictor.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from typing import List
-
-from indiquo.core.reference.BaseReferencePredictor import BaseReferencePredictor
-from indiquo.core.reference.Reference import Reference
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
-
-
-class RegExReferencePredictor(BaseReferencePredictor):
-
-    MAX_PARENTHESES_LENGTH = 25
-
-    def __init__(self, reference_solver):
-        self._reference_solver = reference_solver
-
-    def predict(self, text: str) -> List[ReferencePrediction]:
-        return self.__get_text_in_parentheses(text)
-
-    def __get_text_in_parentheses(self, input_text: str):
-
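-        # match each opening bracket to its closing bracket with a stack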
-        bracket_positions = {}
-        stack = []
-
-        for i, c in enumerate(input_text):
-            if c == '(':
-                stack.append(i)
-            elif c == ')':
-                if len(stack) > 0:
-                    # unmatched closing brackets can occur; we simply ignore them
-                    bracket_positions[stack.pop()] = i
-
-        possible_references = []
-
-        for start, end in bracket_positions.items():
-            if end - start > self.MAX_PARENTHESES_LENGTH:
-                continue
-
-            ref_text = input_text[start + 1:end]
-            act, scene, verse = self.__solve_reference(ref_text)
-
-            if act != -1 or scene != -1 or verse != -1:
-                possible_references.append(ReferencePrediction(ref_text, [Reference(start + 1, end, ref_text)]))
-
-        return possible_references
-
-    def __solve_reference(self, text):
-        return self._reference_solver.solve(text)
diff --git a/indiquo/core/reference/__init__.py b/indiquo/core/reference/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/core/reportedspeech/ReportedSpeechByRefPredictor.py b/indiquo/core/reportedspeech/ReportedSpeechByRefPredictor.py
deleted file mode 100644
index a1ff33260b522ef241205a5326009d021424171d..0000000000000000000000000000000000000000
--- a/indiquo/core/reportedspeech/ReportedSpeechByRefPredictor.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from typing import List
-
-from indiquo.core.chunker.BaseChunker import BaseChunker
-from indiquo.core.chunker.Chunk import Chunk
-from indiquo.core.chunker.ChunkRef import ChunkRef
-from indiquo.core.reference.Reference import Reference
-from indiquo.core.reference.ReferencePrediction import ReferencePrediction
-
-
-class ReportedSpeechByRefPredictor:
-
-    def __init__(self, chunker: BaseChunker):
-        self.chunker = chunker
-
-    def predict(self, text, reference_predictions: List[ReferencePrediction]) -> List[Chunk]:
-        sentences = self.chunker.chunk(text)
-        candidates: List[ChunkRef] = []
-
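-        # keep only sentences that overlap at least one predicted reference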
-        for s in sentences:
-            references = []
-            for rp in reference_predictions:
-                ref = rp.reference
-
-                overlap_start = max(s.start, ref.start)
-                overlap_end = min(s.end, ref.end)
-                overlap_length = overlap_end - overlap_start
-
-                if overlap_length > 0:
-                    references.append(Reference(rp.reference.start, rp.reference.end, rp.reference.text))
-
-            if len(references) > 0:
-                candidates.append(ChunkRef(s.start, s.end, s.text, references))
-
-        return candidates
diff --git a/indiquo/core/reportedspeech/ReportedSpeechPredictor.py b/indiquo/core/reportedspeech/ReportedSpeechPredictor.py
deleted file mode 100644
index d3e28dcaf0127100583c6b6c2d1bdda65e3d6f7b..0000000000000000000000000000000000000000
--- a/indiquo/core/reportedspeech/ReportedSpeechPredictor.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import List
-from flair.data import Sentence
-from flair.nn import Model
-from indiquo.core.chunker.BaseChunker import BaseChunker
-from indiquo.core.chunker.Chunk import Chunk
-
-
-class ReportedSpeechPredictor:
-
-    def __init__(self, chunker: BaseChunker, model: Model):
-        self.chunker = chunker
-        self.model = model
-
-    def predict(self, text) -> List[Chunk]:
-        sentences = self.chunker.chunk(text)
-        flair_sentences = []
-
-        for s in sentences:
-            flair_sentence = Sentence(s.text, use_tokenizer=True, start_position=s.start)
-            self.model.predict(flair_sentence)
-            flair_sentences.append(flair_sentence)
-
-        candidates: List[Chunk] = []
-        current_candidate = None
-
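-        # merge consecutive sentences containing reported speech into one chunk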
-        for fs in flair_sentences:
-            found = False
-            for t in fs.tokens:
-                if t.labels[0].value == 'reported':
-                    found = True
-                    break
-
-            if found:
-                if not current_candidate:
-                    chunk_text = text[fs.start_pos:fs.end_pos]
-                    current_candidate = Chunk(fs.start_pos, fs.end_pos, chunk_text)
-                else:
-                    chunk_text = text[current_candidate.start:fs.end_pos]
-                    current_candidate = Chunk(current_candidate.start, fs.end_pos, chunk_text)
-            else:
-                if current_candidate:
-                    candidates.append(current_candidate)
-                    current_candidate = None
-
-        # flush a trailing candidate; it would otherwise be lost when the text
-        # ends with reported speech
-        if current_candidate:
-            candidates.append(current_candidate)
-
-        return candidates
diff --git a/indiquo/core/reportedspeech/__init__.py b/indiquo/core/reportedspeech/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/core/similarity/Heatmap.py b/indiquo/core/similarity/Heatmap.py
deleted file mode 100644
index f77965fb8c9cc2d97c01cb396cc0008ddc9b9934..0000000000000000000000000000000000000000
--- a/indiquo/core/similarity/Heatmap.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-from indiquo.core.similarity.Line import Line
-
-
-@dataclass
-class Heatmap:
-    lines: List[Line]
diff --git a/indiquo/core/similarity/SimilarityPair.py b/indiquo/core/similarity/SimilarityPair.py
deleted file mode 100644
index f93996b36f91def81df0479f536307e973316530..0000000000000000000000000000000000000000
--- a/indiquo/core/similarity/SimilarityPair.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class SimilarityPair:
-    start: int
-    end: int
-    text: str
diff --git a/indiquo/core/similarity/SimilarityPredictor.py b/indiquo/core/similarity/SimilarityPredictor.py
deleted file mode 100644
index 9cd3cc5229ba58c521f89e6e1dd524e0de6d2e55..0000000000000000000000000000000000000000
--- a/indiquo/core/similarity/SimilarityPredictor.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import List
-from sentence_transformers import SentenceTransformer, util
-from indiquo.core.similarity.Heatmap import Heatmap
-from indiquo.core.similarity.Line import Line
-from indiquo.core.similarity.SimilarityPair import SimilarityPair
-
-
-class SimilarityPredictor:
-
-    def __init__(self, model):
-        self.model = model
-
-    def predict(self, drama, target_text, act_nr, scene_nr):
-        text_blocks = drama.get_text_for_scene_by_character(act_nr, scene_nr, 128)
-        heatmap = self.__get_heatmap(text_blocks, target_text)
-        return heatmap
-
-    def __get_heatmap(self, source_sentences: List[str], target_text: str) -> Heatmap:
-        all_sentences = []
-        all_candidates = []
-
-        # end = -1
-        for line_start, line_end, sent in source_sentences:
-            # start = end + 1
-            content = sent.strip()
-            # end = start + len(content)
-
-            all_candidates.append(SimilarityPair(line_start, line_end, content))
-            all_sentences.append(content)
-
-        source_sent_count = len(all_sentences)
-
-        all_candidates.append(SimilarityPair(0, len(target_text), target_text))
-        all_sentences.append(target_text)
-
-        paraphrases = util.paraphrase_mining(self.model, all_sentences)
-        lines = []
-
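-        # keep only cross pairs between a source sentence and the target text;
-        # pairs from the same side are skipped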
-        for paraphrase in paraphrases:
-            score, s_id, t_id = paraphrase
-
-            if (s_id < source_sent_count and t_id < source_sent_count) or \
-                    (s_id >= source_sent_count and t_id >= source_sent_count):
-                continue
-
-            if score < 0:
-                continue
-
-            source_span = all_candidates[s_id]
-            lines.append(Line(source_span.start, source_span.end, source_span.text, score))
-
-        lines.sort(key=lambda x: x.start, reverse=True)
-
-        heatmap = Heatmap(lines)
-        return heatmap
diff --git a/indiquo/core/similarity/__init__.py b/indiquo/core/similarity/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/match/Match.py b/indiquo/match/Match.py
index 159dfc9ecdcce0daed472d71ef4a82ecce4d2681..23ddd19bec51d0c6c43207d813c253d9d0cd71a4 100644
--- a/indiquo/match/Match.py
+++ b/indiquo/match/Match.py
@@ -1,7 +1,6 @@
 from dataclasses import dataclass
-from typing import Optional, List, Tuple
-
-from indiquo.core.similarity.Heatmap import Heatmap
+from typing import List
+from indiquo.core.ScenePrediction import ScenePrediction
 
 
 @dataclass
@@ -9,13 +8,4 @@ class Match:
     target_start: int
     target_end: int
     target_text: str
-    ref_start: int
-    ref_end: int
-    ref_text: str
-    direct_quote_start: int
-    direct_quote_end: int
-    direct_quote_text: str
-    act: int
-    scene: int
-    source_heatmap: Optional[Heatmap]
-    line_ranges: List[Tuple[int, int]]
+    scene_predictions: List[ScenePrediction]
diff --git a/indiquo/model/__init__.py b/indiquo/model/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/model/reference/ReferenceModelTrainer.py b/indiquo/model/reference/ReferenceModelTrainer.py
deleted file mode 100644
index 13483fe60b477241ffa796169a663841728ea5c3..0000000000000000000000000000000000000000
--- a/indiquo/model/reference/ReferenceModelTrainer.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import numpy as np
-from keras.utils import Progbar
-import transformers
-import tensorflow as tf
-from tensorflow.keras.metrics import SparseCategoricalAccuracy
-import keras
-
-
-class ReferenceModelTrainer:
-
-    def __init__(self, vectorizer, num_epochs, num_steps, loss_type, batches_per_update, mlm_aux=False, alpha=0.5):
-        self.vectorizer = vectorizer
-        self.num_epochs = num_epochs
-        self.num_steps = num_steps
-        self.loss_type = loss_type
-        self.batches_per_update = batches_per_update
-        self.mlm_aux = mlm_aux
-        self.alpha = alpha
-
-    def get_model(self, model_name):
-        model = transformers.TFBertForMaskedLM.from_pretrained(model_name, from_pt=True)
-        model.resize_token_embeddings(self.vectorizer.vocab_size)
-        return model
-
-    def train_model(self, model_name, dataset, template_batch_size):
-        # TODO: improve speed, see https://keras.io/guides/writing_a_training_loop_from_scratch/
-
-        none_id = self.vectorizer.none_id
-        sep_id = self.vectorizer.sep_id
-        pad_id = self.vectorizer.pad_id
-        mask_token_id = self.vectorizer.tokenizer.mask_token_id
-
-        model = self.get_model(model_name)
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, ignore_class=-100)
-
-        optimizer, _ = transformers.create_optimizer(
-            init_lr=1e-5,
-            num_warmup_steps=(self.num_steps * self.num_epochs) * 0.05,
-            num_train_steps=self.num_steps * self.num_epochs,
-            weight_decay_rate=0.01,
-        )
-
-        metrics_names = []
-        train_acc_metric = SparseCategoricalAccuracy()
-        val_acc_metric = SparseCategoricalAccuracy()
-
-        tf_train_dataset = dataset['train'].to_tf_dataset(
-            # columns=["input_ids", "token_type_ids", "attention_mask", "labels", "context_ids"],
-            # label_cols=["labels"],
-            batch_size=template_batch_size,
-            shuffle=False
-        )
-
-        tf_eval_dataset = dataset['eval'].to_tf_dataset(
-            batch_size=template_batch_size,
-            shuffle=False
-        )
-
-        for epoch in range(self.num_epochs):
-            print(f'\nEpoch {epoch + 1}/{self.num_epochs}')
-            progbar = Progbar(self.num_steps, stateful_metrics=metrics_names)
-
-            total_loss = 0
-            for step, elem in enumerate(tf_train_dataset):
-                with tf.GradientTape() as tape:
-                    template_input_ids = elem['input_ids']
-                    template_token_type_ids = elem['token_type_ids']
-                    template_attention_mask = elem['attention_mask']
-                    template_y_batch = elem['labels']
-                    context_ids_batch = elem['context_ids']
-
-                    template_output = model(input_ids=template_input_ids,
-                                            attention_mask=template_attention_mask,
-                                            token_type_ids=template_token_type_ids,
-                                            training=True)
-
-                    if self.loss_type == 'mlm':
-                        # use default mlm loss, i.e. cross entropy loss
-                        template_loss_value = loss_fn(template_y_batch, template_output.logits)
-                    elif self.loss_type == 'weighted':
-                        loss_list = []
-
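-                        # combine the cross-entropy at the true label with an extra penalty
-                        # on probability mass assigned to the other context/special tokens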
-                        for tids, logits, template_y, context_ids in zip(template_input_ids, template_output.logits,
-                                                                         template_y_batch, context_ids_batch):
-
-                            mask_token_indexes = tf.where(tids == mask_token_id)
-                            loss_per_mask = []
-
-                            context_ids = context_ids.numpy().tolist()
-                            context_ids = [i for i in context_ids if i != 0]
-                            context_ids += [none_id, pad_id]
-
-                            if sep_id != -1:
-                                context_ids += [sep_id]
-
-                            for item in mask_token_indexes:
-                                mask_token_index = item[0]
-                                mask_token_logits = logits[mask_token_index, :]
-                                mask_token_probs = tf.nn.softmax(mask_token_logits)
-                                true_label_id = template_y[mask_token_index]
-
-                                true_one_hot = np.zeros(len(mask_token_probs), dtype=np.float32)
-                                true_one_hot[true_label_id] = 1.0
-
-                                bce_loss_list = keras.backend.binary_crossentropy(true_one_hot, mask_token_probs)
-
-                                msk = np.zeros(mask_token_logits.shape, dtype=bool)
-                                msk[context_ids] = True
-                                msk[true_label_id] = False
-
-                                true_label_loss = bce_loss_list[true_label_id]
-                                context_losses = bce_loss_list[msk]
-                                context_loss = tf.reduce_sum(context_losses)
-
-                                comb_loss = true_label_loss + context_loss
-                                loss_per_mask.append(comb_loss)
-
-                            loss_mean = tf.reduce_mean(loss_per_mask)
-                            loss_list.append(loss_mean)
-
-                        template_loss_value = sum(loss_list) / len(loss_list)
-                    else:
-                        raise Exception(f'Invalid loss type: {self.loss_type}')
-
-                    loss = template_loss_value
-
-                sample_weights = np.ones(template_y_batch.shape)
-
-                for i in range(len(sample_weights)):
-                    sample_weights[i][template_y_batch[i] == -100] = 0
-
-                train_acc_metric.update_state(template_y_batch, template_output.logits, sample_weights)
-                grads = tape.gradient(loss, model.trainable_weights)
-
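-                # gradient accumulation: sum gradients over batches_per_update batches and
-                # apply the averaged update, for an effective batch size of
-                # template_batch_size * batches_per_update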
-                if self.batches_per_update > 1:
-                    if step % self.batches_per_update == 0:
-                        accum_grads = [tf.zeros_like(w) for w in grads]
-
-                    accum_grads = [ag + g for ag, g in zip(accum_grads, grads)]
-
-                    if (step + 1) % self.batches_per_update == 0:
-                        optimizer.apply_gradients(
-                            [(ag / self.batches_per_update, w) for ag, w in zip(accum_grads, model.trainable_weights)])
-
-                        train_acc = train_acc_metric.result()
-                        total_loss += loss
-                        values = [('loss', total_loss / (step + 1)), ('acc', train_acc)]
-                        progbar.update(step / self.batches_per_update, values=values)
-                else:
-                    optimizer.apply_gradients(zip(grads, model.trainable_weights))
-
-                    train_acc = train_acc_metric.result()
-                    total_loss += loss
-                    values = [('loss', total_loss / (step + 1)), ('acc', train_acc)]
-                    progbar.update(step, values=values)
-
-            train_acc_metric.reset_states()
-
-            val_loss = 0
-            val_steps = 0
-            for step, elem in enumerate(tf_eval_dataset):
-                template_input_ids = elem['input_ids']
-                template_token_type_ids = elem['token_type_ids']
-                template_attention_mask = elem['attention_mask']
-                template_y_batch = elem['labels']
-
-                model_output = model(input_ids=template_input_ids,
-                                     attention_mask=template_attention_mask,
-                                     token_type_ids=template_token_type_ids,
-                                     training=False)
-
-                sample_weights = np.ones(template_y_batch.shape)
-                for i in range(len(sample_weights)):
-                    sample_weights[i][template_y_batch[i] == -100] = 0
-                val_acc_metric.update_state(template_y_batch, model_output.logits, sample_weights)
-
-                val_loss += loss_fn(template_y_batch, model_output.logits)
-                val_steps += 1
-
-            val_acc = val_acc_metric.result()
-            val_acc_metric.reset_states()
-            values = [('val_loss', val_loss / val_steps), ('val_acc', val_acc)]
-            progbar.update(self.num_steps, values=values)
-
-        return model
diff --git a/indiquo/model/reference/ReferenceVectorizer.py b/indiquo/model/reference/ReferenceVectorizer.py
deleted file mode 100644
index ec8c2cb74dbf1cd67734ffeca44cc882e7eb0dbf..0000000000000000000000000000000000000000
--- a/indiquo/model/reference/ReferenceVectorizer.py
+++ /dev/null
@@ -1,223 +0,0 @@
-import transformers
-import warnings
-import torch
-
-
-class ReferenceVectorizer:
-    tokenizer: transformers.BertTokenizer
-    vocab_size: int
-    max_length: int
-    num_masked_tokens: int
-    none_id: int
-    pad_id: int
-    sep_id: int
-
-    def __init__(self, tokenizer, max_length, num_masked_tokens, use_sep):
-        self.tokenizer = tokenizer
-        self.vocab_size = len(self.tokenizer)
-        self.max_length = max_length
-        self.num_masked_tokens = num_masked_tokens
-
-        self.pad_id = self.__get_special_token_id('<pad>')
-        self.none_id = self.__get_special_token_id('<none>')
-        self.sep_id = -1
-
-        if use_sep:
-            self.sep_id = tokenizer.added_tokens_encoder['<sep>']
-
-    @classmethod
-    def from_raw(cls, model_name, max_length, lower_case, num_masked_tokens, use_sep):
-
-        tokenizer = transformers.BertTokenizerFast.from_pretrained(
-            model_name, do_lower_case=lower_case
-        )
-
-        special_tokens_list = ['<pad>', '<none>']
-
-        if use_sep:
-            special_tokens_list.append('<sep>')
-
-        special_tokens_dict = {'additional_special_tokens': special_tokens_list}
-        tokenizer.add_special_tokens(special_tokens_dict)
-
-        return cls(tokenizer, max_length, num_masked_tokens, use_sep)
-
-    @classmethod
-    def from_saved(cls, tokenizer_path, max_length, lower_case, num_masked_tokens, use_sep):
-        tokenizer = transformers.BertTokenizerFast.from_pretrained(
-            tokenizer_path, do_lower_case=lower_case
-        )
-
-        return cls(tokenizer, max_length, num_masked_tokens, use_sep)
-
-    def vectorize_templates(self, contexts, answers, templates):
-        result_input_id_lists = []
-        result_token_type_id_lists = []
-        result_attention_mask_lists = []
-        result_label_lists = []
-        result_context_id_lists = []
-
-        template_cache = {}
-
-        for context, ans_obj, template in zip(contexts, answers, templates):
-
-            if template in template_cache:
-                template_inputs_ids = template_cache[template][0]
-                template_len = template_cache[template][1]
-                mask_index = template_cache[template][2]
-            else:
-                template_inputs_ids = self.tokenizer.encode(template)[1:-1]
-                template_len = len(template_inputs_ids)
-                mask_index = template_inputs_ids.index(self.tokenizer.mask_token_id)
-                template_cache[template] = (template_inputs_ids, template_len, mask_index)
-
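-            # encode()[1:-1] drops the [CLS] and [SEP] tokens that the tokenizer adds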
-            context_inputs_ids = self.tokenizer.encode(context)[1:-1]
-            answer_list = []
-
-            for ans in ans_obj:
-                answer_list.append(ans['text'])
-
-            answer_combined = answer_list[0]
-
-            if len(answer_list) > 1:
-                if self.sep_id == -1:
-                    warnings.warn('There are multiple answers but <sep> is not set!')
-
-                for i in range(1, len(answer_list)):
-                    answer_combined += f' <sep> {answer_list[i]}'
-            else:
-                if self.sep_id != -1:
-                    answer_combined += ' <sep> <none>'
-
-            answer_labels = self.tokenizer.encode(answer_combined)[1:-1]
-
-            if len(answer_labels) > self.num_masked_tokens:
-                warnings.warn(f'\'{answer_combined}\' is too long: {len(answer_labels)}. Skipping!')
-                continue
-
-            answer_input_ids = [self.tokenizer.mask_token_id] * self.num_masked_tokens
-            answer_padding = [self.pad_id] * (self.num_masked_tokens - len(answer_labels))
-            answer_labels = answer_labels + answer_padding
-
-            assert len(answer_input_ids) == len(answer_labels)
-
-            answer_token_type_ids = [0] * len(answer_input_ids)
-            answer_attention_mask = [1] * len(answer_input_ids)
-
-            template_token_type_ids = [0] * template_len
-            template_attention_mask = [1] * template_len
-            template_labels = [-100] * template_len
-
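-            # splice the num_masked_tokens [MASK] slots (and matching labels)
-            # into the template at its single mask position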
-            all_input_ids = template_inputs_ids.copy()
-            all_input_ids[mask_index:mask_index + 1] = answer_input_ids
-
-            all_token_type_ids = template_token_type_ids
-            all_token_type_ids[mask_index:mask_index + 1] = answer_token_type_ids
-
-            all_attention_mask = template_attention_mask
-            all_attention_mask[mask_index:mask_index + 1] = answer_attention_mask
-
-            all_labels = template_labels
-            all_labels[mask_index:mask_index + 1] = answer_labels
-
-            remaining_space = self.max_length - 2 - len(all_input_ids)
-
-            assert remaining_space > 0
-
-            result_context_ids = context_inputs_ids.copy()
-
-            if len(context_inputs_ids) <= remaining_space:
-                num_pad_tokens = remaining_space - len(context_inputs_ids)
-                context_inputs_ids_padded = ([0] * num_pad_tokens) + context_inputs_ids
-                context_token_type_ids = [0] * len(context_inputs_ids_padded)
-                context_attention_mask = ([0] * num_pad_tokens) + ([1] * len(context_inputs_ids))
-                context_labels = [-100] * len(context_inputs_ids_padded)
-
-                all_input_ids = context_inputs_ids_padded + all_input_ids
-                all_token_type_ids = context_token_type_ids + all_token_type_ids
-                all_attention_mask = context_attention_mask + all_attention_mask
-                all_labels = context_labels + all_labels
-
-                context_diff = len(all_input_ids) - len(result_context_ids)
-                result_context_ids = [0] * context_diff + result_context_ids
-
-            elif len(context_inputs_ids) > remaining_space:
-                num_remove = len(context_inputs_ids) - remaining_space
-                warnings.warn(f'Need to truncate: {num_remove}')
-                context_input_ids_reduced = context_inputs_ids[num_remove:]
-                context_token_type_ids = [0] * len(context_input_ids_reduced)
-                context_attention_mask = ([1] * len(context_input_ids_reduced))
-                context_labels = [-100] * len(context_input_ids_reduced)
-
-                result_context_ids = context_input_ids_reduced
-
-                all_input_ids = context_input_ids_reduced + all_input_ids
-                all_token_type_ids = context_token_type_ids + all_token_type_ids
-                all_attention_mask = context_attention_mask + all_attention_mask
-                all_labels = context_labels + all_labels
-
-                context_diff = len(all_input_ids) - len(result_context_ids)
-                result_context_ids = [0] * context_diff + result_context_ids
-
-            all_input_ids = [self.tokenizer.cls_token_id] + all_input_ids + [self.tokenizer.sep_token_id]
-            all_token_type_ids = [0] + all_token_type_ids + [0]
-            all_attention_mask = [1] + all_attention_mask + [1]
-            all_labels = [-100] + all_labels + [-100]
-            result_context_ids = [0] + result_context_ids + [0]
-
-            assert len(all_input_ids) == len(all_token_type_ids)
-            assert len(all_input_ids) == len(all_attention_mask)
-            assert len(all_input_ids) == len(all_labels)
-            assert len(all_input_ids) == self.max_length
-
-            result_input_id_lists.append(all_input_ids)
-            result_token_type_id_lists.append(all_token_type_ids)
-            result_attention_mask_lists.append(all_attention_mask)
-            result_label_lists.append(all_labels)
-            result_context_id_lists.append(result_context_ids)
-
-        return {
-            'input_ids': result_input_id_lists,
-            'token_type_ids': result_token_type_id_lists,
-            'attention_mask': result_attention_mask_lists,
-            'labels': result_label_lists,
-            'context_ids': result_context_id_lists
-        }
-
-    def vectorize_inference(self, input_strs):
-        inputs_list = self.tokenizer(input_strs, max_length=self.max_length, padding='max_length', truncation=True,
-                                     return_tensors="np")
-
-        return inputs_list
-
-    def __get_special_token_id(self, token):
-        index = self.tokenizer.additional_special_tokens.index(token)
-        return self.tokenizer.additional_special_tokens_ids[index]
-
-    # def vectorize_mlm(self, tokenizer, text, chunk_size):
-    #
-    #     inputs = tokenizer(text)
-    #
-    #     inputs = {'input_ids': inputs['input_ids'],
-    #               'token_type_ids': inputs['token_type_ids'],
-    #               'attention_mask': inputs['attention_mask']}
-    #
-    #     result = {
-    #         k: [t[i:i + chunk_size] for i in range(0, len(t), chunk_size)]
-    #         for k, t in inputs.items()
-    #     }
-    #
-    #     result["labels"] = result["input_ids"].copy()
-    #     samples = []
-    #
-    #     for i in range(0, len(result['input_ids']) - 1):
-    #         samples.append({'input_ids': result['input_ids'][i],
-    #                         'token_type_ids': result['token_type_ids'][i],
-    #                         'attention_mask': result['attention_mask'][i],
-    #                         'labels': result['labels'][i]})
-    #
-    #     data_collator = transformers.DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)
-    #     masked_samples = data_collator(samples)
-    #     masked_samples['attention_mask'] = tf.ones(len(masked_samples['input_ids']), chunk_size, dtype=tf.int32)
-    #     masked_samples['token_type_ids'] = tf.zeros(len(masked_samples['input_ids']), chunk_size, dtype=tf.int32)
-    #     return masked_samples
diff --git a/indiquo/model/reference/__init__.py b/indiquo/model/reference/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/training/reference/TrainReference.py b/indiquo/training/reference/TrainReference.py
deleted file mode 100644
index 345af0b5c82055a3ee8d83c797c6e034f1d5fed9..0000000000000000000000000000000000000000
--- a/indiquo/training/reference/TrainReference.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import json
-import math
-from os.path import join
-from pathlib import Path
-from datasets import DatasetDict, Dataset
-
-from indiquo.model.reference.ReferenceModelTrainer import ReferenceModelTrainer
-from indiquo.model.reference.ReferenceVectorizer import ReferenceVectorizer
-import transformers
-import torch
-
-
-def preprocess_function(samples, vectorizer):
-    train_template_samples = vectorizer.vectorize_templates(samples['context'], samples['answer'], samples['template'])
-    return train_template_samples
-
-
-def train(lower_case, base_model_name, model_type, loss_type, num_masked_tokens, use_sep, num_train_steps,
-          num_training_examples, num_validation_examples, template_batch_size, max_length, batches_per_update,
-          template_train_file_path, val_file_path, output_folder_path):
-
-    # create and write config
-    config = {
-        'base model name': base_model_name,
-        'model type': model_type,
-        'loss type': loss_type,
-        'num masked tokens': num_masked_tokens,
-        'use sep': use_sep,
-        'num train steps': num_train_steps,
-        'num training examples': num_training_examples,
-        'template batch size': template_batch_size,
-        'template max length': max_length,
-        'batches per update': batches_per_update,
-    }
-
-    with open(join(output_folder_path, 'config.json'), 'w', encoding='utf-8') as config_file:
-        content = json.dumps(config)
-        config_file.write(content)
-
-    ref_vectorizer = ReferenceVectorizer.from_raw(base_model_name, max_length, lower_case, num_masked_tokens, use_sep)
-
-    tokenizer_dir = join(output_folder_path, 'tokenizer')
-    Path(tokenizer_dir).mkdir(parents=True, exist_ok=True)
-    ref_vectorizer.tokenizer.save_pretrained(tokenizer_dir)
-
-    with open(val_file_path, 'r', encoding='utf-8') as val_file:
-        val_json = json.load(val_file)
-
-    with open(template_train_file_path, 'r', encoding='utf-8') as train_file:
-        train_json = json.load(train_file)
-
-    train_contexts = []
-    train_answers = []
-
-    for ex in train_json['examples']:
-        train_contexts.append(ex['context'])
-        train_answers.append(ex['answer'])
-
-    val_contexts = []
-    val_answers = []
-
-    for ex in val_json['examples']:
-        val_contexts.append(ex['context'])
-        val_answers.append(ex['answer'])
-
-    # if mlm_aux and mlm_train_file_path:
-    #     with open(mlm_train_file_path, 'r') as txt_file:
-    #         train_text = txt_file.read()
-    #         train_text = train_text[0:400000]
-
-    if model_type == 'combined':
-        datasets = DatasetDict()
-
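-        # 'combined' trains a single model on the cross product of all examples and all templates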
-        all_train_contexts = []
-        all_train_answers = []
-        all_train_templates = []
-
-        for template in train_json['templates']:
-            all_train_contexts.extend(train_contexts)
-            all_train_answers.extend(train_answers)
-            all_train_templates.extend([template] * len(train_contexts))
-
-        train_dataset_dict = {
-            "context": all_train_contexts,
-            "answer": all_train_answers,
-            "template": all_train_templates
-        }
-
-        train_dataset = Dataset.from_dict(train_dataset_dict)
-        tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, batch_size=1000,
-                                                    remove_columns=['context', 'answer', 'template'],
-                                                    fn_kwargs={'vectorizer': ref_vectorizer})
-
-        # shuffle and select return new datasets, so the results must be reassigned
-        tokenized_train_dataset = tokenized_train_dataset.shuffle(seed=42)
-
-        if tokenized_train_dataset.num_rows < num_training_examples:
-            raise Exception('Not enough training examples')
-
-        tokenized_train_dataset = tokenized_train_dataset.select(range(num_training_examples))
-        datasets['train'] = tokenized_train_dataset
-
-        all_val_contexts = []
-        all_val_answers = []
-        all_val_templates = []
-
-        for template in val_json['templates']:
-            all_val_contexts.extend(val_contexts)
-            all_val_answers.extend(val_answers)
-            all_val_templates.extend([template] * len(val_contexts))
-
-        val_dataset_dict = {
-            "context": all_val_contexts,
-            "answer": all_val_answers,
-            "template": all_val_templates
-        }
-
-        val_dataset = Dataset.from_dict(val_dataset_dict)
-        tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, batch_size=1000,
-                                                remove_columns=['context', 'answer', 'template'],
-                                                fn_kwargs={'vectorizer': ref_vectorizer})
-
-        # shuffle returns a new dataset, so the result must be reassigned
-        tokenized_val_dataset = tokenized_val_dataset.shuffle(seed=42)
-        datasets['eval'] = tokenized_val_dataset
-
-        num_examples = datasets.num_rows['train']
-        real_batch_size = template_batch_size * batches_per_update
-        steps_per_epoch = num_examples / real_batch_size
-
-        if num_train_steps > 0:
-            num_epochs = math.ceil(num_train_steps / steps_per_epoch)
-            assert num_epochs > 0
-        else:
-            # TODO: set number of epochs or number of train steps
-            raise Exception('Number of train steps not set')
-
-        # train_mlm_samples = None
-        # if mlm_aux:
-        #     train_mlm_samples = None
-        #
-        #     # sanity check
-        #     mlm_step_count = math.floor(len(train_mlm_samples['input_ids']) / mlm_batch_size)
-        #     assert steps_per_epoch < mlm_step_count
-
-        ref_model_trainer = ReferenceModelTrainer(ref_vectorizer, num_epochs, steps_per_epoch, loss_type,
-                                                  batches_per_update)
-        model = ref_model_trainer.train_model(base_model_name, datasets, template_batch_size)
-
-        # TODO: load and save best model?
-
-        best_model_path = join(output_folder_path, 'model')
-        model.save_pretrained(best_model_path)
-
-    elif model_type == 'template':
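-        # 'template' trains a separate model for each template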
-        for template_pos, template in enumerate(train_json['templates']):
-            print(f'\nStarting template {template_pos + 1}')
-            template_output_folder_path = join(output_folder_path, f'model_{template_pos + 1}')
-            Path(template_output_folder_path).mkdir(parents=True, exist_ok=True)
-
-            train_template_samples = ref_vectorizer.vectorize_templates(train_contexts, train_answers, template,
-                                                                        num_training_examples)
-            val_template_samples = ref_vectorizer.vectorize_templates(val_contexts, val_answers, template,
-                                                                      num_validation_examples)
-
-            steps_per_epoch = math.floor(len(train_template_samples) / template_batch_size)
-
-            num_epochs = 0
-            if num_train_steps > 0:
-                num_epochs = num_train_steps // steps_per_epoch
-
-            train_mlm_samples = None
-            # if mlm_aux:
-            #     train_mlm_samples = None
-            #
-            #     # sanity check
-            #     mlm_step_count = math.floor(len(train_mlm_samples['input_ids']) / mlm_batch_size)
-            #     assert steps_per_epoch < mlm_step_count
-
-            ref_model_trainer = ReferenceModelTrainer(ref_vectorizer, num_epochs, steps_per_epoch, loss_type,
-                                                      batches_per_update)
-            # TODO: update
-            model = ref_model_trainer.train_model(base_model_name, train_template_samples, val_template_samples,
-                                                  template_batch_size, max_length)
-
-            model.save_pretrained(template_output_folder_path)
diff --git a/indiquo/training/reference/__init__.py b/indiquo/training/reference/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/training/summarization/Train.py b/indiquo/training/summarization/Train.py
deleted file mode 100644
index c5fae26874b8cc5743a3eb33d36dc92b3280eb10..0000000000000000000000000000000000000000
--- a/indiquo/training/summarization/Train.py
+++ /dev/null
@@ -1,118 +0,0 @@
-from argparse import ArgumentParser
-from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer,
-                          Seq2SeqTrainingArguments)
-from datasets import load_dataset
-from datetime import datetime
-from os.path import join
-from pathlib import Path
-import numpy as np
-import torch
-
-
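-# prefix each dialogue with the 'summarize: ' instruction and replace pad tokens
-# in the labels with -100 so they are ignored by the loss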
-def preprocess_function(sample, tokenizer, max_source_length, max_target_length, padding="max_length"):
-    inputs = ["summarize: " + item.replace('||', '\n') for item in sample["dialogue"]]
-
-    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
-    labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)
-
-    if padding == "max_length":
-        labels["input_ids"] = [
-           [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-        ]
-
-    model_inputs["labels"] = labels["input_ids"]
-    return model_inputs
-
-
-def main():
-    argument_parser = ArgumentParser(prog='indiquo', description='TBD')
-
-    argument_parser.add_argument("--data", dest="data_file_path", help="", required=True)
-    argument_parser.add_argument("-m", dest="model", help="", required=True)
-    argument_parser.add_argument("--dc", dest="deepspeed_config", help="", required=False, default='')
-    argument_parser.add_argument("-o", dest="output_path", help="", required=True)
-
-    args = argument_parser.parse_args()
-    data_file_path = args.data_file_path
-    model_id = args.model
-    deepspeed_config = args.deepspeed_config
-    output_path = args.output_path
-
-    # print('GPUs:')
-    # print(tf.config.list_physical_devices('GPU'))
-    # print(torch.cuda.is_available())
-    # print(torch.cuda.device_count())
-    # print(torch.cuda.current_device())
-
-    now = datetime.now()
-    date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
-    output_folder_path = join(output_path, date_time_string)
-    Path(output_folder_path).mkdir(parents=True, exist_ok=True)
-
-    max_source_length = 1024
-    max_target_length = 128
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-
-    data_files = {"train": "train_examples.tsv", "eval": "val_examples.tsv"}
-    dataset = load_dataset(data_file_path, data_files=data_files)
-
-    tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['file'],
-                                    fn_kwargs={'tokenizer': tokenizer, 'max_source_length': max_source_length,
-                                               'max_target_length': max_target_length})
-
-    # all_source_lengths = [len(x) for x in tokenized_dataset['train']['input_ids']]
-    # all_target_lengths = [len(x) for x in tokenized_dataset['train']['labels']]
-
-    # max_source_length = max(all_source_lengths)
-    # min_source_length = min(all_source_lengths)
-
-    # max_target_length = max(all_target_lengths)
-    # min_target_length = min(all_target_lengths)
-
-    # all_source_lengths.sort(reverse=True)
-    # source_percentile = np.percentile(all_source_lengths, 90)
-
-    # all_target_lengths.sort(reverse=True)
-    # target_percentile = np.percentile(all_target_lengths, 90)
-
-    # print(f'99 percentile: {source_percentile} characters')
-
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
-
-    # Define training args
-    training_args = Seq2SeqTrainingArguments(
-        output_dir=output_folder_path,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        gradient_accumulation_steps=4,
-        predict_with_generate=True,
-        fp16=False,
-        learning_rate=5e-5,
-        num_train_epochs=3,
-        # logging & evaluation strategies
-        logging_dir=f'{output_folder_path}/logs',
-        logging_strategy="steps",
-        logging_steps=500,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        save_total_limit=2,
-        load_best_model_at_end=True,
-        push_to_hub=False,
-        # deepspeed=deepspeed_config
-    )
-
-    trainer = Seq2SeqTrainer(
-        model=model,
-        args=training_args,
-        data_collator=data_collator,
-        train_dataset=tokenized_dataset["train"],
-        eval_dataset=tokenized_dataset["eval"]
-    )
-
-    trainer.train()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/training/summarization/__init__.py b/indiquo/training/summarization/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/varia/DramaSummary.py b/indiquo/varia/DramaSummary.py
deleted file mode 100644
index 9470f6fcbc6b62b407c4e919cad86e3429e64300..0000000000000000000000000000000000000000
--- a/indiquo/varia/DramaSummary.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from argparse import ArgumentParser
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-import csv
-
-
-def main():
-    argument_parser = ArgumentParser(prog='indiquo', description='TBD')
-
-    argument_parser.add_argument("--test", dest="test_file_path", help="", required=True)
-    argument_parser.add_argument("-m", dest="model", help="", required=True)
-
-    args = argument_parser.parse_args()
-    test_file_path = args.test_file_path
-    model_id = args.model
-
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-    # TODO: adjust
-    tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large')
-
-    with open(test_file_path, 'r', encoding='utf-8') as test_file:
-        reader = csv.reader(test_file, delimiter='\t')
-        # skip first row (header)
-        next(reader, None)
-
-        for row in reader:
-            # use the same lowercase 'summarize: ' prefix as during training
-            inputs = tokenizer(f'summarize: {row[0]}', max_length=1024, padding='max_length', truncation=True,
-                               return_tensors='pt')
-            outputs = model.generate(**inputs, max_new_tokens=128)
-            print(f'Summary: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}')
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/FewShotBertTraining.py b/indiquo/varia/FewShotBertTraining.py
deleted file mode 100644
index e0c306b58ea5b7e8056c8579bd983ef5e8244180..0000000000000000000000000000000000000000
--- a/indiquo/varia/FewShotBertTraining.py
+++ /dev/null
@@ -1,208 +0,0 @@
-from argparse import ArgumentParser
-
-import tensorflow as tf
-import transformers
-from tensorflow.keras.utils import Sequence
-from tensorflow.keras.callbacks import ModelCheckpoint
-import numpy as np
-from tensorflow.keras.optimizers.legacy import Adam
-from datetime import datetime
-from os.path import join
-from pathlib import Path
-import json
-
-
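-# keras Sequence that serves the pre-vectorized masked-LM examples in batches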
-class DataGenerator(Sequence):
-
-    def __init__(self, examples, batch_size):
-        self.examples = examples
-        self.batch_size = batch_size
-
-    def __len__(self):
-        return (np.ceil(len(self.examples) / self.batch_size)).astype(np.int32)
-
-    def __getitem__(self, idx):
-        input_ids = []
-        attention_masks = []
-        token_type_ids = []
-        labels = []
-
-        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
-            if i >= len(self.examples):
-                break
-
-            input_ids.append(self.examples[i]['input_ids'])
-            attention_masks.append(self.examples[i]['attention_mask'])
-            token_type_ids.append(self.examples[i]['token_type_ids'])
-            labels.append(self.examples[i]['labels'])
-
-        input_ids_np = np.array(input_ids, dtype="int32")
-        attention_masks_np = np.array(attention_masks, dtype="int32")
-        token_type_ids_np = np.array(token_type_ids, dtype="int32")
-        labels_np = np.array(labels, dtype="int32")
-
-        return [input_ids_np, attention_masks_np, token_type_ids_np], labels_np
-
-
-def generate_masked_samples(tokenizer, contexts, answers, template, chunk_size, num_masked_tokens):
-    result = []
-    template_inputs_ids = tokenizer.encode(template)[1:-1]
-    mask_index = template_inputs_ids.index(tokenizer.mask_token_id)
-
-    for context, answer in zip(contexts, answers):
-        context_inputs_ids = tokenizer.encode(context)[1:-1]
-        answer_labels = tokenizer.encode(answer)[1:-1]
-
-        if len(answer_labels) > num_masked_tokens:
-            raise Exception(f'\'{answer}\' is too long')
-
-        answer_input_ids = [tokenizer.mask_token_id] * num_masked_tokens
-        answer_padding = [0] * (num_masked_tokens - len(answer_labels))
-        answer_labels = answer_labels + answer_padding
-
-        assert len(answer_input_ids) == len(answer_labels)
-
-        answer_token_type_ids = [0] * len(answer_input_ids)
-        answer_attention_mask = [1] * len(answer_input_ids)
-
-        template_token_type_ids = [0] * len(template_inputs_ids)
-        template_attention_mask = [1] * len(template_inputs_ids)
-        template_labels = [-100] * len(template_inputs_ids)
-
-        # copy the template ids: the slice assignment below would otherwise mutate
-        # template_inputs_ids, which is shared across all loop iterations
-        all_input_ids = template_inputs_ids.copy()
-        all_input_ids[mask_index:mask_index+1] = answer_input_ids
-
-        all_token_type_ids = template_token_type_ids
-        all_token_type_ids[mask_index:mask_index+1] = answer_token_type_ids
-
-        all_attention_mask = template_attention_mask
-        all_attention_mask[mask_index:mask_index+1] = answer_attention_mask
-
-        all_labels = template_labels
-        all_labels[mask_index:mask_index+1] = answer_labels
-
-        remaining_space = chunk_size - 2 - len(all_input_ids)
-
-        if len(context_inputs_ids) < remaining_space:
-            num_pad_tokens = remaining_space - len(context_inputs_ids)
-            context_inputs_ids_padded = ([0] * num_pad_tokens) + context_inputs_ids
-            context_token_type_ids = [0] * len(context_inputs_ids_padded)
-            context_attention_mask = ([0] * num_pad_tokens) + ([1] * len(context_inputs_ids))
-            context_labels = [-100] * len(context_inputs_ids_padded)
-
-            all_input_ids = context_inputs_ids_padded + all_input_ids
-            all_token_type_ids = context_token_type_ids + all_token_type_ids
-            all_attention_mask = context_attention_mask + all_attention_mask
-            all_labels = context_labels + all_labels
-
-        elif len(context_inputs_ids) > remaining_space:
-            print('Need to truncate')
-            num_remove = len(context_inputs_ids) - remaining_space
-            context_input_ids_reduced = context_inputs_ids[num_remove:]
-            context_token_type_ids = [0] * len(context_input_ids_reduced)
-            context_attention_mask = ([1] * len(context_input_ids_reduced))
-            context_labels = [-100] * len(context_input_ids_reduced)
-
-            all_input_ids = context_input_ids_reduced + all_input_ids
-            all_token_type_ids = context_token_type_ids + all_token_type_ids
-            all_attention_mask = context_attention_mask + all_attention_mask
-            all_labels = context_labels + all_labels
-
-        all_input_ids = [tokenizer.cls_token_id] + all_input_ids + [tokenizer.sep_token_id]
-        all_token_type_ids = [0] + all_token_type_ids + [0]
-        all_attention_mask = [1] + all_attention_mask + [1]
-        all_labels = [-100] + all_labels + [-100]
-
-        assert len(all_input_ids) == len(all_token_type_ids)
-        assert len(all_input_ids) == len(all_attention_mask)
-        assert len(all_input_ids) == len(all_labels)
-        assert len(all_input_ids) == chunk_size
-
-        result.append({'input_ids': all_input_ids,
-                       'token_type_ids': all_token_type_ids,
-                       'attention_mask': all_attention_mask,
-                       'labels': all_labels})
-
-    return result
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument('train_file_path', nargs=1, metavar='train-file-path',
-                                 help='Path to the txt file containing the training examples')
-    argument_parser.add_argument('val_file_path', nargs=1, metavar='val-file-path',
-                                 help='Path to the txt file containing the validation examples')
-    argument_parser.add_argument('--output-folder-path', dest='output_folder_path', required=True,
-                                 help='Path to the folder in which a timestamped output folder is created')
-
-    args = argument_parser.parse_args()
-
-    train_file_path = args.train_file_path[0]
-    val_file_path = args.val_file_path[0]
-    output_folder_path = args.output_folder_path
-
-    now = datetime.now()
-    date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
-    output_folder_path = join(output_folder_path, date_time_string)
-    Path(output_folder_path).mkdir(parents=True, exist_ok=True)
-
-    chunk_size = 100
-    batch_size = 5
-
-    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-german-dbmdz-uncased", do_lower_case=True)
-    model = transformers.TFBertForMaskedLM.from_pretrained("bert-base-german-dbmdz-uncased", from_pt=True)
-
-    # with open(val_file_path, 'r') as txt_file:
-    #     val_text = txt_file.read()
-
-    with open(train_file_path, 'r', encoding='utf-8') as train_file:
-        train_json = json.load(train_file)
-
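-    # Only the first template in the file is used for this training run.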
-    template = train_json['templates'][0]
-
-    contexts = []
-    answers = []
-
-    for ex in train_json['examples']:
-        contexts.append(ex['context'])
-        answers.append(ex['answer'])
-
-    train_masked_samples = generate_masked_samples(tokenizer, contexts, answers, template, chunk_size, 10)
-    # val_masked_samples = generate_masked_samples(tokenizer, val_text, chunk_size, data_collator)
-
-    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, ignore_class=-100)
-    # perplexity = keras_nlp.metrics.Perplexity(name="perplexity", mask_token_id=tokenizer.mask_token_id,
-    #                                           from_logits=True)
-
-    # optimizer, _ = transformers.create_optimizer(
-    #     init_lr=2e-5,
-    #     num_warmup_steps=1_000,
-    #     num_train_steps=len(train_masked_samples),
-    #     weight_decay_rate=0.01,
-    # )
-
-    # model.compile(loss=loss, optimizer=Adam(0.00002), metrics=['acc'])
-    # model.compile(loss=loss, optimizer=optimizer)
-    model.compile(loss=loss, optimizer=Adam(0.00002))
-
-    checkpoint_model_path = join(output_folder_path, 'bert.h5')
-    model_checkpoint = ModelCheckpoint(checkpoint_model_path, save_best_only=False, save_weights_only=True)
-
-    training_generator = DataGenerator(train_masked_samples, batch_size)
-    # validation_generator = TemplateDataGenerator(val_masked_samples, batch_size)
-
-    model.fit(x=training_generator,
-              steps_per_epoch=int(len(train_masked_samples) // batch_size),
-              # validation_data=validation_generator,
-              # validation_steps=int(len(val_masked_samples['input_ids']) // batch_size),
-              epochs=20, callbacks=[model_checkpoint])
-
-    model.load_weights(checkpoint_model_path)
-    best_model_path = join(output_folder_path, 'bert')
-    model.save_pretrained(best_model_path)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/InferenceSummary.py b/indiquo/varia/InferenceSummary.py
deleted file mode 100644
index 635e2eec47b82a76729e2ed73d13a4a7769cc3e9..0000000000000000000000000000000000000000
--- a/indiquo/varia/InferenceSummary.py
+++ /dev/null
@@ -1,258 +0,0 @@
-from transformers import PegasusTokenizer, TFPegasusForConditionalGeneration, AutoTokenizer, AutoModel, TFBartForConditionalGeneration, BartTokenizer
-from bertviz import model_view, head_view
-import torch
-import matplotlib.pyplot as plt
-import math
-import plotly.figure_factory as ff
-import numpy as np
-import attention_graph_util
-
-
-def main():
-    # input_text = "Answer the following yes / no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
-    # input_text = "Summarize: Mitarbeiter, die nicht in der Nähe der Büros wohnen, müssten demzufolge umziehen – dazu würde das Unternehmen Unterstützung anbieten. Für eine Entscheidung habe Grindr, das Unternehmen hinter der auf LGBTQ ausgerichteten Dating-App, den Angestellten eine knapp 14-tägige Frist eingeräumt. Betroffen seien die Mitarbeiter der Abteilungen Design, Technik und Marketing. Inzwischen hätten laut WSJ rund 80 der 178 Beschäftigten gekündigt."
-    input_text = "Ten thousand people are missing after unprecedented flooding in Libya, the Red Cross said on Tuesday, as the extent of the damage to Derna, the port city where two dams burst over the weekend, became more clear. Tamer Ramadan, the Libya envoy for the International Federation of Red Cross and Red Crescent Societies, gave the figure at a UN briefing in Geneva, describing the death toll as “huge”. The health minister in the administration that controls the east of Libya said more than 3,000 people had been confirmed dead. “The number of missing people is in the thousands, and the number of dead is expected to reach 10,000,” Othman Abdel Jalil told Al-Massar TV channel."
-    decoder_input_text = "Unprecedented flooding in Libya has led to the Red Cross reporting approximately 10,000 missing individuals, with over 3,000 confirmed deaths, and expectations that the death toll may reach 10,000, according to the health minister of the eastern administration."
-
-    # input_text = """Fasse folgende Szene aus einem Drama zusammen:
-    # Szene: Flur in Nathans Hause.
-    # Nathan von der Reise kommend.
-    # Daja ihm entgegen.
-    # DAJA.
-    # Er ist es! Nathan! – Gott sei ewig Dank,
-    # Daß Ihr doch endlich einmal wiederkommt.
-    # NATHAN.
-    # Ja, Daja; Gott sei Dank! Doch warum endlich?
-    # Hab' ich denn eher wiederkommen wollen?
-    # Und wiederkommen können? Babylon
-    # Ist von Jerusalem, wie ich den Weg,
-    # Seit ab bald rechts, bald links, zu nehmen bin
-    # Genötigt worden, gut zwei hundert Meilen;
-    # Und Schulden einkassieren, ist gewiß
-    # Auch kein Geschäft, das merklich födert, das
-    # So von der Hand sich schlagen läßt.
-    # DAJA.
-    # O Nathan,
-    # Wie elend, elend hättet Ihr indes
-    # Hier werden können! Euer Haus ...
-    # NATHAN.
-    # Das brannte.
-    # So hab' ich schon vernommen. – Gebe Gott,
-    # Daß ich nur alles schon vernommen habe!
-    # DAJA.
-    # Und wäre leicht von Grund aus abgebrannt.
-    # NATHAN.
-    # Dann, Daja, hätten wir ein neues uns
-    # Gebaut; und ein bequemeres.
-    # DAJA.
-    # Schon wahr! –
-    # Doch Recha wär' bei einem Haare mit
-    # Verbrannt.
-    # NATHAN.
-    # Verbrannt? Wer? meine Recha? sie? –
-    # Das hab' ich nicht gehört. – Nun dann! So hätte
-    # Ich keines Hauses mehr bedurft. – Verbrannt
-    # Bei einem Haare! – Ha! sie ist es wohl!
-    # Ist wirklich wohl verbrannt! – Sag' nur heraus!
-    # Heraus nur! – Töte mich: und martre mich
-    # Nicht länger. – Ja, sie ist verbrannt."""
-
-    # tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
-    # model = AutoModel.from_pretrained("facebook/bart-large-cnn")
-
-    # model = TFPegasusForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
-    # tokenizer = PegasusTokenizer.from_pretrained("facebook/bart-large-cnn")
-
-    model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
-    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-
-    encoder_input_ids = tokenizer(input_text, return_tensors="np", add_special_tokens=True).input_ids
-    with tokenizer.as_target_tokenizer():
-        decoder_input_ids = tokenizer(decoder_input_text, return_tensors="np", add_special_tokens=True).input_ids
-
-    inputs = tokenizer(input_text, return_tensors="np")
-    # outputs = model.generate(**inputs, min_new_tokens=100, max_new_tokens=500)
-    outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, output_attentions=True)
-    # print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
-
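-    # bertviz expects torch tensors, so the TF attention outputs are converted.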
-    encoder_attentions_pt = [torch.from_numpy(layer_attn.numpy()) for layer_attn in outputs.encoder_attentions]
-    decoder_attentions_pt = [torch.from_numpy(layer_attn.numpy()) for layer_attn in outputs.decoder_attentions]
-    cross_attentions_pt = [torch.from_numpy(layer_attn.numpy()) for layer_attn in outputs.cross_attentions]
-
-    encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
-    decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
-
-    _attentions = [att.numpy() for att in outputs.encoder_attentions]
-    attentions_mat = np.asarray(_attentions)[:, 0]
-    print(attentions_mat.shape)
-
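-    # Attention rollout preprocessing: average over heads, add the residual
-    # connection as an identity matrix and re-normalise each row.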
-    res_att_mat = attentions_mat.sum(axis=1) / attentions_mat.shape[1]
-    res_att_mat = res_att_mat + np.eye(res_att_mat.shape[1])[None, ...]
-    res_att_mat = res_att_mat / res_att_mat.sum(axis=-1)[..., None]
-
-    joint_attentions = attention_graph_util.compute_joint_attention(res_att_mat, add_residual=False)
-    joint_att_adjmat, joint_labels_to_index = attention_graph_util.get_adjmat(mat=joint_attentions, input_tokens=encoder_text)
-
-    G = attention_graph_util.draw_attention_graph(joint_att_adjmat, joint_labels_to_index, n_layers=joint_attentions.shape[0], length=joint_attentions.shape[-1])
-
-    # res_adj_mat, res_labels_to_index = attention_graph_util.get_adjmat(mat=res_att_mat, input_tokens=encoder_text)
-    # res_G = attention_graph_util.draw_attention_graph(res_adj_mat, res_labels_to_index, n_layers=res_att_mat.shape[0], length=res_att_mat.shape[-1])
-    #
-    # output_nodes = []
-    # input_nodes = []
-    # for key in res_labels_to_index:
-    #     if 'L24' in key:
-    #         output_nodes.append(key)
-    #     if res_labels_to_index[key] < attentions_mat.shape[-1]:
-    #         input_nodes.append(key)
-    #
-    # flow_values = attention_graph_util.compute_flows(res_G, res_labels_to_index, input_nodes, length=attentions_mat.shape[-1])
-
-    # attention_graph_util.plot_attention_heatmap(attentions_mat.sum(axis=1) / attentions_mat.shape[1],
-    #                                             src[ex_id], t_positions=targets[ex_id], sentence=sentence)
-
-    # formatted_attention = format_attention(cross_attentions_pt, [0,1,2,3,4])
-    # formatted_attention = format_attention(cross_attentions_pt, [0], [0])
-
-    # formatted_cross_attention = format_attention(cross_attentions_pt)
-    # formatted_encoder_attention = format_attention(encoder_attentions_pt)
-    # formatted_encoder_attention = format_attention(encoder_attentions_pt, layers=[0, 1, 2, 3, 4, 5, 6])
-    # formatted_encoder_attention = format_attention(encoder_attentions_pt)
-
-    # cross_per_word = torch.sum(formatted_cross_attention, (0, 1, 2))
-    # encoder_per_word = torch.sum(formatted_encoder_attention, (0, 1, 2))
-    # per_word = torch.mean(formatted_attention, (0, 1, 2))
-
-    # merged_words = []
-    #
-    # # for pos, (word, weight) in enumerate(zip(encoder_text, cross_per_word)):
-    # #     merged_words.append((word, weight.item()))
-    #
-    # current_tokens = []
-    # current_weights = []
-    #
-    # for pos, (word, weight) in enumerate(zip(encoder_text, cross_per_word)):
-    #     if word.startswith('▁') or word.startswith('Ġ'):
-    #         if len(current_tokens) > 0:
-    #             joined_tokens = ''.join(current_tokens)
-    #             joined_weights = max(current_weights)
-    #
-    #             if joined_weights > 250:
-    #                 joined_weights = 0
-    #
-    #             merged_words.append((joined_tokens, joined_weights))
-    #             current_tokens.clear()
-    #             current_weights.clear()
-    #
-    #         current_tokens.append(word.replace('▁', '').replace('Ġ', ''))
-    #         current_weights.append(weight.item())
-    #     else:
-    #         current_tokens.append(word)
-    #         current_weights.append(weight.item())
-    #
-    # if len(current_tokens) > 0:
-    #     joined_tokens = ''.join(current_tokens)
-    #     joined_weights = max(current_weights)
-    #     if joined_weights > 50:
-    #         joined_weights = 0
-    #     merged_words.append((joined_tokens, joined_weights))
-    #
-    # x_labels = [[]]
-    # y_values = [[]]
-    #
-    # for pos, word in enumerate(merged_words):
-    #     if pos > 0 and pos % 15 == 0:
-    #         x_labels.append([])
-    #         y_values.append([])
-    #
-    #     x_labels[-1].append(word[0])
-    #     # y_values[-1].append(math.log(weight.item()))
-    #     y_values[-1].append(word[1])
-    #
-    # diff = 15 - len(x_labels[-1])
-    #
-    # if diff > 0:
-    #     x_labels[-1].extend([''] * diff)
-    #     y_values[-1].extend([0] * diff)
-    #
-    # save_plotly(x_labels, y_values)
-    #
-    # res = sorted(merged_words, key=lambda x: x[1], reverse=True)
-    #
-    # for i in res:
-    #     print(f'{i[0]}: {i[1]}')
-
-    # cm = 1 / 2.54
-    # plt.rcParams["figure.figsize"] = (20 * cm, 8 * cm)
-    # plt.rcParams['font.size'] = 9
-    #
-    # plt.scatter(x_labels, y_values, color='0.0', marker='x')
-    #
-    # plt.xticks(rotation=90)
-    # # plt.ylabel('F\u2081-score')
-    # # plt.ylim([0, 1.1])
-    # # plt.legend()
-    # plt.tight_layout()
-    # plt.show()
-
-    # plt.savefig(join(graph_output_path, f'{graph_prefix}_link_eval.pdf'))
-
-    # html_model_view = head_view(
-    #     encoder_attention=encoder_attentions_pt,
-    #     decoder_attention=decoder_attentions_pt,
-    #     cross_attention=cross_attentions_pt,
-    #     encoder_tokens=encoder_text,
-    #     decoder_tokens=decoder_text,
-    #     html_action='return',
-    #     include_layers=[5]
-    # )
-    #
-    # with open("/Users/frede/Arbeit/HU/Indirect_citations/bertviz/model_view.html", 'w') as file:
-    #     file.write(html_model_view.data)
-
-
-def save_plotly(x_labels, y_values):
-    # colorscale = [[0.0, 'rgb(255,255,255)'], [100.0, 'rgb(255, 255, 153)'],
-    #               [500.0, 'rgb(153, 255, 204)'], [2000.0, 'rgb(179, 217, 255)'],
-    #               [3000.0, 'rgb(240, 179, 255)'], [5000.0, 'rgb(255, 77, 148)']]
-
-    # fig = ff.create_annotated_heatmap(color[::-1], annotation_text=symbol[::-1], text=hover[::-1],
-    #                                   colorscale=colorscale, font_colors=['black'], hoverinfo='text')
-
-    fig = ff.create_annotated_heatmap(y_values[::-1], annotation_text=x_labels[::-1], colorscale='rdylbu',
-                                      reversescale=True, showscale=True)
-
-    fig.update_layout(
-        title_text='Token attention weights',
-        # margin=dict(l=10, r=10, t=10, b=10, pad=10),
-        xaxis=dict(zeroline=False, showgrid=False),
-        yaxis=dict(zeroline=False, showgrid=False, scaleanchor="x"),
-    )
-    fig.show()
-
-
-def format_attention(attention, layers=None, heads=None):
-    if layers:
-        attention = [attention[layer_index] for layer_index in layers]
-    squeezed = []
-    for layer_attention in attention:
-        # 1 x num_heads x seq_len x seq_len
-        if len(layer_attention.shape) != 4:
-            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
-                             "output_attentions=True when initializing your model.")
-        layer_attention = layer_attention.squeeze(0)
-        if heads:
-            layer_attention = layer_attention[heads]
-        squeezed.append(layer_attention)
-    # num_layers x num_heads x seq_len x seq_len
-    return torch.stack(squeezed)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/MaskedBertTraining.py b/indiquo/varia/MaskedBertTraining.py
deleted file mode 100644
index 7d8b12552549822a2bc8873ff99565b0138cf89f..0000000000000000000000000000000000000000
--- a/indiquo/varia/MaskedBertTraining.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from argparse import ArgumentParser
-
-import tensorflow as tf
-import transformers
-from tensorflow.keras.utils import Sequence
-from tensorflow.keras.callbacks import ModelCheckpoint
-import numpy as np
-from tensorflow.keras.optimizers.legacy import Adam
-import keras_nlp
-from datetime import datetime
-from os.path import join
-from pathlib import Path
-
-
-class DataGenerator(Sequence):
-
-    def __init__(self, examples, batch_size):
-        self.examples = examples
-        self.batch_size = batch_size
-
-    def __len__(self):
-        return (np.ceil(len(self.examples['input_ids']) / self.batch_size)).astype(np.int32)
-
-    def __getitem__(self, idx):
-        input_ids = []
-        attention_masks = []
-        token_type_ids = []
-        labels = []
-
-        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
-            if i >= len(self.examples['input_ids']):
-                break
-
-            input_ids.append(self.examples['input_ids'][i].tolist())
-            attention_masks.append(self.examples['attention_mask'][i].tolist())
-            token_type_ids.append(self.examples['token_type_ids'][i].tolist())
-            labels.append(self.examples['labels'][i].tolist())
-
-        input_ids_np = np.array(input_ids, dtype="int32")
-        attention_masks_np = np.array(attention_masks, dtype="int32")
-        token_type_ids_np = np.array(token_type_ids, dtype="int32")
-        labels_np = np.array(labels, dtype="int32")
-
-        return [input_ids_np, attention_masks_np, token_type_ids_np], labels_np
-
-
-def generate_masked_samples(tokenizer, text, chunk_size, data_collator):
-    inputs = tokenizer(text)
-
-    inputs = {'input_ids': inputs['input_ids'],
-              'token_type_ids': inputs['token_type_ids'],
-              'attention_mask': inputs['attention_mask']}
-
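-    # Concatenate-and-chunk: cut the tokenised text into contiguous chunks of
-    # chunk_size tokens; labels start as a copy of the input ids and are masked
-    # by the collator later.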
-    result = {
-        k: [t[i:i + chunk_size] for i in range(0, len(t), chunk_size)]
-        for k, t in inputs.items()
-    }
-
-    result["labels"] = result["input_ids"].copy()
-    samples = []
-
-    for i in range(0, len(result['input_ids']) - 1):
-        samples.append({'input_ids': result['input_ids'][i],
-                        'token_type_ids': result['token_type_ids'][i],
-                        'attention_mask': result['attention_mask'][i],
-                        'labels': result['labels'][i]})
-
-    masked_samples = data_collator(samples)
-    return masked_samples
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument('train_file_path', nargs=1, metavar='train-file-path',
-                                 help='Path to the txt file containing the training examples')
-    argument_parser.add_argument('val_file_path', nargs=1, metavar='val-file-path',
-                                 help='Path to the txt file containing the validation examples')
-    argument_parser.add_argument('--output-folder-path', dest='output_folder_path', required=True,
-                                 help='Path to the folder in which a timestamped output folder is created')
-
-    args = argument_parser.parse_args()
-
-    train_file_path = args.train_file_path[0]
-    val_file_path = args.val_file_path[0]
-    output_folder_path = args.output_folder_path
-
-    now = datetime.now()
-    date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
-    output_folder_path = join(output_folder_path, date_time_string)
-    Path(output_folder_path).mkdir(parents=True, exist_ok=True)
-
-    chunk_size = 512
-    batch_size = 8
-
-    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-german-dbmdz-uncased", do_lower_case=True)
-    model = transformers.TFBertForMaskedLM.from_pretrained("bert-base-german-dbmdz-uncased", from_pt=True)
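-    # Standard BERT masking: the collator selects 15% of the tokens and replaces
-    # them with [MASK] (80%), a random token (10%) or leaves them unchanged (10%).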
-    data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
-
-    with open(train_file_path, 'r', encoding='utf-8') as txt_file:
-        train_text = txt_file.read()
-
-    with open(val_file_path, 'r', encoding='utf-8') as txt_file:
-        val_text = txt_file.read()
-
-    train_masked_samples = generate_masked_samples(tokenizer, train_text, chunk_size, data_collator)
-    val_masked_samples = generate_masked_samples(tokenizer, val_text, chunk_size, data_collator)
-
-    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, ignore_class=-100)
-    # perplexity = keras_nlp.metrics.Perplexity(name="perplexity", mask_token_id=tokenizer.mask_token_id,
-    #                                           from_logits=True)
-
-    optimizer, _ = transformers.create_optimizer(
-        init_lr=2e-5,
-        num_warmup_steps=1_000,
-        num_train_steps=len(train_masked_samples['input_ids']),
-        weight_decay_rate=0.01,
-    )
-
-    # model.compile(loss=loss, optimizer=Adam(0.00002), metrics=['acc'])
-    model.compile(loss=loss, optimizer=optimizer)
-
-    checkpoint_model_path = join(output_folder_path, 'bert.h5')
-    model_checkpoint = ModelCheckpoint(checkpoint_model_path, monitor='val_loss', mode='min', save_best_only=True,
-                                       save_weights_only=True)
-
-    training_generator = DataGenerator(train_masked_samples, batch_size)
-    validation_generator = DataGenerator(val_masked_samples, batch_size)
-
-    model.fit(x=training_generator,
-              steps_per_epoch=int(len(train_masked_samples['input_ids']) // batch_size),
-              validation_data=validation_generator,
-              validation_steps=int(len(val_masked_samples['input_ids']) // batch_size),
-              epochs=3, callbacks=[model_checkpoint])
-
-    model.load_weights(checkpoint_model_path)
-    best_model_path = join(output_folder_path, 'bert')
-    model.save_pretrained(best_model_path)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/attention_graph_util.py b/indiquo/varia/attention_graph_util.py
deleted file mode 100644
index 942f63285d3b62b67e043ef351e3015e376c4624..0000000000000000000000000000000000000000
--- a/indiquo/varia/attention_graph_util.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import networkx as nx
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-def get_adjmat(mat, input_tokens):
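-    # Unroll the per-layer attention into one layered graph: each layer gets its
-    # own copy of the token nodes, and the edge from node k_f in layer i to node
-    # k_t in layer i-1 carries the attention weight between those positions.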
-    n_layers, length, _ = mat.shape
-    adj_mat = np.zeros(((n_layers + 1) * length, (n_layers + 1) * length))
-    labels_to_index = {}
-    for k in np.arange(length):
-        labels_to_index[str(k) + "_" + input_tokens[k]] = k
-
-    for i in np.arange(1, n_layers + 1):
-        for k_f in np.arange(length):
-            index_from = (i) * length + k_f
-            label = "L" + str(i) + "_" + str(k_f)
-            labels_to_index[label] = index_from
-            for k_t in np.arange(length):
-                index_to = (i - 1) * length + k_t
-                adj_mat[index_from][index_to] = mat[i - 1][k_f][k_t]
-
-    return adj_mat, labels_to_index
-
-
-def draw_attention_graph(adjmat, labels_to_index, n_layers, length):
-    A = adjmat
-    G = nx.from_numpy_matrix(A, create_using=nx.DiGraph())
-    for i in np.arange(A.shape[0]):
-        for j in np.arange(A.shape[1]):
-            nx.set_edge_attributes(G, {(i, j): A[i, j]}, 'capacity')
-
-    pos = {}
-    label_pos = {}
-    for i in np.arange(n_layers + 1):
-        for k_f in np.arange(length):
-            pos[i * length + k_f] = ((i + 0.4) * 2, length - k_f)
-            label_pos[i * length + k_f] = (i * 2, length - k_f)
-
-    index_to_labels = {}
-    for key in labels_to_index:
-        index_to_labels[labels_to_index[key]] = key.split("_")[-1]
-        if labels_to_index[key] >= length:
-            index_to_labels[labels_to_index[key]] = ''
-
-    # plt.figure(1,figsize=(20,12))
-
-    nx.draw_networkx_nodes(G, pos, node_color='green', node_size=50)
-    nx.draw_networkx_labels(G, pos=label_pos, labels=index_to_labels, font_size=18)
-
-    # Gather all edge weights; the weight of an edge determines its drawn thickness.
-    all_weights = []
-    for (node1, node2, data) in G.edges(data=True):
-        all_weights.append(data['weight'])
-
-    unique_weights = list(set(all_weights))
-
-    # Draw the edges grouped by weight, one group at a time.
-    for weight in unique_weights:
-        weighted_edges = [(node1, node2) for (node1, node2, edge_attr) in G.edges(data=True) if
-                          edge_attr['weight'] == weight]
-        # Alternatively, normalise: (weight - min(all_weights)) / (max(all_weights) - min(all_weights))
-        width = weight
-        nx.draw_networkx_edges(G, pos, edgelist=weighted_edges, width=width, edge_color='darkblue')
-
-    return G
-
-
-def compute_flows(G, labels_to_index, input_nodes, length):
-    number_of_nodes = len(labels_to_index)
-    flow_values = np.zeros((number_of_nodes, number_of_nodes))
-    for key in labels_to_index:
-        if key not in input_nodes:
-            current_layer = int(labels_to_index[key] / length)
-            pre_layer = current_layer - 1
-            u = labels_to_index[key]
-            for inp_node_key in input_nodes:
-                v = labels_to_index[inp_node_key]
-                flow_value = nx.maximum_flow_value(G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
-                flow_values[u][pre_layer * length + v] = flow_value
-            flow_values[u] /= flow_values[u].sum()
-
-    return flow_values
-
-
-def compute_node_flow(G, labels_to_index, input_nodes, output_nodes, length):
-    number_of_nodes = len(labels_to_index)
-    flow_values = np.zeros((number_of_nodes, number_of_nodes))
-    for key in output_nodes:
-        if key not in input_nodes:
-            current_layer = int(labels_to_index[key] / length)
-            pre_layer = current_layer - 1
-            u = labels_to_index[key]
-            for inp_node_key in input_nodes:
-                v = labels_to_index[inp_node_key]
-                flow_value = nx.maximum_flow_value(G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
-                flow_values[u][pre_layer * length + v] = flow_value
-            flow_values[u] /= flow_values[u].sum()
-
-    return flow_values
-
-
-def compute_joint_attention(att_mat, add_residual=True):
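-    # Attention rollout (cf. Abnar & Zuidema, 2020): multiply the layer-wise
-    # attention matrices so entry (i, j) approximates how much information from
-    # input token j reaches position i after all layers.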
-    if add_residual:
-        residual_att = np.eye(att_mat.shape[1])[None, ...]
-        aug_att_mat = att_mat + residual_att
-        aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
-    else:
-        aug_att_mat = att_mat
-
-    joint_attentions = np.zeros(aug_att_mat.shape)
-
-    layers = joint_attentions.shape[0]
-    joint_attentions[0] = aug_att_mat[0]
-    for i in np.arange(1, layers):
-        joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i - 1])
-
-    return joint_attentions
\ No newline at end of file
diff --git a/indiquo/varia/backup/CLMInference.py b/indiquo/varia/backup/CLMInference.py
deleted file mode 100644
index de1cb5ce4c3aa10dd5dda763945d1339dc607543..0000000000000000000000000000000000000000
--- a/indiquo/varia/backup/CLMInference.py
+++ /dev/null
@@ -1,167 +0,0 @@
-from argparse import ArgumentParser
-from os import listdir
-
-import transformers
-import numpy as np
-import tensorflow as tf
-import json
-from os.path import join
-import torch
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument('test_file_path', nargs=1, metavar='test-file-path',
-                                 help='Path to the json file containing the test examples')
-    argument_parser.add_argument('model_path', nargs=1, metavar='model-path',
-                                 help='Path to the folder containing the trained model(s)')
-
-    args = argument_parser.parse_args()
-    test_file_path = args.test_file_path[0]
-    model_folder_path = args.model_path[0]
-
-    single_model = True
-
-    tokenizer_path = join(model_folder_path, 'tokenizer')
-    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
-
-    pad_id = tokenizer.pad_token_id
-    none_id = -10
-    # none_id = tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('<none>')]
-    # sep_id is needed in the decoding loop below, so it must not stay commented out:
-    sep_id = tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('<sep>')]
-
-    sub_models = []
-
-    if not single_model:
-        for file_or_folder in listdir(model_folder_path):
-            if not file_or_folder.startswith('template'):
-                continue
-
-            full_path = join(model_folder_path, file_or_folder, 'bert')
-            model = transformers.TFGPT2LMHeadModel.from_pretrained(full_path)
-            sub_models.append(model)
-    else:
-        full_path = join(model_folder_path, 'combined', 'bert')
-        model = transformers.TFGPT2LMHeadModel.from_pretrained(full_path)
-        # model = transformers.TFGPT2LMHeadModel.from_pretrained("dbmdz/german-gpt2")
-        sub_models.append(model)
-
-    with open(test_file_path, 'r', encoding='utf-8') as test_file:
-        test_json = json.load(test_file)
-
-    contexts = []
-    answers = []
-    templates = []
-
-    for template in test_json['templates']:
-        templates.append(template)
-
-    for ex in test_json['examples']:
-        contexts.append(ex['context'])
-
-        answer_list = ex['answer']
-        answer_combined = answer_list[0]
-
-        for i in range(1, len(answer_list)):
-            answer_combined += f' <sep> {answer_list[i]}'
-
-        answers.append(answer_combined)
-
-    for context, answer in zip(contexts, answers):
-        inputs_context = tokenizer(context)
-        inputs_context['input_ids'] += [none_id]
-
-        inputs_list = []
-        for template_pos, template in enumerate(templates):
-            if not single_model and template_pos >= len(sub_models):
-                break
-
-            inputs_list.append(tokenizer.bos_token + ' ' + context + ' ' + template)
-
-        # Debugging override that bypasses the templates built above:
-        # inputs_list = ['Die Hauptstadt von Deutschland ist']
-        predicted_ids = []
-        finished = False
-
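-        # Greedy decoding constrained to the context vocabulary: each step
-        # averages the next-token distributions over all prompts and appends the
-        # most probable token that still occurs in the remaining context tokens.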
-        while not finished:
-            probs_list = []
-
-            if single_model:
-                model = sub_models[0]
-
-                for input_str in inputs_list:
-                    inputs = tokenizer.encode(input_str, return_tensors="tf")
-                    inputs_2 = tokenizer(input_str, return_tensors="np")
-
-                    outputs_test = model.generate(inputs, penalty_alpha=0, top_k=1, max_new_tokens=2,
-                                                  num_beams=1, temperature=0, do_sample=False)
-                    test = tokenizer.batch_decode(outputs_test, skip_special_tokens=False)
-
-                    last_layer_logits = model(**inputs_2).logits[:, -1, :]
-
-                    top_5_tokens = np.argsort(-last_layer_logits[0])[:5].tolist()
-
-                    probs = tf.nn.softmax(last_layer_logits[0]).numpy()
-                    probs_list.append(probs)
-
-                    # top_logits = transformers.top_k_top_p_filtering(last_layer_logits, top_k=5, top_p=1.0)
-                    # probabilities = tf.nn.softmax(top_logits, dim=-1)
-
-                    # generated_next_token = torch.multinomial(probabilities, num_samples=1)
-                    # generated = torch.cat([inputs, generated_next_token], dim=-1)
-
-            else:
-                for model, input_str in zip(sub_models, inputs_list):
-                    # inputs_list holds plain strings, so they have to be tokenized first
-                    inputs = tokenizer(input_str, return_tensors="np")
-                    token_logits = model(**inputs).logits
-                    mask_token_indexes = np.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)
-
-                    mask_token_index = mask_token_indexes[0][1]
-                    mask_token_logits = token_logits[0, mask_token_index, :]
-                    probs = tf.nn.softmax(mask_token_logits).numpy()
-                    probs_list.append(probs)
-
-            averaged_probs = np.mean(probs_list, axis=0)
-
-            highest_prob = 0
-            best_token_id = 0
-            best_token_index = 0
-
-            for i, valid_id in enumerate(inputs_context['input_ids']):
-                prob = averaged_probs[valid_id]
-
-                if prob > highest_prob:
-                    highest_prob = prob
-                    best_token_id = valid_id
-                    best_token_index = i
-
-            predicted_ids.append(best_token_id)
-
-            if best_token_id == tokenizer.eos_token_id or len(predicted_ids) > 15:
-                finished = True
-                break
-
-            new_token = tokenizer.decode(best_token_id)
-
-            for i in range(len(inputs_list)):
-                inputs_list[i] += f' {new_token}'
-
-            if none_id in inputs_context['input_ids']:
-                del inputs_context['input_ids'][inputs_context['input_ids'].index(none_id)]
-
-            if pad_id not in inputs_context['input_ids']:
-                inputs_context['input_ids'] += [pad_id]
-
-            if sep_id not in inputs_context['input_ids']:
-                inputs_context['input_ids'] += [sep_id]
-
-            if best_token_id != pad_id and best_token_id != none_id and best_token_id != sep_id:
-                del inputs_context['input_ids'][best_token_index]
-
-        result = tokenizer.decode(predicted_ids)
-        print(f'\n\nExpected: {answer}\nPredicted: {result}\nContext: {context}')
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/backup/CLMTraining.py b/indiquo/varia/backup/CLMTraining.py
deleted file mode 100644
index dfaea0ae08f0e7b59a5166c83fbc624b3d4eda99..0000000000000000000000000000000000000000
--- a/indiquo/varia/backup/CLMTraining.py
+++ /dev/null
@@ -1,369 +0,0 @@
-import math
-from argparse import ArgumentParser
-
-import tensorflow as tf
-import transformers
-from tensorflow.keras.utils import Sequence, Progbar
-from tensorflow.keras.callbacks import ModelCheckpoint
-from tensorflow.keras.metrics import SparseCategoricalAccuracy
-import numpy as np
-from tensorflow.keras.optimizers.legacy import Adam
-from datetime import datetime
-from os.path import join
-from pathlib import Path
-import json
-import torch
-import random
-
-
-class TemplateDataGenerator(Sequence):
-
-    def __init__(self, examples, batch_size):
-        self.examples = examples
-        self.batch_size = batch_size
-        self.len = self.__len__()
-
-    def __len__(self):
-        return (np.floor(len(self.examples) / self.batch_size)).astype(np.int32)
-
-    def __getitem__(self, idx):
-        input_ids = []
-        attention_masks = []
-        labels = []
-
-        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
-            if i >= len(self.examples):
-                break
-
-            input_ids.append(self.examples[i]['input_ids'])
-            attention_masks.append(self.examples[i]['attention_mask'])
-            labels.append(self.examples[i]['labels'])
-
-        input_ids_np = np.array(input_ids, dtype="int32")
-        attention_masks_np = np.array(attention_masks, dtype="int32")
-        labels_np = np.array(labels, dtype="int32")
-
-        return [input_ids_np, attention_masks_np], labels_np
-
-    def getitem(self, index):
-        return self.__getitem__(index)
-
-
-def generate_masked_template_samples(tokenizer, contexts, answers, template, chunk_size, num_examples):
-
-    # test = tokenizer.encode('test <|endoftext|>')
-
-    pad_id = tokenizer.pad_token_id
-    # eos_id = tokenizer.eos_token_id
-
-    # template_inputs_ids = tokenizer.encode(template)
-    result = []
-
-    for context, answer_list in zip(contexts, answers):
-        # context_input_ids = tokenizer.encode(context)
-        answer_combined = answer_list[0]
-
-        for i in range(1, len(answer_list)):
-            answer_combined += f' <sep> {answer_list[i]}'
-
-        # answer_input_ids = tokenizer.encode(answer_combined)
-        # stop_id = tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index('<|stop|>')]
-
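-        # One causal training example: context, template prompt and the gold
-        # answer(s), terminated with GPT-2's end-of-text token so the model
-        # learns when to stop generating.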
-        complete_string = f'{context} {template} {answer_combined} <|endoftext|>'
-        input_ids = tokenizer.encode(complete_string)
-        remaining_space = chunk_size - len(input_ids)
-
-        attention_mask = [1] * len(input_ids)
-        labels = input_ids.copy()
-
-        if remaining_space > 0:
-            num_pad_tokens = remaining_space
-            padding_input_ids = [pad_id] * num_pad_tokens
-
-            # input_ids += padding_input_ids
-            # attention_mask += [0] * num_pad_tokens
-            # labels += [-100] * num_pad_tokens
-
-            input_ids = padding_input_ids + input_ids
-            attention_mask = [0] * num_pad_tokens + attention_mask
-            labels = [-100] * num_pad_tokens + labels
-
-        elif remaining_space < 0:
-            num_remove = -remaining_space
-            print(f'Need to truncate: {num_remove}')
-
-            input_ids = input_ids[num_remove:]
-            attention_mask = attention_mask[num_remove:]
-            labels = labels[num_remove:]
-
-        assert len(input_ids) == len(labels)
-        assert len(input_ids) == len(attention_mask)
-        assert len(input_ids) == chunk_size
-
-        result.append({'input_ids': input_ids,
-                       'attention_mask': attention_mask,
-                       'labels': labels})
-
-        if len(result) == num_examples:
-            break
-
-    return result
-
-
-def train_model(model, template_train_dataset, template_val_dataset, epochs, step_count):
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, ignore_class=-100)
-
-    optimizer, _ = transformers.create_optimizer(
-        init_lr=2e-5,
-        num_warmup_steps=int(step_count * epochs * 0.05),
-        num_train_steps=step_count * epochs,
-        weight_decay_rate=0.01,
-    )
-
-    model.compile(optimizer=optimizer, loss=model.compute_loss)
-
-    metrics_names = []
-    train_acc_metric = SparseCategoricalAccuracy()
-    val_acc_metric = SparseCategoricalAccuracy()
-
-    for epoch in range(epochs):
-        print(f'\nEpoch {epoch + 1}/{epochs}')
-        progbar = Progbar(step_count, stateful_metrics=metrics_names)
-
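-        # Manual training loop: forward pass under GradientTape; positions with
-        # label -100 (padding) are excluded from the accuracy via sample weights.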
-        for step, (template_x_batch, template_y_batch) in enumerate(template_train_dataset):
-            with tf.GradientTape() as tape:
-                template_input_ids = template_x_batch[0]
-                template_output = model(input_ids=template_input_ids,
-                                        attention_mask=template_x_batch[1],
-                                        training=True)
-
-                # loss = model.compute_loss(template_y_batch, template_output.logits)
-
-                loss = loss_fn(template_y_batch, template_output.logits)
-
-                # loss_per_mask = tf.keras.losses.sparse_categorical_crossentropy(template_y_batch,
-                #                                                                 template_output.logits,
-                #                                                                from_logits=True, ignore_class=-100)
-                # loss_list = []
-                # for tids, lmp, logits in zip(template_input_ids, loss_per_mask, template_output.logits):
-                #     # new_loss_mean = tf.reduce_mean(loss_per_mask)
-                #     template_loss_value = tf.math.divide_no_nan(tf.reduce_sum(weighted_losses), len(mask_token_indexes))
-                #     loss_list.append(template_loss_value)
-                #
-                # template_loss_value = sum(loss_list) / len(loss_list)
-                # test_loss = loss_fn(template_y_batch, template_output.logits)
-
-                # loss = template_loss_value
-
-            sample_weights = np.ones(template_y_batch.shape)
-
-            for i in range(len(sample_weights)):
-                sample_weights[i][template_y_batch[i] == -100] = 0
-
-            train_acc_metric.update_state(template_y_batch, template_output.logits, sample_weights)
-            train_acc = train_acc_metric.result()
-            values = [('loss', loss), ('acc', train_acc)]  # Progbar averages reported values across steps
-            progbar.update(step, values=values)
-
-            grads = tape.gradient(loss, model.trainable_weights)
-            optimizer.apply_gradients(zip(grads, model.trainable_weights))
-
-        train_acc_metric.reset_states()
-
-        for step, (template_x_batch, template_y_batch) in enumerate(template_val_dataset):
-            model_output = model(input_ids=template_x_batch[0],
-                                 attention_mask=template_x_batch[1],
-                                 training=False)
-
-            sample_weights = np.ones(template_y_batch.shape)
-            for i in range(len(sample_weights)):
-                sample_weights[i][template_y_batch[i] == -100] = 0
-            val_acc_metric.update_state(template_y_batch, model_output.logits, sample_weights)
-
-        val_acc = val_acc_metric.result()
-        val_acc_metric.reset_states()
-        values = [('val_acc', val_acc)]
-        progbar.update(step_count, values=values)
-
-    return model
-
-
-def prepare_datasets(train_template_masked_samples, val_template_masked_samples, template_batch_size,
-                     template_chunk_size):
-    template_training_generator = TemplateDataGenerator(train_template_masked_samples, template_batch_size)
-    template_val_generator = TemplateDataGenerator(val_template_masked_samples, template_batch_size)
-
-    def template_training_generator_helper():
-        for i in range(template_training_generator.len):
-            item = template_training_generator.getitem(i)
-            yield item
-
-    def template_val_generator_helper():
-        for i in range(template_val_generator.len):
-            item = template_val_generator.getitem(i)
-            yield item
-
-    template_train_dataset = tf.data.Dataset.from_generator(template_training_generator_helper, output_signature=(
-        tf.TensorSpec(shape=(2, template_batch_size, template_chunk_size), dtype=tf.int32),
-        tf.TensorSpec(shape=(template_batch_size, template_chunk_size), dtype=tf.int32)
-    ))
-
-    template_val_dataset = tf.data.Dataset.from_generator(template_val_generator_helper, output_signature=(
-        tf.TensorSpec(shape=(2, template_batch_size, template_chunk_size), dtype=tf.int32),
-        tf.TensorSpec(shape=(template_batch_size, template_chunk_size), dtype=tf.int32)
-    ))
-
-    return template_train_dataset, template_val_dataset
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument('template_train_file_path', nargs=1, metavar='template-train-file-path',
-                                 help='Path to the json file containing the training examples')
-    argument_parser.add_argument('val_file_path', nargs=1, metavar='val-file-path',
-                                 help='Path to the json file containing the validation examples')
-    argument_parser.add_argument('--output-folder-path', dest='output_folder_path', required=True,
-                                 help='Path to the folder in which a timestamped output folder is created')
-
-    args = argument_parser.parse_args()
-
-    template_train_file_path = args.template_train_file_path[0]
-    val_file_path = args.val_file_path[0]
-    output_folder_path = args.output_folder_path
-
-    now = datetime.now()
-    date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
-    output_folder_path = join(output_folder_path, date_time_string)
-    Path(output_folder_path).mkdir(parents=True, exist_ok=True)
-
-    template_chunk_size = 160
-    template_batch_size = 4
-    epochs = 20
-    train_steps = 1500
-    num_training_examples = 64
-    num_val_examples = 32
-
-    # model_types = ['combined', 'template']
-    model_types = ['combined']
-    # model_types = ['template']
-
-    PAD_TOKEN = "<|pad|>"
-    # EOS_TOKEN = '<|endoftext|>'
-
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        "dbmdz/german-gpt2",
-        # bos_token=EOS_TOKEN,
-        # eos_token=EOS_TOKEN,
-        pad_token=PAD_TOKEN,
-        max_length=template_chunk_size)
-
-    special_tokens_dict = {'additional_special_tokens': ['<none>', '<sep>']}
-    tokenizer.add_special_tokens(special_tokens_dict)
-
-    tokenizer_dir = join(output_folder_path, 'tokenizer')
-    Path(tokenizer_dir).mkdir(parents=True, exist_ok=True)
-    tokenizer.save_pretrained(tokenizer_dir)
-
-    with open(val_file_path, 'r', encoding='utf-8') as val_file:
-        val_json = json.load(val_file)
-
-    with open(template_train_file_path, 'r', encoding='utf-8') as train_file:
-        train_json = json.load(train_file)
-
-    train_contexts = []
-    train_answers = []
-
-    for ex in train_json['examples']:
-        train_contexts.append(ex['context'])
-        train_answers.append(ex['answer'])
-
-    val_contexts = []
-    val_answers = []
-
-    for ex in val_json['examples']:
-        val_contexts.append(ex['context'])
-        val_answers.append(ex['answer'])
-
-    if 'combined' in model_types:
-        all_train_template_masked_samples = []
-        all_val_template_masked_samples = []
-
-        for template in train_json['templates']:
-            train_template_masked_samples = generate_masked_template_samples(tokenizer, train_contexts, train_answers,
-                                                                             template, template_chunk_size,
-                                                                             num_training_examples)
-            val_masked_samples = generate_masked_template_samples(tokenizer, val_contexts, val_answers, template,
-                                                                  template_chunk_size, num_val_examples)
-
-            all_train_template_masked_samples.extend(train_template_masked_samples)
-            all_val_template_masked_samples.extend(val_masked_samples)
-
-        random.shuffle(all_train_template_masked_samples)
-        random.shuffle(all_val_template_masked_samples)
-
-        # ratio = math.floor(len(all_train_template_masked_samples) / num_training_examples)
-
-        steps_per_epoch = math.floor(len(all_train_template_masked_samples) / template_batch_size)
-
-        if train_steps > 0:
-            epochs = train_steps // steps_per_epoch
-
-        template_train_dataset, template_val_dataset = \
-            prepare_datasets(all_train_template_masked_samples, all_val_template_masked_samples,
-                             template_batch_size, template_chunk_size)
-
-        model = transformers.TFGPT2LMHeadModel.from_pretrained(
-            "dbmdz/german-gpt2",
-            pad_token_id=tokenizer.pad_token_id,
-            # eos_token_id=tokenizer.eos_token_id,
-        )
-        model.resize_token_embeddings(len(tokenizer))
-
-        # unk_tok_emb = model.transformer.wte.weight[tokenizer.unk_token_id, :]
-        # for i in range(num_added_special_tokens):
-        #     model.transformer.wte.weight[-(i + 1), :] = unk_tok_emb
-
-        # best_model_path = join(output_folder_path, 'combined', 'bert')
-        # model.save_pretrained(best_model_path)
-
-        model = train_model(model, template_train_dataset, template_val_dataset, epochs,
-                            steps_per_epoch)
-
-        best_model_path = join(output_folder_path, 'combined', 'bert')
-        model.save_pretrained(best_model_path)
-
-    if 'template' in model_types:
-        for template_pos, template in enumerate(train_json['templates']):
-            print(f'\nStarting template {template_pos + 1}')
-            template_output_folder_path = join(output_folder_path, f'template_{template_pos + 1}')
-            Path(template_output_folder_path).mkdir(parents=True, exist_ok=True)
-
-            train_template_masked_samples = generate_masked_template_samples(tokenizer, train_contexts, train_answers,
-                                                                             template, template_chunk_size,
-                                                                             num_training_examples)
-            val_masked_samples = generate_masked_template_samples(tokenizer, val_contexts, val_answers, template,
-                                                                  template_chunk_size, num_val_examples)
-
-            steps_per_epoch = math.floor(len(train_template_masked_samples) / template_batch_size)
-
-            if train_steps > 0:
-                epochs = train_steps // steps_per_epoch
-
-            template_train_dataset, template_val_dataset = \
-                prepare_datasets(train_template_masked_samples, val_masked_samples, template_batch_size,
-                                 template_chunk_size)
-
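-            # NOTE: this branch loads a masked-LM BERT although the tokenizer and
-            # samples above are GPT-2/causal; it appears to predate the 'combined'
-            # setup and would need matching masked-LM data to run.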
-            model = transformers.TFBertForMaskedLM.from_pretrained("bert-base-german-dbmdz-uncased", from_pt=True)
-            # model = transformers.TFBertForMaskedLM.from_pretrained("deepset/gbert-large", from_pt=True)
-            model.resize_token_embeddings(len(tokenizer))
-
-            model = train_model(model, template_train_dataset, template_val_dataset, epochs, steps_per_epoch)
-
-            best_model_path = join(template_output_folder_path, 'bert')
-            model.save_pretrained(best_model_path)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/backup/MLMInference.py b/indiquo/varia/backup/MLMInference.py
deleted file mode 100644
index 9b16a844de7215234d590d73553b491dbf730311..0000000000000000000000000000000000000000
--- a/indiquo/varia/backup/MLMInference.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from argparse import ArgumentParser
-from os import listdir
-
-import transformers
-import json
-from os.path import join
-import csv
-from pathlib import Path
-
-from indiquo.core.reference.ReferencePredictor import ReferencePredictor
-from indiquo.model.reference.ReferenceVectorizer import ReferenceVectorizer
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument('test_file_path', nargs=1, metavar='test-file-path',
-                                 help='Path to the json file containing the test examples')
-    argument_parser.add_argument('model_path', nargs=1, metavar='model-path',
-                                 help='Path to the folder containing the trained model configurations')
-    argument_parser.add_argument('output_path', nargs=1, metavar='output-path',
-                                 help='Path to the folder in which the result files are created')
-    argument_parser.add_argument('--configs', nargs='+', dest='configs',
-                                 help='Names of the configuration subfolders of model-path to evaluate')
-
-    args = argument_parser.parse_args()
-    test_file_path = args.test_file_path[0]
-    model_folder_path = args.model_path[0]
-    output_folder_path = args.output_path[0]
-    configs = args.configs
-
-    mask_count = 10
-    lower_case = True
-
-    left_to_right = True
-    only_next_token = True
-    use_sep = False
-
-    average_mode = 'all'
-    # average_mode = 'context'
-
-    max_length = 160
-
-    for config in configs:
-        config_path: str = join(model_folder_path, config)
-        tokenizer_path = join(config_path, 'tokenizer')
-        ref_vectorizer = ReferenceVectorizer.from_saved(tokenizer_path, max_length, lower_case, mask_count, use_sep)
-
-        sub_models = []
-        sub_folders = listdir(config_path)
-        single_model = False
-
-        if 'model' in sub_folders:
-            single_model = True
-            full_path = join(config_path, 'model')
-            model = transformers.TFBertForMaskedLM.from_pretrained(full_path)
-            sub_models.append(model)
-        else:
-            for file_or_folder in listdir(config_path):
-                if not file_or_folder.startswith('model_'):
-                    continue
-
-                full_path = join(config_path, file_or_folder)
-                model = transformers.TFBertForMaskedLM.from_pretrained(full_path)
-                sub_models.append(model)
-
-        with open(test_file_path, 'r', encoding='utf-8') as test_file:
-            test_json = json.load(test_file)
-
-        templates = test_json['templates']
-
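-        # Each config is evaluated either with one combined model or with an
-        # ensemble of per-template models wrapped in a single ReferencePredictor.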
-        mlm_predictor = ReferencePredictor(sub_models, ref_vectorizer, templates, single_model, left_to_right, mask_count,
-                                           average_mode, only_next_token)
-        mlm_predictor.expand_templates()
-
-        contexts = []
-        answers = []
-        for ex in test_json['examples']:
-            contexts.append(ex['context'])
-
-            answer_list = []
-
-            for ans in ex['answer']:
-                answer_list.append(ans['text'])
-
-            answer_combined = answer_list[0]
-
-            for i in range(1, len(answer_list)):
-                answer_combined += f' <sep> {answer_list[i]}'
-
-            answers.append(answer_combined)
-
-        for ex_count, (context, answer) in enumerate(zip(contexts, answers)):
-            ref_prediction = mlm_predictor.predict(context)
-            print(f'\n\nExpected: {answer}\nPredicted: {ref_prediction.predicted_text}\nContext: {context}')
-
-            output_config_path = join(output_folder_path, config)
-            Path(output_config_path).mkdir(parents=True, exist_ok=True)
-
-            with open(join(output_config_path, f'file_{ex_count + 1}.tsv'), "w", encoding='utf-8') as out_file:
-                writer = csv.writer(out_file, delimiter="\t", lineterminator="\n")
-                writer.writerow(['start', 'end', 'text'])
-
-                for ref in ref_prediction.references:
-                    writer.writerow([ref.start, ref.end, ref.text])
-
-
-if __name__ == '__main__':
-    main()
diff --git a/indiquo/varia/backup/__init__.py b/indiquo/varia/backup/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/indiquo/varia/persons/DramaTest.py b/indiquo/varia/persons/DramaTest.py
deleted file mode 100644
index 2d41d08187e0f42953db53b3b392e0daf46a19c5..0000000000000000000000000000000000000000
--- a/indiquo/varia/persons/DramaTest.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import sys
-from argparse import ArgumentParser
-from os.path import isfile, join, isdir, splitext, basename
-from os import listdir
-from Person import Person
-from Scene import Scene
-from SpeechCount import SpeechCount
-import json
-import re
-
-
-def __json_decoder_person(json_input):
-    if 'name' in json_input:
-        return Person(json_input['id'], json_input['name'], json_input['sex'], [], set(json_input['scenes']))
-    elif 'token_count' in json_input:
-        return SpeechCount(json_input['start'], json_input['end'], json_input['token_count'])
-    else:
-        return Scene(json_input['id'], json_input['start'], json_input['end'])
-
-
-def get_persons_for_drama(input_path):
-    with open(input_path, 'r', encoding='utf-8') as persons_file:
-        persons = json.load(persons_file, object_hook=__json_decoder_person)
-
-    return persons
-
-
-def process_file(drama_content, persons, target_path):
-
-    if isfile(target_path) and target_path.endswith(".txt"):
-        with open(target_path, 'r', encoding='utf-8') as target_file:
-            target_file_content = target_file.read()
-
-        # filename = splitext(basename(target_path))[0]
-        target_file_content = target_file_content.lower()
-
-        paragraphs = target_file_content.split('\n\n')
-        text_blocks = []
-
-        current_text_block = ''
-        for paragraph in paragraphs:
-            paragraph = paragraph.strip()
-            if len(paragraph) < 5:
-                continue
-
-            if len(paragraph.split()) < 20:
-                current_text_block += f'\n\n{paragraph}'
-            else:
-                if current_text_block:
-                    text_blocks.append(current_text_block)
-                    current_text_block = ''
-
-                text_blocks.append(paragraph)
-
-        # tokens = target_file_content.split()
-
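-        # Heuristic: for every block of the target text, find the pair of
-        # character names that both occur in the block and share the fewest
-        # scenes; those shared scenes are printed as candidate source passages.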
-        for text_block in text_blocks:
-            # text_block = ' '.join(tokens[i:i + 400])
-            best_person1 = None
-            best_person2 = None
-            best_inter_count = 10000
-
-            for person1 in persons:
-                if re.search(re.escape(person1.name), text_block, re.IGNORECASE):
-                    for person2 in persons:
-                        if person1 == person2:
-                            continue
-
-                        if re.search(re.escape(person2.name), text_block, re.IGNORECASE):
-                            inter_count = len(person1.scenes.intersection(person2.scenes))
-
-                            if 0 < inter_count < best_inter_count:
-                                best_person1 = person1
-                                best_person2 = person2
-                                best_inter_count = inter_count
-
-            if best_person1 and best_person2:
-                inter = best_person1.scenes.intersection(best_person2.scenes)
-                print(f'\n\n----------------------------------------\n\n{text_block}')
-                print(f'\n\nPerson 1: {best_person1.name}, Person 2: {best_person2.name}, Scenes: {len(inter)}')
-                for scene in inter:
-                    print(f'\n\nScene {scene.id}:')
-                    print(drama_content[scene.start:scene.end])
-
-
-def main():
-    argument_parser = ArgumentParser()
-
-    argument_parser.add_argument("-d", "--drama_path", dest="drama_path",
-                                 help="Path to the input xml file or folder", required=True)
-    argument_parser.add_argument("-m", "--mapping_path", dest="mapping_path",
-                                 help="Path to the input xml file or folder", required=True)
-    argument_parser.add_argument("-t", "--target_path", dest="target_path",
-                                 help="Path to the input xml file or folder", required=True)
-
-    # argument_parser.add_argument("-o", "--output_path", dest="output_path",
-    #                              help="Path to the folder for storing the txt files with raw text", required=True)
-
-    args = argument_parser.parse_args()
-
-    drama_path = args.drama_path
-    mapping_path = args.mapping_path
-    target_path = args.target_path
-
-    if isfile(drama_path) and drama_path.endswith(".txt"):
-        with open(drama_path, 'r', encoding='utf-8') as drama_file:
-            drama_content = drama_file.read()
-    else:
-        # without this guard, drama_content would be unbound below
-        print(f'Drama path is not a .txt file: {drama_path}', file=sys.stderr)
-        return 1
-
-    persons = get_persons_for_drama(mapping_path)
-
-    if isfile(target_path) and target_path.endswith(".txt"):
-        process_file(drama_content, persons, target_path)
-    elif isdir(target_path):
-        for fileOrFolder in listdir(target_path):
-            full_path = join(target_path, fileOrFolder)
-
-            if isfile(full_path) and full_path.endswith(".txt"):  # process_file only handles plain text
-                process_file(drama_content, persons, full_path)
-
-
-if __name__ == '__main__':
-    sys.exit(main())
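The heart of the deleted DramaTest.py is a pairing heuristic: within each text block, pick the two co-mentioned characters who share the fewest scenes, on the idea that a rare co-occurrence narrows an indirect quotation down to a handful of candidate scenes. A minimal, self-contained sketch of that idea (toy Person stand-in with made-up names and scene ids, not the project's classes):

import re
from dataclasses import dataclass


@dataclass
class Person:
    name: str
    scenes: frozenset  # ids of the scenes this person appears in


def best_pair(persons, text_block):
    # Among all persons mentioned in the block, return the pair that shares
    # the fewest (but more than zero) scenes.
    mentioned = [p for p in persons
                 if re.search(re.escape(p.name), text_block, re.IGNORECASE)]
    best, best_count = None, float('inf')
    for i, p1 in enumerate(mentioned):
        for p2 in mentioned[i + 1:]:
            shared = len(p1.scenes & p2.scenes)
            if 0 < shared < best_count:
                best, best_count = (p1, p2), shared
    return best


persons = [Person('Faust', frozenset({1, 2, 5})),
           Person('Mephistopheles', frozenset({2, 5, 7})),
           Person('Gretchen', frozenset({5}))]
print(best_pair(persons, 'Faust meets Gretchen in the garden.'))
# -> the Faust/Gretchen pair, since they share only scene 5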
diff --git a/indiquo/varia/persons/Person.py b/indiquo/varia/persons/Person.py
deleted file mode 100644
index 5f00e8333c971996e624dff8cf51aacdb4499511..0000000000000000000000000000000000000000
--- a/indiquo/varia/persons/Person.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from dataclasses import dataclass, field
-from typing import List, Set
-
-from Scene import Scene
-from Speech import Speech
-
-
-@dataclass
-class Person:
-    id: str
-    name: str
-    sex: str
-    speeches: List[Speech] = field(default_factory=list)
-    scenes: Set[Scene] = field(default_factory=set)
-
-    def add_speech(self, speech: Speech):
-        self.speeches.append(speech)
-
-    def add_scene(self, scene: Scene):
-        self.scenes.add(scene)
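The default_factory fields above are what keep instances from sharing one list or set; a bare `speeches: List[Speech] = []` would even raise ValueError when the dataclass is created. A tiny standalone demonstration (simplified stand-in, not the deleted class):

from dataclasses import dataclass, field
from typing import List


@dataclass
class Person:
    name: str
    speeches: List[int] = field(default_factory=list)  # a fresh list per instance


a, b = Person('Faust'), Person('Gretchen')
a.speeches.append(42)
assert b.speeches == []  # no state shared between instances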
diff --git a/indiquo/varia/persons/Scene.py b/indiquo/varia/persons/Scene.py
deleted file mode 100644
index d139761a73d018e4b0805f2c6c41bdb1bdcc8ec3..0000000000000000000000000000000000000000
--- a/indiquo/varia/persons/Scene.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass(eq=True, frozen=False, unsafe_hash=True)
-class Scene:
-    id: int
-    start: int
-    end: int
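unsafe_hash=True is what allows these mutable Scene objects to live in the Person.scenes sets and be intersected; the cost is that mutating a stored instance silently breaks set lookups. A short demonstration of both effects (standard dataclass behaviour, toy values):

from dataclasses import dataclass


@dataclass(eq=True, frozen=False, unsafe_hash=True)
class Scene:
    id: int
    start: int
    end: int


scenes = {Scene(1, 0, 100)}
assert Scene(1, 0, 100) in scenes      # hash and eq are derived from the field values

next(iter(scenes)).start = 50          # mutating a stored instance changes its hash,
assert Scene(1, 0, 100) not in scenes  # so equal-looking lookups now miss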
diff --git a/indiquo/varia/persons/Speech.py b/indiquo/varia/persons/Speech.py
deleted file mode 100644
index 92329b7dd333c2347a2e8c489ce7d9f3447682e4..0000000000000000000000000000000000000000
--- a/indiquo/varia/persons/Speech.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class Speech:
-    start: int
-    end: int
-    token_count: int
diff --git a/indiquo/varia/persons/SpeechCount.py b/indiquo/varia/persons/SpeechCount.py
deleted file mode 100644
index 24b52b38efdfaf3ee7807920c0f8a2681d7efa59..0000000000000000000000000000000000000000
--- a/indiquo/varia/persons/SpeechCount.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from dataclasses import dataclass
-from Speech import Speech
-
-
-@dataclass
-class SpeechCount(Speech):
-    quoted_token_count: int = 0
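SpeechCount leans on dataclass inheritance: it reuses start, end and token_count positionally and appends one defaulted field (defaulted fields must follow the inherited non-defaulted ones). This is also what the `'token_count' in json_input` dispatch in the decoder above keys on. A standalone sketch using the same field names as the deleted classes:

from dataclasses import dataclass


@dataclass
class Speech:
    start: int
    end: int
    token_count: int


@dataclass
class SpeechCount(Speech):
    quoted_token_count: int = 0  # must carry a default, since it follows inherited fields


sc = SpeechCount(0, 120, 30, quoted_token_count=6)
print(sc.quoted_token_count / sc.token_count)  # quoted share of the speech: 0.2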
diff --git a/indiquo/varia/persons/__init__.py b/indiquo/varia/persons/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000