From ccd622a9cfc182c0ac2612d84743358e7f7a959d Mon Sep 17 00:00:00 2001 From: Frederik Arnold <frederik.arnold@hu-berlin.de> Date: Tue, 23 Jan 2024 07:06:15 +0100 Subject: [PATCH] Update similarity training --- .../training/similarity/TrainSimilarity.py | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/indiquo/training/similarity/TrainSimilarity.py b/indiquo/training/similarity/TrainSimilarity.py index cd11fc4..d7981bd 100644 --- a/indiquo/training/similarity/TrainSimilarity.py +++ b/indiquo/training/similarity/TrainSimilarity.py @@ -36,9 +36,6 @@ def main(): with open(join(input_path, 'train_set.tsv'), 'r') as train_file: reader = csv.reader(train_file, delimiter='\t') - # skip first row (header) - # next(reader, None) - for row in reader: ie = InputExample(texts=[row[0], row[1], row[2]]) train_examples.append(ie) @@ -49,35 +46,21 @@ def main(): with open(join(input_path, 'val_set.tsv'), 'r') as train_file: reader = csv.reader(train_file, delimiter='\t') - # skip first row (header) - # next(reader, None) - for row in reader: val_anchor.append(row[0]) val_positive.append(row[1]) val_negative.append(row[2]) - # model = SentenceTransformer('deutsche-telekom/gbert-large-paraphrase-cosine') model = SentenceTransformer(model_name) - - # Define your train examples. You need more than just two examples... - # train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8), - # InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)] - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8) - # train_loss = losses.CosineSimilarityLoss(model) train_loss = losses.TripletLoss(model=model) - # train_loss = losses.BatchHardSoftMarginTripletLoss( - # model=model, - # distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance, - # ) - # evaluator = evaluation.EmbeddingSimilarityEvaluator(val_anchor, sen, scores) - # evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative) + evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative) # Tune the model - model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100, output_path=output_path) + model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100, + evaluator=evaluator, evaluation_steps=10000, output_path=output_path) if __name__ == '__main__': -- GitLab