Quantcast
Channel: Recent Questions - Stack Overflow
Viewing all articles
Browse latest Browse all 22544

Training a custom U-Net model

$
0
0

I'm working on a segmentation project. When I tried either a custom U-Net or a pretrained fully convolutional network (FCN) from PyTorch, I got these strange results no matter which hyperparameters I used. However, when I used the pretrained DeepLab model, I got quite reasonable results. Does anyone have an explanation? Thanks

losses graph

dice scores graph

This is the custom model:

    import torch
    from torch import nn


    def create_conv_block(in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages as one nn.Sequential.

        The convolutions use stride 1 and padding 1, so the spatial size is
        preserved; only the channel count changes (in_channels -> out_channels).
        """
        return nn.Sequential(
            # bias=False: a conv bias is redundant right before BatchNorm.
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


    class Encoder(nn.Module):
        """U-Net contracting path.

        Args:
            channels: channel counts per stage, starting with the input image
                channel count (e.g. ``[3, 64, 128, 256, 512, 1024]``). One conv
                block is built for each consecutive pair of entries.
        """

        def __init__(self, channels) -> None:
            super().__init__()
            self.model = nn.ModuleList(
                [create_conv_block(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
            )
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        def forward(self, x):
            """Return the output of every conv block, highest resolution first.

            Every block after the first runs on a 2x max-pooled copy of the
            previous block's output. The last block's output is not pooled,
            because nothing consumes a pooled version of it.
            """
            features = []
            for idx, block in enumerate(self.model):
                if idx:
                    x = self.maxpool(x)
                x = block(x)
                features.append(x)
            return features


    class Decoder(nn.Module):
        """U-Net expanding path.

        Args:
            channels: the encoder channel list reversed
                (e.g. ``[1024, 512, 256, 128, 64, 3]``). The last entry (the
                image channel count) is not used by any layer.
            num_classes: output channels of the final 1x1 convolution.
                Defaults to 21, the value that used to be hard-coded here.

        ``self.model`` alternates [upsample, conv block, upsample, conv block,
        ..., final 1x1 conv]. This layout (and thus the state_dict keys) is
        unchanged, so previously saved weights still load.
        """

        def __init__(self, channels, num_classes=21) -> None:
            super().__init__()
            blocks = []
            for c_in, c_out in zip(channels[:-2], channels[1:-1]):
                blocks.append(nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=2, stride=2)))
                # The input is the upsampled map concatenated with the skip
                # connection, hence twice the channels.
                blocks.append(create_conv_block(c_out * 2, c_out))
            blocks.append(nn.Sequential(
                nn.Conv2d(in_channels=channels[-2], out_channels=num_classes, kernel_size=1)))
            self.model = nn.ModuleList(blocks)

        def forward(self, x, encoder_featuers):
            """Decode the bottleneck ``x`` using the encoder skip features.

            Args:
                x: output of the deepest encoder block.
                encoder_featuers: the remaining encoder outputs, highest
                    resolution first. They are consumed from the end. The
                    caller's sequence is not modified.

            Returns:
                Logits with ``num_classes`` channels at the resolution of the
                first skip feature.
            """
            skips = list(encoder_featuers)  # copy so the caller's list isn't consumed
            for i, module in enumerate(self.model):
                if i % 2:  # odd indices are conv blocks: concatenate a skip first
                    skip = skips.pop()
                    # If the input H/W was not divisible by 2**depth, pooling
                    # floored the size and the upsampled map is smaller than
                    # the skip. Resize it so torch.cat does not fail.
                    if x.shape[-2:] != skip.shape[-2:]:
                        x = nn.functional.interpolate(
                            x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
                    x = torch.cat([skip, x], dim=1)
                x = module(x)
            return x


    class UNet(nn.Module):
        """Encoder/decoder U-Net for semantic segmentation.

        Args:
            channels: encoder channel list starting with the input channel
                count, e.g. ``[3, 64, 128, 256, 512, 1024]``.
            num_classes: number of output logit channels (default 21).

        ``forward`` returns a plain logits tensor, not a dict with an ``"out"``
        key like torchvision's segmentation models.
        """

        def __init__(self, channels, num_classes=21):
            super().__init__()
            self.encoder = Encoder(channels)
            self.decoder = Decoder(channels[::-1], num_classes=num_classes)

        def forward(self, x):
            encoder_features = self.encoder(x)
            # The deepest feature is the decoder input; the others are skips.
            return self.decoder(encoder_features[-1], encoder_features[:-1])

And this is the training function

    def train(loader, model, loss_fn, optimizer, device="cuda"):
        """Run one training epoch and return the mean per-batch loss.

        Args:
            loader: sized iterable of ``(images, masks)`` batches.
            model: segmentation model. It may return either a logits tensor
                (e.g. a custom U-Net) or a dict with an ``"out"`` entry
                (torchvision segmentation models).
            loss_fn: callable ``loss_fn(logits, masks)`` returning a scalar tensor.
            optimizer: optimizer over ``model``'s parameters.
            device: device the batches are moved to (default ``"cuda"``, the
                value previously hard-coded).

        Returns:
            The sum of ``loss.item()`` over all batches divided by ``len(loader)``.

        Raises:
            ValueError: if ``loader`` is empty (there is nothing to average).
        """
        if len(loader) == 0:
            raise ValueError("train() received an empty loader")
        model.train()
        epoch_loss = 0.0
        for images, masks in loader:
            images = images.to(device)
            # Cast on the target device. The old masks.type(torch.LongTensor)
            # forced a round-trip through CPU memory.
            masks = masks.to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs = model(images)
            if isinstance(outputs, dict):  # torchvision models return {"out": ...}
                outputs = outputs["out"]
            loss = loss_fn(outputs, masks)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        return epoch_loss / len(loader)

And this is the validation function

    def validate(loader, model, loss_fn, device="cuda"):
        """Evaluate ``model`` over ``loader``. Return (mean loss, mean Dice).

        Puts ``model`` into eval mode and does not switch it back.

        Args:
            loader: sized iterable of ``(images, masks)`` batches.
            model: segmentation model. It may return either a logits tensor or
                a dict with an ``"out"`` entry (torchvision segmentation models).
            loss_fn: callable ``loss_fn(logits, masks)`` returning a scalar tensor.
            device: device the batches are moved to (default ``"cuda"``, the
                value previously hard-coded).

        Returns:
            ``(epoch_loss, dice)``. Both are per-batch values averaged over
            ``len(loader)``. Dice is computed by ``calc_dice_score`` on the
            argmax predictions.

        Raises:
            ValueError: if ``loader`` is empty (there is nothing to average).
        """
        if len(loader) == 0:
            raise ValueError("validate() received an empty loader")
        model.eval()
        epoch_loss = 0.0
        dice_score = 0.0
        with torch.no_grad():
            for images, masks in loader:
                images = images.to(device)
                masks = masks.to(device, dtype=torch.long)
                outputs = model(images)
                if isinstance(outputs, dict):  # torchvision models return {"out": ...}
                    outputs = outputs["out"]
                epoch_loss += loss_fn(outputs, masks).item()
                preds = torch.argmax(outputs, dim=1)  # per-pixel class index
                dice_score += calc_dice_score(preds, masks)
        return epoch_loss / len(loader), dice_score / len(loader)

And this is the Dice score function:

    def calc_dice_score(preds, targets, ignore_index=None):
        """Return the micro-averaged foreground Dice score as a Python float.

        Label 0 is treated as background. A pixel counts toward the
        intersection when its predicted label equals its non-zero target label.
        The denominator is the number of predicted foreground pixels plus the
        number of target foreground pixels.

        Args:
            preds: integer label tensor (e.g. argmax of logits).
            targets: integer label tensor of the same shape as ``preds``.
            ignore_index: optional target label (e.g. 255 for void pixels)
                whose pixels are excluded from both prediction and target.
                ``None`` (the default) keeps every pixel.

        Returns:
            Dice in [0, 1]. It is 1.0 when neither ``preds`` nor ``targets``
            contains any foreground, which previously produced NaN from 0/0.
        """
        if ignore_index is not None:
            valid = targets != ignore_index
            preds = preds[valid]
            targets = targets[valid]
        target_fg = targets > 0
        intersection = ((preds == targets) & target_fg).sum()
        denominator = (preds > 0).sum() + target_fg.sum()
        if denominator.item() == 0:
            # Both empty: perfect agreement instead of 0/0 = NaN.
            return 1.0
        return (2.0 * intersection / denominator).item()

Viewing all articles
Browse latest Browse all 22544

Trending Articles



<script src="https://jsc.adskeeper.com/r/s/rssing.com.1596347.js" async> </script>