General steps for designing a neural network:
1. Design the overall framework
2. Design the backbone network
Steps for designing the U-Net network:
1. Design a factory pattern for building U-Net
2. Design the encoder-decoder structure
3. Design the convolution block
4. Instantiate the U-Net model
The most important characteristics of U-Net:
1. The encoder-decoder structure.
2. Skip connections: the decoder concatenates encoder features with upsampled decoder features, a more complete fusion scheme than FCN's (see the minimal sketch after this list).
3. It is essentially a framework: many image-classification networks can serve as the encoder.
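To make the skip-connection idea concrete before the full listing, here is a minimal sketch of channel-wise concatenation; the tensor shapes are illustrative only, not taken from the original post:

import torch

# Hypothetical shapes: the encoder feature map and the upsampled decoder map
# must share the same spatial size before concatenation (U-Net crops the larger one).
encoder_feat = torch.rand(1, 64, 56, 56)   # skip connection from the encoder
decoder_feat = torch.rand(1, 64, 56, 56)   # output of a transposed convolution

# Channel-wise concatenation doubles the channel count: 64 + 64 = 128
fused = torch.cat([encoder_feat, decoder_feat], dim=1)
print(fused.shape)  # torch.Size([1, 128, 56, 56])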
Example code:
import torch
import torch.nn as nn


class Unet(nn.Module):
    # init arguments: encoder block list, decoder block list, bridge
    # bridge defaults to None; if a module is passed in, it replaces None
    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super(Unet, self).__init__()
        self.encoder = Encoder(encoder_blocks)
        self.decoder = Decoder(decoder_blocks)
        self.bridge = bridge

    def forward(self, x):
        res = self.encoder(x)
        out, skips = res[0], res[1:]
        if self.bridge is not None:
            out = self.bridge(out)
        out = self.decoder(skips, out)
        return out


# Encoder module
class Encoder(nn.Module):
    def __init__(self, blocks):
        super(Encoder, self).__init__()
        # assert guards against an empty block list
        assert len(blocks) > 0
        # nn.ModuleList registers every submodule's parameters with the
        # network, but has no forward function of its own
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        # every block's output is kept as a skip connection; the deepest
        # feature map is also returned first as the encoder output
        skips = []
        for block in self.blocks:
            x = block(x)
            skips.append(x)
        res = [x]
        # lists can be concatenated with the + operator
        res += skips
        return res


# Decoder module
class Decoder(nn.Module):
    def __init__(self, blocks):
        super(Decoder, self).__init__()
        assert len(blocks) > 0
        self.blocks = nn.ModuleList(blocks)

    def center_crop(self, skip, x):
        _, _, height1, width1 = skip.shape
        _, _, height2, width2 = x.shape
        # crop both feature maps so their spatial sizes match for concatenation
        ht, wt = min(height1, height2), min(width1, width2)
        dh1 = (height1 - height2) // 2 if height1 > height2 else 0
        dw1 = (width1 - width2) // 2 if width1 > width2 else 0
        dh2 = (height2 - height1) // 2 if height2 > height1 else 0
        dw2 = (width2 - width1) // 2 if width2 > width1 else 0
        return skip[:, :, dh1:(dh1 + ht), dw1:(dw1 + wt)], \
               x[:, :, dh2:(dh2 + ht), dw2:(dw2 + wt)]

    def forward(self, skips, x, reverse_skips=True):
        assert len(skips) == len(self.blocks) - 1
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for i in range(1, len(self.blocks)):
            skip, x = self.center_crop(skips[i - 1], x)
            x = torch.cat([skip, x], 1)
            x = self.blocks[i](x)
        return x


# A convolution block: (Conv -> BN -> ReLU) twice
def unet_convs(in_channels, out_channels, padding=0):
    # nn.Sequential, unlike nn.ModuleList, comes with a forward function
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


# Instantiate the Unet model
def unet(in_channels, out_channels):
    encoder_blocks = [
        unet_convs(in_channels, 64),
        nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                      unet_convs(64, 128)),
        nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                      unet_convs(128, 256)),
        nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                      unet_convs(256, 512)),
    ]
    bridge = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                           unet_convs(512, 1024))
    decoder_blocks = [
        nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
        nn.Sequential(unet_convs(1024, 512),
                      nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)),
        nn.Sequential(unet_convs(512, 256),
                      nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)),
        nn.Sequential(unet_convs(256, 128),
                      nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)),
        nn.Sequential(unet_convs(128, 64),
                      nn.Conv2d(64, out_channels, kernel_size=1)),
    ]
    return Unet(encoder_blocks, decoder_blocks, bridge)
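A quick smoke test of the factory might look like the sketch below. The 572x572 input size follows the original U-Net paper, and the expected output size assumes the valid (unpadded) convolutions used in unet_convs:

# Minimal usage sketch for the factory above
model = unet(in_channels=1, out_channels=2)
x = torch.rand(1, 1, 572, 572)   # input size from the U-Net paper
y = model(x)
print(y.shape)                   # expected: torch.Size([1, 2, 388, 388])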
Supplementary: building the U-Net network in PyTorch
U-Net: Convolutional Networks for Biomedical Image Segmentation
import torch
import torch.nn as nn
from torchsummary import summary


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # transposed convolution; upsampling could be used instead
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        # center crops are hard-coded for a 572x572 input
        crop1 = c1[:, :, 88:480, 88:480]
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        crop2 = c2[:, :, 40:240, 40:240]
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        crop3 = c3[:, :, 16:120, 16:120]
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        crop4 = c4[:, :, 4:60, 4:60]
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, crop4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, crop3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, crop2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, crop1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out


if __name__ == "__main__":
    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    summary(model, (1, 572, 572))
    output = model(test_input)
    print(output.size())
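Note that the hard-coded crop indices above (88:480, 40:240, and so on) only work for a 572x572 input. To support other input sizes, one option is a small helper like the sketch below; center_crop is a hypothetical name, not part of the original code, and it assumes the skip map is never smaller than the target size (true for this valid-convolution U-Net):

def center_crop(feat, target_h, target_w):
    # Crop a (N, C, H, W) feature map to the target spatial size,
    # keeping the center region, as in the original U-Net paper.
    _, _, h, w = feat.shape
    dh = (h - target_h) // 2
    dw = (w - target_w) // 2
    return feat[:, :, dh:dh + target_h, dw:dw + target_w]

# e.g. inside forward(): crop4 = center_crop(c4, up_6.size(2), up_6.size(3))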
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
        DoubleConv-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
       DoubleConv-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
       DoubleConv-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
       DoubleConv-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
       DoubleConv-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-42          [-1, 512, 54, 54]           1,024
             ReLU-43          [-1, 512, 54, 54]               0
           Conv2d-44          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-45          [-1, 512, 52, 52]           1,024
             ReLU-46          [-1, 512, 52, 52]               0
       DoubleConv-47          [-1, 512, 52, 52]               0
  ConvTranspose2d-48        [-1, 256, 104, 104]         524,544
           Conv2d-49        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-50        [-1, 256, 102, 102]             512
             ReLU-51        [-1, 256, 102, 102]               0
           Conv2d-52        [-1, 256, 100, 100]         590,080
      BatchNorm2d-53        [-1, 256, 100, 100]             512
             ReLU-54        [-1, 256, 100, 100]               0
       DoubleConv-55        [-1, 256, 100, 100]               0
  ConvTranspose2d-56        [-1, 128, 200, 200]         131,200
           Conv2d-57        [-1, 128, 198, 198]         295,040
      BatchNorm2d-58        [-1, 128, 198, 198]             256
             ReLU-59        [-1, 128, 198, 198]               0
           Conv2d-60        [-1, 128, 196, 196]         147,584
      BatchNorm2d-61        [-1, 128, 196, 196]             256
             ReLU-62        [-1, 128, 196, 196]               0
       DoubleConv-63        [-1, 128, 196, 196]               0
  ConvTranspose2d-64         [-1, 64, 392, 392]          32,832
           Conv2d-65         [-1, 64, 390, 390]          73,792
      BatchNorm2d-66         [-1, 64, 390, 390]             128
             ReLU-67         [-1, 64, 390, 390]               0
           Conv2d-68         [-1, 64, 388, 388]          36,928
      BatchNorm2d-69         [-1, 64, 388, 388]             128
             ReLU-70         [-1, 64, 388, 388]               0
       DoubleConv-71         [-1, 64, 388, 388]               0
           Conv2d-72          [-1, 2, 388, 388]             130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])
That concludes this article on implementing the U-Net network from the paper in PyTorch. I hope it serves as a useful reference, and I hope you will continue to support 服务器之家.
Original link: https://blog.csdn.net/weixin_38410551/article/details/104294545