实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测

未分类2个月前发布 tree
38 0 0
↑ 点击蓝字 关注极市平台
实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
作者丨Ctrl CV
来源丨笑傲算法江湖
编辑丨极市平台

极市导读

 

本文基于PyTorch框架,实现通过Faster RCNN算法检测图像中的小麦麦穗。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

本文基于PyTorch框架,实现通过Faster RCNN算法检测图像中的小麦麦穗。当然,用YOLO算法也同样能够完成。本文最终实现的效果如下:

实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
麦穗检测示例

一、数据下载

数据集名:Global Wheat Head Dataset

下载地址www.kaggle.com/c/global-wheat-detection

更多深度学习数据集:https://www.cvmart.net/dataSets

相关论文:Global Wheat Head Detection (GWHD) Dataset: A Large and Diverse Dataset of High-Resolution RGB-Labelled Images to Develop and Benchmark Wheat Head Detection Methods

数据描述:全球麦穗数据集由来自7个国家的9个研究机构领导,东京大学、国家农业、营养和环境研究所、Arvalis、ETHZ、萨斯喀彻温大学、昆士兰大学、南京农业大学和洛桑研究所。包括全球粮食安全研究所、DigitAg、Kubota和Hiphen在内的许多机构都加入了这些机构的行列,致力于精确的小麦麦穗检测。

实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
数据集贡献机构

数据集为室外小麦植物图像,包括来自全球各地不同平台采集的4698张RGB图像,标记了193,634个小麦麦穗,1024×1024像素,每张图像含有20~70个麦穗。2020年通过Kaggle举办了相关比赛,并在2021年更新了数据集。该数据集可以用于麦穗检测,评估穗数和大小。研究成果有助于准确估计不同品种小麦麦穗的密度和大小。

实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
数据集示例

二、代码实战

2.1 导入所需要的包

# 导入所需要的包  
import numpy as np  
import pandas as pd  
import matplotlib.pyplot as plt  
import torch  
  
  
import torch.nn as nn  
import albumentations as A   # pip install albumentations==1.1.0  
from albumentations.pytorch import ToTensorV2  
import torchvision  
from torchvision import datasets,transforms  
from tqdm import tqdm  
import cv2  
from torch.utils.data import Dataset,DataLoader  
import torch.optim as optim  
from PIL import Image  
import os  
import torch.nn.functional as F  
import ast

2.2 参数配置

# 定义参数  
LR = 1e-4  
SPLIT = 0.2  
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  
BATCH_SIZE = 4  
EPOCHS = 2  
DATAPATH = '../global-wheat-detection' 

2.3 读取数据

# 读取 train.csv文件  
df = pd.read_csv(DATAPATH + '/train.csv')  
df.bbox = df.bbox.apply(ast.literal_eval)   # # 将string of list 转成list数据  
  
#  # 利用groupby 将同一个image_id的数据进行聚合,方式为list进行,并且用reset_index直接转变成dataframe  
df = df.groupby("image_id")["bbox"].apply(list).reset_index(name="bboxes")  

2.4 划分数据

# # 划分数据集  
def train_test_split(dataFrame,split):  
    len_tot = len(dataFrame)  
    val_len = int(split*len_tot)  
    train_len = len_tot-val_len  
    train_data,val_data = dataFrame.iloc[:train_len][:],dataFrame.iloc[train_len:][:]  
    return train_data,val_data  
  
len(df)  
  
train_data_df,val_data_df = train_test_split(df,SPLIT)  # 划分 train val 8:2  
len(train_data_df), len(val_data_df)  
  
# 查看数据  
train_data_df  

2.5 构建Dataset类

# 定义WheatDataset 返回 图片,标签  
class WheatDataset(Dataset):  
    def __init__(self,data,root_dir,transform=None,train=True):  
        self.data = data  
        self.root_dir = root_dir  
        self.image_names = self.data.image_id.values  
        self.bboxes = self.data.bboxes.values  
        self.transform = transform  
        self.isTrain = train  
    def __len__(self):  
        return len(self.data)  
    def __getitem__(self,index):  
