资源简介

mmdetection在2019年12月13号进行了新版本的更新,其中对api/train.py增加torch.distributed,这块在windows下不支持,所以要在windows中训练的话需要把v1.0rc1的版本的train与新版本的train进行合并,主要是去除torch.distributed以及_non_dist_train的修改为主。

资源截图

代码片段和文件信息

from __future__ import division
import logging
import random
import numpy as np
import re
from collections import OrderedDict

import torch
from mmcv.runner import Runner DistSamplerSeedHook obj_from_dict
from mmcv.parallel import MMDataParallel MMDistributedDataParallel

from mmdet import datasets
from mmdet.core import (DistOptimizerHook DistEvalmAPHook
                        CocoDistEvalRecallHook CocoDistEvalmAPHook
                        Fp16OptimizerHook)
from mmdet.datasets import build_dataloader DATASETS
from mmdet.models import RPN
# from .env import get_root_logger

def get_root_logger(log_file=None log_level=logging.INFO):
    logger = logging.getLogger(‘mmdet‘)
    # if the logger has been initialized just return it
    if logger.hasHandlers():
        return logger

    logging.basicConfig(
        format=‘%(asctime)s - %(levelname)s - %(message)s‘ level=log_level)
    # rank _ = get_dist_info()
    # if rank != 0:
    #     logger.setLevel(‘ERROR‘)
    # elif log_file is not None:
    #     file_handler = logging.FileHandler(log_file ‘w‘)
    #     file_handler.setFormatter(
    #         logging.Formatter(‘%(asctime)s - %(levelname)s - %(message)s‘))
    #     file_handler.setLevel(log_level)
    #     logger.addHandler(file_handler)

    return logger

def set_random_seed(seed deterministic=False):
    “““Set random seed.
    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend i.e. set ‘torch.backends.cudnn.deterministic‘
            to True and ‘torch.backends.cudnn.benchmark‘ to False.
            Default: False.
    “““
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name loss_value in losses.items():
        if isinstance(loss_value torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                ‘{} is not a tensor or list of tensors‘.format(loss_name))

    loss = sum(_value for _key _value in log_vars.items() if ‘loss‘ in _key)

    log_vars[‘loss‘] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss log_vars


def batch_processor(model data train_mode):
    losses = model(**data)
    loss log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss log_vars=log_vars num_samples=len(data[‘img‘].data))

    return outputs


def train_detector(model
                   dataset
                   cfg
                   distributed=False
  

评论

共有 条评论