Skip to content

API Reference

SetConv2D

Bases: Layer

Implementation of the SetConv2D layer. For more details see Chinello & Boracchi (2025).

Parameters:

Name Type Description Default
filters int

Number of output filters in the convolution.

required
kernel_size int | tuple

Size of the convolution kernel.

required
activation string | None

Activation function to use.

None
mhsa_dropout float

Dropout rate for the MHSA layer.

0.0
padding string

Padding mode for convolution (same or valid).

'same'
strides int | tuple

Stride size for convolution.

1
**kwargs

Additional keyword arguments for the Layer base class.

{}
Source code in src/cstmodels/layers.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
@keras.utils.register_keras_serializable()
class SetConv2D(layers.Layer):
    """
    Implementation of the SetConv2D layer. For more details see Chinello & Boracchi (2025).

    Args:
        filters (int): Number of output filters in the convolution.
        kernel_size (int | tuple): Size of the convolution kernel.
        activation (string | None): Activation function to use.
        mhsa_dropout (float): Dropout rate for the MHSA layer.
        padding (string): Padding mode for convolution (`same` or `valid`).
        strides (int | tuple): Stride size for convolution.
        **kwargs: Additional keyword arguments for the Layer base class.
    """
    def __init__(
            self,
            filters,
            kernel_size,
            activation=None,
            mhsa_dropout=0.0,
            padding='same',
            strides=1,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.mhsa_dropout = mhsa_dropout
        self.padding = padding
        self.strides = strides

        # Only string identifiers (or None) are accepted, which keeps the
        # layer serializable through get_config().
        if activation is not None and not isinstance(activation, str):
            raise ValueError("Activation must be a string or None")
        self.activation = activation or 'linear'

        # The activation is intentionally NOT fused into the convolution:
        # it is applied only after the dynamic bias is added in call().
        self.conv = layers.Conv2D(
            self.filters,
            self.kernel_size,
            activation=None,
            padding=self.padding,
            strides=self.strides,
            name='conv'
        )
        self.gap = layers.GlobalAveragePooling2D()
        self.mha = layers.MultiHeadAttention(
            num_heads=max(1, self.filters // CST15_MHSA_HEAD_DIM),
            key_dim=min(self.filters, CST15_MHSA_HEAD_DIM),
            dropout=self.mhsa_dropout,
            name='mhsa'
        )
        self.activ = layers.Activation(activation=self.activation)

    def build(self, input_shape):
        """
        This method simply marks the layer as built.

        No weights are created here; only the `built` flag is set.

        Args:
            input_shape (shapelike): Shape of the input to the layer.
        """
        self.built = True

    def call(self, X, set_size):
        """
        Main logic for the SetConv2D layer.

        Args:
            X (tensor): Input tensor of shape `(batch * set_size, H, W, C)`.
            set_size (scalar): Size of the set dimension.

        Returns:
            X (tensor): Output tensor after applying SetConv2D operations.
        """
        # 1. Plain convolution (activation deferred to step 5)
        features = self.conv(X)

        # 2. One channel-descriptor vector per image via GAP
        descriptors = self.gap(features)

        # 3. Self-attention across the set dimension produces a dynamic
        #    bias vector for every image
        descriptors = ops.reshape(descriptors, [-1, set_size, self.filters])
        dyn_bias = self.mha(descriptors, descriptors)
        dyn_bias = ops.reshape(dyn_bias, [-1, self.filters])

        # 4. Broadcast the bias over both spatial dimensions and add it
        features = features + ops.expand_dims(ops.expand_dims(dyn_bias, axis=1), axis=1)

        # 5. Final activation
        return self.activ(features)

    def get_config(self):
        """
        Returns the configuration of the layer for serialization.

        Returns:
            config (dict): Configuration of the layer for serialization.
        """
        cfg = super().get_config()
        cfg.update(
            filters=self.filters,
            kernel_size=self.kernel_size,
            activation=self.activation,
            mhsa_dropout=self.mhsa_dropout,
            padding=self.padding,
            strides=self.strides,
        )
        return cfg

build(input_shape)

This method simply marks the layer as built.

Parameters:

Name Type Description Default
input_shape shapelike

Shape of the input to the layer.

required
Source code in src/cstmodels/layers.py
141
142
143
144
145
146
147
148
def build(self, input_shape):
    """
    This method simply marks the layer as built.

    The layer allocates no weights of its own in build(); it only flips
    the `built` flag.

    Args:
        input_shape (shapelike): Shape of the input to the layer.
    """
    self.built = True

call(X, set_size)

Main logic for the SetConv2D layer.

Parameters:

Name Type Description Default
X tensor

Input tensor of shape (batch * set_size, H, W, C).

required
set_size scalar

Size of the set dimension.

required

Returns:

Name Type Description
X tensor

Output tensor after applying SetConv2D operations.

Source code in src/cstmodels/layers.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def call(self, X, set_size):
    """
    Main logic for the SetConv2D layer.

    Args:
        X (tensor): Input tensor of shape `(batch * set_size, H, W, C)`.
        set_size (scalar): Size of the set dimension.

    Returns:
        X (tensor): Output tensor after applying SetConv2D operations.
    """
    # Step 1: plain convolution; the activation is deferred to step 5.
    features = self.conv(X)

    # Step 2: one channel-descriptor vector per image via global pooling.
    descriptors = self.gap(features)

    # Step 3: attend across the set dimension to build a dynamic bias
    # vector for every image in the set.
    descriptors = ops.reshape(descriptors, [-1, set_size, self.filters])
    dyn_bias = self.mha(descriptors, descriptors)
    dyn_bias = ops.reshape(dyn_bias, [-1, self.filters])

    # Step 4: broadcast the bias over both spatial dimensions and add it
    # to the convolution output.
    features = features + ops.expand_dims(ops.expand_dims(dyn_bias, axis=1), axis=1)

    # Step 5: apply the configured activation.
    return self.activ(features)

get_config()

Returns the configuration of the layer for serialization.

Returns:

Name Type Description
config dict

Configuration of the layer for serialization.

Source code in src/cstmodels/layers.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def get_config(self):
    """
    Returns the configuration of the layer for serialization.

    Returns:
        config (dict): Configuration of the layer for serialization.
    """
    cfg = super().get_config()
    # Record every constructor argument so the layer round-trips through
    # keras serialization.
    cfg.update(
        filters=self.filters,
        kernel_size=self.kernel_size,
        activation=self.activation,
        mhsa_dropout=self.mhsa_dropout,
        padding=self.padding,
        strides=self.strides,
    )
    return cfg

SmartReshape2D

Bases: Layer

Reshapes 4D or 5D tensors to handle an explicit set dimension.

This layer is useful when working with data that may or may not have a set dimension (e.g., (batch * set_size, H, W, C) vs. (batch, set_size, H, W, C)). It automatically infers the correct shape and reshapes the input tensor accordingly.

Parameters:

Name Type Description Default
**kwargs

Keyword arguments for the Layer base class.

{}
Source code in src/cstmodels/layers.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@keras.utils.register_keras_serializable()
class SmartReshape2D(layers.Layer):
    """
    Reshapes 4D or 5D tensors to handle an explicit set dimension.

    This layer is useful when working with data that may or may not have a set dimension
    (e.g., `(batch * set_size, H, W, C)` vs. `(batch, set_size, H, W, C)`).
    It automatically infers the correct shape and reshapes the input tensor accordingly.

    Args:
        **kwargs: Keyword arguments for the Layer base class.
    """
    def __init__(self, **kwargs):
        # The layer has no hyperparameters of its own.
        super().__init__(**kwargs)

    def build(self, input_shape):
        """
        This method simply marks the layer as built.

        Args:
            input_shape (shapelike): Shape of the input to the layer.
        """
        self.built = True

    def call(self, x, set_size=None):
        """
        Main logic for the SmartReshape2D layer.

        If the input tensor has 5 dimensions, it is assumed to be in the format
        `(batch, set_size, H, W, C)` and is reshaped to
        `(batch * set_size, H, W, C)`.

        If the input tensor has 4 dimensions, it is assumed to be in the format
        `(batch * set_size, H, W, C)` and is reshaped to
        `(batch, set_size, H, W, C)`, where `set_size` is provided as an argument.

        Args:
            x (tensor): Input tensor of shape `(batch * set_size, H, W, C)`
                        or `(batch, set_size, H, W, C)`.
            set_size (scalar): Size of the set dimension (required if input is 4D, optional if 5D)
              or `None`.

        Returns:
            x (tensor): The reshaped tensor.
            set_size (scalar): The set size.

        Raises:
            ValueError: If the input is 4D and `set_size` is `None`.
        """
        tensor_shape = ops.shape(x)
        height = tensor_shape[-3]
        width = tensor_shape[-2]
        channels = tensor_shape[-1]
        n_dims = ops.ndim(x)

        if n_dims == 5:
            # Input is already in (batch, set_size, height, width, channels) format
            # -> Flatten the set dimension
            target_shape = (-1, height, width, channels)
            set_size = ops.shape(x)[1] # Extract set_size from the input shape
        else:
            # Input is in (batch * set_size, height, width, channels) format
            # -> Reshape to include set dimension
            if set_size is None:
                # Fail fast with a clear message instead of letting the
                # backend raise an opaque reshape error on a None dimension.
                raise ValueError(
                    "set_size must be provided when the input tensor is 4D"
                )
            target_shape = (-1, set_size, height, width, channels)

        x = ops.reshape(x, target_shape)

        return x, set_size

    def get_config(self):
        """
        Returns the configuration of the layer for serialization.

        Returns:
            config (dict): Configuration of the layer for serialization.
        """
        # No extra hyperparameters to record; the base config is complete.
        config = super().get_config()
        return config

build(input_shape)

This method simply marks the layer as built.

Parameters:

Name Type Description Default
input_shape shapelike

Shape of the input to the layer.

required
Source code in src/cstmodels/layers.py
27
28
29
30
31
32
33
34
def build(self, input_shape):
    """
    This method simply marks the layer as built.

    Args:
        input_shape (shapelike): Shape of the input to the layer.
    """
    # No weights are created: the layer only reshapes its input, so
    # setting the flag is all that is required.
    self.built = True

call(x, set_size=None)

Main logic for the SmartReshape2D layer.

If the input tensor has 5 dimensions, it is assumed to be in the format (batch, set_size, H, W, C) and is reshaped to (batch * set_size, H, W, C).

If the input tensor has 4 dimensions, it is assumed to be in the format (batch * set_size, H, W, C) and is reshaped to (batch, set_size, H, W, C), where set_size is provided as an argument.

Parameters:

Name Type Description Default
x tensor

Input tensor of shape (batch * set_size, H, W, C) or (batch, set_size, H, W, C).

required
set_size scalar

Size of the set dimension (required if input is 4D, optional if 5D) or None.

None

Returns:

Name Type Description
x tensor

The reshaped tensor.

set_size scalar

The set size.

Source code in src/cstmodels/layers.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def call(self, x, set_size=None):
    """
    Main logic for the SmartReshape2D layer.

    If the input tensor has 5 dimensions, it is assumed to be in the format
    `(batch, set_size, H, W, C)` and is reshaped to
    `(batch * set_size, H, W, C)`.

    If the input tensor has 4 dimensions, it is assumed to be in the format
    `(batch * set_size, H, W, C)` and is reshaped to
    `(batch, set_size, H, W, C)`, where `set_size` is provided as an argument.

    Args:
        x (tensor): Input tensor of shape `(batch * set_size, H, W, C)`
                    or `(batch, set_size, H, W, C)`.
        set_size (scalar): Size of the set dimension (required if input is 4D, optional if 5D)
          or `None`.

    Returns:
        x (tensor): The reshaped tensor.
        set_size (scalar): The set size.
    """
    shape = ops.shape(x)
    height, width, channels = shape[-3], shape[-2], shape[-1]

    if ops.ndim(x) == 5:
        # 5D input (batch, set_size, H, W, C): remember the set size and
        # collapse the set dimension into the batch dimension.
        new_shape = (-1, height, width, channels)
        set_size = ops.shape(x)[1]
    else:
        # 4D input (batch * set_size, H, W, C): split the batch dimension
        # using the caller-supplied set_size.
        new_shape = (-1, set_size, height, width, channels)

    return ops.reshape(x, new_shape), set_size

get_config()

Returns the configuration of the layer for serialization.

Returns:

Name Type Description
config dict

Configuration of the layer for serialization.

Source code in src/cstmodels/layers.py
78
79
80
81
82
83
84
85
86
def get_config(self):
    """
    Returns the configuration of the layer for serialization.

    Returns:
        config (dict): Configuration of the layer for serialization.
    """
    # The layer has no hyperparameters of its own, so the base config
    # already describes it completely.
    return super().get_config()

CST15(pretrained=True)

Loads or builds the CST15 model. In both cases, the model is compiled with Adam optimizer and Categorical Crossentropy loss.

Parameters:

Name Type Description Default
pretrained bool

If True, loads CST15 pretrained on ImageNet. If False, builds a new CST15 model from scratch.

True

Returns:

Name Type Description
model KerasModel

The CST15 model instance.

Source code in src/cstmodels/models.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def CST15(pretrained=True):
    """
    Loads or builds the CST15 model. In both cases, the model is compiled
    with Adam optimizer and Categorical Crossentropy loss.

    Args:
        pretrained (bool): If `True`, loads CST15 pretrained on ImageNet.
            If `False`, builds a new CST15 model from scratch.

    Returns:
        model (KerasModel): The CST15 model instance.
    """
    if not isinstance(pretrained, bool):
        raise ValueError("Pretrained must be a boolean value")

    if pretrained:
        # Download (or reuse the locally cached copy of) the saved model
        # and restore it.
        model_path = keras.utils.get_file('CST15.keras', origin=CST15_URL)
        return keras.saving.load_model(model_path)

    # Build a fresh, untrained model and compile it.
    fresh_model = _build_CST15()
    fresh_model.compile(
        optimizer='adam',
        loss='CategoricalCrossentropy',
        metrics=['accuracy']
    )
    return fresh_model