Automatic Postoperative Brain Tumor Segmentation with Limited Data using Transfer Learning and Triplet Attention

Accurate brain tumor segmentation is clinically important for diagnosis and treatment planning. Convolutional neural networks (CNNs) have achieved promising performance in various visual recognition tasks, but training such networks usually requires large amounts of labeled data, which is often challenging to obtain for medical applications. In this work, we address the segmentation problem by applying transfer learning to a downstream segmentation task. Specifically, we explore how knowledge acquired from a large preoperative dataset can be transferred to postoperative tumor segmentation on a smaller dataset. To this end, we developed a 3D CNN for brain tumor segmentation and fine-tuned the pretrained models on the target-domain data. To better exploit inter-channel and spatial information, triplet attention is incorporated and extended within the segmentation network. Extensive experiments on our dataset demonstrate the effectiveness of transfer learning and attention modules for improved postoperative tumor segmentation when only a limited amount of annotated data is available.


Introduction
Glioblastoma is the most aggressive brain tumor and is commonly treated with surgery and chemoradiotherapy [5,7,19]. Accurate diagnosis and segmentation of glioblastoma are essential for treatment planning and postoperative analysis. Magnetic resonance imaging (MRI) provides high soft-tissue contrast and is the modality of choice for structural brain analysis [1,4]. Manual tumor segmentation is challenging and time-consuming due to complex tumor structure and high inter-rater variability [6,7,19,20], making automatic segmentation methods increasingly popular. Deep neural networks such as convolutional neural networks (CNNs) have achieved state-of-the-art performance in a range of visual recognition tasks, showing great potential for improved brain tumor segmentation. However, training such networks usually requires large amounts of labeled data, which limits their application in medical imaging.
Transfer learning leverages knowledge gained in a source domain to improve learning in a target domain without training the network from scratch. Zoetmulder et al. [21] assessed transfer learning performance on multiple medical segmentation tasks by investigating various combinations of domains and tasks. Wacker et al. [19] applied fully convolutional networks with encoders pretrained on the ImageNet dataset to brain tumor segmentation and achieved improved and more robust segmentation results. Ghaffari et al. [6] developed a 3D densely connected U-net and transferred the knowledge acquired on a preoperative brain tumor dataset to a target postoperative dataset.
Recent advances in deep learning, notably attention mechanisms, have been shown to yield performance gains in a variety of tasks, owing to their capability of learning more discriminative representations by leveraging the semantic correlations among image regions. One of the most promising methods is the squeeze-and-excitation network (SENet) [9]. Triplet attention [14], a successor of SENet, learns more robust representations by exploiting the interdependencies among channels and spatial locations, and provides computationally affordable and effective performance gains.
In this paper, we develop a popular 3D U-net variant architecture for brain tumor segmentation and perform extensive experiments with various fine-tuning strategies on the target domain. Concretely, the network is first trained on a large public preoperative dataset, the Brain Tumor Segmentation Challenge (BraTS) [13,2,3,12], and the model is then fine-tuned via transfer learning on an in-house target dataset of postoperative gliomas. Inspired by the success of attention-based methods, we propose to add the triplet attention module to the baseline network, introducing an additional attention branch and adapting the attention module from the ResNet [8] backbone to fit our network.

Network Architecture
We use a 3D U-Net variant as our baseline network, which follows an encoder-decoder architecture with an asymmetrically larger encoding pathway [16,15]. The encoder comprises four stages of ResNet blocks. Each block consists of two convolutions with Instance Normalization [17] and Rectified Linear Unit (ReLU) activation, followed by an additive identity skip connection. The input patches are progressively downsampled by convolutions with a stride of 2. Each decoder level consists of a single ResNet block, and transpose convolutions with a stride of 2 are used to double the spatial dimensions and reduce the number of features. The endpoint of the decoder has the same spatial dimensions as the input image, with its channel dimension reduced to three by a 1 × 1 × 1 convolution followed by a sigmoid function. The network architecture is shown in Figure 1.
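The encoder block described above can be sketched in PyTorch as follows. This is a minimal illustration of a residual block with Instance Normalization and ReLU, not the authors' implementation; the class name, channel count, and exact operation ordering are assumptions.

```python
import torch
import torch.nn as nn

class ResNetBlock3D(nn.Module):
    """Sketch of one encoder block: two 3x3x3 convolutions, each with
    Instance Normalization, plus an additive identity skip connection."""
    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Additive identity skip connection around the two convolutions
        return self.relu(self.block(x) + x)
```

Because the block preserves spatial and channel dimensions, downsampling is left to separate stride-2 convolutions between stages, as in the text.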

Loss
The Dice Similarity Coefficient (DSC) measures the similarity between the predicted segmentation P and the ground truth G:

DSC(P, G) = 2|P ∩ G| / (|P| + |G|).

The associated soft Dice loss can be expressed as:

L_Dice = 1 − (2 Σ_i p_i q_i) / (Σ_i p_i + Σ_i q_i),

where p_i and q_i are the predicted probability and the ground truth label of the i-th voxel, respectively.
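A minimal soft Dice loss of this form can be written as follows; the smoothing term `eps` is an illustrative addition (a common stabilizer for empty masks, not stated in the text).

```python
import torch

def soft_dice_loss(pred: torch.Tensor, target: torch.Tensor,
                   eps: float = 1e-5) -> torch.Tensor:
    """Soft Dice loss: 1 - 2*sum(p*q) / (sum(p) + sum(q)).

    pred: per-voxel predicted probabilities, target: binary labels.
    """
    p = pred.flatten()
    q = target.flatten()
    intersection = (p * q).sum()
    return 1.0 - (2.0 * intersection + eps) / (p.sum() + q.sum() + eps)
```

A perfect prediction gives a loss near 0, while a fully disjoint prediction gives a loss near 1.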

Transfer learning
Training a deep neural network from scratch on a small labeled dataset is challenging. The model is often unable to learn meaningful information when the training dataset is small, a problem called overfitting. One popular strategy to cope with limited data and reduce overfitting is transfer learning. The model is first trained on a large dataset, which allows it to learn more general-purpose features. Directly applying the trained model to the local data may not work well, as the model weights are optimized for the large dataset. The trained model is therefore tuned on the local data to adjust its weights to better fit the new training data. This fine-tuning step can stabilize the training process and improve predictive performance while making the model less prone to overfitting the local data.
During transfer learning, we usually prevent certain parts of the network from being trained so that the learned weights are reused, while the rest of the network is trained as usual. The frozen part typically corresponds to the first shallow layers of the neural network, which tend to capture low-level features that are shared across domains, while deeper layers learn high-level features that are more task-specific [18]. We retrain only those deep layers, thus avoiding the loss of the shared low-level information learned in the shallow layers. In other words, transfer learning adapts the pretrained model to the target domain, retaining shared information and fine-tuning domain-specific knowledge. More specifically, we partially transfer the pretrained model weights to the target domain by freezing the weights of the initial shallow layers and fine-tuning the remaining layers.
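In PyTorch, this freezing strategy amounts to disabling gradients for the parameters of the chosen shallow layers. The sketch below is a generic illustration; the prefix names passed in would depend on the actual module names of the network.

```python
import torch
import torch.nn as nn

def freeze_layers(model: nn.Module, frozen_prefixes) -> None:
    """Freeze parameters whose names start with any given prefix
    (e.g. the first encoder level); the rest remain trainable."""
    for name, param in model.named_parameters():
        if any(name.startswith(p) for p in frozen_prefixes):
            param.requires_grad = False

# Fine-tuning then only updates the unfrozen parameters, e.g.:
# optimizer = torch.optim.Adam(
#     (p for p in model.parameters() if p.requires_grad), lr=1e-5)
```

Frozen parameters keep their pretrained values throughout fine-tuning, preserving the shared low-level features.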

Triplet attention
The triplet attention module was initially applied to ResNet backbone networks for classification and object detection [14]. We adapt the attention module to our 3D segmentation network by adding one additional branch capturing the cross-dimensional interaction between the channel and depth dimensions, as illustrated in Figure 2. We also replace batch normalization with instance normalization, in accordance with the baseline network.
The input feature x is passed to four branches: three of them compute attention weights across the channel dimension C and the spatial dimensions H, W, and D, and the last captures spatial dependencies (H, W, and D). In each channel-spatial branch, the feature is passed through a rotation operation followed by a residual transformation block, which consists of a Z-pool layer shrinking the depth dimension and a convolution layer; the feature is rotated back afterwards to retain the same shape as the input x. The last branch performs a similar residual transformation, without rotation. The outputs of all branches are averaged to generate the refined output x* of the attention module:

x* = 1/4 (x̂₁ ω₁ + x̂₂ ω₂ + x̂₃ ω₃ + x ω₄),

where ω₁, ω₂, ω₃, and ω₄ are the cross-dimensional attention weights, and x̂₁, x̂₂, and x̂₃ are the rotated and Z-pooled features. Following the original paper, the triplet attention module is appended to the bottleneck of the encoder-decoder network.
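A 3D four-branch module of this kind might be sketched as below. This is a simplified reconstruction from the description, not the authors' code: the class names, kernel size, and exact permutation choices are assumptions, and the gating follows the Z-pool → convolution → normalization → sigmoid pattern of the original triplet attention paper.

```python
import torch
import torch.nn as nn

class ZPool(nn.Module):
    """Concatenate max- and mean-pooled features along dim 1."""
    def forward(self, x):
        return torch.cat([x.max(dim=1, keepdim=True).values,
                          x.mean(dim=1, keepdim=True)], dim=1)

class AttentionBranch(nn.Module):
    """One branch: rotate so the chosen dimension pair interacts,
    Z-pool, convolve, gate with a sigmoid, and rotate back."""
    def __init__(self, perm, kernel_size=7):
        super().__init__()
        self.perm = perm
        self.inv = [perm.index(i) for i in range(len(perm))]
        self.pool = ZPool()
        self.conv = nn.Conv3d(2, 1, kernel_size, padding=kernel_size // 2)
        self.norm = nn.InstanceNorm3d(1)  # instance norm per the baseline

    def forward(self, x):
        y = x.permute(self.perm)
        w = torch.sigmoid(self.norm(self.conv(self.pool(y))))
        return (y * w).permute(self.inv)

class TripletAttention3D(nn.Module):
    """Four-branch variant for (N, C, D, H, W) inputs: one spatial
    branch (no rotation) and three channel-spatial branches; the
    branch outputs are averaged."""
    def __init__(self):
        super().__init__()
        self.branches = nn.ModuleList([
            AttentionBranch((0, 1, 2, 3, 4)),  # spatial branch
            AttentionBranch((0, 2, 1, 3, 4)),  # channel-depth
            AttentionBranch((0, 3, 2, 1, 4)),  # channel-height
            AttentionBranch((0, 4, 2, 3, 1)),  # channel-width
        ])

    def forward(self, x):
        return sum(b(x) for b in self.branches) / 4.0
```

Since every branch restores the input shape, the module can be dropped into the bottleneck of the encoder-decoder without changing feature dimensions.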

Source Dataset
The publicly available BraTS dataset provides 3D multimodal MRI data with ground-truth segmentations annotated by domain experts. The BraTS 2020 dataset comprises multi-parametric MRI (mp-MRI) scans of 369 cases for training and 125 cases for validation. The MRI scans were collected with different clinical protocols and from multiple institutions. Each scan contains four modalities: native T1-weighted (T1), post-contrast T1-weighted (T1ce), T2-weighted (T2), and T2 Fluid Attenuated Inversion Recovery (FLAIR). Each 3D volume is skull-stripped, rigidly co-registered, and resampled to 1 mm³ isotropic voxel resolution. Three tumor sub-regions were manually annotated by one to four raters: the Gd-enhancing tumor, the peritumoral edema, and the necrotic and non-enhancing tumor core. The annotations were combined into overlapping sub-regions for evaluation: enhancing tumor (ET), tumor core (TC), and whole tumor (WT).

Target Dataset
The target dataset is a subset of an ongoing glioblastoma study at our institution and consists of postoperative longitudinal brain MRI from 13 glioblastoma patients, with one to three time points per patient (see Table 1). The dataset was preprocessed similarly to [12], including converting DICOM images to NIfTI format, skull-stripping using [10], and resampling to 1 mm × 1 mm × 1 mm voxel resolution, yielding an image size of 256 × 256 × 190. The same segmentation labels as in BraTS were generated. An initial automated segmentation algorithm was applied to a larger dataset, from which a subset (n = 13 patients) was selected based on a radiologist's review and rating of the segmentation quality of the enhancing tumor and edema. The necrotic tumor was not specifically assessed by the radiologist.

Experimental Settings
All experiments were implemented in PyTorch, and the network was trained on NVIDIA A100 GPUs. We used five-fold cross-validation (CV) on 80% of the data and held out the remaining 20% as a test set. All partitions were performed randomly at the patient level. In each experiment, there were roughly 8 patients in the training set, 2 in the validation set, and 3 in the test set.
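The patient-level partitioning described above can be sketched as follows; the function name and interface are illustrative, and the specific random assignment will differ from the one actually used.

```python
import random

def patient_level_splits(patient_ids, n_folds=5, test_frac=0.2, seed=0):
    """Hold out test_frac of patients as a test set, then split the
    remainder into n_folds CV folds, all at patient level so no
    patient appears in more than one partition."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)
    n_test = max(1, round(test_frac * len(ids)))
    test, rest = ids[:n_test], ids[n_test:]
    folds = [rest[i::n_folds] for i in range(n_folds)]
    return test, folds
```

With 13 patients this yields a 3-patient test set and five folds over the remaining 10, matching the rough 8/2/3 training/validation/test counts per experiment.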

Preprocessing
For both pretraining and fine-tuning, we used the same preprocessing steps. We randomly cropped the MRI images to a fixed size of 192 × 192 × 144 and concatenated the four MRI modalities of each patient into a four-channel input. We independently normalized each channel by subtracting the mean intensity and dividing by the standard deviation of intensities within the brain region. Data augmentation techniques have been shown to effectively reduce overfitting. We randomly flipped each spatial axis with a probability of 0.5 and applied a random intensity shift within [−0.1, 0.1] of the standard deviation of each input channel, followed by a random intensity scaling in the range [0.9, 1.1].
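The normalization and augmentation steps can be sketched as below. Function names are illustrative; in particular, `brain_mask` stands in for however the brain region is delineated in practice, and the small `1e-8` term is an assumed numerical guard.

```python
import torch

def normalize_per_channel(img: torch.Tensor,
                          brain_mask: torch.Tensor) -> torch.Tensor:
    """Normalize each modality channel of a (C, D, H, W) image by the
    mean and std of intensities inside the brain mask."""
    out = img.clone()
    for c in range(img.shape[0]):
        vals = img[c][brain_mask]
        out[c] = (img[c] - vals.mean()) / (vals.std() + 1e-8)
    return out

def augment(img: torch.Tensor) -> torch.Tensor:
    # Random flip along each spatial axis with probability 0.5
    for axis in (1, 2, 3):
        if torch.rand(1) < 0.5:
            img = torch.flip(img, dims=(axis,))
    # Per-channel intensity shift within ±0.1 std, then scale in [0.9, 1.1]
    std = img.reshape(img.shape[0], -1).std(dim=1).view(-1, 1, 1, 1)
    shift = (torch.rand(img.shape[0], 1, 1, 1) * 0.2 - 0.1) * std
    scale = torch.rand(img.shape[0], 1, 1, 1) * 0.2 + 0.9
    return (img + shift) * scale
```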

Training
In both training stages, we used the Adam optimizer, a polynomial learning-rate decay schedule, and L2 regularization with a decay rate of 1e−5. A batch size of 1 was used to accommodate the large crop size, and the maximum number of epochs was set to 300. The learning rates were set to 1e−4 for pretraining and 1e−5 for fine-tuning.
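These settings could be configured as in the sketch below (shown for the pretraining stage). The polynomial decay exponent of 0.9 is an assumption, as the text does not specify it; the `Conv3d` stands in for the actual network.

```python
import torch

MAX_EPOCHS = 300
model = torch.nn.Conv3d(4, 3, 3, padding=1)  # placeholder for the real network

# Adam with L2 regularization (weight decay) of 1e-5; lr = 1e-4 for
# pretraining (1e-5 would be used for fine-tuning)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Polynomial decay: lr_e = lr_0 * (1 - e / MAX_EPOCHS) ** 0.9
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: (1 - epoch / MAX_EPOCHS) ** 0.9)
```

Calling `scheduler.step()` once per epoch gradually anneals the learning rate toward zero over the 300 epochs.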
The model with the best performance on the validation set was chosen. In our experiments, we found that freezing the first layer (level) achieved the best performance, as discussed in the Appendix.

Evaluation and Comparisons
All experiments were evaluated by the average Dice score over the 5-fold cross-validation results. To evaluate the segmentation performance with transfer learning, we trained the baseline network on the BraTS 2020 dataset and tested its fine-tuned model on the target postoperative dataset. We then compared the fine-tuned pretrained model with the baseline model trained on the target set from scratch.

Results and Discussion
Table 2 presents the results of the baseline network trained and evaluated on the BraTS 2020 (validation) dataset, showing performance close to the state of the art in the BraTS challenge. These competitive results on the related preoperative brain tumor segmentation task convinced us that our baseline network is a good choice for evaluating the postoperative segmentation task. Table 3 compares the average Dice scores of the baseline and transfer learning, where fine-tuning achieves the highest value of 0.8018, an absolute improvement of 0.04 over the baseline. The Dice scores of the whole tumor increase the most, which can be explained by the fact that segmenting the whole tumor is fairly easy in both domains, so the adaptation of the pretrained model weights to the target domain is more straightforward. By incorporating triplet attention into the baseline network, a further performance gain is obtained. An example segmentation result is provided in Figure 3.

Conclusion
Accurate segmentation of the different pathological components of postoperative glioblastomas from MRI is clinically important yet technically challenging. Starting with a network that achieves close to state-of-the-art performance on the preoperative brain tumor segmentation challenge, we show that, with limited labeled data, transfer learning from a pretrained model can improve segmentation performance on postoperative MRIs.
We also add an attention module, triplet attention, to the 3D segmentation network. Triplet attention captures more discriminative features across the channel and spatial dimensions and can thus improve model accuracy.
Note that transfer learning and triplet attention were evaluated independently. For future work, we plan to integrate the two techniques and evaluate how attention mechanisms and transfer learning could benefit each other.

Appendix
We experimented with the number of shallow layers to freeze. As illustrated in Figure 4, fine-tuning with the first layer frozen achieved the highest Dice score on the target dataset. Note also that performance was very similar when freezing 0, 1, or 2 layers.

Figure 1 :
Figure 1: Schematic visualization of the baseline network architecture.

Figure 2 :
Figure 2: Illustration of the triplet attention module. The first three branches compute the attention weights between the channel dimension C and each of the three spatial dimensions (H, W, and D). The last branch captures the spatial dependencies (H, W, and D). The output is obtained by averaging the computed weights from each branch. The attention module is added to the bottleneck of the encoder-decoder network.

Figure 3 :
Figure 3: Visualization of the segmentation results of two sample cases. From left to right: FLAIR, T1, T1ce, ground truth overlaid on T1ce, predicted segmentation overlaid on T1ce.

Figure 4 :
Figure 4: Average Dice scores when freezing different numbers of layers. In total, there are seven levels of convolutional layers in addition to the output endpoint.

Table 1 :
Statistics of the target postoperative dataset.

Table 2 :
The baseline network evaluated on the BraTS 2020 validation set, in comparison with state-of-the-art results as well as an ensemble of the best-performing variant models. All results were provided by the BraTS evaluation platform.

Table 3 :
Quantitative results of the different methods on the target dataset. The pretrained model is trained on BraTS data. The baseline is trained from scratch on the target dataset. Transfer learning (TL) fine-tunes the pretrained network with its first layer frozen. Triplet attention (TA) adds the triplet attention module to the baseline network. Note that TL and TA were evaluated independently.