diff --git a/indiquo/training/similarity/TrainSimilarityContrastive.py b/indiquo/training/similarity/TrainSimilarityContrastive.py index bbf143d8a5fb74996f3ee79bafe69f3c7002ddf2..d624d64ee4fdd5b14aa893d60d74d9b5c56753a2 100644 --- a/indiquo/training/similarity/TrainSimilarityContrastive.py +++ b/indiquo/training/similarity/TrainSimilarityContrastive.py @@ -42,7 +42,7 @@ def main(): model = SentenceTransformer(model_name) train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16) - train_loss = losses.ContrastiveLoss(model=model, margin=1.0) + train_loss = losses.OnlineContrastiveLoss(model=model) # evaluator = evaluation.TripletEvaluator(val_anchor, val_positive, val_negative)