The main features of this library are:
- High-level API (just two lines of code to create a neural network for segmentation)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 25 available backbones for each architecture
- All backbones have pre-trained weights for faster and better convergence
Since the library is built on the Keras framework, the created segmentation model is just a Keras Model, which can be created as easily as:

```python
from segmentation_models import Unet

model = Unet()
```
Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters, and use pretrained weights to initialize it:

```python
model = Unet('resnet34', encoder_weights='imagenet')
```
Change the number of output classes in the model:

```python
model = Unet('resnet34', classes=3, activation='softmax')
```
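With `activation='softmax'`, the network ends in a per-pixel softmax over the class channel, so every pixel gets a probability distribution over the 3 classes. As a rough illustration of that final activation (a plain NumPy sketch, not the library's code):

```python
import numpy as np

def softmax_channels(logits):
    """Softmax over the last (class) axis, applied independently per pixel."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=-1, keepdims=True)

# a tiny fake 2x2 "image" of logits for 3 classes
logits = np.array([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]],
                   [[5.0, 1.0, 1.0], [2.0, 2.0, 4.0]]])
probs = softmax_channels(logits)
# each pixel's class probabilities now sum to 1
```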
Change the input shape of the model:

```python
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
```
Simple training pipeline
```python
from segmentation_models import Unet
from segmentation_models import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
```
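The `iou_score` metric used above is the intersection-over-union (Jaccard index) between the predicted and ground-truth masks. A minimal NumPy sketch of the idea for binary masks (not the library's implementation, which also handles batches, thresholds, and smoothing):

```python
import numpy as np

def iou(y_true, y_pred, eps=1e-7):
    """Intersection over union for two binary masks of the same shape."""
    y_true = y_true.astype(bool)
    y_pred = y_pred.astype(bool)
    intersection = np.logical_and(y_true, y_pred).sum()
    union = np.logical_or(y_true, y_pred).sum()
    return (intersection + eps) / (union + eps)  # eps guards against empty masks

a = np.array([[1, 1], [0, 0]])
b = np.array([[1, 0], [0, 0]])
score = iou(a, b)  # intersection 1, union 2 -> 0.5
```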
The same manipulations can be done with FPN. For more detailed information about the models API and use cases, see Read the Docs.
Models and Backbones
All backbones have weights trained on the 2012 ILSVRC ImageNet dataset.
Sometimes it is useful to train only a randomly initialized decoder, so as not to damage the weights of a properly trained encoder with large gradients during the first steps of training. In this case, all you need to do is pass the `encoder_freeze=True` argument when initializing the model.
```python
from segmentation_models import Unet
from segmentation_models.utils import set_trainable

model = Unet(backbone_name='resnet34', encoder_weights='imagenet', encoder_freeze=True)
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# pretrain model decoder
model.fit(x, y, epochs=2)

# release all layers for training
set_trainable(model)  # set all layers trainable and recompile model

# continue training
model.fit(x, y, epochs=100)
```
Training with non-RGB data
In case you have non-RGB images (e.g. grayscale or some medical/remote sensing data), you have a few different options:

- Train the network from scratch with randomly initialized weights
```python
from segmentation_models import Unet

# read/scale/preprocess data
x, y = ...

# define number of channels
N = x.shape[-1]

# define model
model = Unet(backbone_name='resnet34', encoder_weights=None, input_shape=(None, None, N))

# continue with usual steps: compile, fit, etc.
```
- Add an extra convolution layer to map N -> 3 channels and train with pretrained weights
```python
from segmentation_models import Unet
from keras.layers import Input, Conv2D
from keras.models import Model

# read/scale/preprocess data
x, y = ...

# define number of channels
N = x.shape[-1]

base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet')

inp = Input(shape=(None, None, N))
l1 = Conv2D(3, (1, 1))(inp)  # map N channel data to 3 channels
out = base_model(l1)

model = Model(inp, out, name=base_model.name)

# continue with usual steps: compile, fit, etc.
```
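The `Conv2D(3, (1, 1))` layer in this recipe is just a learned per-pixel linear map from N input channels down to 3, so the result can feed an ImageNet-pretrained encoder. In NumPy terms, a 1x1 convolution is a matrix multiplication over the channel axis (a sketch of the operation only; in the real model these weights are learned, and the values below are random placeholders):

```python
import numpy as np

N = 6                             # e.g. 6-band remote sensing data
x = np.random.rand(2, 32, 32, N)  # batch of two N-channel images
W = np.random.rand(N, 3)          # a 1x1 conv kernel is just an (N, 3) matrix
b = np.zeros(3)                   # per-output-channel bias

rgb_like = x @ W + b              # per-pixel channel mixing, shape (2, 32, 32, 3)
```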