- CVPR 2019
- arXiv https://arxiv.org/abs/1903.04197
- The need for neural networks with small model size, low computation cost, and high segmentation accuracy for applications on mobile devices.
- Deep neural networks have achieved significant improvement in segmentation accuracy.
- Knowledge distillation has proven effective in classification tasks.
- Knowledge distillation for training accurate, compact semantic segmentation networks.
- Two structured knowledge distillation strategies for semantic segmentation, pair-wise distillation and holistic distillation.
- Experimental validation on three datasets with different network configurations.
T denotes the teacher network (PSPNet with ResNet101); S denotes the student network (ResNet18, MobileNetV2Plus, ESPNet-C, and ESPNet).
- Overall architecture
Consists of three parts: (a) Pair-wise distillation; (b) Pixel-wise distillation; (c) Holistic distillation.
- Structured knowledge distillation
- Pixel-wise distillation
Treat the segmentation task as a collection of separate per-pixel classification problems. Directly align the class probabilities of each pixel produced by S with those produced by T.
Pixel-wise distillation loss:

$$\ell_{pi}(S) = \frac{1}{|\mathcal{R}|} \sum_{i \in \mathcal{R}} \mathrm{KL}\left(q_i^s \,\|\, q_i^t\right)$$

$q_i^s$ and $q_i^t$ are the class probabilities of the $i$th pixel produced by S and T, respectively. $\mathrm{KL}(\cdot)$ is the Kullback-Leibler divergence. $\mathcal{R}$ denotes the set of all pixels.
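As a concrete sketch (not the authors' code; the array shapes and the `1e-12` numerical stabilizer are my assumptions), the pixel-wise loss can be computed from the teacher and student score maps like this:

```python
import numpy as np

def softmax(logits):
    # Softmax over the class axis (last axis), shifted for numerical stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def pixel_wise_distillation_loss(logits_s, logits_t):
    """KL(q^s || q^t) averaged over all pixels.

    logits_s, logits_t: (H, W, C) score maps from student and teacher.
    """
    qs = softmax(logits_s)
    qt = softmax(logits_t)
    # Per-pixel KL divergence, summed over classes.
    kl = np.sum(qs * (np.log(qs + 1e-12) - np.log(qt + 1e-12)), axis=-1)
    return float(kl.mean())  # average over the pixel set R
```

In a real training setup this would run on framework tensors so gradients flow back into the student; the numpy version only illustrates the arithmetic.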
- Pair-wise distillation
Pair-wise distillation is inspired by the pair-wise Markov random field framework. Instead of the per-pixel class probabilities, pair-wise similarities among pixels are transferred.
Pair-wise distillation loss, built on the similarity between pairs of pixels:

$$\ell_{pa}(S) = \frac{1}{|\mathcal{R}|^2} \sum_{i \in \mathcal{R}} \sum_{j \in \mathcal{R}} \left(a_{ij}^s - a_{ij}^t\right)^2$$

$a_{ij}^s$ is the similarity between the $i$th and $j$th pixels produced by S; $a_{ij}^t$ is the corresponding similarity produced by T.
In implementation, the similarity between two pixels is simplified to the cosine similarity of their features:

$$a_{ij} = \frac{f_i^\top f_j}{\|f_i\|_2 \, \|f_j\|_2}$$

$f_i$ and $f_j$ are the feature vectors of the $i$th and $j$th pixels.
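A minimal sketch of the pair-wise term (illustrative only; flattening pixels into an `(N, C)` matrix and the `1e-12` stabilizer are my assumptions):

```python
import numpy as np

def similarity_matrix(feats):
    """Cosine similarity a_ij between every pair of pixel features.

    feats: (N, C) array -- N pixels, C channels per pixel.
    Returns an (N, N) similarity matrix.
    """
    norm = np.linalg.norm(feats, axis=1, keepdims=True) + 1e-12
    f = feats / norm
    return f @ f.T

def pair_wise_distillation_loss(feats_s, feats_t):
    """Squared difference between student and teacher similarity matrices,
    averaged over all pixel pairs."""
    a_s = similarity_matrix(feats_s)
    a_t = similarity_matrix(feats_t)
    return float(np.mean((a_s - a_t) ** 2))
```

Note that the student and teacher feature channels need not match: only the `(N, N)` similarity matrices are compared, which is what makes this transfer architecture-agnostic.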
- Holistic distillation
Conditional generative adversarial learning is employed. S is treated as the generator conditioned on the input image I. The segmentation map $Q^s$ is a fake sample; $Q^t$ is regarded as a real sample. $Q^s$ needs to be as similar as possible to $Q^t$. The Wasserstein distance is used to evaluate the difference between $Q^s$ and $Q^t$:

$$\ell_{ho}(S, D) = \mathbb{E}_{Q^s \sim p_s}\left[D(Q^s \mid I)\right] - \mathbb{E}_{Q^t \sim p_t}\left[D(Q^t \mid I)\right]$$

$\mathbb{E}$ is the expectation operator. $D(\cdot)$ is an embedding network consisting of five convolutions, with two self-attention modules inserted between the final three layers. $Q^s$ and I are concatenated and fed into $D(\cdot)$.
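The Wasserstein objective itself reduces to a difference of mean critic scores. A toy sketch, assuming `d_fake` and `d_real` are batches of scalar scores already produced by the embedding network D (the function name is mine):

```python
import numpy as np

def holistic_loss(d_fake, d_real):
    """Empirical estimate of E[D(Q^s | I)] - E[D(Q^t | I)].

    d_fake: critic scores on student (fake) segmentation maps
    d_real: critic scores on teacher (real) segmentation maps
    """
    return float(np.mean(d_fake) - np.mean(d_real))
```

Minimizing this over D drives real scores up and fake scores down, while the student is later updated in the opposite direction so its maps earn higher scores.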
- Optimization and training
Overall loss function:

$$\ell = \ell_{mc}(S) + \lambda_1 \left(\ell_{pi}(S) + \ell_{pa}(S)\right) - \lambda_2 \, \ell_{ho}(S, D)$$

$\ell_{mc}$ is the conventional multi-class cross-entropy loss with the ground-truth labels; $\lambda_1$ and $\lambda_2$ balance the distillation terms.
Optimization in two steps:
- Train the discriminator D by minimizing lho(S,D), with S fixed.
- Train the segmentation network S by minimizing the cross-entropy loss together with the three distillation terms, with D fixed.
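The combination of terms minimized in the second step can be sketched as follows (the default weights are illustrative placeholders, not necessarily the paper's reported settings):

```python
def total_loss(l_mc, l_pi, l_pa, l_ho, lam1=10.0, lam2=0.1):
    """Overall objective for the student update:
    cross-entropy + weighted pixel-wise and pair-wise distillation,
    minus the weighted holistic (adversarial) term.

    lam1, lam2: balancing weights -- values here are illustrative.
    """
    return l_mc + lam1 * (l_pi + l_pa) - lam2 * l_ho
```

The holistic term enters with a minus sign because the student is trained adversarially: lowering the total loss pushes its segmentation maps toward higher critic scores.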
- Effectiveness of the three distillation strategies
- A new strategy to incorporate GANs into segmentation. Directly enforcing alignment between the segmentation map and the ground truth may limit the success of the GAN's discriminator, as there is a mismatch between the generator's continuous output and the discrete true labels. Here, the segmentation map is instead compared with the continuous output of the teacher network.
- The pixel-wise and holistic distillations were applied on the final score maps, and the pair-wise distillation was applied on the feature maps of the last layer. In the comparison, it seems that attention transfer was only applied on the score maps, so the full capacity of attention transfer may not have been utilized.