Segmentation Models Python API

Getting started with segmentation models is easy.

Unet

segmentation_models.Unet(backbone_name='vgg16', input_shape=(None, None, 3), classes=1, activation='sigmoid', encoder_weights='imagenet', encoder_freeze=False, encoder_features='default', decoder_block_type='upsampling', decoder_filters=(256, 128, 64, 32, 16), decoder_use_batchnorm=True, **kwargs)

Unet is a fully convolutional neural network for image semantic segmentation.

Parameters:
  • backbone_name – name of classification model (without last dense layers) used as feature extractor to build segmentation model.
  • input_shape – shape of input data/image (H, W, C). In the general case you do not need to set H and W; just pass (None, None, C) so the model can process images of any size, but H and W of input images must be divisible by 32.
  • classes – a number of classes for output (output shape - (h, w, classes)).
  • activation – name of one of keras.activations for last model layer (e.g. sigmoid, softmax, linear).
  • encoder_weights – one of None (random initialization), imagenet (pre-training on ImageNet).
  • encoder_freeze – if True set all layers of encoder (backbone model) as non-trainable.
  • encoder_features – a list of layer numbers or names, starting from the top of the model. Each of these layers will be concatenated with the corresponding decoder block. If default is used, layer names are taken from DEFAULT_SKIP_CONNECTIONS.
  • decoder_block_type – one of the following decoder block structures:

    • upsampling: UpSampling2D -> Conv2D -> Conv2D
    • transpose: Conv2DTranspose -> Conv2D
  • decoder_filters – list of numbers of Conv2D layer filters in decoder blocks
  • decoder_use_batchnorm – if True, BatchNormalisation layer between Conv2D and Activation layers is used.
Returns: Unet

Return type: keras.models.Model
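
For example, a minimal sketch of building and compiling a binary-segmentation Unet (the resnet34 backbone, the frozen encoder, and the loss here are illustrative choices, not requirements):

from segmentation_models import Unet

# Binary segmentation with an ImageNet-pretrained encoder.
# 'resnet34' is just one of the supported backbones.
model = Unet(
    backbone_name='resnet34',
    input_shape=(None, None, 3),   # H and W stay flexible (divisible by 32)
    classes=1,
    activation='sigmoid',
    encoder_weights='imagenet',
    encoder_freeze=True,           # train only the decoder at first
)
model.compile('Adam', loss='binary_crossentropy', metrics=['accuracy'])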

Linknet

segmentation_models.Linknet(backbone_name='vgg16', input_shape=(None, None, 3), classes=1, activation='sigmoid', encoder_weights='imagenet', encoder_freeze=False, encoder_features='default', decoder_filters=(None, None, None, None, 16), decoder_use_batchnorm=True, decoder_block_type='upsampling', **kwargs)

Linknet is a fully convolutional neural network for fast image semantic segmentation.

Note

This implementation has 4 skip connections by default (the original paper uses 3).

Parameters:
  • backbone_name – name of classification model (without last dense layers) used as feature extractor to build segmentation model.
  • input_shape – shape of input data/image (H, W, C). In the general case you do not need to set H and W; just pass (None, None, C) so the model can process images of any size, but H and W of input images must be divisible by 32.
  • classes – a number of classes for output (output shape - (h, w, classes)).
  • activation – name of one of keras.activations for last model layer (e.g. sigmoid, softmax, linear).
  • encoder_weights – one of None (random initialization), imagenet (pre-training on ImageNet).
  • encoder_freeze – if True set all layers of encoder (backbone model) as non-trainable.
  • encoder_features – a list of layer numbers or names, starting from the top of the model. Each of these layers will be concatenated with the corresponding decoder block. If default is used, layer names are taken from DEFAULT_SKIP_CONNECTIONS.
  • decoder_filters – list of numbers of Conv2D layer filters in decoder blocks; for blocks with a skip connection the number of filters equals the number of filters in the corresponding encoder block (estimated automatically when passed as None).
  • decoder_use_batchnorm – if True, BatchNormalisation layer between Conv2D and Activation layers is used.
  • decoder_block_type – one of upsampling (use the UpSampling2D keras layer) or transpose (use the Conv2DTranspose keras layer).
Returns: Linknet

Return type: keras.models.Model
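
A sketch analogous to the Unet example, using the automatic decoder filter estimation (the None entries); the backbone is again an illustrative choice:

from segmentation_models import Linknet

# Decoder filters are inferred from the encoder for all skip-connected
# blocks (the None entries); only the last block is set explicitly.
model = Linknet(
    backbone_name='resnet34',
    classes=1,
    activation='sigmoid',
    encoder_weights='imagenet',
    decoder_filters=(None, None, None, None, 16),
)
model.compile('Adam', loss='binary_crossentropy')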

FPN

segmentation_models.FPN(backbone_name='vgg16', input_shape=(None, None, 3), input_tensor=None, classes=21, activation='softmax', encoder_weights='imagenet', encoder_freeze=False, encoder_features='default', pyramid_block_filters=256, pyramid_use_batchnorm=True, pyramid_dropout=None, final_interpolation='bilinear', **kwargs)

FPN is a fully convolutional neural network for image semantic segmentation.

Parameters:
  • backbone_name – name of classification model (without last dense layers) used as feature extractor to build segmentation model.
  • input_shape – shape of input data/image (H, W, C). In the general case you do not need to set H and W; just pass (None, None, C) so the model can process images of any size, but H and W of input images must be divisible by 32.
  • input_tensor – optional Keras tensor (i.e. output of layers.Input()) to use as image input for the model (works only if encoder_weights is None).
  • classes – a number of classes for output (output shape - (h, w, classes)).
  • activation – name of one of keras.activations for last model layer (e.g. sigmoid, softmax, linear).
  • encoder_weights – one of None (random initialization), imagenet (pre-training on ImageNet).
  • encoder_freeze – if True set all layers of encoder (backbone model) as non-trainable.
  • encoder_features – a list of layer numbers or names, starting from the top of the model. Each of these layers will be used to build the feature pyramid. If default is used, layer names are taken from DEFAULT_FEATURE_PYRAMID_LAYERS.
  • pyramid_block_filters – a number of filters in Feature Pyramid Block of FPN.
  • pyramid_use_batchnorm – if True, BatchNormalisation layer between Conv2D and Activation layers is used.
  • pyramid_dropout – spatial dropout rate for feature pyramid in range (0, 1).
  • final_interpolation – interpolation type for upsampling layers, one of nearest, bilinear.
Returns: FPN

Return type: keras.models.Model
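
A multi-class sketch matching the defaults above (21 classes, softmax); the backbone and the dropout rate are illustrative choices:

from segmentation_models import FPN

model = FPN(
    backbone_name='resnet34',
    classes=21,
    activation='softmax',
    encoder_weights='imagenet',
    pyramid_block_filters=256,
    pyramid_dropout=0.25,       # optional spatial dropout in (0, 1)
)
model.compile('Adam', loss='categorical_crossentropy')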

PSPNet

segmentation_models.PSPNet(backbone_name='vgg16', input_shape=(384, 384, 3), classes=21, activation='softmax', encoder_weights='imagenet', encoder_freeze=False, downsample_factor=8, psp_conv_filters=512, psp_pooling_type='avg', psp_use_batchnorm=True, psp_dropout=None, final_interpolation='bilinear', **kwargs)

PSPNet is a fully convolutional neural network for image semantic segmentation.

Parameters:
  • backbone_name – name of classification model used as feature extractor to build segmentation model.
  • input_shape – shape of input data/image (H, W, C). H and W should be divisible by 6 * downsample_factor and NOT None!
  • classes – a number of classes for output (output shape - (h, w, classes)).
  • activation – name of one of keras.activations for last model layer (e.g. sigmoid, softmax, linear).
  • encoder_weights – one of None (random initialization), imagenet (pre-training on ImageNet).
  • encoder_freeze – if True set all layers of encoder (backbone model) as non-trainable.
  • downsample_factor – one of 4, 8, 16. Downsampling rate, i.e. the backbone depth at which the PSP module is constructed.
  • psp_conv_filters – number of filters in Conv2D layer in each PSP block.
  • psp_pooling_type – one of 'avg', 'max'. PSP block pooling type (average or maximum).
  • psp_use_batchnorm – if True, BatchNormalisation layer between Conv2D and Activation layers is used.
  • psp_dropout – dropout rate between 0 and 1.
  • final_interpolation – one of duc, bilinear. Interpolation type for the final upsampling layer.
Returns: PSPNet

Return type: keras.models.Model
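
Unlike the other models, PSPNet requires a fixed input shape. A sketch with the default 384x384 input (384 is divisible by 6 * 8 = 48); the backbone is an illustrative choice:

from segmentation_models import PSPNet

# H and W must be fixed and divisible by 6 * downsample_factor:
# 384 / (6 * 8) = 8, so (384, 384, 3) is valid here.
model = PSPNet(
    backbone_name='resnet34',
    input_shape=(384, 384, 3),
    classes=21,
    activation='softmax',
    downsample_factor=8,
)
model.compile('Adam', loss='categorical_crossentropy')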

metrics

segmentation_models.metrics.iou_score(gt, pr, class_weights=1.0, smooth=1e-12, per_image=True)

The Jaccard index, also known as Intersection over Union and the Jaccard similarity coefficient (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, and is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A, B) = \frac{|A \cap B|}{|A \cup B|}\]
Parameters:
  • gt – ground truth 4D keras tensor (B, H, W, C)
  • pr – prediction 4D keras tensor (B, H, W, C)
  • class_weights – 1. or a list of class weights, len(weights) = C
  • smooth – value to avoid division by zero
  • per_image – if True, metric is calculated as mean over images in batch (B), else over whole batch
Returns: IoU/Jaccard score in range [0, 1]
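
Like any Keras metric function, it can be passed to model.compile directly (the backbone and loss are illustrative choices):

from segmentation_models import Unet
from segmentation_models.metrics import iou_score

model = Unet('resnet34')
model.compile('Adam', loss='binary_crossentropy', metrics=[iou_score])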

segmentation_models.metrics.f_score(gt, pr, class_weights=1, beta=1, smooth=1e-12, per_image=True)

The F-score (Dice coefficient) can be interpreted as a weighted average of precision and recall, where an F-score reaches its best value at 1 and its worst at 0. The relative contributions of precision and recall to the F1-score are equal. The formula for the F-score is:

\[F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall} {\beta^2 \cdot precision + recall}\]

The formula in terms of Type I and Type II errors:

\[F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP}\]
where: TP - true positives; FP - false positives; FN - false negatives.
Parameters:
  • gt – ground truth 4D keras tensor (B, H, W, C)
  • pr – prediction 4D keras tensor (B, H, W, C)
  • class_weights – 1. or a list of class weights, len(weights) = C
  • beta – f-score coefficient
  • smooth – value to avoid division by zero
  • per_image – if True, metric is calculated as mean over images in batch (B), else over whole batch
Returns: F-score in range [0, 1]
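
To use a beta other than 1 as a Keras metric, the extra argument has to be bound first. A sketch with a hypothetical f2_score wrapper (the backbone and loss are illustrative):

from segmentation_models import Unet
from segmentation_models.metrics import f_score

def f2_score(gt, pr):
    # F2 weights recall higher than precision (beta=2); Keras reports
    # the metric under this wrapper's name.
    return f_score(gt, pr, beta=2)

model = Unet('resnet34')
model.compile('Adam', loss='binary_crossentropy', metrics=[f2_score])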

losses

segmentation_models.losses.jaccard_loss(gt, pr, class_weights=1.0, smooth=1e-12, per_image=True)

Jaccard loss function for imbalanced datasets:

\[L(A, B) = 1 - \frac{|A \cap B|}{|A \cup B|}\]
Parameters:
  • gt – ground truth 4D keras tensor (B, H, W, C)
  • pr – prediction 4D keras tensor (B, H, W, C)
  • class_weights – 1. or a list of class weights, len(weights) = C
  • smooth – value to avoid division by zero
  • per_image – if True, metric is calculated as mean over images in batch (B), else over whole batch
Returns: Jaccard loss in range [0, 1]
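
The loss is passed to model.compile like any Keras loss function; it pairs naturally with iou_score as the monitored metric (the backbone is an illustrative choice):

from segmentation_models import Unet
from segmentation_models.losses import jaccard_loss
from segmentation_models.metrics import iou_score

model = Unet('resnet34')
model.compile('Adam', loss=jaccard_loss, metrics=[iou_score])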

segmentation_models.losses.dice_loss(gt, pr, class_weights=1.0, smooth=1e-12, per_image=True)

Dice loss function for imbalanced datasets:

\[L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall} {\beta^2 \cdot precision + recall}\]
with β = 1 (the Dice coefficient is the F1-score).
Parameters:
  • gt – ground truth 4D keras tensor (B, H, W, C)
  • pr – prediction 4D keras tensor (B, H, W, C)
  • class_weights – 1. or a list of class weights, len(weights) = C
  • smooth – value to avoid division by zero
  • per_image – if True, metric is calculated as mean over images in batch (B), else over whole batch
Returns: Dice loss in range [0, 1]
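
Dice loss is often summed with pixel-wise crossentropy for imbalanced masks. A sketch of that recipe with a hypothetical bce_dice wrapper (the combination and the backbone are illustrative, not part of this API):

import keras.backend as K
from segmentation_models import Unet
from segmentation_models.losses import dice_loss

def bce_dice(gt, pr):
    # Pixel-wise binary crossentropy plus Dice loss.
    return K.mean(K.binary_crossentropy(gt, pr)) + dice_loss(gt, pr)

model = Unet('resnet34')
model.compile('Adam', loss=bce_dice)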

utils

segmentation_models.backbones.get_preprocessing(name)

Returns the input preprocessing function for the given backbone name; apply it to images before feeding them to the model.

segmentation_models.utils.set_trainable(model)

Set all layers of model trainable and recompile it

Note

Model is recompiled using the same optimizer, loss and metrics:

model.compile(model.optimizer, model.loss, model.metrics)
Parameters:
  • model (keras.models.Model) – instance of keras model
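
A sketch of the two-stage fine-tuning workflow these utilities support (the backbone, the dummy data, and the epoch counts are illustrative):

import numpy as np
from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.utils import set_trainable

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# Dummy data just to make the sketch self-contained; H and W are
# divisible by 32 as Unet requires.
x_train = np.random.rand(8, 224, 224, 3).astype('float32')
y_train = np.random.randint(0, 2, (8, 224, 224, 1)).astype('float32')
x_train = preprocess_input(x_train)

model = Unet(BACKBONE, encoder_weights='imagenet', encoder_freeze=True)
model.compile('Adam', loss='binary_crossentropy')
model.fit(x_train, y_train, epochs=2)    # warm up the decoder

set_trainable(model)                     # unfreeze encoder and recompile
model.fit(x_train, y_train, epochs=10)   # fine-tune the whole model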