"""Model base classes."""
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from typing import Generic
from typing import TypeVar
import pandas as pd
from keras import ops
from prfmodel.stimuli.base import Stimulus
from prfmodel.typing import Tensor
from prfmodel.utils import _get_norm_fun
# Type variable for stimulus types. It is bound to `Stimulus` so that generic classes parametrized
# with it (e.g. `BaseResponse[S]`) can only be specialized with `Stimulus` subclasses.
S = TypeVar("S", bound=Stimulus)
[docs]
class BatchDimensionError(Exception):
    """
    Exception raised when arguments have different sizes in the batch (first) dimension.

    Parameters
    ----------
    arg_names: Sequence[str]
        Names of arguments that have different sizes in batch dimension.
    arg_shapes: Sequence[tuple of int]
        Shapes of arguments that have different sizes in batch dimension.

    Attributes
    ----------
    arg_names : tuple of str
        The argument names passed to the constructor.
    arg_shapes : tuple
        The argument shapes passed to the constructor.
    """

    def __init__(self, arg_names: Sequence[str], arg_shapes: Sequence[tuple[int, ...]]):
        # Materialize once so that one-shot iterables can be both stored and reported.
        self.arg_names = tuple(arg_names)
        self.arg_shapes = tuple(arg_shapes)
        names = ", ".join(self.arg_names)
        # A zero-dimensional shape has no batch dimension: report it explicitly instead of
        # raising an IndexError while building the message (which would hide this error).
        sizes = ", ".join(
            str(shape[0]) if len(shape) > 0 else "scalar" for shape in self.arg_shapes
        )
        super().__init__(f"Arguments {names} have different sizes in batch (first) dimension: {sizes}")
[docs]
class ShapeError(Exception):
    """
    Error signalling that an argument has fewer than two dimensions.

    Parameters
    ----------
    arg_name: str
        Name of the offending argument.
    arg_shape: tuple of int
        Actual shape of the offending argument, included in the error message.
    """

    def __init__(self, arg_name: str, arg_shape: tuple[int, ...]):
        message = f"Argument {arg_name} must have at least two dimensions but has shape {arg_shape}"
        super().__init__(message)
[docs]
class BaseModel(ABC):
    """
    Abstract base class for models.

    Cannot be instantiated on its own.
    Can only be used as a parent class to create custom model classes.
    Subclasses must override the abstract `parameter_names` property.

    Attributes
    ----------
    parameter_names

    Examples
    --------
    Create a custom model class that inherits from the base class:

    >>> class CustomModel(BaseModel):
    ...     @property
    ...     def parameter_names(self):
    ...         return ["a", "b"]
    >>> model = CustomModel()
    >>> print(model.parameter_names)
    ['a', 'b']
    """

    @property
    @abstractmethod
    def parameter_names(self) -> list[str]:
        """A list with names of parameters that are used by the model."""
"""A list with names of parameters that are used by the model."""
[docs]
class BaseResponse(BaseModel, Generic[S]):
    """
    Generic abstract base class for response models.

    Cannot be instantiated on its own.
    Can only be used as a parent class to create custom population receptive field models.
    Subclasses must override the abstract `parameter_names` property (inherited from `BaseModel`) and the
    abstract `__call__` method, and must be defined with a specific stimulus type.
    """

    @abstractmethod
    def __call__(self, stimulus: S, parameters: pd.DataFrame, dtype: str | None = None) -> Tensor:
        """
        Predict the model response for a stimulus.

        Parameters
        ----------
        stimulus : Stimulus
            Stimulus object.
        parameters : pandas.DataFrame
            Dataframe with columns containing different model parameters and rows containing parameter values
            for different voxels.
        dtype : str, optional
            The dtype of the prediction result. If `None` (the default), uses the dtype from
            :func:`prfmodel.utils.get_dtype`.

        Returns
        -------
        Tensor
            Model predictions of shape `(num_voxels, ...)` and dtype `dtype`. The number of voxels is the
            number of rows in `parameters`. The number and size of other axes depends on the stimulus.
        """