#         print(self.image_names)  
#         print(self.bboxes)  
        img_path = os.path.join(self.root_dir,self.image_names[index]+".jpg")  # 拼接路径  
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)   # 读取图片  
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)  # BGR2RGB  
        image /= 255.0    # 归一化  
        bboxes = torch.tensor(self.bboxes[index],dtype=torch.float64)  
#         print(bboxes)  
        """  
            As per the docs of torchvision  
            we need bboxes in format (xmin,ymin,xmax,ymax)  
            Currently we have them in format (xmin,ymin,width,height)  
        "
""  
        bboxes[:,2] = bboxes[:,0]+bboxes[:,2]   # 格式转换 (xmin,ymin,width,height)-----> (xmin,ymin,xmax,ymax)  
        bboxes[:,3] = bboxes[:,1]+bboxes[:,3]  
#         print(image.size,type(image))  
        """  
            we need to return image and a target dictionary  
            target:  
                boxes,labels,image_id,area,iscrowd  
        "
""  
        area = (bboxes[:,3]-bboxes[:,1])*(bboxes[:,2]-bboxes[:,0])   # 计算面积  
        area = torch.as_tensor(area,dtype=torch.float32)  
          
        # there is only one class  
        labels = torch.ones((len(bboxes),),dtype=torch.int64)   # 标签  
          
        # suppose all instances are not crowded  
        iscrowd = torch.zeros((len(bboxes),),dtype=torch.int64)  
          
        target = {}   # target是个字典 里面 包括 boxes,labels,image_id,area,iscrowd  
        target['boxes'] = bboxes  
        target['labels']= labels  
        target['image_id'] = torch.tensor([index])  
        target["area"] = area  
        target['iscrowd'] = iscrowd  
          
        if self.transform is not None:  
            sample = {  
                'image': image,  
                'bboxes': target['boxes'],  
                'labels': labels  
            }  
            sample = self.transform(**sample)  
            image = sample['image']  
              
            # 沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,   
#             把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠  
            target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)  
              
        return image,target  

2.6 数据增强

