BERT文本分类 情感分类 BertForSequenceClassification的使用

avatar 2024年04月27日22:42:37 0 107 views
博主分享免费Java教学视频,B站账号:Java刘哥

直接贴代码

import os

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
# from sklearn.metrics import accuracy_score, classification_report  # pip3 install scikit-learn

# Location of the pretrained bert-base-chinese checkpoint passed to from_pretrained().
# Set the BERT_PATH environment variable to point at a different copy; when it is unset
# the author's original local path is used, so existing behavior is unchanged.
BERT_PATH = os.environ.get('BERT_PATH', 'K:/workspace-sync/models/bert-base-chinese')


# 1. Dataset definition
class TextClassificationDataset(Dataset):
    """Map-style dataset that tokenizes (text, label) pairs on access.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors padded or
    truncated to ``max_length`` tokens, plus a scalar ``labels`` tensor of dtype long.
    These are the keyword inputs BertForSequenceClassification consumes.

    Args:
        texts: Sequence of raw input strings.
        labels: Sequence of integer class ids, aligned index-by-index with ``texts``.
        tokenizer: Callable with the Hugging Face tokenizer call signature
            (e.g. ``BertTokenizer``).
        max_length: Sequence length every example is padded/truncated to. Defaults to
            512, the maximum input length of bert-base models. Datasets of short
            sentences can pass a much smaller value to save memory and compute.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        # Catch misaligned inputs up front instead of silently ignoring extra labels
        # or failing later with an IndexError inside the DataLoader.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch dimension of 1 for a single string.
        # flatten() drops it so DataLoader can stack items into (batch, max_length).
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
        }


# 2. Fine-tune the model
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer loaded from the same checkpoint directory as the model.
tokenizer = BertTokenizer.from_pretrained(BERT_PATH)

# BertForSequenceClassification is the bare BertModel encoder plus a
# sequence-classification head; num_labels=2 sizes that head for binary sentiment.
model = BertForSequenceClassification.from_pretrained(BERT_PATH, num_labels=2)
model = model.to(device)

# AdamW optimizer over all model parameters with a small fine-tuning learning rate.
optimizer = AdamW(model.parameters(), lr=1e-5)


# 3. Training loop
def train(model, train_dataloader, device, optimizer):
    """Run one epoch of supervised training and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask=..., labels=...)``
            whose output exposes a scalar ``.loss`` (e.g. BertForSequenceClassification).
        train_dataloader: Iterable of dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the batch tensors are moved to; should match the model's device.
        optimizer: Optimizer over the model's parameters; stepped once per batch.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if the loader yields no batches.
    """
    model.train()  # switch to training mode (enables dropout)
    # Discard any gradients accumulated before this call so the first update is
    # driven only by the first batch.
    optimizer.zero_grad()

    total_loss = 0.0
    num_batches = 0
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Passing labels makes the model compute the loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')


# 4. Evaluation
def test(model, test_dataloader, device):
    """Evaluate ``model``, print every prediction, and return the accuracy.

    Works with any batch size: predictions are compared element-wise, and the accuracy
    is the number of correct samples divided by the number of samples (not batches).

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)`` whose output
            exposes ``.logits`` of shape (batch, num_labels).
        test_dataloader: Iterable of dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the batch tensors are moved to; should match the model's device.

    Returns:
        Fraction of correctly classified samples, or ``nan`` if the loader is empty.
    """
    model.eval()  # inference mode: disables dropout
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Only logits are needed, so labels are not passed and no loss is computed.
            outputs = model(input_ids, attention_mask=attention_mask)
            preds = outputs.logits.argmax(dim=1)  # predicted class id per sample

            for pred, label in zip(preds.tolist(), labels.tolist()):
                print(f'Predicted: {pred}, Actual: {label}')

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {accuracy}')
    return accuracy


# This evaluation helper is named ``test``; stop pytest from collecting it as a test case.
test.__test__ = False


# 5. Train the model
# Tiny hand-made English sentiment set: label 1 = positive ("love"), 0 = negative ("hate").
train_texts = [
    "I love programming",
    "I love java",
    "I love c",
    "I love everything",
    "I hate error",
    "I hate bug",
    "I hate bad boy",
    "I hate easy girl",
]
train_labels = [1] * 4 + [0] * 4
train_num_epochs = 3

train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer)
# All 8 samples fit in one batch of 8, so each epoch performs a single optimizer step.
train_dataloader = DataLoader(train_dataset, batch_size=8)

for epoch in range(train_num_epochs):
    print(f"Train Epoch {epoch + 1}/{train_num_epochs} ")
    train(model, train_dataloader, device, optimizer)

# 6. Evaluate the model
test_texts = ["I love programming", "I love redis", "I hate rbr", "I hate japanese", "I love mysql"]
test_labels = [1, 1, 0, 0, 1]
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer)
# batch_size=1: one sample per batch so each prediction can be inspected individually.
test_dataloader = DataLoader(test_dataset, batch_size=1)
# test() switches the model to eval mode, so repeated passes are expected to print
# the same predictions each time.
test_num_epochs = 3
for epoch in range(test_num_epochs):
    # Report progress against the number of evaluation passes, not training epochs.
    print(f"\nTest Epoch {epoch + 1}/{test_num_epochs} ")
    test(model, test_dataloader, device)

 

运行结果

Train Epoch 1/3 
Train Epoch 2/3 
Train Epoch 3/3 

Test Epoch 1/3 
Predicted: 1, Actual: 1
Predicted: 1, Actual: 1
Predicted: 0, Actual: 0
Predicted: 1, Actual: 0
Predicted: 1, Actual: 1
Accuracy: 0.8

Test Epoch 2/3 
Predicted: 1, Actual: 1
Predicted: 1, Actual: 1
Predicted: 0, Actual: 0
Predicted: 1, Actual: 0
Predicted: 1, Actual: 1
Accuracy: 0.8

Test Epoch 3/3 
Predicted: 1, Actual: 1
Predicted: 1, Actual: 1
Predicted: 0, Actual: 0
Predicted: 1, Actual: 0
Predicted: 1, Actual: 1
Accuracy: 0.8

目前只用了自己构造的少量数据(训练集仅 8 条、训练 3 轮),所以上面 0.8 的准确率参考意义不大;明天尝试换用比较标准的电影评论数据集。

  • 微信
  • 交流学习,有偿服务
  • weinxin
  • 博客/Java交流群
  • 资源分享,问题解决,技术交流。群号:590480292
  • weinxin
avatar

发表评论

avatar 登录者:匿名
匿名评论,评论回复后会有邮件通知

  

已通过评论:0   待审核评论数:0