MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。
具体实现为将原先线性分类模块:
1
|
self .classifier = nn.Linear(config.hidden_size, num_labels) |
替换为:
1
|
self .classifier = MLP(config.hidden_size, num_labels) |
并且添加MLP模块:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
class MLP(nn.Module): def __init__( self , input_size, common_size): super (MLP, self ).__init__() self .linear = nn.Sequential( nn.Linear(input_size, input_size / / 2 ), nn.ReLU(inplace = True ), nn.Linear(input_size / / 2 , input_size / / 4 ), nn.ReLU(inplace = True ), nn.Linear(input_size / / 4 , common_size) ) def forward( self , x): out = self .linear(x) return out |
看一下模块结构:
1
2
|
mlp = MLP( 1000 , 3 ) print (mlp) |
以上这篇关于Pytorch的MLP模块实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_33373858/article/details/88108153