# 训练与验证数据增强,利用albumentations  随机翻转转换,随机图片处理  
# 对象检测的增强与正常增强不同,因为在这里需要确保 bbox 在转换后仍然正确与对象对齐  
train_transform = A.Compose([  
    A.Flip(0.5),  
    ToTensorV2(p=1.0)  
],bbox_params = {'format':"pascal_voc",'label_fields': ['labels']})  
val_transform = A.Compose([  
      ToTensorV2(p=1.0)  
],bbox_params = {'format':"pascal_voc","label_fields":['labels']})  
`

### 2.7 数据整理

`"""  
collate_fn默认是对数据(图片)通过torch.stack()进行简单的拼接。对于分类网络来说,默认方法是可以的(因为传入的就是数据的图片),  
但是对于目标检测来说,train_dataset返回的是一个tuple,即(image, target)。  
如果我们还是采用默认的合并方法,那么就会出错。  
所以我们需要自定义一个方法,即collate_fn=train_dataset.collate_fn  
"
""  
def collate_fn(batch):  
    return tuple(zip(*batch))  

2.8 创建数据加载器

# 创建数据加载器  
  
train_data = WheatDataset(train_data_df,DATAPATH+"/train",transform=train_transform)  
valid_data = WheatDataset(val_data_df,DATAPATH+"/train",transform=val_transform)  

2.9 查看数据

# 查看一个训练集中的数据  
image,target = train_data.__getitem__(0)  
plt.imshow(image.numpy().transpose(1,2,0))  
print(image.shape)  
实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
训练集示例

2.10 定义模型

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  
  
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)  
num_classes = 2  
in_features = model.roi_heads.box_predictor.cls_score.in_features  
model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes)  

2.11 定义Averager类

# 这一个类来保存对应的loss  
class Averager:  
    def __init__(self):  
        self.current_total = 0.0  
        self.iterations = 0.0  
  
    def send(self, value):  
        self.current_total += value  
        self.iterations += 1  
  
    @property  
    def value(self):  
        if self.iterations == 0:  
            return 0  
        else:  
            return 1.0 * self.current_total / self.iterations  
  
    def reset(self):  
        self.current_total = 0.0  
        self.iterations = 0.0  

2.12 构建训练和测试 dataloader

# 构建训练和测试 dataloader  
train_dataloader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn)  
val_dataloader = DataLoader(valid_data,batch_size=BATCH_SIZE,shuffle=False,collate_fn=collate_fn)  

2.13 定义模型参数

# 定义模型, 优化器,损失, 迭代,以及 学习率  
train_loss = []  
# val_loss = []  
model = model.to(DEVICE)  
params =[p for p in model.parameters() if p.requires_grad]  
optimizer = optim.Adam(params,lr=LR)  
loss_hist = Averager()  
itr = 1  
lr_scheduler=None  
  
loss_hist = Averager()  
itr = 1  

2.14 模型训练

if __name__ == '__main__':  
  
    for epoch in range(EPOCHS):  
        loss_hist.reset()  
  
        for images, targets in train_dataloader:  
  
            # print(images)  
            # print(targets)  
  
            # for image in images:  
            #     print(image.dtype)  # torch.float32  
  
            # for t in targets:  
            #     for k, v in t.items():  
            #         print(k ,v.dtype)  
  
            images = list(image.to(DEVICE) for image in images)  
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]  
  
  
            loss_dict = model(images, targets)  
  
            # for loss in loss_dict.values():  
            #     print(loss.dtype)  
  
            losses = sum(loss for loss in loss_dict.values())  
            loss_value = losses.item()  
  
            loss_hist.send(loss_value)  
  
            optimizer.zero_grad()  
            losses.backward()  
            optimizer.step()  
  
            if itr % 50 == 0:  
                print(f"Iteration #{itr} loss: {loss_value}")  
  
            itr += 1  
  
        # update the learning rate  
        if lr_scheduler is not None:  
            lr_scheduler.step()  
  
        print(f"Epoch #{epoch} loss: {loss_hist.value}")  

2.15 模型保存

# 模型保存  
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn.pth')  
实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
训练好的模型

2.16 加载模型进行预测

images, targets = next(iter(val_dataloader))  
images = list(img.to(DEVICE) for img in images)  
# print(images[0].shape)  
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]  
boxes = targets[1]['boxes'].cpu().numpy().astype(np.int32)  
sample = images[1].permute(1, 2, 0).cpu().numpy()  
  
model.eval()  
cpu_device = torch.device("cpu")  
# print(images[0].shape)  
  
  
outputs = model(images)  
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]  
# print(outputs[1]['boxes'].detach().numpy().astype(np.int32))  
  
pred_boxes = outputs[1]['boxes'].detach().numpy().astype(np.int32)  
  
fig, ax = plt.subplots(1, 1, figsize=(16, 8))  
  
for b, box in zip(boxes, pred_boxes):  
    # 绘制预测边框 红色表示  
    cv2.rectangle(sample,  
                  (box[0], box[1]),  
                  (box[2], box[3]),  
                  (220, 0, 0), 3)  
    # 绘制实际边框  绿色表示  
    cv2.rectangle(sample,  
                  (b[0], b[1]),  
                  (b[2], b[3]),  
                  (0, 220, 0), 3)  
  
ax.set_axis_off()  
ax.imshow(sample)  
plt.show()  
实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测
检测结果

对比预测框与实际框,可以看出模型能够很好的预测出麦穗。可以尝试测试不同的麦穗图片,来进行测试查看效果。

实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

实践教程|PyTorch框架Faster RCNN算法实现小麦麦穗检测

点击阅读原文进入CV社区

收获更多技术干货

© 版权声明

相关文章

暂无评论

暂无评论...