From 88b1182166707ed4ea7e86789d06c11765eea015 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 14:21:55 +0800
Subject: [PATCH 01/10] Weights in alpha for FocalLoss

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 54 +++++++++++++++++++++++++++++---------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 28d1c0cdc9..ede08b970d 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -70,7 +70,7 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
-        alpha: float | None = None,
+        alpha: float | Sequence[float] | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
@@ -78,11 +78,13 @@
         """
         Args:
             include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
-                If False, `alpha` is invalid when using softmax.
+                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
             to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
-                The value should be in [0, 1]. Defaults to None.
+                The value should be in [0, 1].
+                If a sequence is provided, it must match the number of classes (after excluding background if set).
+                Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
                 of the sequence should be the same as the number of classes. If not ``include_background``,
@@ -156,13 +158,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: Optional[torch.Tensor] = None
         input = input.float()
         target = target.float()
+
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
-                self.alpha = None
-                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+                if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
+                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -203,7 +208,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
 
 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -215,8 +220,18 @@
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
 
     if alpha is not None:
-        # (1-alpha) for the background class and alpha for the other classes
-        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            # (1-alpha) for the background class and alpha for the other classes
+            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        else:  # sequence
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            alpha_fac = alpha_t
+
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
@@ -225,7 +240,7 @@
 
 
 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -248,8 +263,21 @@
     loss = (invprobs * gamma).exp() * loss
 
     if alpha is not None:
-        # alpha if t==1; (1-alpha) if t==0
-        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if alpha_t.ndim == 0:  # scalar
+            # alpha if t==1; (1-alpha) if t==0
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+        else:  # sequence / per-channel alpha
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            # Reshape alpha for broadcasting: (1, C, 1, 1...)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            alpha_t = alpha_t.view(broadcast_dims)
+            # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+
         loss = alpha_factor * loss
 
     return loss
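
For reference, a minimal usage sketch of the per-class `alpha` introduced by the patch above. The shapes, class count, and alpha values are illustrative only, and the snippet assumes a MONAI build with this patch applied:

    import torch
    from monai.losses import FocalLoss

    logits = torch.randn(2, 3, 16, 16)            # (batch, classes, H, W)
    labels = torch.randint(0, 3, (2, 1, 16, 16))  # class-index labels

    # scalar alpha: the pre-existing behaviour
    loss_scalar = FocalLoss(to_onehot_y=True, alpha=0.25, use_softmax=True)(logits, labels)

    # sequence alpha: one weight per class; the length must equal the number of classes
    loss_per_class = FocalLoss(to_onehot_y=True, alpha=[0.2, 0.3, 0.5], use_softmax=True)(logits, labels)
    print(loss_scalar.item(), loss_per_class.item())
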
From 820ce947bd3d4b7192afc951590faafc7dbea837 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 14:29:05 +0800
Subject: [PATCH 02/10] fix Local variable alpha_arg is assigned to but never used

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index ede08b970d..b000a7f8fb 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -165,9 +165,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 if isinstance(self.alpha, (float, int)):
                     alpha_arg = None
                     warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha_arg)
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha_arg)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:

From 50cc7e97bd2103a97c048c785dbffbdcebeb2353 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 14:47:14 +0800
Subject: [PATCH 03/10] fix mypy error

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index b000a7f8fb..fcd31a80f6 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -223,8 +223,9 @@ def softmax_focal_loss(
         alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
 
         if alpha_t.ndim == 0:  # scalar
+            alpha_val = alpha_t.item()
             # (1-alpha) for the background class and alpha for the other classes
-            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+            alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
         else:  # sequence
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(

From 1b2483441dccbc0ee117dfdec15f4663bd4a8e73 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 16:03:28 +0800
Subject: [PATCH 04/10] fix undefined type error

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index fcd31a80f6..e463ee9e8d 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -112,12 +112,17 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.alpha: float | torch.Tensor | None
+
+        if isinstance(alpha, (list, tuple)):
+            self.alpha = torch.tensor(alpha)
+        else:
+            self.alpha = alpha
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -159,7 +164,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()
 
-        alpha_arg = self.alpha
+        alpha_arg: float | torch.Tensor | None = self.alpha
+        if isinstance(alpha_arg, torch.Tensor):
+            alpha_arg = alpha_arg.to(input.device)
+
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
@@ -208,7 +216,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
 
 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -241,7 +249,7 @@ def softmax_focal_loss(
 
 
 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
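
The normalization in `__init__` above can be pictured in isolation as follows. This is only a sketch of the idea (a scalar alpha stays a plain number, a sequence becomes a tensor that is later moved to the input's device); the helper name is hypothetical and this is not the MONAI code itself:

    from typing import Optional, Sequence, Union
    import torch

    def normalize_alpha(alpha: Union[float, Sequence[float], None]) -> Union[float, torch.Tensor, None]:
        # scalar (or None) is kept as-is; a sequence is converted once to a tensor
        if alpha is None or isinstance(alpha, (float, int)):
            return alpha
        return torch.as_tensor(alpha, dtype=torch.float32)

    alpha = normalize_alpha([0.2, 0.3, 0.5])
    x = torch.randn(2, 3, 8, 8)
    if isinstance(alpha, torch.Tensor):
        alpha = alpha.to(x.device)  # align device with the input before computing the loss
    print(alpha)
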
From 258b79dc8785292dce42e98d8a1eac80f96ac283 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 16:19:34 +0800
Subject: [PATCH 05/10] fix alpha type bugs

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 23 ++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index e463ee9e8d..e3fc246c0f 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -114,15 +114,15 @@ def __init__(
         self.gamma = gamma
         self.weight = weight
         self.use_softmax = use_softmax
-        weight = torch.as_tensor(weight) if weight is not None else None
-        self.register_buffer("class_weight", weight)
-        self.class_weight: None | torch.Tensor
         self.alpha: float | torch.Tensor | None
 
         if isinstance(alpha, (list, tuple)):
             self.alpha = torch.tensor(alpha)
         else:
             self.alpha = alpha
+        weight = torch.as_tensor(weight) if weight is not None else None
+        self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -165,8 +165,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         target = target.float()
 
         alpha_arg: float | torch.Tensor | None = self.alpha
-        if isinstance(alpha_arg, torch.Tensor):
-            alpha_arg = alpha_arg.to(input.device)
 
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
@@ -228,13 +226,16 @@ def softmax_focal_loss(
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
 
     if alpha is not None:
-        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
 
         if alpha_t.ndim == 0:  # scalar
             alpha_val = alpha_t.item()
             # (1-alpha) for the background class and alpha for the other classes
             alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
-        else:  # sequence
+        else:  # tensor (sequence)
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(
                     f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
@@ -272,11 +273,15 @@ def sigmoid_focal_loss(
     loss = (invprobs * gamma).exp() * loss
 
     if alpha is not None:
-        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
+
         if alpha_t.ndim == 0:  # scalar
             # alpha if t==1; (1-alpha) if t==0
             alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
-        else:  # sequence / per-channel alpha
+        else:  # tensor (sequence)
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(
                     f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
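
The reshape used in both branches above relies on standard broadcasting: a per-class weight vector of shape (C,) is viewed as (C, 1, ..., 1) so it scales a loss map of shape (B, C, *spatial). A small self-contained illustration with arbitrary values:

    import torch

    loss_map = torch.rand(2, 3, 4, 4)          # (B, C, H, W)
    alpha_t = torch.tensor([0.2, 0.3, 0.5])    # one weight per class, shape (C,)

    broadcast_dims = [-1] + [1] * (loss_map.ndim - 2)   # -> [-1, 1, 1]
    alpha_fac = alpha_t.view(broadcast_dims)            # shape (3, 1, 1)

    weighted = alpha_fac * loss_map            # broadcasts over batch and spatial dims
    assert weighted.shape == loss_map.shape
    # channel c of every sample is scaled by alpha_t[c]
    assert torch.allclose(weighted[:, 1], 0.3 * loss_map[:, 1])
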
From a574d7ba2af3a9f6c539ac8c237485208fc3066a Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 19 Dec 2025 16:36:07 +0800
Subject: [PATCH 06/10] fix alpha type bugs

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index e3fc246c0f..a5d3748814 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -115,11 +115,12 @@ def __init__(
         self.weight = weight
         self.use_softmax = use_softmax
         self.alpha: float | torch.Tensor | None
-
-        if isinstance(alpha, (list, tuple)):
-            self.alpha = torch.tensor(alpha)
+        if alpha is None:
+            self.alpha = None
+        elif isinstance(alpha, (float, int)):
+            self.alpha = float(alpha)
         else:
-            self.alpha = alpha
+            self.alpha = torch.as_tensor(alpha)
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
@@ -165,6 +166,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         target = target.float()
 
         alpha_arg: float | torch.Tensor | None = self.alpha
+        if isinstance(alpha_arg, torch.Tensor):
+            alpha_arg = alpha_arg.to(input.device)
 
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:

From 1f37d0dfbb9e9485182d519668e675202e159c5f Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Tue, 6 Jan 2026 13:54:00 +0800
Subject: [PATCH 07/10] add test case for alpha as a sequence

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py      |  9 ++------
 tests/losses/test_focal_loss.py | 41 +++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index a5d3748814..526885e3f1 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -165,18 +165,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()
 
-        alpha_arg: float | torch.Tensor | None = self.alpha
-        if isinstance(alpha_arg, torch.Tensor):
-            alpha_arg = alpha_arg.to(input.device)
-
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
-                    alpha_arg = None
                     warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:

diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py
index e7f447d90e..15f78d4f37 100644
--- a/tests/losses/test_focal_loss.py
+++ b/tests/losses/test_focal_loss.py
@@ -374,6 +374,47 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
+    def test_alpha_sequence_broadcasting(self):
+        """
+        Test FocalLoss with alpha as a sequence for proper broadcasting.
+        """
+        num_classes = 3
+        alpha_seq = [0.1, 0.5, 2.0]
+        batch_size = 2
+        spatial_dims = (4, 4)
+
+        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+        for device in devices:
+            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+            # Case 1: Softmax + Alpha Sequence
+            loss_func_softmax = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
+            )
+            loss_soft = loss_func_softmax(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_soft))
+            self.assertEqual(loss_soft.ndim, 0)
+            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
+
+            # Case 2: Sigmoid + Alpha Sequence
+            loss_func_sigmoid = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
+            )
+            loss_sig = loss_func_sigmoid(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_sig))
+            self.assertEqual(loss_sig.ndim, 0)
+            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
+
+            # Case 3: Error Handling (Mismatched alpha length)
+            if device == devices[0]:
+                wrong_alpha = [0.1, 0.5]
+                with self.assertRaisesRegex(ValueError, "length of alpha"):
+                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+
 
 if __name__ == "__main__":
     unittest.main()
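
The new test exercises the length check added earlier; at usage level, a mismatched alpha raises a ValueError. A small sketch of the expected failure mode (illustrative shapes, assumes the patched FocalLoss):

    import torch
    from monai.losses import FocalLoss

    logits = torch.randn(2, 3, 8, 8)
    labels = torch.randint(0, 3, (2, 1, 8, 8))
    try:
        # two weights for three classes: rejected by the length check
        FocalLoss(to_onehot_y=True, alpha=[0.1, 0.5], use_softmax=True)(logits, labels)
    except ValueError as err:
        print(err)  # "The length of alpha (2) must match the number of classes (3)."
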
From 7ef77d1717d64a13f292bc7aeb333e493b549d95 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Wed, 7 Jan 2026 13:46:32 +0800
Subject: [PATCH 08/10] add another test case without the background class

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py      |  4 +-
 tests/losses/test_focal_loss.py | 68 ++++++++++++++++-----------------
 2 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 526885e3f1..f9d1c94f54 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -168,7 +168,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
-                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+                    warnings.warn(
+                        "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
+                    )
             loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
         else:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py
index 15f78d4f37..151db2dc71 100644
--- a/tests/losses/test_focal_loss.py
+++ b/tests/losses/test_focal_loss.py
@@ -21,7 +21,7 @@
 from monai.losses import FocalLoss
 from monai.networks import one_hot
-from tests.test_utils import test_script_save
+from tests.test_utils import TEST_DEVICES, test_script_save
 
 TEST_CASES = []
 for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
@@ -77,6 +77,13 @@ TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
 TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])
 
+TEST_ALPHA_BROADCASTING = []
+for case in TEST_DEVICES:
+    device = case[0]
+    for include_background in [True, False]:
+        for use_softmax in [True, False]:
+            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
+
 
 class TestFocalLoss(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES)
@@ -374,46 +381,39 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
-    def test_alpha_sequence_broadcasting(self):
+    @parameterized.expand(TEST_ALPHA_BROADCASTING)
+    def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):
         """
         Test FocalLoss with alpha as a sequence for proper broadcasting.
         """
         num_classes = 3
-        alpha_seq = [0.1, 0.5, 2.0]
         batch_size = 2
         spatial_dims = (4, 4)
 
-        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-
-        for device in devices:
-            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
-            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
-
-            # Case 1: Softmax + Alpha Sequence
-            loss_func_softmax = FocalLoss(
-                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
-            )
-            loss_soft = loss_func_softmax(logits, target)
-
-            self.assertTrue(torch.is_tensor(loss_soft))
-            self.assertEqual(loss_soft.ndim, 0)
-            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
-
-            # Case 2: Sigmoid + Alpha Sequence
-            loss_func_sigmoid = FocalLoss(
-                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
-            )
-            loss_sig = loss_func_sigmoid(logits, target)
-
-            self.assertTrue(torch.is_tensor(loss_sig))
-            self.assertEqual(loss_sig.ndim, 0)
-            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
-
-            # Case 3: Error Handling (Mismatched alpha length)
-            if device == devices[0]:
-                wrong_alpha = [0.1, 0.5]
-                with self.assertRaisesRegex(ValueError, "length of alpha"):
-                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+        logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+        target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+        if include_background:
+            alpha_seq = [0.1, 0.5, 2.0]
+        else:
+            alpha_seq = [0.5, 2.0]
+
+        loss_func = FocalLoss(
+            to_onehot_y=True,
+            gamma=2.0,
+            alpha=alpha_seq,
+            include_background=include_background,
+            use_softmax=use_softmax,
+            reduction="mean",
+        )
+
+        result = loss_func(logits, target)
+
+        self.assertTrue(torch.is_tensor(result))
+        self.assertEqual(result.ndim, 0)
+        self.assertTrue(
+            result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}"
+        )
 
 
 if __name__ == "__main__":
     unittest.main()
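
One way to observe the warning configured above (now raised with stacklevel=2) is to capture it directly; a rough sketch, not part of the test suite, using illustrative shapes:

    import warnings
    import torch
    from monai.losses import FocalLoss

    logits = torch.randn(2, 3, 8, 8)
    labels = torch.randint(0, 3, (2, 1, 8, 8))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # softmax mode without background: a scalar alpha is ignored and a warning is emitted
        FocalLoss(to_onehot_y=True, include_background=False, alpha=0.3, use_softmax=True)(logits, labels)

    assert any("scalar `alpha` ignored" in str(w.message) for w in caught)
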
From b547410ac2bc01aa21a8311868cd33fd946afadc Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Wed, 7 Jan 2026 18:30:47 +0800
Subject: [PATCH 09/10] fix bug: scalar alpha still passed despite ignored warning

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py      | 7 ++++---
 tests/losses/test_focal_loss.py | 4 ++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index f9d1c94f54..5d0a1c4d27 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -164,16 +164,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: Optional[torch.Tensor] = None
         input = input.float()
         target = target.float()
-
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
                     warnings.warn(
                         "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
                     )
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:

diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py
index 151db2dc71..0c2716054a 100644
--- a/tests/losses/test_focal_loss.py
+++ b/tests/losses/test_focal_loss.py
@@ -79,10 +79,10 @@
 TEST_ALPHA_BROADCASTING = []
 for case in TEST_DEVICES:
-    device = case[0]
+    dev = case[0]
     for include_background in [True, False]:
         for use_softmax in [True, False]:
-            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
+            TEST_ALPHA_BROADCASTING.append([dev, include_background, use_softmax])
 
 
 class TestFocalLoss(unittest.TestCase):

From 65e1c363205429e16321eb79e0d466e01dfa4175 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Thu, 8 Jan 2026 10:55:55 +0800
Subject: [PATCH 10/10] fix: When alpha is a sequence, each alpha[c] should be interpreted as the weight for positive samples of class c. Negative samples should have a default weight of 1.0

Signed-off-by: ytl0623
---
 monai/losses/focal_loss.py      | 9 ++++++---
 tests/losses/test_focal_loss.py | 7 ++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index bf8eb1fb31..caa237fca8 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -82,7 +82,8 @@ def __init__(
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
                 The value should be in [0, 1].
-                If a sequence is provided, it must match the number of classes (after excluding background if set).
+                If a sequence is provided, its length must match the number of classes
+                (excluding the background class if `include_background=False`).
                 Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
@@ -289,8 +290,10 @@ def sigmoid_focal_loss(
             # Reshape alpha for broadcasting: (1, C, 1, 1...)
             broadcast_dims = [-1] + [1] * len(target.shape[2:])
             alpha_t = alpha_t.view(broadcast_dims)
-            # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
-            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+            # Apply per-class weight only to positive samples
+            # For positive samples (target==1): multiply by alpha[c]
+            # For negative samples (target==0): keep weight as 1.0
+            alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
 
         loss = alpha_factor * loss
 
     return loss

diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py
index 0c2716054a..35017ec898 100644
--- a/tests/losses/test_focal_loss.py
+++ b/tests/losses/test_focal_loss.py
@@ -24,7 +24,8 @@
 from monai.losses import FocalLoss
 from monai.networks import one_hot
 from tests.test_utils import TEST_DEVICES, test_script_save
 
 TEST_CASES = []
-for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
+for case in TEST_DEVICES:
+    device = case[0]
     input_data = {
         "input": torch.tensor(
             [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
         ),
@@ -79,10 +80,10 @@
 TEST_ALPHA_BROADCASTING = []
 for case in TEST_DEVICES:
-    dev = case[0]
+    device = case[0]
     for include_background in [True, False]:
         for use_softmax in [True, False]:
-            TEST_ALPHA_BROADCASTING.append([dev, include_background, use_softmax])
+            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
 
 
 class TestFocalLoss(unittest.TestCase):
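
After the final patch, a sequence alpha in the sigmoid path weights only the positive samples of each class: positives of class c are scaled by alpha[c], while negatives keep a weight of 1.0. The selection reduces to a torch.where over the one-hot target; a minimal illustration with arbitrary values:

    import torch

    # one-hot target for 2 classes over a 2x2 image, batch of 1: shape (1, 2, 2, 2)
    target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]])
    alpha_t = torch.tensor([0.25, 0.75]).view(-1, 1, 1)  # per-class weights, shape (2, 1, 1)

    # positives of class c get alpha[c]; negatives keep weight 1.0
    alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))

    per_element_loss = torch.full_like(target, 2.0)  # stand-in loss values
    weighted = alpha_factor * per_element_loss
    print(weighted[0, 0])  # class 0: positives scaled to 0.5, negatives stay at 2.0
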