From 9935b0e355fd5b750a4a75c8119676051c28e24f Mon Sep 17 00:00:00 2001 From: Kcstring Date: Wed, 29 Apr 2026 13:44:26 +0800 Subject: [PATCH] Fix KL divergence with zero true labels --- machine_learning/loss_functions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/machine_learning/loss_functions.py b/machine_learning/loss_functions.py index 0bd9aa8b5401..0e54cf2e422a 100644 --- a/machine_learning/loss_functions.py +++ b/machine_learning/loss_functions.py @@ -655,11 +655,16 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float Traceback (most recent call last): ... ValueError: Input arrays must have the same length. + >>> true_labels = np.array([0.0, 0.5, 0.5]) + >>> predicted_probs = np.array([0.2, 0.3, 0.5]) + >>> float(kullback_leibler_divergence(true_labels, predicted_probs)) + 0.25541281188299536 """ if len(y_true) != len(y_pred): raise ValueError("Input arrays must have the same length.") - kl_loss = y_true * np.log(y_true / y_pred) + non_zero = y_true != 0 + kl_loss = y_true[non_zero] * np.log(y_true[non_zero] / y_pred[non_zero]) return np.sum(kl_loss)