Skip to content
12 changes: 8 additions & 4 deletions monai/losses/image_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ def __init__(
raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

_kernel = look_up_option(kernel_type, kernel_dict)
self.kernel = _kernel(self.kernel_size)
self.kernel.require_grads = False
self.kernel_vol = self.get_kernel_vol()
self.kernel: torch.Tensor
self.kernel_vol: torch.Tensor
self.register_buffer("kernel", _kernel(self.kernel_size), persistent=False)
Comment thread
ericspod marked this conversation as resolved.
self.register_buffer("kernel_vol", self.get_kernel_vol(), persistent=False)
Comment thread
ericspod marked this conversation as resolved.

self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)

def get_kernel_vol(self):
def get_kernel_vol(self) -> torch.Tensor:
assert self.kernel is not None
vol = self.kernel
for _ in range(self.ndim - 1):
vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
Expand All @@ -138,6 +140,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")

t2, p2, tp = target * target, pred * pred, target * pred
assert self.kernel is not None
assert self.kernel_vol is not None
kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
kernels = [kernel] * self.ndim
# sum over kernel
Expand Down
Loading