[docs]
class BaseEncoder(BaseModel, Generic[S]):
    """
    Generic abstract base class for encoding model responses.

    Cannot be instantiated on its own.
    Can only be used as a parent class to create custom encoding models.
    Subclasses must override the abstract `parameter_names` property and `__call__` method and must be defined
    with a specific stimulus type.
    """

    @abstractmethod
    def __call__(
        self,
        stimulus: S,
        response: Tensor,
        parameters: pd.DataFrame,
        dtype: str | None = None,
    ) -> Tensor:
        """
        Encode a model response with a stimulus.

        Parameters
        ----------
        stimulus : Stimulus
            Stimulus object.
        response : Tensor
            Model response.
        parameters : pandas.DataFrame
            Dataframe with columns containing different model parameters and rows containing parameter values
            for different voxels.
        dtype : str, optional
            The dtype of the encoded response. If `None` (the default), uses the dtype from
            :func:`prfmodel.utils.get_dtype`.

        Returns
        -------
        Tensor
            The stimulus-encoded model response with shape `(num_voxels, ...)` and dtype `dtype`. The number of
            voxels is the number of rows in `parameters`. The number and size of other axes depends on the
            stimulus and the response.
        """
[docs]
class BaseImpulse(BaseModel):
    """
    Abstract base class for impulse response models.

    Cannot be instantiated on its own.
    Can only be used as a parent class to create custom impulse response models.
    Subclasses must override the abstract `parameter_names` property and `__call__` method.

    Parameters
    ----------
    duration : float, default=32.0
        The duration of the impulse response (in seconds).
    offset : float, default=0.0001
        The offset of the impulse response (in seconds). By default a very small offset is added to prevent infinite
        response values at t = 0.
    resolution : float, default=1.0
        The time resolution of the impulse response (in seconds). The impulse response function is evaluated at
        `duration / resolution` points (truncated to an integer, see `num_frames`), linearly spaced between
        `offset` and `duration`. Because both end points are included, the actual spacing between frames is
        `(duration - offset) / (num_frames - 1)`, which only approximately equals `resolution`.
    norm : str, optional, default="sum"
        The normalization of the response. Can be `"sum"` (default), `"mean"`, `"max"`, `"norm"`, or `None`.
        If `None`, no normalization is performed.
    default_parameters : dict of float, optional
        Dictionary with scalar default parameter values. Keys must be valid parameter names. The dictionary is
        copied, so later changes to the passed dictionary do not affect the model.

    Raises
    ------
    ValueError
        If `default_parameters` contains keys that are not in `parameter_names` or values that are not floats.
    """

    def __init__(
        self,
        duration: float = 32.0,
        offset: float = 0.0001,
        resolution: float = 1.0,
        norm: str | None = "sum",
        default_parameters: dict[str, float] | None = None,
    ):
        super().__init__()
        self.duration = duration
        self.offset = offset
        self.resolution = resolution
        # Check that `norm` is valid: `_get_norm_fun` is called here only for validation
        # (presumably it raises for unknown names); the function itself is not stored.
        if norm is not None:
            _get_norm_fun(norm)
        self.norm = norm
        if default_parameters is not None:
            valid_names = self.parameter_names
            invalid_names = [key for key in default_parameters if key not in valid_names]
            if invalid_names:
                msg = (
                    f"Invalid default parameter name(s) {invalid_names}, please provide valid default parameter "
                    f"names: {valid_names}"
                )
                raise ValueError(msg)
            if any(not isinstance(val, float) for val in default_parameters.values()):
                msg = "Default parameters must be single float values"
                raise ValueError(msg)
            # Copy so that later mutation of the caller's dict does not silently change the model.
            default_parameters = dict(default_parameters)
        self.default_parameters = default_parameters
        self._frames: Tensor | None = None
        # (offset, duration, num_frames) used to build `_frames`; lets `frames` detect a stale cache.
        self._frames_key: tuple[float, float, int] | None = None

    @property
    def num_frames(self) -> int:
        """
        The total number of time frames at which the impulse response function is evaluated.

        Computed as `duration / resolution`, truncated to an integer.
        """
        # Round away floating-point noise before truncating: e.g. 0.3 / 0.1 == 2.9999999999999996
        # would otherwise yield 2 frames instead of 3.
        return int(round(self.duration / self.resolution, 9))

    @property
    def frames(self) -> Tensor:
        """
        The time frames at which the impulse response function is evaluated.

        Time frames are linearly spaced between `offset` and `duration` (both included) and have shape
        (1, `num_frames`). The result is cached and recomputed whenever `offset`, `duration`, or `num_frames`
        changes.
        """
        key = (self.offset, self.duration, self.num_frames)
        # `duration`, `offset` and `resolution` are public, mutable attributes, so the cache must be
        # invalidated when they change rather than being computed only once.
        if self._frames is None or getattr(self, "_frames_key", None) != key:
            offset, duration, num_frames = key
            self._frames = ops.expand_dims(ops.linspace(offset, duration, num_frames), 0)
            self._frames_key = key
        return self._frames

    def _join_default_parameters(self, parameters: pd.DataFrame) -> pd.DataFrame:
        """
        Return `parameters` with a column for every default parameter set to its default value.

        Existing columns with the same names are overwritten in a copy; the input frame is not modified.
        If no default parameters were given, `parameters` is returned unchanged.
        """
        if self.default_parameters is not None:
            parameters = parameters.copy()
            for key, val in self.default_parameters.items():
                parameters[key] = val
        return parameters

    @abstractmethod
    def __call__(self, parameters: pd.DataFrame, dtype: str | None = None) -> Tensor:
        """
        Compute the impulse response.

        Parameters
        ----------
        parameters : pandas.DataFrame
            Dataframe with columns containing different model parameters and rows containing parameter values
            for different voxels.
        dtype : str, optional
            The dtype of the prediction result. If `None` (the default), uses the dtype from
            :func:`prfmodel.utils.get_dtype`.

        Returns
        -------
        Tensor
            Model predictions of shape `(num_voxels, num_frames)` and dtype `dtype`. The number of voxels is the
            number of rows in `parameters`.
        """
