diff --git a/indiquo/cli/IndiQuoCLI.py b/indiquo/cli/IndiQuoCLI.py
index d51e47b0483ea257714f98cd2a2e70a3ceb0b524..b83ed982f3b7427f97cd9df6e8d94ee346909aeb 100644
--- a/indiquo/cli/IndiQuoCLI.py
+++ b/indiquo/cli/IndiQuoCLI.py
@@ -1,10 +1,13 @@
 import logging
 import sys
-from argparse import ArgumentParser
+from argparse import ArgumentParser, BooleanOptionalAction
 from datetime import datetime
 from os import listdir
 from os.path import join, isfile, splitext, basename, isdir
 from pathlib import Path
+
+from indiquo.training.scene import TrainSceneIdentification
+
 try:
     from flair.models import SequenceTagger
     from indiquo.core.CandidatePredictorRW import CandidatePredictorRW
@@ -32,6 +35,10 @@ def __train_candidate_st(train_folder_path, output_folder_path, model_name):
     TrainCandidateClassifierST.train(train_folder_path, output_folder_path, model_name)
 
 
+def __train_scene(train_folder_path, output_folder_path, model_name):
+    TrainSceneIdentification.train(train_folder_path, output_folder_path, model_name)
+
+
 def __process_file(pro_quo_lm, quid: Quid, indi_quo: IndiQuo, filename, drama, target_text, direct_quotes,
                    output_folder_path):
     print(f'Processing {filename} ...')
@@ -68,7 +75,7 @@ def __process_file(pro_quo_lm, quid: Quid, indi_quo: IndiQuo, filename, drama, t
 
 
 def __run_compare(source_file_path, target_path, candidate_model_path, scene_model_path, direct_quotes_path,
-                  output_folder_path, approach):
+                  output_folder_path, approach, add_context):
     drama_processor = Dramatist()
     drama = drama_processor.from_file(source_file_path)
     sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)
