博客
关于我
Pytorch深度学习框架YOLOv3目标检测学习笔记(二)——创建网络框架的层
阅读量:521 次
发布时间:2019-03-07

本文共 5413 字,大约阅读时间需要 18 分钟。

YOLO网络结构代码实现

1. 配置文件解析

YOLO网络的配置文件包含了网络的各个层级信息,以 Keyword 为基础,使用

标识符分隔各个参数。

1.1 配置文件示例

convolutionalbatch_normalize=1filters=64size=3stride=2pad=1activation=leaky

1.2 解析过程

  • 读取文件:将(cfgfile中的内容读取为一个列表。
  • 处理多空行:过滤掉空白行,并去除注释。
  • 解析块:遍历每一行,检查是否标记为新块(以‘[’开头),生成对应层的配置字典。

1.3 工具函数

from __future__ import divisionimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.autograd import Variableimport numpy as npdef parse_cfg(cfgfile):    """解析配置文件,返回网络结构块列表。"""    f = open(cfgfile, 'r')    lines = f.read().split('\n')    lines = [x.strip() for x in lines if x.strip() != '']    lines = [x for x in lines if x[0] != '#']    blocks = []    current_block = {}    for line in lines:        if line.startswith('['):            if current_block:                blocks.append(current_block)                current_block = {}            block_type = line[1:-1].strip()            current_block['type'] = block_type        else:            if '=' in line:                key, value = line.split('=', 1)                key = key.strip()                value = value.strip()                current_block[key] = value            else:                line = line.split()[0]                if line:                    current_block[line] = 'True'    blocks.append(current_block)    return blocks

1.4 模块生成

遍历解析得到的块列表,创建对应的PyTorch模块。使用Sequential和ModuleList来组织模块,确保层级结构清晰。

def create_modules(blocks):    net_info = blocks[0]    parent_module = nn.ModuleList()    prev_filters = 3    output_filters = []        for index, block in enumerate(blocks[1:]):        module = nn.Sequential()        module_type = block['type']                if module_type == 'convolutional':            activation = block.get('activation', 'leaky')            batch_norm = block.get('batch_normalize', 0)            filters = int(block['filters'])            padding = block.get('pad', 0)            kernel_size = int(block['size'])            stride = int(block['stride'])            if padding != 0:                padding = (kernel_size - 1) // 2                pad = padding            else:                pad = 0                        conv = nn.Conv2d(prev_filters, filters, kernel_size, stride=stride, padding=pad,                                bias=block.get('bias', True))            module.add_module(f'conv_{index}', conv)                        if batch_norm:                bn = nn.BatchNorm2d(filters)                module.add_module(f'batch_norm_{index}', bn)                            if activation == 'leaky':                activn = nn.LeakyReLU(0.1, inplace=True)                module.add_module(f'leaky_{index}', activn)                    elif module_type == 'upsample':            stride = int(block['stride'])            upsample = nn.Upsample(scale_factor=2, mode='bilinear')            module.add_module(f'upsample_{index}', upsample)                    elif module_type == 'route':            layers = block['layers'].split(',')            start = int(layers[0])            end = int(layers[1]) if len(layers) > 1 else -1                        start_idx = start - index            end_idx = end - index if end != -1 else -1                        route = EmptyLayer()            module.add_module(f'route_{index}', route)                        if end != -1:                filters = output_filters[index + start] + output_filters[index + end]            else:                filters = output_filters[index + start]                        elif module_type == 'shortcut':            shortcut = EmptyLayer()            module.add_module(f'shortcut_{index}', shortcut)                    elif module_type == 'yolo':            mask = block['mask'].split(',')            mask = [int(x) for x in mask]            anchors = block['anchors'].split(',')            anchors = [int(a) for a in anchors]            anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]            selected_anchors = [a for a in anchors if a in mask]                        detection = DetectionLayer(selected_anchors)            module.add_module(f'Detection_{index}', detection)                    module_list.append(module)        module_list.add_module(module, parent_module)        prev_filters = filters        output_filters.append(filters)        return (net_info, parent_module)

1.5 重要模块定义

class EmptyLayer(nn.Module):    def __init__(self):        super(EmptyLayer, self).__init__()class DetectionLayer(nn.Module):    def __init__(self, anchors):        super(DetectionLayer, self).__init__()        self.anchors = anchors

2. 模型构建

最终将解析得到的模块列表和网络信息整合,返回完整的PyTorch模型结构。

def build_net(net_info):    net = nn.Sequential()    with net_info['convolutional'] as conv:        net.add_module('卷积层', conv)    # 其他层的添加类似    return net

2.1 模型返回

完整模型可以通过将create_modules函数得到的module_listnet_info结合起来构建。

if __name__ == '__main__':    import os    os.environ['DETATCH'] = '1'    blocks = parse_cfg("cfg/yolov3.cfg")    net_info, modules = create_modules(blocks)    print("模块数量:", len(modules))    print(modules[0])

2.2 模型输出

通过上述方法,模型将包含所有配置文件中定义的网络层,并通过PyTorch的模块系统正确连接。可以为YOLO模型添加输入和输出层,完成完整的网络定义。

# 输入层input_layer = nn.Input(3, 320, 320)# 前向传播流程outputs = modules.input_layer(input)= modules.conv_0(outputs)= modules.upsample_1(outputs)= modules.batch_norm_2(outputs)= modules.leaky_3(outputs)...= modules.Detection_8(outputs)# 最后的输出y_pred = detection(outputs)

3. 模型训练与使用

完成模型构建后,可以使用PyTorch的训练 utilities进行训练,使用预定义的数据集和优化器进行优化。

转载地址:http://nscnz.baihongyu.com/

你可能感兴趣的文章
MySQL 中日志的面试题总结
查看>>
MySQL 中随机抽样:order by rand limit 的替代方案
查看>>
MySQL 为什么需要两阶段提交?
查看>>
mysql 为某个字段的值加前缀、去掉前缀
查看>>
mysql 主从 lock_mysql 主从同步权限mysql 行锁的实现
查看>>
mysql 主从互备份_mysql互为主从实战设置详解及自动化备份(Centos7.2)
查看>>
mysql 主键重复则覆盖_数据库主键不能重复
查看>>
mysql 优化器 key_mysql – 选择*和查询优化器
查看>>
MySQL 优化:Explain 执行计划详解
查看>>
Mysql 会导致锁表的语法
查看>>
mysql 使用sql文件恢复数据库
查看>>
mysql 修改默认字符集为utf8
查看>>
Mysql 共享锁
查看>>
MySQL 内核深度优化
查看>>
mysql 内连接、自然连接、外连接的区别
查看>>
mysql 写入慢优化
查看>>
mysql 分组统计SQL语句
查看>>
Mysql 分页语句 Limit原理
查看>>
MySQL 创建新用户及授予权限的完整流程
查看>>
mysql 创建表,不能包含关键字values 以及 表id自增问题
查看>>