-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add explicit spatial_ndim tracking to MetaTensor #8765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
74e7ca4
c52a149
ea915cb
e50ae41
787eef4
9f359b0
dc16dda
ab2be3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,22 +13,59 @@ | |||||||||
|
|
||||||||||
| import functools | ||||||||||
| import warnings | ||||||||||
| from collections.abc import Sequence | ||||||||||
| from collections.abc import Mapping, Sequence | ||||||||||
| from copy import deepcopy | ||||||||||
| from typing import Any | ||||||||||
|
|
||||||||||
| import numpy as np | ||||||||||
| import torch | ||||||||||
|
|
||||||||||
| import monai | ||||||||||
| from monai.config.type_definitions import NdarrayTensor | ||||||||||
| from monai.data.meta_obj import MetaObj, get_track_meta | ||||||||||
| from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata | ||||||||||
| from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor | ||||||||||
| from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta | ||||||||||
| from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata | ||||||||||
| from monai.utils import look_up_option | ||||||||||
| from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys | ||||||||||
| from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor | ||||||||||
|
|
||||||||||
| __all__ = ["MetaTensor"] | ||||||||||
| __all__ = ["MetaTensor", "get_spatial_ndim"] | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: | ||||||||||
| """Clamp spatial dims to a valid range for the current tensor shape.""" | ||||||||||
| limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) | ||||||||||
| return max(1, min(int(spatial_ndim), limit)) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _has_explicit_no_channel(meta: Mapping | None) -> bool: | ||||||||||
| return ( | ||||||||||
| isinstance(meta, Mapping) | ||||||||||
| and MetaKeys.ORIGINAL_CHANNEL_DIM in meta | ||||||||||
| and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def get_spatial_ndim(img: NdarrayOrTensor) -> int: | ||||||||||
| """Return the number of spatial dimensions assuming channel-first layout. | ||||||||||
|
|
||||||||||
| Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to | ||||||||||
| ``img.ndim - 1``. Always assumes channel-first (``no_channel=False``) | ||||||||||
| because callers run after ``EnsureChannelFirst`` has already added one. | ||||||||||
| """ | ||||||||||
| if isinstance(img, MetaTensor): | ||||||||||
| return _normalize_spatial_ndim(img.spatial_ndim, img.ndim) | ||||||||||
| return img.ndim - 1 | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _is_batch_only_index(index: Any) -> bool: | ||||||||||
| """True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``).""" | ||||||||||
| if isinstance(index, (int, np.integer)): | ||||||||||
| return True | ||||||||||
| if not isinstance(index, Sequence) or not index: | ||||||||||
| return False | ||||||||||
| if not isinstance(index[0], (int, np.integer)): | ||||||||||
| return False | ||||||||||
| return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:]) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(None) | ||||||||||
|
|
@@ -111,6 +148,7 @@ def __new__( | |||||||||
| meta: dict | None = None, | ||||||||||
| applied_operations: list | None = None, | ||||||||||
| *args, | ||||||||||
| spatial_ndim: int | None = None, | ||||||||||
| **kwargs, | ||||||||||
| ) -> MetaTensor: | ||||||||||
| _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {} | ||||||||||
|
|
@@ -123,6 +161,7 @@ def __init__( | |||||||||
| meta: dict | None = None, | ||||||||||
| applied_operations: list | None = None, | ||||||||||
| *_args, | ||||||||||
| spatial_ndim: int | None = None, | ||||||||||
| **_kwargs, | ||||||||||
| ) -> None: | ||||||||||
| """ | ||||||||||
|
|
@@ -134,6 +173,8 @@ def __init__( | |||||||||
| the list is typically maintained by `monai.transforms.TraceableTransform`. | ||||||||||
| See also: :py:class:`monai.transforms.TraceableTransform` | ||||||||||
| _args: additional args (currently not in use in this constructor). | ||||||||||
| spatial_ndim: optional number of spatial dimensions. If ``None``, derived | ||||||||||
| from the affine matrix clamped by the tensor shape. | ||||||||||
| _kwargs: additional kwargs (currently not in use in this constructor). | ||||||||||
|
|
||||||||||
| Note: | ||||||||||
|
|
@@ -158,6 +199,14 @@ def __init__( | |||||||||
| self.affine = self.meta[MetaKeys.AFFINE] | ||||||||||
| else: | ||||||||||
| self.affine = self.get_default_affine() | ||||||||||
| # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. | ||||||||||
| # This cached value is kept in sync via the affine setter for hot-path performance. | ||||||||||
| no_channel = _has_explicit_no_channel(self.meta) | ||||||||||
| if spatial_ndim is not None: | ||||||||||
| self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) | ||||||||||
| elif self.affine.ndim == 2: | ||||||||||
| self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) | ||||||||||
|
|
||||||||||
| # applied_operations | ||||||||||
| if applied_operations is not None: | ||||||||||
| self.applied_operations = applied_operations | ||||||||||
|
|
@@ -237,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): | |||||||||
| if func == torch.Tensor.__getitem__: | ||||||||||
| if idx > 0 or len(args) < 2 or len(args[0]) < 1: | ||||||||||
| return ret | ||||||||||
| full_idx = args[1] | ||||||||||
| batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1] | ||||||||||
| # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the | ||||||||||
| # first element will be `slice(None, None, None)` and `Ellipsis`, | ||||||||||
|
|
@@ -258,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): | |||||||||
| ret_meta.is_batch = False | ||||||||||
| if hasattr(ret_meta, "__dict__"): | ||||||||||
| ret.__dict__ = ret_meta.__dict__.copy() | ||||||||||
| if _is_batch_only_index(full_idx): | ||||||||||
| ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) | ||||||||||
| # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. | ||||||||||
| # But we only want to split the batch if the `unbind` is along the 0th dimension. | ||||||||||
| elif func == torch.Tensor.unbind: | ||||||||||
|
|
@@ -467,15 +519,40 @@ def affine(self) -> torch.Tensor: | |||||||||
|
|
||||||||||
| @affine.setter | ||||||||||
| def affine(self, d: NdarrayTensor) -> None: | ||||||||||
| """Set the affine.""" | ||||||||||
| self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) | ||||||||||
| """Set the affine. | ||||||||||
|
|
||||||||||
| When setting a non-batched affine matrix, automatically synchronizes the cached | ||||||||||
| spatial_ndim attribute to maintain consistency between the affine matrix (source of truth) | ||||||||||
| and the cached spatial dimension count. | ||||||||||
| """ | ||||||||||
| a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) | ||||||||||
| self.meta[MetaKeys.AFFINE] = a | ||||||||||
| if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) | ||||||||||
| no_channel = _has_explicit_no_channel(self.meta) | ||||||||||
| self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def spatial_ndim(self) -> int: | ||||||||||
| """Get the number of spatial dimensions. | ||||||||||
|
|
||||||||||
| This value is cached for hot-path performance and is kept in sync with the affine matrix | ||||||||||
| via the affine setter. The affine matrix is the source of truth for spatial dimensions. | ||||||||||
| """ | ||||||||||
| return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM) | ||||||||||
|
|
||||||||||
| @spatial_ndim.setter | ||||||||||
| def spatial_ndim(self, val: int) -> None: | ||||||||||
| """Set the number of spatial dimensions.""" | ||||||||||
| if val < 1: | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| raise ValueError(f"spatial_ndim must be >= 1, got {val}") | ||||||||||
| self._spatial_ndim = val | ||||||||||
|
Comment on lines
+544
to
+548
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Harden Line 518 only checks Proposed fix+from numbers import Integral
...
`@spatial_ndim.setter`
def spatial_ndim(self, val: int) -> None:
"""Set the number of spatial dimensions."""
+ if not isinstance(val, Integral):
+ raise TypeError(f"spatial_ndim must be an integer, got {type(val).__name__}")
if val < 1:
raise ValueError(f"spatial_ndim must be >= 1, got {val}")
- self._spatial_ndim = val
+ self._spatial_ndim = _normalize_spatial_ndim(int(val), self.ndim)🧰 Tools🪛 Ruff (0.15.2)[warning] 519-519: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
|
||||||||||
| @property | ||||||||||
| def pixdim(self): | ||||||||||
| """Get the spacing""" | ||||||||||
| if self.is_batch: | ||||||||||
| return [affine_to_spacing(a) for a in self.affine] | ||||||||||
| return affine_to_spacing(self.affine) | ||||||||||
| return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine] | ||||||||||
| return affine_to_spacing(self.affine, r=self.spatial_ndim) | ||||||||||
|
|
||||||||||
| def peek_pending_shape(self): | ||||||||||
| """ | ||||||||||
|
|
@@ -490,7 +567,7 @@ def peek_pending_shape(self): | |||||||||
|
|
||||||||||
| def peek_pending_affine(self): | ||||||||||
| res = self.affine | ||||||||||
| r = len(res) - 1 | ||||||||||
| r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim | ||||||||||
| if r not in (2, 3): | ||||||||||
| warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.") | ||||||||||
| for p in self.pending_operations: | ||||||||||
|
|
@@ -503,8 +580,10 @@ def peek_pending_affine(self): | |||||||||
| return res | ||||||||||
|
|
||||||||||
| def peek_pending_rank(self): | ||||||||||
| a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine | ||||||||||
| return 1 if a is None else int(max(1, len(a) - 1)) | ||||||||||
| if self.pending_operations: | ||||||||||
| a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) | ||||||||||
| return 1 if a is None else int(max(1, len(a) - 1)) | ||||||||||
| return self.spatial_ndim | ||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||
|
|
||||||||||
| def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override] | ||||||||||
| """ | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.