A Journey into News Text Classification: BERT

All of the code is available on GitHub.
Download link for the pretrained BERT weights and the related code:
Link: https://pan.baidu.com/s/1zd6wN7elGgp1NyuzYKpvGQ  Extraction code: tmp5

We know that the input to the BERT model consists of three parts: token embeddings, segment embeddings, and position embeddings.

Subsequent processing of the embeddings

First generate the Segment Embeddings and the Position Embeddings, then add the three together:
Input = Token Embeddings + Segment Embeddings + Position Embeddings
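
To make the sum concrete, here is a minimal PyTorch sketch of how the three embedding tables are combined. The sizes (hidden size 256 to match bert-mini, sequence length 200, vocabulary size) are illustrative assumptions; real BERT implementations also apply LayerNorm and dropout after the sum.

import torch

# Illustrative sizes; hidden=256 matches bert-mini, the rest are assumptions.
vocab_size, max_len, type_vocab, hidden = 30522, 200, 2, 256

token_emb    = torch.nn.Embedding(vocab_size, hidden)
segment_emb  = torch.nn.Embedding(type_vocab, hidden)
position_emb = torch.nn.Embedding(max_len, hidden)

ids            = torch.randint(0, vocab_size, (1, max_len))   # token ids
token_type_ids = torch.zeros(1, max_len, dtype=torch.long)    # a single segment
positions      = torch.arange(max_len).unsqueeze(0)           # 0 .. max_len-1

# Element-wise sum: Input = Token + Segment + Position embeddings
inputs = token_emb(ids) + segment_emb(token_type_ids) + position_emb(positions)
print(inputs.shape)  # torch.Size([1, 200, 256])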

Related reading: BERT source code analysis, the transformers library, Self-Attention and the Transformer.
Model creation

import torch
from transformers import BertConfig, BertModel

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.config = BertConfig.from_pretrained('../emb/bert-mini/bert_config.json', output_hidden_states=True)
        self.l1 = BertModel.from_pretrained('../emb/bert-mini/pytorch_model.bin', config=self.config)
        # Input: [CLS] vectors of the last two hidden layers, concatenated (256 + 256 = 512)
        self.bilstm1 = torch.nn.LSTM(512, 64, 1, bidirectional=True)
        self.l2 = torch.nn.Linear(128, 64)   # 128 = 2 directions * 64 hidden units
        self.a1 = torch.nn.ReLU()
        self.l3 = torch.nn.Dropout(0.3)
        self.l4 = torch.nn.Linear(64, 14)    # 14 news categories

    def forward(self, ids, mask, token_type_ids):
        sequence_output, pooler_output, hidden_states = self.l1(
            ids, attention_mask=mask, token_type_ids=token_type_ids)
        # sequence_output: [bs, 200, 256]; pooler_output: [bs, 256]
        bs = len(sequence_output)
        h12 = hidden_states[-1][:, 0].view(1, bs, 256)  # [CLS] of the last layer
        h11 = hidden_states[-2][:, 0].view(1, bs, 256)  # [CLS] of the second-to-last layer
        concat_hidden = torch.cat((h12, h11), 2)        # [1, bs, 512]
        x, _ = self.bilstm1(concat_hidden)              # [1, bs, 128]
        x = self.l2(x.view(bs, 128))
        x = self.a1(x)
        x = self.l3(x)
        output = self.l4(x)
        return output

net = BERTClass()
net.to(device)
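
The forward pass expects ids, mask, and token_type_ids as produced by a BERT tokenizer. The post does not show its data pipeline, so the following is only a sketch of how one batch entry might be built with transformers' BertTokenizer; the vocab path and the max length of 200 are assumptions taken from the shape comments above.

from transformers import BertTokenizer

# Assumed vocab path; encode_plus with padding= requires transformers >= 3.0.
tokenizer = BertTokenizer.from_pretrained('../emb/bert-mini/vocab.txt')

enc = tokenizer.encode_plus(
    "example news text",
    max_length=200,
    padding='max_length',
    truncation=True,
    return_token_type_ids=True)

ids = torch.tensor([enc['input_ids']], dtype=torch.long).to(device)
mask = torch.tensor([enc['attention_mask']], dtype=torch.long).to(device)
token_type_ids = torch.tensor([enc['token_type_ids']], dtype=torch.long).to(device)
logits = net(ids, mask, token_type_ids)  # [1, 14] class scores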

Training the model

import time
from tqdm import tqdm

def train(train_iter, test_iter, criterion, num_epochs, optimizer, device):
    print('training on', device)
    net.to(device)
    best_test_acc = 0
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # step-decay learning-rate schedule
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)  # cosine annealing
    for epoch in range(num_epochs):
        train_l_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        train_acc_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        n, start = 0, time.time()
        for data in tqdm(train_iter):
            net.train()
            optimizer.zero_grad()
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)
            y_hat = net(ids, mask, token_type_ids)
            loss = criterion(y_hat, targets.long())
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                targets = targets.long()
                train_l_sum += loss.float()
                train_acc_sum += (torch.sum((torch.argmax(y_hat, dim=1) == targets))).float()
                n += targets.shape[0]
        valid_acc = evaluate_accuracy(test_iter, net, device)
        train_acc = train_acc_sum / n
        print('epoch %d, loss %.4f, train acc %.3f, valid acc %.3f, '
              'time %.1f sec'
              % (epoch + 1, train_l_sum / n, train_acc, valid_acc,
                 time.time() - start))
        if valid_acc > best_test_acc:
            print('find best! save at model/best.pth')
            best_test_acc = valid_acc
            torch.save(net.state_dict(), 'model/best.pth')
        scheduler.step()  # update the learning rate
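
The loop calls an evaluate_accuracy helper that the post does not show. Below is a minimal sketch consistent with the calling convention above (the batch dict with ids/mask/token_type_ids/targets mirrors the training loop), followed by a typical way to kick off training; the loss, optimizer, and learning rate are assumptions, not taken from the post.

def evaluate_accuracy(data_iter, net, device):
    # Classification accuracy over a held-out iterator.
    net.eval()
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for data in data_iter:
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.long)
            y_hat = net(ids, mask, token_type_ids)
            acc_sum += (torch.argmax(y_hat, dim=1) == targets).float().sum().item()
            n += targets.shape[0]
    return acc_sum / n

# Typical setup (assumed values for illustration):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=2e-5)
train(train_iter, test_iter, criterion, num_epochs=10, optimizer=optimizer, device=device)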