[docs]
class BaseTemporal(BaseModel):
    """
    Abstract base class for temporal models.

    Cannot be instantiated on its own.
    Can only be used as a parent class to create custom temporal models.
    Subclasses must override the abstract `parameter_names` property (inherited from `BaseModel`) and the
    abstract `__call__` method.
    """

    @abstractmethod
    def __call__(self, inputs: Tensor, parameters: pd.DataFrame, dtype: str | None = None) -> Tensor:
        """
        Make predictions with the temporal model.

        Parameters
        ----------
        inputs : Tensor
            Input tensor with temporal responses of shape `(num_voxels, num_frames)`. The first (batch) axis
            corresponds to the rows of `parameters`.
        parameters : pandas.DataFrame
            Dataframe with columns containing different model parameters and rows containing parameter values
            for different voxels (one row per entry along the first axis of `inputs`).
        dtype : str, optional
            The dtype of the prediction result. If `None` (the default), uses the dtype from
            :func:`prfmodel.utils.get_dtype`.

        Returns
        -------
        Tensor
            Model predictions of shape `(num_voxels, num_frames)` and dtype `dtype`. The number of voxels is the
            number of rows in `parameters`.
        """
[docs]
class BaseComposite(BaseModel, Generic[S]):
    """
    Generic abstract base class for composite models.

    A composite model bundles several submodels and defines, in its `__call__` method, how those submodels
    interact to produce a joint prediction. This class cannot be instantiated directly: subclass it, implement
    `__call__`, and bind it to a concrete stimulus type.

    Parameters
    ----------
    **models
        Submodels to combine, passed as keyword arguments and stored in the `models` attribute. Each value must be
        an instance of a `BaseModel` subclass or `None`; `None` entries are kept in `models` but skipped when
        collecting parameter names.

    Raises
    ------
    TypeError
        If any non-`None` submodel is not an instance of a `BaseModel` subclass.
    """

    def __init__(self, **models: BaseModel | None):
        super().__init__()
        offending = [
            submodel
            for submodel in models.values()
            if submodel is not None and not issubclass(submodel.__class__, BaseModel)
        ]
        if offending:
            msg = "Model instance must inherit from BaseModel"
            raise TypeError(msg)
        self.models = models

    @property
    def parameter_names(self) -> list[str]:
        """Unique parameter names of all non-`None` submodels, in order of first appearance."""
        unique_names: dict[str, None] = {}
        for submodel in self.models.values():
            if submodel is None:
                continue
            for name in submodel.parameter_names:
                # setdefault keeps the first occurrence, so insertion order reflects first appearance.
                unique_names.setdefault(name, None)
        return list(unique_names)

    @abstractmethod
    def __call__(
        self,
        stimulus: S,
        parameters: pd.DataFrame,
        dtype: str | None = None,
    ) -> Tensor:
        """
        Predict the combined response of the submodels to a stimulus.

        Parameters
        ----------
        stimulus : Stimulus
            Stimulus object.
        parameters : pandas.DataFrame
            Dataframe whose columns hold the (sub-)model parameters and whose rows hold parameter values for
            different voxels.
        dtype : str, optional
            The dtype of the prediction result. If `None` (the default), uses the dtype from
            :func:`prfmodel.utils.get_dtype`.

        Returns
        -------
        Tensor
            Composite model predictions of shape `(num_voxels, num_frames)` and dtype `dtype`. The number of voxels
            is the number of rows in `parameters`.
        """