@@ -76,10 +83,11 @@ def __run_compare(source_file_path, target_path, candidate_model_path, scene_mod
     if approach == 'iq':
         candidate_tokenizer = AutoTokenizer.from_pretrained(candidate_model_path)
         candidate_model = AutoModelForSequenceClassification.from_pretrained(candidate_model_path)
-        candidate_predictor = CandidatePredictor(drama, candidate_tokenizer, candidate_model, sentence_chunker)
+        candidate_predictor = CandidatePredictor(drama, candidate_tokenizer, candidate_model, sentence_chunker,
+                                                 add_context)
     elif approach == 'st':
         candidate_model = SentenceTransformer(candidate_model_path)
-        candidate_predictor = CandidatePredictorST(drama, candidate_model, sentence_chunker)
+        candidate_predictor = CandidatePredictorST(drama, candidate_model, sentence_chunker, add_context)
     elif approach == 'rw':
         candidate_model = SequenceTagger.load(candidate_model_path)
         candidate_predictor = CandidatePredictorRW(candidate_model, sentence_chunker)
@@ -147,6 +155,7 @@ def main(argv=None):
     parser_train_candidate.add_argument('--model', dest='model', default='deepset/gbert-large',
                                         help="")
 
+    # TODO: rename to candidate_st or similar
     parser_train_st = subparsers_train_model.add_parser('st', help='', description='')
 
     parser_train_st.add_argument('train_folder_path', nargs=1, metavar='train-folder-path',
@@ -156,6 +165,15 @@ def main(argv=None):
     parser_train_st.add_argument('--model', dest='model', default='deutsche-telekom/gbert-large-paraphrase-cosine',
                                         help="")
 
+    parser_train_scene = subparsers_train_model.add_parser('scene', help='', description='')
+
+    parser_train_scene.add_argument('train_folder_path', nargs=1, metavar='train-folder-path',
+                                 help='Path to the ')
+    parser_train_scene.add_argument('output_folder_path', nargs=1, metavar='output-folder-path',
+                                 help="Path to the input folder")
+    parser_train_scene.add_argument('--model', dest='model', default='deutsche-telekom/gbert-large-paraphrase-cosine',
+                                 help="")
+
     parser_compare = subparsers_command.add_parser('compare', help='', description='')
 
     parser_compare.add_argument("source_file_path", nargs=1, metavar="source-file-path",
@@ -174,6 +192,8 @@ def main(argv=None):
                                      ' ProQuoLM is used to find direct quotes.')
     parser_compare.add_argument('--approach', choices=['st', 'rw', 'iq'], dest='approach',
                                 default='iq', help='TBD')
+    parser_compare.add_argument('--add-context', dest='add_context', default=True,
+                                action=BooleanOptionalAction, help='')
 
     args = argument_parser.parse_args(argv)
 
@@ -181,7 +201,7 @@ def main(argv=None):
     logging.getLogger().setLevel(logging.getLevelName(log_level))
 
     if args.command == 'train':
-        if args.train_model == 'candidate' or args.train_model == 'st':
+        if args.train_model == 'candidate' or args.train_model == 'st' or args.train_model == 'scene':
             train_folder_path = args.train_folder_path[0]
             output_folder_path = args.output_folder_path[0]
             model = args.model
@@ -197,6 +217,8 @@ def main(argv=None):
                 __train_candidate(train_folder_path, output_folder_path, model)
             elif args.train_model == 'st':
                 __train_candidate_st(train_folder_path, output_folder_path, model)
+            elif args.train_model == 'scene':
+                __train_scene(train_folder_path, output_folder_path, model)
 
     elif args.command == 'compare':
         source_file_path = args.source_file_path[0]
@@ -206,6 +228,7 @@ def main(argv=None):
         output_folder_path = args.output_folder_path
         direct_quotes_path = args.direct_quotes_path
         approach = args.approach
+        add_context = args.add_context
 
         now = datetime.now()
         date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
@@ -213,7 +236,7 @@ def main(argv=None):
         Path(output_folder_path).mkdir(parents=True, exist_ok=True)
 
         __run_compare(source_file_path, target_path, candidate_model_folder_path, scene_model_folder_path,
-                      direct_quotes_path, output_folder_path, approach)
+                      direct_quotes_path, output_folder_path, approach, add_context)
 
 
 if __name__ == '__main__':
diff --git a/indiquo/core/CandidatePredictor.py b/indiquo/core/CandidatePredictor.py
index 8d3512068531145031c2d1f17d272dd11909fd1b..79c533ca5daf0979f9ceb6d2dbfc2c71b19bf7e0 100644
--- a/indiquo/core/CandidatePredictor.py
+++ b/indiquo/core/CandidatePredictor.py
@@ -14,13 +14,14 @@ class CandidatePredictor(BasePredictor):
     MAX_LENGTH = 128
     SIMILARITY_THRESHOLD = 0.0
 
-    def __init__(self, drama: Drama, tokenizer, model, chunker: BaseChunker):
+    def __init__(self, drama: Drama, tokenizer, model, chunker: BaseChunker, add_context):
         self.drama = drama
         self.tokenizer = tokenizer
         self.model = model
         self.chunker = chunker
         self.all_text_blocks = []
         self.source_text_blocks = []
+        self.add_context = add_context
 
         # for act_nr, act in enumerate(drama.acts):
         #     for scene_nr, scene in enumerate(act.scenes):
@@ -61,7 +62,11 @@ class CandidatePredictor(BasePredictor):
 
         candidates: List[Candidate] = []
         for chunk in filtered_chunks:
-            text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end)
+            if self.add_context:
+                text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end)
+            else:
+                text = chunk.text
+
             sim_result = self.__predict(text)
             if sim_result:
                 candidates.append(Candidate(chunk.start, chunk.end, chunk.text, sim_result))
@@ -69,7 +74,7 @@ class CandidatePredictor(BasePredictor):
         return candidates
 
     def __predict(self, target_text):
-        inputs = self.tokenizer(target_text, return_tensors="pt")
+        inputs = self.tokenizer(target_text, truncation=True, return_tensors="pt")
 
         with torch.no_grad():
             logits = self.model(**inputs).logits
@@ -83,6 +88,9 @@ class CandidatePredictor(BasePredictor):
     def __add_context(self, quote_text, text, quote_start, quote_end):
         rest_len = self.MAX_LENGTH - len(quote_text.split())
 
+        if rest_len < 0:
+            raise Exception('No rest len!')
+
         text_before = text[:quote_start]
         text_after = text[quote_end:]
 
@@ -104,6 +112,9 @@ class CandidatePredictor(BasePredictor):
         count_before = min(round(rest_len / 2), parts_before_count)
         count_after = min(rest_len - count_before, parts_after_count)
 
