diff --git a/indiquo/cli/IndiQuoCLI.py b/indiquo/cli/IndiQuoCLI.py index 8daa9e03070b74d738aa421f89328b12fa235f00..ab4f6417e64a300d341a2bd6001edea8a421d1bd 100644 --- a/indiquo/cli/IndiQuoCLI.py +++ b/indiquo/cli/IndiQuoCLI.py @@ -35,19 +35,19 @@ from indiquo.training.candidate import TrainCandidateClassifier, TrainCandidateC logger = logging.getLogger(__name__) -def __train_candidate(train_folder_path, output_folder_path, model_name): +def __train_candidate(train_folder_path: str, output_folder_path: str, model_name: str): TrainCandidateClassifier.train(train_folder_path, output_folder_path, model_name) -def __train_candidate_st(train_folder_path, output_folder_path, model_name): +def __train_candidate_st(train_folder_path: str, output_folder_path: str, model_name: str): TrainCandidateClassifierST.train(train_folder_path, output_folder_path, model_name) -def __train_scene(train_folder_path, output_folder_path, model_name): +def __train_scene(train_folder_path: str, output_folder_path: str, model_name: str): TrainSceneIdentification.train(train_folder_path, output_folder_path, model_name) -def __process_file(indi_quo: IndiQuoBase, filename, target_text, output_folder_path): +def __process_file(indi_quo: IndiQuoBase, filename: str, target_text: str, output_folder_path: str): logger.info(f'Processing {filename} ...') matches = indi_quo.compare(target_text) @@ -69,8 +69,9 @@ def __process_file(indi_quo: IndiQuoBase, filename, target_text, output_folder_p writer.writerow([m.target_start, m.target_end, speech_text, m.score, scene_predictions]) -def __run_compare(compare_approach, model_type, source_file_path, target_path, candidate_model_path, scene_model_path, - output_folder_path, add_context, max_candidate_length, summaries_file_path): +def __run_compare(compare_approach: str, model_type: str, source_file_path: str, target_path: str, + candidate_model_path: str, scene_model_path: str, output_folder_path: str, add_context: bool, + max_candidate_length: int, summaries_file_path: str): drama_processor = Dramatist() drama = drama_processor.from_file(source_file_path) sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1) diff --git a/indiquo/core/BaseCandidatePredictor.py b/indiquo/core/BaseCandidatePredictor.py index 6d5467e1adc103303cbeb88a2dab7f726d042179..1cdac9c80166a8bdcc9174e5e2f9f09715b70948 100644 --- a/indiquo/core/BaseCandidatePredictor.py +++ b/indiquo/core/BaseCandidatePredictor.py @@ -6,5 +6,5 @@ from indiquo.core.Candidate import Candidate class BaseCandidatePredictor(ABC): @abstractmethod - def get_candidates(self, target_text) -> List[Candidate]: + def get_candidates(self, target_text: str) -> List[Candidate]: pass diff --git a/indiquo/core/BaseScenePredictor.py b/indiquo/core/BaseScenePredictor.py index 9380fc9459d382a86b28f297d488803f1201c729..281a1f8823fa30902c5676cbff04a7accc3caf2b 100644 --- a/indiquo/core/BaseScenePredictor.py +++ b/indiquo/core/BaseScenePredictor.py @@ -6,5 +6,5 @@ from indiquo.core.ScenePrediction import ScenePrediction class BaseScenePredictor(ABC): @abstractmethod - def predict_scene(self, text) -> List[List[ScenePrediction]]: + def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]: pass diff --git a/indiquo/core/CandidatePredictor.py b/indiquo/core/CandidatePredictor.py index 6c76927384d05ba03bce3c5b3dd1bb37e831dde4..431fdbc74a25eefabf680fe8570e31861025fb67 100644 --- a/indiquo/core/CandidatePredictor.py +++ b/indiquo/core/CandidatePredictor.py @@ -11,7 +11,7 @@ from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_foot # noinspection PyMethodMayBeStatic class CandidatePredictor(BaseCandidatePredictor): - def __init__(self, tokenizer, model, chunker: BaseChunker, add_context, max_length): + def __init__(self, tokenizer, model, chunker: BaseChunker, add_context: bool, max_length: int): self.tokenizer = tokenizer self.model = model self.chunker = chunker @@ -21,7 +21,7 @@ class CandidatePredictor(BaseCandidatePredictor): self.max_length = max_length # overriding abstract method - def get_candidates(self, target_text) -> List[Candidate]: + def get_candidates(self, target_text: str) -> List[Candidate]: fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text) target_text_wo_fn: str = remove_footnotes(target_text) chunks = self.chunker.chunk(target_text_wo_fn) @@ -46,7 +46,7 @@ class CandidatePredictor(BaseCandidatePredictor): return candidates - def __predict(self, target_text): + def __predict(self, target_text: str) -> float: inputs = self.tokenizer(target_text, truncation=True, return_tensors="pt") with torch.no_grad(): @@ -55,7 +55,7 @@ class CandidatePredictor(BaseCandidatePredictor): score = float(p[0, 1]) return score - def __add_context(self, quote_text, text, quote_start, quote_end): + def __add_context(self, quote_text: str, text: str, quote_start: int, quote_end: int) -> str: rest_len = self.max_length - len(quote_text.split()) if rest_len < 0: diff --git a/indiquo/core/CandidatePredictorDummy.py b/indiquo/core/CandidatePredictorDummy.py index 3175ba3c5bb210b7f0d20839a3f47ab46a8f2106..fe18e310b62dbcc3fbd59af7ee6d93146a6e222e 100644 --- a/indiquo/core/CandidatePredictorDummy.py +++ b/indiquo/core/CandidatePredictorDummy.py @@ -13,7 +13,7 @@ class CandidatePredictorDummy(BaseCandidatePredictor): self.chunker = chunker # overriding abstract method - def get_candidates(self, target_text) -> List[Candidate]: + def get_candidates(self, target_text: str) -> List[Candidate]: fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text) target_text_wo_fn: str = remove_footnotes(target_text) chunks = self.chunker.chunk(target_text_wo_fn) diff --git a/indiquo/core/CandidatePredictorRW.py b/indiquo/core/CandidatePredictorRW.py index 38ccb9f4c6dcb2baada7240027e084fe5bce6d91..5c7ccc2763a56d311be5d91f0c8ce041e3b2639e 100644 --- a/indiquo/core/CandidatePredictorRW.py +++ b/indiquo/core/CandidatePredictorRW.py @@ -19,7 +19,7 @@ class CandidatePredictorRW(BaseCandidatePredictor): self.chunker = chunker # overriding abstract method - def get_candidates(self, target_text) -> List[Candidate]: + def get_candidates(self, target_text: str) -> List[Candidate]: fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text) target_text_wo_fn: str = remove_footnotes(target_text) chunks = self.chunker.chunk(target_text_wo_fn) diff --git a/indiquo/core/CandidatePredictorST.py b/indiquo/core/CandidatePredictorST.py index 2c5ae6f18a979003e8a611faa83c6e9d142061a8..2ba060594f048c34f7b1e880d9af31677963f24a 100644 --- a/indiquo/core/CandidatePredictorST.py +++ b/indiquo/core/CandidatePredictorST.py @@ -1,5 +1,5 @@ from typing import List -from sentence_transformers import util +from sentence_transformers import util, SentenceTransformer from dramatist.drama.Drama import Drama from indiquo.core.BaseCandidatePredictor import BaseCandidatePredictor @@ -12,7 +12,7 @@ import re # noinspection PyMethodMayBeStatic class CandidatePredictorST(BaseCandidatePredictor): - def __init__(self, drama: Drama, model, chunker: BaseChunker, add_context, max_length): + def __init__(self, drama: Drama, model: SentenceTransformer, chunker: BaseChunker, add_context: bool, max_length: int): self.drama = drama self.model = model self.chunker = chunker @@ -32,7 +32,7 @@ class CandidatePredictorST(BaseCandidatePredictor): self.source_embeddings = model.encode(self.source_text_blocks, convert_to_tensor=True) # overriding abstract method - def get_candidates(self, target_text) -> List[Candidate]: + def get_candidates(self, target_text: str) -> List[Candidate]: fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text) target_text_wo_fn: str = remove_footnotes(target_text) chunks = self.chunker.chunk(target_text_wo_fn) @@ -57,12 +57,12 @@ class CandidatePredictorST(BaseCandidatePredictor): return candidates - def __predict(self, target_text): + def __predict(self, target_text: str) -> float: target_embedding = self.model.encode([target_text], convert_to_tensor=True) hits = util.semantic_search(target_embedding, self.source_embeddings, top_k=1)[0] return hits[0]['score'] - def __add_context(self, quote_text, text, quote_start, quote_end): + def __add_context(self, quote_text: str, text: str, quote_start: int, quote_end: int) -> str: rest_len = self.max_length - len(quote_text.split()) if rest_len < 0: diff --git a/indiquo/core/CandidatePredictorSum.py b/indiquo/core/CandidatePredictorSum.py index fd1e52070430a0f56e25ccef2d33ba9572be861d..3c5f624d50200e5e9daf16cd51ae66facf124b50 100644 --- a/indiquo/core/CandidatePredictorSum.py +++ b/indiquo/core/CandidatePredictorSum.py @@ -1,21 +1,21 @@ -from typing import List +from typing import List, Tuple from indiquo.core.CandidateWithScenes import CandidateWithScenes from indiquo.core.ScenePrediction import ScenePrediction from indiquo.core.chunker.BaseChunker import BaseChunker from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_footnotes -from sentence_transformers import util +from sentence_transformers import util, SentenceTransformer # noinspection PyMethodMayBeStatic class CandidatePredictorSum: - def __init__(self, summaries, model, chunker: BaseChunker): + def __init__(self, summaries: List[Tuple[int, int, str]], model: SentenceTransformer, chunker: BaseChunker): self.summaries = summaries self.model = model self.chunker = chunker self.summary_embeddings = model.encode([x[2] for x in self.summaries], convert_to_tensor=True) - def get_candidates(self, target_text) -> List[CandidateWithScenes]: + def get_candidates(self, target_text: str) -> List[CandidateWithScenes]: fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text) target_text_wo_fn: str = remove_footnotes(target_text) chunks = self.chunker.chunk(target_text_wo_fn) @@ -37,7 +37,7 @@ class CandidatePredictorSum: return candidates - def __predict(self, sentences): + def __predict(self, sentences: List[str]) -> Tuple[List[float], List[List[ScenePrediction]]]: target_embedding = self.model.encode(sentences, convert_to_tensor=True) hits = util.semantic_search(target_embedding, self.summary_embeddings, top_k=10) diff --git a/indiquo/core/IndiQuoBase.py b/indiquo/core/IndiQuoBase.py index 34fa49bda0c154649439ed2addf83e747e82c92c..6c833db2facb4eecd417bb887029790096bec0f9 100644 --- a/indiquo/core/IndiQuoBase.py +++ b/indiquo/core/IndiQuoBase.py @@ -6,5 +6,5 @@ from indiquo.match.Match import Match class IndiQuoBase(ABC): @abstractmethod - def compare(self, target_text) -> List[Match]: + def compare(self, target_text: str) -> List[Match]: pass diff --git a/indiquo/core/ScenePredictor.py b/indiquo/core/ScenePredictor.py index 299a0e9e8ace732c819e15489a132ec4590b150a..4558b885a8b0def664f8fb50a778eb2abff870b5 100644 --- a/indiquo/core/ScenePredictor.py +++ b/indiquo/core/ScenePredictor.py @@ -1,3 +1,5 @@ +from typing import List + from dramatist.drama.Drama import Drama from sentence_transformers import util @@ -7,7 +9,7 @@ from indiquo.core.ScenePrediction import ScenePrediction class ScenePredictor(BaseScenePredictor): - def __init__(self, drama: Drama, model, top_k): + def __init__(self, drama: Drama, model, top_k: int): self.model = model self.top_k = top_k self.all_text_blocks = [] @@ -25,7 +27,7 @@ class ScenePredictor(BaseScenePredictor): self.source_embeddings = model.encode(source_text_blocks, convert_to_tensor=True) # overriding abstract method - def predict_scene(self, text): + def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]: if isinstance(text, str): text = [text] diff --git a/indiquo/core/ScenePredictorDummy.py b/indiquo/core/ScenePredictorDummy.py index 1790d89d81b82b48efe3a831699c0e1587a56ac4..855cbd3c8705182ba05c4ef0acbc035e29863062 100644 --- a/indiquo/core/ScenePredictorDummy.py +++ b/indiquo/core/ScenePredictorDummy.py @@ -6,7 +6,7 @@ from indiquo.core.ScenePrediction import ScenePrediction class ScenePredictorDummy(BaseScenePredictor): - def predict_scene(self, text) -> List[List[ScenePrediction]]: + def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]: if isinstance(text, str): text = [text] diff --git a/indiquo/core/chunker/SentenceChunker.py b/indiquo/core/chunker/SentenceChunker.py index fd285f75395247201d1dedc3cbe6d61fd7687bd2..ab0864b9f651cf4e00a326fa306b279d4816e08c 100644 --- a/indiquo/core/chunker/SentenceChunker.py +++ b/indiquo/core/chunker/SentenceChunker.py @@ -9,7 +9,7 @@ import re class SentenceChunker(BaseChunker): - def __init__(self, min_length=0, max_length=10000, max_sentences=25): + def __init__(self, min_length: int = 0, max_length: int = 10000, max_sentences: int = 25): self.min_length = min_length self.max_length = max_length self.max_sentences = max_sentences @@ -97,7 +97,7 @@ class SentenceChunker(BaseChunker): return chunks_2 - def __split_too_long(self, chunk, length) -> List[Chunk]: + def __split_too_long(self, chunk: Chunk, length: int) -> List[Chunk]: org_text = chunk.text factor = (length // self.max_length) + 1 words = org_text.split() diff --git a/indiquo/training/candidate/TrainCandidateClassifier.py b/indiquo/training/candidate/TrainCandidateClassifier.py index 63177390b0c8e1b5661c55e3abf9c6783de7a72e..a6f2807a0d83754e6c16d36ff352ca02a2dc4b35 100644 --- a/indiquo/training/candidate/TrainCandidateClassifier.py +++ b/indiquo/training/candidate/TrainCandidateClassifier.py @@ -17,7 +17,7 @@ def __prepare_compute_metrics(eval_metric): return custom_compute_metrics -def train(train_folder_path, output_folder_path, model_name): +def train(train_folder_path: str, output_folder_path: str, model_name: str): tokenizer = AutoTokenizer.from_pretrained(model_name) special_tokens_dict = {'additional_special_tokens': ['<Q>', '</Q>']} diff --git a/indiquo/training/candidate/TrainCandidateClassifierST.py b/indiquo/training/candidate/TrainCandidateClassifierST.py index 10179d78cfd005efc527823a5654a8d4c96c5cd4..125f9f186b8d09b75948ffcde3c9d802e0381df1 100644 --- a/indiquo/training/candidate/TrainCandidateClassifierST.py +++ b/indiquo/training/candidate/TrainCandidateClassifierST.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader import csv -def train(train_folder_path, output_folder_path, model_name): +def train(train_folder_path: str, output_folder_path: str, model_name: str): train_examples = [] with open(join(train_folder_path, 'train_set.tsv'), 'r') as train_file: diff --git a/indiquo/training/scene/TrainSceneIdentification.py b/indiquo/training/scene/TrainSceneIdentification.py index 496287ff110f9e5e240d5c6dfb315ea260e3ebe6..604b2924b5ac215d3b1913571982cbf86913aec4 100644 --- a/indiquo/training/scene/TrainSceneIdentification.py +++ b/indiquo/training/scene/TrainSceneIdentification.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader import csv -def train(train_folder_path, output_folder_path, model_name): +def train(train_folder_path: str, output_folder_path: str, model_name: str): train_examples = []