
Commit a41b34a

add temp fix for dice
1 parent daa9254 commit a41b34a

1 file changed (+46, -15 lines)


multimedeval/task_families.py

Lines changed: 46 additions & 15 deletions
@@ -511,8 +511,9 @@ def evaluate(self, predictions):
             return metrics

         for label in labels_list:
-            predicted_answers = []
-            ground_truth = []
+            # predicted_answers = []
+            # ground_truth = []
+            dsc_list = []

             for prediction in predictions:
                 answer = prediction["answer"].masks
@@ -527,27 +528,57 @@ def evaluate(self, predictions):
                 else:
                     pred = self.get_predicted_answer(answer)

-                predicted_answers.append(pred)
-                ground_truth.append(gt)
+                dice_similarity_coefficient = self.compute_dice_coefficient(
+                    gt, pred
+                )
+                # print(dice_similarity_coefficient)
+                dsc_list.append(dice_similarity_coefficient)
+                # predicted_answers.append(pred)
+                # ground_truth.append(gt)

-            predicted_answers = np.array(predicted_answers)
-            ground_truth = np.array(ground_truth)
+            # predicted_answers = np.array(predicted_answers)
+            # ground_truth = np.array(ground_truth)
             # print(predicted_answers.shape, ground_truth.shape)

-            predicted_answers = torch.tensor(predicted_answers, dtype=torch.long)
-            ground_truth = torch.tensor(ground_truth, dtype=torch.long)
+            # dice_similarity_coefficient = self.compute_dice_coefficient(
+            #     ground_truth, predicted_answers
+            # )

-            dice = dice_scorer(predicted_answers, ground_truth).item()
-            answers_log.append(
-                (
-                    f"Label {label} have {len(predicted_answers)} data points, and the dice score is: {dice}."
-                )
-            )
+            # predicted_answers = torch.tensor(predicted_answers, dtype=torch.long)
+            # ground_truth = torch.tensor(ground_truth, dtype=torch.long)

-            metrics[f"{label}_dice"] = dice
+            # dice = dice_scorer(predicted_answers, ground_truth).item()
+            # answers_log.append(
+            #     (
+            #         f"Label {label} have {len(predicted_answers)} data points, and the dice score is: {dice}."
+            #     )
+            # )

+            # metrics[f"{label}_generalized_dice_score"] = dice
+            # print(sum(dsc_list), len(dsc_list))
+            metrics[f"{label}_DSC"] = sum(dsc_list) / len(dsc_list)
+            # del predicted_answers, ground_truth
         return EvaluationOutput(metrics=metrics, answer_log=answers_log)

+    def compute_dice_coefficient(self, mask_gt, mask_pred):
+        """Compute soerensen-dice coefficient.
+
+        Compute the soerensen-dice coefficient between the ground truth mask
+        `mask_gt` and the predicted mask `mask_pred`.
+
+        Args:
+            mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
+            mask_pred: 3-dim Numpy array of type bool. The predicted mask.
+
+        Returns:
+            The dice coefficient as float. If both masks are empty, the result is NaN.
+        """
+        volume_sum = mask_gt.sum() + mask_pred.sum()
+        if volume_sum == 0:
+            return np.NaN
+        volume_intersect = (mask_gt & mask_pred).sum()
+        return 2 * volume_intersect / volume_sum
+

 class ReportComparison(Benchmark):
     """A benchmark for report comparison tasks."""