+        if count_before < 0 or count_after < 0:
+            raise Exception(f'Count before: {count_before}, after: {count_after}')
+
         text_before = ' '.join(parts_before[-count_before:])
         text_after = ' '.join(parts_after[:count_after])
 
diff --git a/indiquo/core/CandidatePredictorST.py b/indiquo/core/CandidatePredictorST.py
index d82b164851b8bd9df52c0c8ee7ac9f881197a62a..45a368c16e701c1a5190207f6ef6bae104b62297 100644
--- a/indiquo/core/CandidatePredictorST.py
+++ b/indiquo/core/CandidatePredictorST.py
@@ -6,18 +6,21 @@ from indiquo.core.BasePredictor import BasePredictor
 from indiquo.core.Candidate import Candidate
 from indiquo.core.chunker.BaseChunker import BaseChunker
 from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_footnotes
+import re
 
 
 # noinspection PyMethodMayBeStatic
 class CandidatePredictorST(BasePredictor):
+    MAX_LENGTH = 128
     SIMILARITY_THRESHOLD = 0.0
 
-    def __init__(self, drama: Drama, model, chunker: BaseChunker):
+    def __init__(self, drama: Drama, model, chunker: BaseChunker, add_context):
         self.drama = drama
         self.model = model
         self.chunker = chunker
         self.all_text_blocks = []
         self.source_text_blocks = []
+        self.add_context = add_context
 
         for act_nr, act in enumerate(drama.acts):
             for scene_nr, scene in enumerate(act.scenes):
@@ -58,7 +61,13 @@ class CandidatePredictorST(BasePredictor):
 
         candidates: List[Candidate] = []
         for chunk in filtered_chunks:
-            score = self.__predict(chunk.text)
+
+            if self.add_context:
+                text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end)
+            else:
+                text = chunk.text
+
+            score = self.__predict(text)
             if score:
                 candidates.append(Candidate(chunk.start, chunk.end, chunk.text, score))
 
@@ -82,3 +91,39 @@ class CandidatePredictorST(BasePredictor):
             return scene_scores[0][2]
 
         return None
+
+    def __add_context(self, quote_text, text, quote_start, quote_end):
+        rest_len = self.MAX_LENGTH - len(quote_text.split())
+
+        if rest_len < 0:
+            raise Exception('No rest len!')
+
+        text_before = text[:quote_start]
+        text_after = text[quote_end:]
+
+        text_before = text_before.replace('\n', ' ')
+        text_before = text_before.replace('\t', ' ')
+
+        text_after = text_after.replace('\n', ' ')
+        text_after = text_after.replace('\t', ' ')
+
+        text_before = re.sub(r'\[\[\[(?:.|\n)+?]]]', ' ', text_before)
+        text_after = re.sub(r'\[\[\[(?:.|\n)+?]]]', ' ', text_after)
+
+        parts_before = text_before.split()
+        parts_after = text_after.split()
+
+        parts_before_count = len(parts_before)
+        parts_after_count = len(parts_after)
+
+        count_before = min(round(rest_len / 2), parts_before_count)
+        count_after = min(rest_len - count_before, parts_after_count)
+
+        if count_before < 0 or count_after < 0:
+            raise Exception(f'Count before: {count_before}, after: {count_after}')
+
+        text_before = ' '.join(parts_before[-count_before:])
+        text_after = ' '.join(parts_after[:count_after])
+
+        ex_text = f'{text_before} {quote_text} {text_after}'
+        return ex_text
diff --git a/indiquo/core/chunker/SentenceChunker.py b/indiquo/core/chunker/SentenceChunker.py
index 04c687166de1e08958a381a03e908aca0422592c..a0821b2a6145f7c9635625493d5d08072cfe7c01 100644
--- a/indiquo/core/chunker/SentenceChunker.py
+++ b/indiquo/core/chunker/SentenceChunker.py
@@ -32,21 +32,6 @@ class SentenceChunker(BaseChunker):
                 last_chunk.end = ts.end
                 continue
 
-            # TODO: Mit oder ohne Limit? Ohne Limit werden Zitatblöcke besser zusammengehalten, aber es entstehen auch
-            #  teilweise lange Abschnitte.
-
-            # if len(ts.sent) < 25:
-            #     match = re.search(r'[0-9S)][.,]$', ts.sent.strip())
-            #
-            #     if match:
-            #         if prev_chunk:
-            #             prev_chunk.text += ts.sent
-            #             prev_chunk.end = ts.end
-            #         else:
-            #             prev_chunk = Chunk(ts.start, ts.end, ts.sent)
-            #
-            #         continue
-
             if prev_chunk:
                 prev_chunk.text += ts.sent
                 prev_chunk.end = ts.end
diff --git a/indiquo/training/scene/TrainSceneIdentification.py b/indiquo/training/scene/TrainSceneIdentification.py
new file mode 100644
index 0000000000000000000000000000000000000000..496287ff110f9e5e240d5c6dfb315ea260e3ebe6
--- /dev/null
+++ b/indiquo/training/scene/TrainSceneIdentification.py
@@ -0,0 +1,62 @@
+import math
+import random
+from os.path import join
+
+from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
+from torch.utils.data import DataLoader
+import csv
+
+
+def train(train_folder_path, output_folder_path, model_name):
+
+    train_examples = []
+
+    with open(join(train_folder_path, 'train_set.tsv'), 'r') as train_file:
+        reader = csv.reader(train_file, delimiter='\t')
+        # skip first row (header)
+        next(reader, None)
+
+        for row in reader:
+            ie = InputExample(texts=[row[0], row[1]])
+            train_examples.append(ie)
+
+    all_validation_elements = []
+
+    with open(join(train_folder_path, 'val_set.tsv'), 'r') as train_file:
+        reader = csv.reader(train_file, delimiter='\t')
+        # skip first row (header)
+        next(reader, None)
+
+        for row in reader:
+            all_validation_elements.append((row[0], row[1]))
+
+    val_sentences_1 = []
+    val_sentences_2 = []
+    val_labels = []
+
+    for pos, val_item in enumerate(all_validation_elements):
+        val_sentences_1.append(val_item[0])
+        val_sentences_2.append(val_item[1])
+        val_labels.append(1)
+
+        poses = random.sample(range(0, len(all_validation_elements)), 2)
+
+        if poses[0] != pos:
+            other = all_validation_elements[poses[0]]
+        else:
+            other = all_validation_elements[poses[1]]
+
+        val_sentences_1.append(val_item[0])
+        val_sentences_2.append(other[1])
+        val_labels.append(0)
+
+    model = SentenceTransformer(model_name)
+    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
+    train_loss = losses.MultipleNegativesRankingLoss(model=model)
+
+    evaluator = evaluation.BinaryClassificationEvaluator(val_sentences_1, val_sentences_2, val_labels)
+
+    num_epochs = 5
+    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
+    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=num_epochs, warmup_steps=warmup_steps,
+              evaluator=evaluator, output_path=output_folder_path)
diff --git a/indiquo/training/scene/__init__.py b/indiquo/training/scene/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/indiquo/varia/st/IndiQuo.py b/indiquo/varia/st/IndiQuo.py
index b6ca0446c667c6836f6ae3854b9738e7d9da5ae8..abf011887cfd6c0fd820c22cc208a4cba7774e1e 100644
--- a/indiquo/varia/st/IndiQuo.py
+++ b/indiquo/varia/st/IndiQuo.py
@@ -3,8 +3,6 @@ from argparse import ArgumentParser
 from os.path import join
 
 from sentence_transformers import SentenceTransformer, util
-import plotly.graph_objects as go
-from textwrap import wrap
 import csv
 import random
 
@@ -14,7 +12,7 @@ from indiquo.varia.st.TestItem import TestItem
 
 def main(argv=None):
     TOP_LIMIT = 10
-    MIN_SCORE = 0.7
+    # MIN_SCORE = 0.7
 
     argument_parser = ArgumentParser()
 
@@ -60,8 +58,9 @@ def main(argv=None):
         source_embeddings = model.encode(source_text_blocks, convert_to_tensor=True)
 
         top_count_dict = {}
-        tn_count = 0
-        fn_count = 0
+
+        # tn_count = 0
+        # fn_count = 0
 
         for i in range(1, TOP_LIMIT + 1):
             top_count_dict[i] = 0
@@ -81,20 +80,20 @@ def main(argv=None):
                 start_line, end_line = drama.acts[act_nr].scenes[scene_nr].get_line_range()
                 scene_scores.append((start_line, end_line, score, text))
 
