diff --git a/indiquo/cli/IndiQuoCLI.py b/indiquo/cli/IndiQuoCLI.py index d51e47b0483ea257714f98cd2a2e70a3ceb0b524..b83ed982f3b7427f97cd9df6e8d94ee346909aeb 100644 --- a/indiquo/cli/IndiQuoCLI.py +++ b/indiquo/cli/IndiQuoCLI.py @@ -1,10 +1,13 @@ import logging import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, BooleanOptionalAction from datetime import datetime from os import listdir from os.path import join, isfile, splitext, basename, isdir from pathlib import Path + +from indiquo.training.scene import TrainSceneIdentification + try: from flair.models import SequenceTagger from indiquo.core.CandidatePredictorRW import CandidatePredictorRW @@ -32,6 +35,10 @@ def __train_candidate_st(train_folder_path, output_folder_path, model_name): TrainCandidateClassifierST.train(train_folder_path, output_folder_path, model_name) +def __train_scene(train_folder_path, output_folder_path, model_name): + TrainSceneIdentification.train(train_folder_path, output_folder_path, model_name) + + def __process_file(pro_quo_lm, quid: Quid, indi_quo: IndiQuo, filename, drama, target_text, direct_quotes, output_folder_path): print(f'Processing {filename} ...') @@ -68,7 +75,7 @@ def __process_file(pro_quo_lm, quid: Quid, indi_quo: IndiQuo, filename, drama, t def __run_compare(source_file_path, target_path, candidate_model_path, scene_model_path, direct_quotes_path, - output_folder_path, approach): + output_folder_path, approach, add_context): drama_processor = Dramatist() drama = drama_processor.from_file(source_file_path) sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1) @@ -76,10 +83,11 @@ def __run_compare(source_file_path, target_path, candidate_model_path, scene_mod if approach == 'iq': candidate_tokenizer = AutoTokenizer.from_pretrained(candidate_model_path) candidate_model = AutoModelForSequenceClassification.from_pretrained(candidate_model_path) - candidate_predictor = CandidatePredictor(drama, candidate_tokenizer, candidate_model, sentence_chunker) + candidate_predictor = CandidatePredictor(drama, candidate_tokenizer, candidate_model, sentence_chunker, + add_context) elif approach == 'st': candidate_model = SentenceTransformer(candidate_model_path) - candidate_predictor = CandidatePredictorST(drama, candidate_model, sentence_chunker) + candidate_predictor = CandidatePredictorST(drama, candidate_model, sentence_chunker, add_context) elif approach == 'rw': candidate_model = SequenceTagger.load(candidate_model_path) candidate_predictor = CandidatePredictorRW(candidate_model, sentence_chunker) @@ -147,6 +155,7 @@ def main(argv=None): parser_train_candidate.add_argument('--model', dest='model', default='deepset/gbert-large', help="") + # TODO: rename to candidate_st or similar parser_train_st = subparsers_train_model.add_parser('st', help='', description='') parser_train_st.add_argument('train_folder_path', nargs=1, metavar='train-folder-path', @@ -156,6 +165,15 @@ def main(argv=None): parser_train_st.add_argument('--model', dest='model', default='deutsche-telekom/gbert-large-paraphrase-cosine', help="") + parser_train_scene = subparsers_train_model.add_parser('scene', help='', description='') + + parser_train_scene.add_argument('train_folder_path', nargs=1, metavar='train-folder-path', + help='Path to the ') + parser_train_scene.add_argument('output_folder_path', nargs=1, metavar='output-folder-path', + help="Path to the input folder") + parser_train_scene.add_argument('--model', dest='model', default='deutsche-telekom/gbert-large-paraphrase-cosine', + help="") + parser_compare = subparsers_command.add_parser('compare', help='', description='') parser_compare.add_argument("source_file_path", nargs=1, metavar="source-file-path", @@ -174,6 +192,8 @@ def main(argv=None): ' ProQuoLM is used to find direct quotes.') parser_compare.add_argument('--approach', choices=['st', 'rw', 'iq'], dest='approach', default='iq', help='TBD') + parser_compare.add_argument('--add-context', dest='add_context', default=True, + action=BooleanOptionalAction, help='') args = argument_parser.parse_args(argv) @@ -181,7 +201,7 @@ def main(argv=None): logging.getLogger().setLevel(logging.getLevelName(log_level)) if args.command == 'train': - if args.train_model == 'candidate' or args.train_model == 'st': + if args.train_model == 'candidate' or args.train_model == 'st' or args.train_model == 'scene': train_folder_path = args.train_folder_path[0] output_folder_path = args.output_folder_path[0] model = args.model @@ -197,6 +217,8 @@ def main(argv=None): __train_candidate(train_folder_path, output_folder_path, model) elif args.train_model == 'st': __train_candidate_st(train_folder_path, output_folder_path, model) + elif args.train_model == 'scene': + __train_scene(train_folder_path, output_folder_path, model) elif args.command == 'compare': source_file_path = args.source_file_path[0] @@ -206,6 +228,7 @@ def main(argv=None): output_folder_path = args.output_folder_path direct_quotes_path = args.direct_quotes_path approach = args.approach + add_context = args.add_context now = datetime.now() date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S') @@ -213,7 +236,7 @@ def main(argv=None): Path(output_folder_path).mkdir(parents=True, exist_ok=True) __run_compare(source_file_path, target_path, candidate_model_folder_path, scene_model_folder_path, - direct_quotes_path, output_folder_path, approach) + direct_quotes_path, output_folder_path, approach, add_context) if __name__ == '__main__': diff --git a/indiquo/core/CandidatePredictor.py b/indiquo/core/CandidatePredictor.py index 8d3512068531145031c2d1f17d272dd11909fd1b..79c533ca5daf0979f9ceb6d2dbfc2c71b19bf7e0 100644 --- a/indiquo/core/CandidatePredictor.py +++ b/indiquo/core/CandidatePredictor.py @@ -14,13 +14,14 @@ class CandidatePredictor(BasePredictor): MAX_LENGTH = 128 SIMILARITY_THRESHOLD = 0.0 - def __init__(self, drama: Drama, tokenizer, model, chunker: BaseChunker): + def __init__(self, drama: Drama, tokenizer, model, chunker: BaseChunker, add_context): self.drama = drama self.tokenizer = tokenizer self.model = model self.chunker = chunker self.all_text_blocks = [] self.source_text_blocks = [] + self.add_context = add_context # for act_nr, act in enumerate(drama.acts): # for scene_nr, scene in enumerate(act.scenes): @@ -61,7 +62,11 @@ class CandidatePredictor(BasePredictor): candidates: List[Candidate] = [] for chunk in filtered_chunks: - text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end) + if self.add_context: + text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end) + else: + text = chunk.text + sim_result = self.__predict(text) if sim_result: candidates.append(Candidate(chunk.start, chunk.end, chunk.text, sim_result)) @@ -69,7 +74,7 @@ class CandidatePredictor(BasePredictor): return candidates def __predict(self, target_text): - inputs = self.tokenizer(target_text, return_tensors="pt") + inputs = self.tokenizer(target_text, truncation=True, return_tensors="pt") with torch.no_grad(): logits = self.model(**inputs).logits @@ -83,6 +88,9 @@ class CandidatePredictor(BasePredictor): def __add_context(self, quote_text, text, quote_start, quote_end): rest_len = self.MAX_LENGTH - len(quote_text.split()) + if rest_len < 0: + raise Exception('No rest len!') + text_before = text[:quote_start] text_after = text[quote_end:] @@ -104,6 +112,9 @@ class CandidatePredictor(BasePredictor): count_before = min(round(rest_len / 2), parts_before_count) count_after = min(rest_len - count_before, parts_after_count) + if count_before < 0 or count_after < 0: + raise Exception(f'Count before: {count_before}, after: {count_after}') + text_before = ' '.join(parts_before[-count_before:]) text_after = ' '.join(parts_after[:count_after]) diff --git a/indiquo/core/CandidatePredictorST.py b/indiquo/core/CandidatePredictorST.py index d82b164851b8bd9df52c0c8ee7ac9f881197a62a..45a368c16e701c1a5190207f6ef6bae104b62297 100644 --- a/indiquo/core/CandidatePredictorST.py +++ b/indiquo/core/CandidatePredictorST.py @@ -6,18 +6,21 @@ from indiquo.core.BasePredictor import BasePredictor from indiquo.core.Candidate import Candidate from indiquo.core.chunker.BaseChunker import BaseChunker from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_footnotes +import re # noinspection PyMethodMayBeStatic class CandidatePredictorST(BasePredictor): + MAX_LENGTH = 128 SIMILARITY_THRESHOLD = 0.0 - def __init__(self, drama: Drama, model, chunker: BaseChunker): + def __init__(self, drama: Drama, model, chunker: BaseChunker, add_context): self.drama = drama self.model = model self.chunker = chunker self.all_text_blocks = [] self.source_text_blocks = [] + self.add_context = add_context for act_nr, act in enumerate(drama.acts): for scene_nr, scene in enumerate(act.scenes): @@ -58,7 +61,13 @@ class CandidatePredictorST(BasePredictor): candidates: List[Candidate] = [] for chunk in filtered_chunks: - score = self.__predict(chunk.text) + + if self.add_context: + text = self.__add_context(chunk.text, target_text, chunk.start, chunk.end) + else: + text = chunk.text + + score = self.__predict(text) if score: candidates.append(Candidate(chunk.start, chunk.end, chunk.text, score)) @@ -82,3 +91,39 @@ class CandidatePredictorST(BasePredictor): return scene_scores[0][2] return None + + def __add_context(self, quote_text, text, quote_start, quote_end): + rest_len = self.MAX_LENGTH - len(quote_text.split()) + + if rest_len < 0: + raise Exception('No rest len!') + + text_before = text[:quote_start] + text_after = text[quote_end:] + + text_before = text_before.replace('\n', ' ') + text_before = text_before.replace('\t', ' ') + + text_after = text_after.replace('\n', ' ') + text_after = text_after.replace('\t', ' ') + + text_before = re.sub(r'\[\[\[(?:.|\n)+?]]]', ' ', text_before) + text_after = re.sub(r'\[\[\[(?:.|\n)+?]]]', ' ', text_after) + + parts_before = text_before.split() + parts_after = text_after.split() + + parts_before_count = len(parts_before) + parts_after_count = len(parts_after) + + count_before = min(round(rest_len / 2), parts_before_count) + count_after = min(rest_len - count_before, parts_after_count) + + if count_before < 0 or count_after < 0: + raise Exception(f'Count before: {count_before}, after: {count_after}') + + text_before = ' '.join(parts_before[-count_before:]) + text_after = ' '.join(parts_after[:count_after]) + + ex_text = f'{text_before} {quote_text} {text_after}' + return ex_text diff --git a/indiquo/core/chunker/SentenceChunker.py b/indiquo/core/chunker/SentenceChunker.py index 04c687166de1e08958a381a03e908aca0422592c..a0821b2a6145f7c9635625493d5d08072cfe7c01 100644 --- a/indiquo/core/chunker/SentenceChunker.py +++ b/indiquo/core/chunker/SentenceChunker.py @@ -32,21 +32,6 @@ class SentenceChunker(BaseChunker): last_chunk.end = ts.end continue - # TODO: Mit oder ohne Limit? Ohne Limit werden Zitatblöcke besser zusammengehalten, aber es entstehen auch - # teilweise lange Abschnitte. - - # if len(ts.sent) < 25: - # match = re.search(r'[0-9S)][.,]$', ts.sent.strip()) - # - # if match: - # if prev_chunk: - # prev_chunk.text += ts.sent - # prev_chunk.end = ts.end - # else: - # prev_chunk = Chunk(ts.start, ts.end, ts.sent) - # - # continue - if prev_chunk: prev_chunk.text += ts.sent prev_chunk.end = ts.end diff --git a/indiquo/training/scene/TrainSceneIdentification.py b/indiquo/training/scene/TrainSceneIdentification.py new file mode 100644 index 0000000000000000000000000000000000000000..496287ff110f9e5e240d5c6dfb315ea260e3ebe6 --- /dev/null +++ b/indiquo/training/scene/TrainSceneIdentification.py @@ -0,0 +1,62 @@ +import math +import random +from os.path import join + +from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation +from torch.utils.data import DataLoader +import csv + + +def train(train_folder_path, output_folder_path, model_name): + + train_examples = [] + + with open(join(train_folder_path, 'train_set.tsv'), 'r') as train_file: + reader = csv.reader(train_file, delimiter='\t') + # skip first row (header) + next(reader, None) + + for row in reader: + ie = InputExample(texts=[row[0], row[1]]) + train_examples.append(ie) + + all_validation_elements = [] + + with open(join(train_folder_path, 'val_set.tsv'), 'r') as train_file: + reader = csv.reader(train_file, delimiter='\t') + # skip first row (header) + next(reader, None) + + for row in reader: + all_validation_elements.append((row[0], row[1])) + + val_sentences_1 = [] + val_sentences_2 = [] + val_labels = [] + + for pos, val_item in enumerate(all_validation_elements): + val_sentences_1.append(val_item[0]) + val_sentences_2.append(val_item[1]) + val_labels.append(1) + + poses = random.sample(range(0, len(all_validation_elements)), 2) + + if poses[0] != pos: + other = all_validation_elements[poses[0]] + else: + other = all_validation_elements[poses[1]] + + val_sentences_1.append(val_item[0]) + val_sentences_2.append(other[1]) + val_labels.append(0) + + model = SentenceTransformer(model_name) + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16) + train_loss = losses.MultipleNegativesRankingLoss(model=model) + + evaluator = evaluation.BinaryClassificationEvaluator(val_sentences_1, val_sentences_2, val_labels) + + num_epochs = 5 + warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up + model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=num_epochs, warmup_steps=warmup_steps, + evaluator=evaluator, output_path=output_folder_path) diff --git a/indiquo/training/scene/__init__.py b/indiquo/training/scene/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indiquo/varia/st/IndiQuo.py b/indiquo/varia/st/IndiQuo.py index b6ca0446c667c6836f6ae3854b9738e7d9da5ae8..abf011887cfd6c0fd820c22cc208a4cba7774e1e 100644 --- a/indiquo/varia/st/IndiQuo.py +++ b/indiquo/varia/st/IndiQuo.py @@ -3,8 +3,6 @@ from argparse import ArgumentParser from os.path import join from sentence_transformers import SentenceTransformer, util -import plotly.graph_objects as go -from textwrap import wrap import csv import random @@ -14,7 +12,7 @@ from indiquo.varia.st.TestItem import TestItem def main(argv=None): TOP_LIMIT = 10 - MIN_SCORE = 0.7 + # MIN_SCORE = 0.7 argument_parser = ArgumentParser() @@ -60,8 +58,9 @@ def main(argv=None): source_embeddings = model.encode(source_text_blocks, convert_to_tensor=True) top_count_dict = {} - tn_count = 0 - fn_count = 0 + + # tn_count = 0 + # fn_count = 0 for i in range(1, TOP_LIMIT + 1): top_count_dict[i] = 0 @@ -81,20 +80,20 @@ def main(argv=None): start_line, end_line = drama.acts[act_nr].scenes[scene_nr].get_line_range() scene_scores.append((start_line, end_line, score, text)) - if test_item.line_ranges[0][0] > -1: - if scene_scores[0][2] < MIN_SCORE: - fn_count += 1 - - # print(f'Top Score: {scene_scores[0][2]}') - - for i in range(1, TOP_LIMIT + 1): - if scene_scores[i-1][0] <= test_item.line_ranges[0][0] <= scene_scores[i-1][1]: - # print(f'In Top {i}, {scene_scores[i-1][2]}') - top_count_dict[i] += 1 - break - else: - if scene_scores[0][2] < MIN_SCORE: - tn_count += 1 + # if test_item.line_ranges[0][0] > -1: + # if scene_scores[0][2] < MIN_SCORE: + # fn_count += 1 + # + # # print(f'Top Score: {scene_scores[0][2]}') + # + # for i in range(1, TOP_LIMIT + 1): + # if scene_scores[i-1][0] <= test_item.line_ranges[0][0] <= scene_scores[i-1][1]: + # # print(f'In Top {i}, {scene_scores[i-1][2]}') + # top_count_dict[i] += 1 + # break + # else: + # if scene_scores[0][2] < MIN_SCORE: + # tn_count += 1 test_items_cnt = sum(1 for ti in test_items if ti.line_ranges[0][0] > -1) total_cnt = 0 @@ -145,40 +144,40 @@ def main(argv=None): writer.writerow([perc]) -def visualize(candidates, act_nr, scene_nr): - x_labels = [] - y_values = [] - - for c in candidates: - # print(f'\n{c.source_span.text}\n{c.target_span.text}\n{c.score}') - y_values.append([c.score]) - x_labels.append(['<br>'.join(wrap(c.source_span.text, width=150))]) - - # fig = px.imshow(y_values, text=x_labels, zmin=0, zmax=1) - - fig = go.Figure(data=go.Heatmap( - z=y_values, - text=x_labels, - texttemplate="%{text}", - colorscale='rdylbu', - zmin=0, - zmax=1, - reversescale=True)) - - # fig = ff.create_annotated_heatmap(y_values, annotation_text=x_labels, colorscale='rdylbu', font_colors=['black'], - # reversescale=True, showscale=True, zmin=0, zmax=1, zauto=False) - - title_wrapped = '<br>'.join(wrap(candidates[0].target_span.text, width=150)) - title_wrapped += f'<br>{act_nr}. Akt, {scene_nr}. Szene' - - fig.update_layout( - title_text=f'"{title_wrapped}"', - # margin=dict(l=10, r=10, t=10, b=10, pad=10), - xaxis=dict(zeroline=False, showgrid=False, visible=False), - yaxis=dict(zeroline=False, showgrid=False, visible=False), - ) - - fig.show() +# def visualize(candidates, act_nr, scene_nr): +# x_labels = [] +# y_values = [] +# +# for c in candidates: +# # print(f'\n{c.source_span.text}\n{c.target_span.text}\n{c.score}') +# y_values.append([c.score]) +# x_labels.append(['<br>'.join(wrap(c.source_span.text, width=150))]) +# +# # fig = px.imshow(y_values, text=x_labels, zmin=0, zmax=1) +# +# fig = go.Figure(data=go.Heatmap( +# z=y_values, +# text=x_labels, +# texttemplate="%{text}", +# colorscale='rdylbu', +# zmin=0, +# zmax=1, +# reversescale=True)) +# +# # fig = ff.create_annotated_heatmap(y_values, annotation_text=x_labels, colorscale='rdylbu', font_colors=['black'], +# # reversescale=True, showscale=True, zmin=0, zmax=1, zauto=False) +# +# title_wrapped = '<br>'.join(wrap(candidates[0].target_span.text, width=150)) +# title_wrapped += f'<br>{act_nr}. Akt, {scene_nr}. Szene' +# +# fig.update_layout( +# title_text=f'"{title_wrapped}"', +# # margin=dict(l=10, r=10, t=10, b=10, pad=10), +# xaxis=dict(zeroline=False, showgrid=False, visible=False), +# yaxis=dict(zeroline=False, showgrid=False, visible=False), +# ) +# +# fig.show() def get_candidates(source_sentences_embeddings, target_sentence_embedding, top_k):