Segmentation Models is a Python library of neural networks for image segmentation, built on the Keras (TensorFlow) framework.

The main features of this library are:

  • High-level API (just two lines to create a neural network)
  • 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 25 available backbones for each architecture
  • All backbones have pre-trained weights for faster and better convergence

Quick start

Because the library is built on Keras, every segmentation model it creates is just a Keras Model, which can be built as easily as:

from segmentation_models import Unet

model = Unet()

Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and initialize it with pretrained weights:

model = Unet('resnet34', encoder_weights='imagenet')

Change the number of output classes in the model:

model = Unet('resnet34', classes=3, activation='softmax')
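
With classes=3 and a softmax activation, the model outputs one channel per class, so the training masks usually need the same layout. The following is a minimal NumPy sketch, assuming your masks are stored as integer class ids of shape (H, W); the helper name one_hot_mask is hypothetical and not part of the library.

```python
import numpy as np


def one_hot_mask(mask, num_classes):
    """Convert an integer mask of shape (H, W) to one-hot shape (H, W, num_classes).

    Hypothetical helper for illustration; not part of segmentation_models.
    """
    mask = np.asarray(mask)
    # Indexing the identity matrix by class id turns each pixel into a one-hot vector.
    return np.eye(num_classes, dtype=np.float32)[mask]


# Tiny 2x2 mask with three classes (0, 1, 2).
mask = np.array([[0, 1],
                 [2, 1]])
encoded = one_hot_mask(mask, num_classes=3)
print(encoded.shape)  # (2, 2, 3)
```

The last axis of the encoded mask then lines up with the three output channels of the model.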

Change the input shape of the model:

model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
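
With input_shape=(None, None, 6) the model accepts images of varying size, but encoder-decoder networks of this kind generally expect the height and width to be divisible by the total downsampling factor. The sketch below assumes a factor of 32 (check the documentation for your architecture and backbone) and zero-pads an image accordingly; pad_to_multiple is a hypothetical helper, not a library function.

```python
import numpy as np


def pad_to_multiple(image, factor=32):
    """Zero-pad an (H, W, C) image at the bottom and right so that
    H and W become multiples of `factor`.

    Hypothetical helper; the factor of 32 is an assumption.
    """
    height, width = image.shape[:2]
    pad_h = (-height) % factor  # rows to add at the bottom
    pad_w = (-width) % factor   # columns to add on the right
    return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")


# A 6-channel image whose spatial size is not divisible by 32.
image = np.zeros((500, 375, 6), dtype=np.float32)
padded = pad_to_multiple(image)
print(padded.shape)  # (512, 384, 6)
```

Remember to crop the predicted mask back to the original height and width after inference.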

Simple training pipeline

from segmentation_models import Unet
from segmentation_models import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
)

The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see the documentation on Read the Docs.
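
To make the iou_score metric in the pipeline above more concrete, here is a plain NumPy sketch of the intersection-over-union (Jaccard) score for a binary mask. It is an illustration only: the threshold and eps parameters are assumptions, and the library's own implementation may differ in smoothing, thresholding, and averaging.

```python
import numpy as np


def iou(y_true, y_pred, threshold=0.5, eps=1e-7):
    """Intersection over union of a binary ground-truth mask and a
    predicted probability map (illustrative, not the library's code)."""
    true_mask = np.asarray(y_true) > 0.5
    pred_mask = np.asarray(y_pred) > threshold  # binarize probabilities
    intersection = np.logical_and(true_mask, pred_mask).sum()
    union = np.logical_or(true_mask, pred_mask).sum()
    # eps avoids division by zero when both masks are empty.
    return (intersection + eps) / (union + eps)


y_true = np.array([[1, 1],
                   [0, 0]])
y_pred = np.array([[0.9, 0.2],
                   [0.7, 0.1]])
print(round(iou(y_true, y_pred), 3))  # 0.333
```

Here one pixel is predicted correctly as foreground, and three pixels are foreground in at least one of the two masks, which gives a score of 1/3.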

Models and Backbones


Models: Unet, Linknet, PSPNet, FPN


All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (encoder_weights='imagenet').

Fine tuning

Sometimes it is useful to train only the randomly initialized decoder, so that the large gradients in the first steps of training do not damage the weights of the properly trained encoder. In this case, just pass the encoder_freeze=True argument when initializing the model.

from segmentation_models import Unet
from segmentation_models.utils import set_trainable

model = Unet(backbone_name='resnet34', encoder_weights='imagenet', encoder_freeze=True)
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# pretrain model decoder
model.fit(x, y, epochs=2)

# release all layers for training
set_trainable(model) # set all layers trainable and recompile model

# continue training
model.fit(x, y, epochs=100)

Training with non-RGB data

If you have non-RGB images (e.g. grayscale or some medical/remote sensing data), you have a few options:

  1. Train the network from scratch with randomly initialized weights
from segmentation_models import Unet

# read/scale/preprocess data
x, y = ...

# define number of channels
N = x.shape[-1]

# define model
model = Unet(backbone_name='resnet34', encoder_weights=None, input_shape=(None, None, N))

# continue with usual steps: compile, fit, etc..
  2. Add an extra convolution layer to map N -> 3 channels and train with pretrained weights
from segmentation_models import Unet
from keras.layers import Input, Conv2D
from keras.models import Model

# read/scale/preprocess data
x, y = ...

# define number of channels
N = x.shape[-1]

base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet')

inp = Input(shape=(None, None, N))
l1 = Conv2D(3, (1, 1))(inp) # map N channels data to 3 channels
out = base_model(l1)

model = Model(inp, out, name=base_model.name)

# continue with usual steps: compile, fit, etc..
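
As an intuition for the extra-convolution option: a 1x1 convolution applies the same linear map to the channel vector of every pixel. The NumPy sketch below illustrates this with randomly chosen weights (the array names and sizes are made up for the example); in the Keras model above, these weights are learned during training.

```python
import numpy as np

rng = np.random.default_rng(0)

N = 5                                  # number of input channels (example value)
x = rng.normal(size=(4, 4, N))         # one 4x4 image with N channels
kernel = rng.normal(size=(N, 3))       # weights of a 1x1 convolution with 3 filters
bias = np.zeros(3)

# A 1x1 convolution is a matrix product over the channel axis,
# applied independently at every pixel: (4, 4, N) @ (N, 3) -> (4, 4, 3).
y = x @ kernel + bias
print(y.shape)  # (4, 4, 3)

# The same result for a single pixel, computed by hand.
pixel = x[1, 2] @ kernel + bias
print(np.allclose(y[1, 2], pixel))  # True
```

Because the map is purely per-pixel, it changes only the number of channels and leaves the spatial size untouched, which is why the 3-channel pretrained encoder can consume its output directly.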