Shortcuts

Source code for eisen.models.segmentation.vnet

from torch import nn

# ATTRIBUTION: this implementation has been obtained from https://github.com/JunMa11/SegWithDistMap
# We thank the authors for open-sourcing this implementation. Furthermore, we thank Lequan Yu.

# This code was originally distributed under Apache 2 license. The terms of this license can be found here.
# https://github.com/JunMa11/SegWithDistMap/blob/master/LICENSE
# This implementation is also closely related to Lequan Yu's https://github.com/yulequan/UA-MT

# This code has been adapted to work within Eisen and therefore modified, without changing its functionality.
# The interface of the constructor has been slightly modified. Features related to dropout have been removed.
# Normalization schemes include groupnorm, instancenorm, batchnorm, and none.


class ConvBlock(nn.Module):
    """Stack of ``n_stages`` 3D convolutions, each followed by optional normalization and a ReLU.

    The convolutions use "same" padding (``filter_size // 2``) so that, for odd filter sizes,
    the spatial shape of the input is preserved.

    :param n_stages: number of convolution stages
    :type n_stages: int
    :param n_filters_in: number of channels of the block input
    :type n_filters_in: int
    :param n_filters_out: number of channels produced by every stage
    :type n_filters_out: int
    :param filter_size: spatial size of the convolution kernels (int or 3-tuple)
    :type filter_size: int or tuple
    :param normalization: one of "batchnorm", "groupnorm", "instancenorm" or "none".
        "groupnorm" uses 16 groups, so ``n_filters_out`` must be divisible by 16.
    :type normalization: str
    :raises ValueError: if ``normalization`` is not one of the supported values
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, filter_size, normalization="none"):
        super(ConvBlock, self).__init__()

        if normalization not in ("batchnorm", "groupnorm", "instancenorm", "none"):
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'batchnorm', 'groupnorm', 'instancenorm' or 'none'"
            )

        # "Same" padding for odd kernel sizes; equals 1 for the default filter_size=3.
        # A fixed padding of 1 would shrink the feature maps for larger kernels.
        if isinstance(filter_size, int):
            padding = filter_size // 2
        else:
            padding = tuple(k // 2 for k in filter_size)

        ops = []
        for i in range(n_stages):
            # Only the first stage maps from the block's input channels.
            input_channel = n_filters_in if i == 0 else n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, filter_size, padding=padding))
            if normalization == "batchnorm":
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == "groupnorm":
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == "instancenorm":
                ops.append(nn.InstanceNorm3d(n_filters_out))
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        x = self.conv(x)
        return x


class ResidualConvBlock(nn.Module):
    """Residual stack of 3D convolutions computing ``relu(conv_stack(x) + x)``.

    Each of the ``n_stages`` stages is a 3D convolution followed by optional normalization.
    A ReLU follows every stage except the last; the last stage's output is summed with the
    block input and then passed through a final ReLU.

    Because the input is added to the stack output, ``n_filters_in`` must equal
    ``n_filters_out`` for ``forward`` to succeed. The convolutions use "same" padding
    (``filter_size // 2``) so the spatial shape is preserved for odd filter sizes.

    :param n_stages: number of convolution stages
    :type n_stages: int
    :param n_filters_in: number of channels of the block input
    :type n_filters_in: int
    :param n_filters_out: number of channels produced by every stage
    :type n_filters_out: int
    :param filter_size: spatial size of the convolution kernels (int or 3-tuple)
    :type filter_size: int or tuple
    :param normalization: one of "batchnorm", "groupnorm", "instancenorm" or "none".
        "groupnorm" uses 16 groups, so ``n_filters_out`` must be divisible by 16.
    :type normalization: str
    :raises ValueError: if ``normalization`` is not one of the supported values
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, filter_size, normalization="none"):
        super(ResidualConvBlock, self).__init__()

        if normalization not in ("batchnorm", "groupnorm", "instancenorm", "none"):
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'batchnorm', 'groupnorm', 'instancenorm' or 'none'"
            )

        # "Same" padding for odd kernel sizes (1 for the default 3). The residual sum in
        # forward() requires the stack to preserve the spatial shape.
        if isinstance(filter_size, int):
            padding = filter_size // 2
        else:
            padding = tuple(k // 2 for k in filter_size)

        ops = []
        for i in range(n_stages):
            input_channel = n_filters_in if i == 0 else n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, filter_size, padding=padding))
            if normalization == "batchnorm":
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == "groupnorm":
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == "instancenorm":
                ops.append(nn.InstanceNorm3d(n_filters_out))

            # The last stage's ReLU is applied after the residual sum (self.relu).
            if i != n_stages - 1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv_stack(x) + x)``."""
        x = self.conv(x) + x
        x = self.relu(x)
        return x


class DownsamplingConvBlock(nn.Module):
    """Downsample a 3D feature map with a strided convolution, optional normalization and a ReLU.

    The convolution uses ``kernel_size == stride`` and no padding, so each spatial dimension
    is divided by ``stride`` (any remainder is dropped).

    :param n_filters_in: number of input channels
    :type n_filters_in: int
    :param n_filters_out: number of output channels
    :type n_filters_out: int
    :param stride: kernel size and stride of the convolution
    :type stride: int
    :param normalization: one of "batchnorm", "groupnorm", "instancenorm" or "none".
        "groupnorm" uses 16 groups, so ``n_filters_out`` must be divisible by 16.
    :type normalization: str
    :raises ValueError: if ``normalization`` is not one of the supported values
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization="none"):
        super(DownsamplingConvBlock, self).__init__()

        # The convolution is the same regardless of normalization; only the optional
        # normalization layer between it and the ReLU varies.
        ops = [nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == "batchnorm":
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == "groupnorm":
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == "instancenorm":
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != "none":
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'batchnorm', 'groupnorm', 'instancenorm' or 'none'"
            )
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Downsample ``x`` by ``stride`` along each spatial dimension."""
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    """Upsample a 3D feature map with a transposed convolution, optional normalization and a ReLU.

    The transposed convolution uses ``kernel_size == stride`` and no padding, so each spatial
    dimension is multiplied by ``stride``.

    :param n_filters_in: number of input channels
    :type n_filters_in: int
    :param n_filters_out: number of output channels
    :type n_filters_out: int
    :param stride: kernel size and stride of the transposed convolution
    :type stride: int
    :param normalization: one of "batchnorm", "groupnorm", "instancenorm" or "none".
        "groupnorm" uses 16 groups, so ``n_filters_out`` must be divisible by 16.
    :type normalization: str
    :raises ValueError: if ``normalization`` is not one of the supported values
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization="none"):
        super(UpsamplingDeconvBlock, self).__init__()

        # The transposed convolution is the same regardless of normalization; only the
        # optional normalization layer between it and the ReLU varies.
        ops = [nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == "batchnorm":
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == "groupnorm":
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == "instancenorm":
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != "none":
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'batchnorm', 'groupnorm', 'instancenorm' or 'none'"
            )
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Upsample ``x`` by ``stride`` along each spatial dimension."""
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    """Trilinear upsampling followed by a 3x3x3 convolution, optional normalization and a ReLU.

    :param n_filters_in: number of input channels
    :type n_filters_in: int
    :param n_filters_out: number of output channels
    :type n_filters_out: int
    :param stride: spatial scale factor of the trilinear upsampling
    :type stride: int
    :param normalization: one of "batchnorm", "groupnorm", "instancenorm" or "none".
        "groupnorm" uses 16 groups, so ``n_filters_out`` must be divisible by 16.
    :type normalization: str
    :raises ValueError: if ``normalization`` is not one of the supported values
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization="none"):
        super(Upsampling, self).__init__()

        ops = [
            nn.Upsample(scale_factor=stride, mode="trilinear", align_corners=False),
            # kernel 3 with padding 1 keeps the upsampled spatial shape.
            nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1),
        ]
        if normalization == "batchnorm":
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == "groupnorm":
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == "instancenorm":
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != "none":
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'batchnorm', 'groupnorm', 'instancenorm' or 'none'"
            )
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Upsample ``x`` by ``stride`` and apply the convolution stack."""
        x = self.conv(x)
        return x


class VNet(nn.Module):
    """V-Net style 3D encoder/decoder with additive skip connections.

    The encoder has five resolution levels separated by four stride-2 downsampling blocks. The
    decoder mirrors it with transposed-convolution upsampling and adds the matching encoder
    features. For those additions to line up, each spatial dimension of the input must be
    divisible by 16.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=2,
        n_filters=16,
        filter_size=3,
        normalization="none",
        outputs_activation="sigmoid",
    ):
        """
        :param input_channels: number of input channels
        :type input_channels: int
        :param output_channels: number of output channels
        :type output_channels: int
        :param n_filters: number of filters
        :type n_filters: int
        :param filter_size: spatial size of the filters
        :type filter_size: int
        :param normalization: normalization either groupnorm, batchnorm, instancenorm or none
        :type normalization: str
        :param outputs_activation: output activation. either sigmoid, softmax or none
        :type outputs_activation: str
        :raises ValueError: if normalization or outputs_activation is not a supported value

        <json>
        [
            {"name": "input_names", "type": "list:string", "value": "['images']"},
            {"name": "output_names", "type": "list:string", "value": "['output']"},
            {"name": "input_channels", "type": "int", "value": ""},
            {"name": "output_channels", "type": "int", "value": ""},
            {"name": "n_filters", "type": "int", "value": "16"},
            {"name": "filter_size", "type": "int", "value": "3"},
            {"name": "normalization", "type": "string", "value": ["groupnorm", "batchnorm", "instancenorm", "none"]},
            {"name": "outputs_activation", "type": "string", "value": ["sigmoid", "softmax", "none"]}
        ]
        </json>
        """
        super(VNet, self).__init__()

        # Validate the string options before building any layers. An unknown activation would
        # otherwise leave self.activation unset and only fail at the first forward pass.
        if outputs_activation == "sigmoid":
            activation = nn.Sigmoid()
        elif outputs_activation == "softmax":
            # Softmax over the channel axis, assuming a batched (N, C, D, H, W) output.
            # Passing dim explicitly avoids PyTorch's implicit-dim deprecation warning; the
            # implicit choice for a 5D input is also dim=1.
            activation = nn.Softmax(dim=1)
        elif outputs_activation == "none":
            activation = nn.Identity()
        else:
            raise ValueError(
                f"Unsupported outputs_activation '{outputs_activation}'; "
                "expected one of 'sigmoid', 'softmax' or 'none'"
            )

        if normalization not in ("groupnorm", "batchnorm", "instancenorm", "none"):
            raise ValueError(
                f"Unsupported normalization '{normalization}'; expected one of "
                "'groupnorm', 'batchnorm', 'instancenorm' or 'none'"
            )

        # Encoder: channel count doubles at each downsampling step.
        self.block_one = ConvBlock(1, input_channels, n_filters, filter_size, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, filter_size, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, filter_size, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, filter_size, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, filter_size, normalization=normalization)

        # Decoder: channel count halves at each upsampling step, matching the encoder level it is added to.
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, filter_size, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, filter_size, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, filter_size, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, filter_size, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, output_channels, 1, padding=0)

        self.activation = activation

    def encoder(self, input):
        """Run the encoder and return the features of all five levels as ``[x1, x2, x3, x4, x5]``."""
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        """Decode the encoder ``features`` into pre-activation outputs.

        Each upsampled map is summed with the encoder feature of the same level before the next
        convolution block; the result goes through a final 1x1x1 convolution.
        """
        x1, x2, x3, x4, x5 = features[0], features[1], features[2], features[3], features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        out = self.out_conv(x9)
        return out

    def forward(self, x):
        """Run encoder, decoder and the configured output activation on ``x``."""
        features = self.encoder(x)
        outputs = self.decoder(features)
        ret = self.activation(outputs)
        return ret

Docs

Access comprehensive developer documentation for Eisen

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources