Skip to content
Snippets Groups Projects
Commit 130b3597 authored by Frederik Arnold's avatar Frederik Arnold
Browse files

Add similarity training for drama summary

parent 35a2b31c
No related branches found
No related tags found
No related merge requests found
from argparse import ArgumentParser
from os.path import join
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import csv
from datetime import datetime
def main():
argument_parser = ArgumentParser()
argument_parser.add_argument('input_path', nargs=1, metavar='input-path',
help="Path to the input folder")
argument_parser.add_argument('output_path', nargs=1, metavar='output-path',
help="Path to the input folder")
argument_parser.add_argument('model', nargs=1, metavar='model',
help="")
args = argument_parser.parse_args()
input_path = args.input_path[0]
output_path = args.output_path[0]
model_name = args.model[0]
model_name_repl = model_name.replace('/', '')
now = datetime.now()
date_time_string = now.strftime('%Y_%m_%d_%H_%M_%S')
date_time_string += f'_{model_name_repl}'
output_path = join(output_path, date_time_string)
train_examples = []
with open(join(input_path, 'train_set.tsv'), 'r') as train_file:
reader = csv.reader(train_file, delimiter='\t')
# skip first row (header)
# next(reader, None)
for row in reader:
ie = InputExample(texts=[row[0], row[1]])
train_examples.append(ie)
# val_anchor = []
# val_positive = []
# val_negative = []
# with open(join(input_path, 'val_set.tsv'), 'r') as train_file:
# reader = csv.reader(train_file, delimiter='\t')
#
# # skip first row (header)
# # next(reader, None)
#
# for row in reader:
# val_anchor.append(row[0])
# val_positive.append(row[1])
# val_negative.append(row[2])
# model = SentenceTransformer('deutsche-telekom/gbert-large-paraphrase-cosine')
model = SentenceTransformer(model_name)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)
train_loss = losses.MultipleNegativesRankingLoss(model=model)
# evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative)
# Tune the model
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100,
evaluation_steps=10000, output_path=output_path)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment