Commit 17bdc37c authored by Frederik Arnold

Add more type hints

parent ea17109e
Showing with 36 additions and 33 deletions
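Every change below follows the same pattern: function and method signatures gain standard type annotations while the behaviour of the annotated functions stays unchanged. As a reminder of what such hints buy (an illustrative sketch, not code from this repository), Python does not enforce annotations at call time; they pay off through a static checker such as mypy or through IDE support:

from typing import List

def get_scores(texts: List[str]) -> List[float]:  # hypothetical helper, not part of IndiQuo
    return [float(len(t)) for t in texts]

get_scores(["a", "bb"])     # fine
get_scores("not a list")    # still runs, but a static checker would flag the argument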
@@ -35,19 +35,19 @@ from indiquo.training.candidate import TrainCandidateClassifier, TrainCandidateC
 logger = logging.getLogger(__name__)


-def __train_candidate(train_folder_path, output_folder_path, model_name):
+def __train_candidate(train_folder_path: str, output_folder_path: str, model_name: str):
     TrainCandidateClassifier.train(train_folder_path, output_folder_path, model_name)


-def __train_candidate_st(train_folder_path, output_folder_path, model_name):
+def __train_candidate_st(train_folder_path: str, output_folder_path: str, model_name: str):
     TrainCandidateClassifierST.train(train_folder_path, output_folder_path, model_name)


-def __train_scene(train_folder_path, output_folder_path, model_name):
+def __train_scene(train_folder_path: str, output_folder_path: str, model_name: str):
     TrainSceneIdentification.train(train_folder_path, output_folder_path, model_name)


-def __process_file(indi_quo: IndiQuoBase, filename, target_text, output_folder_path):
+def __process_file(indi_quo: IndiQuoBase, filename: str, target_text: str, output_folder_path: str):
     logger.info(f'Processing {filename} ...')
     matches = indi_quo.compare(target_text)

@@ -69,8 +69,9 @@ def __process_file(indi_quo: IndiQuoBase, filename, target_text, output_folder_p
             writer.writerow([m.target_start, m.target_end, speech_text, m.score, scene_predictions])


-def __run_compare(compare_approach, model_type, source_file_path, target_path, candidate_model_path, scene_model_path,
-                  output_folder_path, add_context, max_candidate_length, summaries_file_path):
+def __run_compare(compare_approach: str, model_type: str, source_file_path: str, target_path: str,
+                  candidate_model_path: str, scene_model_path: str, output_folder_path: str, add_context: bool,
+                  max_candidate_length: int, summaries_file_path: str):
     drama_processor = Dramatist()
     drama = drama_processor.from_file(source_file_path)
     sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)

@@ -6,5 +6,5 @@ from indiquo.core.Candidate import Candidate
 class BaseCandidatePredictor(ABC):

     @abstractmethod
-    def get_candidates(self, target_text) -> List[Candidate]:
+    def get_candidates(self, target_text: str) -> List[Candidate]:
         pass

@@ -6,5 +6,5 @@ from indiquo.core.ScenePrediction import ScenePrediction
 class BaseScenePredictor(ABC):

     @abstractmethod
-    def predict_scene(self, text) -> List[List[ScenePrediction]]:
+    def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]:
         pass

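The union written as str | List[str] uses PEP 604 syntax. Because annotations on a def are evaluated when the module is imported (unless postponed evaluation is enabled via from __future__ import annotations), this spelling requires Python 3.10 or newer. A fallback sketch for older interpreters, not part of this commit:

from abc import ABC, abstractmethod
from typing import List, Union


class BaseScenePredictor(ABC):

    @abstractmethod
    def predict_scene(self, text: Union[str, List[str]]) -> List[List['ScenePrediction']]:
        # 'ScenePrediction' stays a string forward reference so this sketch is
        # self-contained; the real module imports it from indiquo.core.
        pass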
@@ -11,7 +11,7 @@ from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_foot
 # noinspection PyMethodMayBeStatic
 class CandidatePredictor(BaseCandidatePredictor):

-    def __init__(self, tokenizer, model, chunker: BaseChunker, add_context, max_length):
+    def __init__(self, tokenizer, model, chunker: BaseChunker, add_context: bool, max_length: int):
         self.tokenizer = tokenizer
         self.model = model
         self.chunker = chunker

@@ -21,7 +21,7 @@ class CandidatePredictor(BaseCandidatePredictor):
         self.max_length = max_length

     # overriding abstract method
-    def get_candidates(self, target_text) -> List[Candidate]:
+    def get_candidates(self, target_text: str) -> List[Candidate]:
         fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text)
         target_text_wo_fn: str = remove_footnotes(target_text)
         chunks = self.chunker.chunk(target_text_wo_fn)

@@ -46,7 +46,7 @@ class CandidatePredictor(BaseCandidatePredictor):
         return candidates

-    def __predict(self, target_text):
+    def __predict(self, target_text: str) -> float:
         inputs = self.tokenizer(target_text, truncation=True, return_tensors="pt")

         with torch.no_grad():

@@ -55,7 +55,7 @@ class CandidatePredictor(BaseCandidatePredictor):
             score = float(p[0, 1])

         return score

-    def __add_context(self, quote_text, text, quote_start, quote_end):
+    def __add_context(self, quote_text: str, text: str, quote_start: int, quote_end: int) -> str:
         rest_len = self.max_length - len(quote_text.split())

         if rest_len < 0:

@@ -13,7 +13,7 @@ class CandidatePredictorDummy(BaseCandidatePredictor):
         self.chunker = chunker

     # overriding abstract method
-    def get_candidates(self, target_text) -> List[Candidate]:
+    def get_candidates(self, target_text: str) -> List[Candidate]:
         fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text)
         target_text_wo_fn: str = remove_footnotes(target_text)
         chunks = self.chunker.chunk(target_text_wo_fn)

@@ -19,7 +19,7 @@ class CandidatePredictorRW(BaseCandidatePredictor):
         self.chunker = chunker

     # overriding abstract method
-    def get_candidates(self, target_text) -> List[Candidate]:
+    def get_candidates(self, target_text: str) -> List[Candidate]:
         fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text)
         target_text_wo_fn: str = remove_footnotes(target_text)
         chunks = self.chunker.chunk(target_text_wo_fn)

 from typing import List
-from sentence_transformers import util
+from sentence_transformers import util, SentenceTransformer
 from dramatist.drama.Drama import Drama
 from indiquo.core.BaseCandidatePredictor import BaseCandidatePredictor

@@ -12,7 +12,7 @@ import re
 # noinspection PyMethodMayBeStatic
 class CandidatePredictorST(BaseCandidatePredictor):

-    def __init__(self, drama: Drama, model, chunker: BaseChunker, add_context, max_length):
+    def __init__(self, drama: Drama, model: SentenceTransformer, chunker: BaseChunker, add_context: bool, max_length: int):
         self.drama = drama
         self.model = model
         self.chunker = chunker

@@ -32,7 +32,7 @@ class CandidatePredictorST(BaseCandidatePredictor):
         self.source_embeddings = model.encode(self.source_text_blocks, convert_to_tensor=True)

     # overriding abstract method
-    def get_candidates(self, target_text) -> List[Candidate]:
+    def get_candidates(self, target_text: str) -> List[Candidate]:
         fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text)
         target_text_wo_fn: str = remove_footnotes(target_text)
         chunks = self.chunker.chunk(target_text_wo_fn)

@@ -57,12 +57,12 @@ class CandidatePredictorST(BaseCandidatePredictor):
         return candidates

-    def __predict(self, target_text):
+    def __predict(self, target_text: str) -> float:
         target_embedding = self.model.encode([target_text], convert_to_tensor=True)
         hits = util.semantic_search(target_embedding, self.source_embeddings, top_k=1)[0]
         return hits[0]['score']

-    def __add_context(self, quote_text, text, quote_start, quote_end):
+    def __add_context(self, quote_text: str, text: str, quote_start: int, quote_end: int) -> str:
         rest_len = self.max_length - len(quote_text.split())

         if rest_len < 0:

-from typing import List
+from typing import List, Tuple
 from indiquo.core.CandidateWithScenes import CandidateWithScenes
 from indiquo.core.ScenePrediction import ScenePrediction
 from indiquo.core.chunker.BaseChunker import BaseChunker
 from kpcommons.Footnote import map_to_real_pos, get_footnote_ranges, remove_footnotes
-from sentence_transformers import util
+from sentence_transformers import util, SentenceTransformer


 # noinspection PyMethodMayBeStatic
 class CandidatePredictorSum:

-    def __init__(self, summaries, model, chunker: BaseChunker):
+    def __init__(self, summaries: List[Tuple[int, int, str]], model: SentenceTransformer, chunker: BaseChunker):
         self.summaries = summaries
         self.model = model
         self.chunker = chunker
         self.summary_embeddings = model.encode([x[2] for x in self.summaries], convert_to_tensor=True)

-    def get_candidates(self, target_text) -> List[CandidateWithScenes]:
+    def get_candidates(self, target_text: str) -> List[CandidateWithScenes]:
         fn_ranges, fn_ranges_with_offset = get_footnote_ranges(target_text)
         target_text_wo_fn: str = remove_footnotes(target_text)
         chunks = self.chunker.chunk(target_text_wo_fn)

@@ -37,7 +37,7 @@ class CandidatePredictorSum:
         return candidates

-    def __predict(self, sentences):
+    def __predict(self, sentences: List[str]) -> Tuple[List[float], List[List[ScenePrediction]]]:
         target_embedding = self.model.encode(sentences, convert_to_tensor=True)
         hits = util.semantic_search(target_embedding, self.summary_embeddings, top_k=10)

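For orientation, a value matching the new summaries annotation could look like the sketch below. Only the third tuple element is certain from the class itself (it is encoded as the summary text via x[2]); the meaning of the two integers is an assumption made for this illustration, for example an act and a scene number:

from typing import List, Tuple

# Hypothetical data, not taken from the repository; the two ints are assumed to
# locate the summarised scene in the drama.
summaries: List[Tuple[int, int, str]] = [
    (1, 1, "The prince confides his plan to a trusted friend."),
    (1, 2, "A messenger brings unwelcome news from the court."),
]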
@@ -6,5 +6,5 @@ from indiquo.match.Match import Match
 class IndiQuoBase(ABC):

     @abstractmethod
-    def compare(self, target_text) -> List[Match]:
+    def compare(self, target_text: str) -> List[Match]:
         pass

+from typing import List
+
 from dramatist.drama.Drama import Drama
 from sentence_transformers import util

@@ -7,7 +9,7 @@ from indiquo.core.ScenePrediction import ScenePrediction
 class ScenePredictor(BaseScenePredictor):

-    def __init__(self, drama: Drama, model, top_k):
+    def __init__(self, drama: Drama, model, top_k: int):
         self.model = model
         self.top_k = top_k
         self.all_text_blocks = []

@@ -25,7 +27,7 @@ class ScenePredictor(BaseScenePredictor):
         self.source_embeddings = model.encode(source_text_blocks, convert_to_tensor=True)

     # overriding abstract method
-    def predict_scene(self, text):
+    def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]:
         if isinstance(text, str):
             text = [text]

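Since predict_scene now accepts either a single string or a list of strings and, per the return annotation, yields one list of ScenePrediction objects per input, callers can treat both cases uniformly. A usage sketch, assuming an already constructed ScenePredictor instance named predictor (the variable and the example texts are illustrative):

single_result = predictor.predict_scene("Sein oder Nichtsein, das ist hier die Frage.")
batch_result = predictor.predict_scene(["first passage", "second passage"])

# single_result contains one inner list; batch_result contains one inner list per
# input text, each presumably holding up to top_k ScenePrediction entries.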
@@ -6,7 +6,7 @@ from indiquo.core.ScenePrediction import ScenePrediction
 class ScenePredictorDummy(BaseScenePredictor):

-    def predict_scene(self, text) -> List[List[ScenePrediction]]:
+    def predict_scene(self, text: str | List[str]) -> List[List[ScenePrediction]]:
         if isinstance(text, str):
             text = [text]

@@ -9,7 +9,7 @@ import re
 class SentenceChunker(BaseChunker):

-    def __init__(self, min_length=0, max_length=10000, max_sentences=25):
+    def __init__(self, min_length: int = 0, max_length: int = 10000, max_sentences: int = 25):
         self.min_length = min_length
         self.max_length = max_length
         self.max_sentences = max_sentences

@@ -97,7 +97,7 @@ class SentenceChunker(BaseChunker):
         return chunks_2

-    def __split_too_long(self, chunk, length) -> List[Chunk]:
+    def __split_too_long(self, chunk: Chunk, length: int) -> List[Chunk]:
         org_text = chunk.text
         factor = (length // self.max_length) + 1
         words = org_text.split()

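The annotated defaults make the chunker configuration explicit and match the call sites elsewhere in this commit. A short sketch of both call styles; the import path is assumed from the package layout visible in the other imports (indiquo.core.chunker), and the variable names are illustrative:

from indiquo.core.chunker.SentenceChunker import SentenceChunker  # assumed module path

# As configured in __run_compare earlier in this diff:
sentence_chunker = SentenceChunker(min_length=10, max_length=64, max_sentences=1)

# With no arguments, the annotated defaults apply:
default_chunker = SentenceChunker()  # min_length=0, max_length=10000, max_sentences=25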
@@ -17,7 +17,7 @@ def __prepare_compute_metrics(eval_metric):
     return custom_compute_metrics


-def train(train_folder_path, output_folder_path, model_name):
+def train(train_folder_path: str, output_folder_path: str, model_name: str):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     special_tokens_dict = {'additional_special_tokens': ['<Q>', '</Q>']}

@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
 import csv


-def train(train_folder_path, output_folder_path, model_name):
+def train(train_folder_path: str, output_folder_path: str, model_name: str):
     train_examples = []

     with open(join(train_folder_path, 'train_set.tsv'), 'r') as train_file:

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 import csv


-def train(train_folder_path, output_folder_path, model_name):
+def train(train_folder_path: str, output_folder_path: str, model_name: str):
     train_examples = []

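All three training entry points now share the explicit (str, str, str) signature. A hypothetical invocation of one of them; the paths and the model name are placeholders, not values taken from this commit:

from indiquo.training.candidate import TrainCandidateClassifier  # import path as used above

# Placeholder arguments for illustration only.
TrainCandidateClassifier.train('/data/candidate/train', '/models/candidate', 'bert-base-german-cased')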