From ccd622a9cfc182c0ac2612d84743358e7f7a959d Mon Sep 17 00:00:00 2001
From: Frederik Arnold <frederik.arnold@hu-berlin.de>
Date: Tue, 23 Jan 2024 07:06:15 +0100
Subject: [PATCH] Update similarity training

---
 .../training/similarity/TrainSimilarity.py    | 23 +++----------------
 1 file changed, 3 insertions(+), 20 deletions(-)

diff --git a/indiquo/training/similarity/TrainSimilarity.py b/indiquo/training/similarity/TrainSimilarity.py
index cd11fc4..d7981bd 100644
--- a/indiquo/training/similarity/TrainSimilarity.py
+++ b/indiquo/training/similarity/TrainSimilarity.py
@@ -36,9 +36,6 @@ def main():
     with open(join(input_path, 'train_set.tsv'), 'r') as train_file:
         reader = csv.reader(train_file, delimiter='\t')
 
-        # skip first row (header)
-        # next(reader, None)
-
         for row in reader:
             ie = InputExample(texts=[row[0], row[1], row[2]])
             train_examples.append(ie)
@@ -49,35 +46,21 @@ def main():
     with open(join(input_path, 'val_set.tsv'), 'r') as train_file:
         reader = csv.reader(train_file, delimiter='\t')
 
-        # skip first row (header)
-        # next(reader, None)
-
         for row in reader:
             val_anchor.append(row[0])
             val_positive.append(row[1])
             val_negative.append(row[2])
 
-    # model = SentenceTransformer('deutsche-telekom/gbert-large-paraphrase-cosine')
     model = SentenceTransformer(model_name)
-
-    # Define your train examples. You need more than just two examples...
-    # train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
-    #                   InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]
-
     train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)
 
-    # train_loss = losses.CosineSimilarityLoss(model)
     train_loss = losses.TripletLoss(model=model)
-    # train_loss = losses.BatchHardSoftMarginTripletLoss(
-    #     model=model,
-    #     distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance,
-    # )
 
-    # evaluator = evaluation.EmbeddingSimilarityEvaluator(val_anchor, sen, scores)
-    # evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative)
+    evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative)
 
     # Tune the model
-    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100, output_path=output_path)
+    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100,
+              evaluator=evaluator, evaluation_steps=10000, output_path=output_path)
 
 
 if __name__ == '__main__':
-- 
GitLab