-            if test_item.line_ranges[0][0] > -1:
-                if scene_scores[0][2] < MIN_SCORE:
-                    fn_count += 1
-
-                # print(f'Top Score: {scene_scores[0][2]}')
-
-                for i in range(1, TOP_LIMIT + 1):
-                    if scene_scores[i-1][0] <= test_item.line_ranges[0][0] <= scene_scores[i-1][1]:
-                        # print(f'In Top {i}, {scene_scores[i-1][2]}')
-                        top_count_dict[i] += 1
-                        break
-            else:
-                if scene_scores[0][2] < MIN_SCORE:
-                    tn_count += 1
+            # if test_item.line_ranges[0][0] > -1:
+            #     if scene_scores[0][2] < MIN_SCORE:
+            #         fn_count += 1
+            #
+            #     # print(f'Top Score: {scene_scores[0][2]}')
+            #
+            #     for i in range(1, TOP_LIMIT + 1):
+            #         if scene_scores[i-1][0] <= test_item.line_ranges[0][0] <= scene_scores[i-1][1]:
+            #             # print(f'In Top {i}, {scene_scores[i-1][2]}')
+            #             top_count_dict[i] += 1
+            #             break
+            # else:
+            #     if scene_scores[0][2] < MIN_SCORE:
+            #         tn_count += 1
 
         test_items_cnt = sum(1 for ti in test_items if ti.line_ranges[0][0] > -1)
         total_cnt = 0
@@ -145,40 +144,40 @@ def main(argv=None):
             writer.writerow([perc])
 
 
-def visualize(candidates, act_nr, scene_nr):
-    x_labels = []
-    y_values = []
-
-    for c in candidates:
-        # print(f'\n{c.source_span.text}\n{c.target_span.text}\n{c.score}')
-        y_values.append([c.score])
-        x_labels.append(['<br>'.join(wrap(c.source_span.text, width=150))])
-
-    # fig = px.imshow(y_values, text=x_labels, zmin=0, zmax=1)
-
-    fig = go.Figure(data=go.Heatmap(
-                    z=y_values,
-                    text=x_labels,
-                    texttemplate="%{text}",
-                    colorscale='rdylbu',
-                    zmin=0,
-                    zmax=1,
-                    reversescale=True))
-
-    # fig = ff.create_annotated_heatmap(y_values, annotation_text=x_labels, colorscale='rdylbu', font_colors=['black'],
-    #                                   reversescale=True, showscale=True, zmin=0, zmax=1, zauto=False)
-
-    title_wrapped = '<br>'.join(wrap(candidates[0].target_span.text, width=150))
-    title_wrapped += f'<br>{act_nr}. Akt, {scene_nr}. Szene'
-
-    fig.update_layout(
-        title_text=f'"{title_wrapped}"',
-        # margin=dict(l=10, r=10, t=10, b=10, pad=10),
-        xaxis=dict(zeroline=False, showgrid=False, visible=False),
-        yaxis=dict(zeroline=False, showgrid=False, visible=False),
-    )
-
-    fig.show()
+# def visualize(candidates, act_nr, scene_nr):
+#     x_labels = []
+#     y_values = []
+#
+#     for c in candidates:
+#         # print(f'\n{c.source_span.text}\n{c.target_span.text}\n{c.score}')
+#         y_values.append([c.score])
+#         x_labels.append(['<br>'.join(wrap(c.source_span.text, width=150))])
+#
+#     # fig = px.imshow(y_values, text=x_labels, zmin=0, zmax=1)
+#
+#     fig = go.Figure(data=go.Heatmap(
+#                     z=y_values,
+#                     text=x_labels,
+#                     texttemplate="%{text}",
+#                     colorscale='rdylbu',
+#                     zmin=0,
+#                     zmax=1,
+#                     reversescale=True))
+#
+#     # fig = ff.create_annotated_heatmap(y_values, annotation_text=x_labels, colorscale='rdylbu', font_colors=['black'],
+#     #                                   reversescale=True, showscale=True, zmin=0, zmax=1, zauto=False)
+#
+#     title_wrapped = '<br>'.join(wrap(candidates[0].target_span.text, width=150))
+#     title_wrapped += f'<br>{act_nr}. Akt, {scene_nr}. Szene'
+#
+#     fig.update_layout(
+#         title_text=f'"{title_wrapped}"',
+#         # margin=dict(l=10, r=10, t=10, b=10, pad=10),
+#         xaxis=dict(zeroline=False, showgrid=False, visible=False),
+#         yaxis=dict(zeroline=False, showgrid=False, visible=False),
+#     )
+#
+#     fig.show()
 
 
 def get_candidates(source_sentences_embeddings, target_sentence_embedding, top_k):