# cityscapes_semsegm_test_enet.py (4.56 KB)
import numpy as np
import sys
import os
import fnmatch
import argparse

try:
    import cv2 as cv
except ImportError:
    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
                      'configure environemnt variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
try:
    import torch
except ImportError:
    raise ImportError('Can\'t find pytorch. Please intall it by following instructions on the official site')

from torch.utils.serialization import load_lua
from pascal_semsegm_test_fcn import eval_segm_result, get_conf_mat, get_metrics, DatasetImageFetch, SemSegmEvaluation
from imagenet_cls_test_alexnet import Framework, DnnCaffeModel


class NormalizePreproc:
    """Turn an H x W x C image into a float32 batch of one, channels first, divided by 255."""

    def __init__(self):
        pass

    @staticmethod
    def process(img):
        """Return ``img`` as a (1, C, H, W) float32 array with every value divided by 255.

        ``img`` must be convertible to a 3-D array whose last axis is the channel axis.
        The input itself is never modified: ``astype`` always produces a fresh copy.
        """
        channels_first = np.array(img).transpose(2, 0, 1)
        batch = channels_first.astype(np.float32)[np.newaxis, ...]
        batch /= 255.0
        return batch


class CityscapesDataFetch(DatasetImageFetch):
    """Iterate over Cityscapes (preprocessed image, ground truth) pairs.

    Ground-truth ``*_color.png`` files are found recursively under ``segm_dir``.
    The matching input image sits at the same relative path under ``img_dir``,
    with the last ``gtFine_color`` in the path replaced by ``leftImg8bit``.
    Both images are resized to 1024x512: ground truth with nearest-neighbour
    interpolation, so no new colours are invented.
    """
    img_dir = ''
    segm_dir = ''
    segm_files = []
    colors = []
    i = 0

    def __init__(self, img_dir, segm_dir, preproc):
        """
        :param img_dir: root of the input images (mirrors the layout of ``segm_dir``)
        :param segm_dir: root searched recursively for ``*_color.png`` ground truth
        :param preproc: object whose ``process(img)`` prepares an RGB image for the network
        """
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.segm_files = sorted(self.locate('*_color.png', segm_dir))
        self.colors = self.get_colors()
        self.data_prepoc = preproc
        self.i = 0

    @staticmethod
    def get_colors():
        """Return the 20 class colours, each converted with ``DatasetImageFetch.pix_to_c``.

        The tuples are RGB, matching the channel order produced by ``_read_rgb``.
        """
        colors_list = (
         (0, 0, 0), (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
         (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
         (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32))
        return [DatasetImageFetch.pix_to_c(c) for c in colors_list]

    def __iter__(self):
        return self

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as a 3-channel image with channels in RGB order.

        :raises IOError: if OpenCV cannot decode the file (``cv.imread`` returned None)
        """
        bgr = cv.imread(path, cv.IMREAD_COLOR)
        if bgr is None:
            raise IOError('Failed to read image: ' + path)
        return bgr[:, :, ::-1]

    def next(self):
        """Return the next ``(preprocessed_image, gt)`` pair.

        When every file has been served, rewind to the first one and raise
        StopIteration, so the object can be iterated again.

        :raises IOError: if the matching input image is missing or unreadable
        """
        if self.i >= len(self.segm_files):
            self.i = 0
            raise StopIteration

        segm_file = self.segm_files[self.i]
        segm = self._read_rgb(segm_file)
        segm = cv.resize(segm, (1024, 512), interpolation=cv.INTER_NEAREST)

        # locate() yields absolute paths, so use relpath rather than slicing by
        # len(segm_dir). Slicing breaks for relative dirs and for trailing separators.
        rel_path = os.path.relpath(segm_file, self.segm_dir)
        img_file = self.rreplace(os.path.join(self.img_dir, rel_path), 'gtFine_color', 'leftImg8bit')
        if not os.path.exists(img_file):
            raise IOError('Input image not found for ' + segm_file + ': ' + img_file)
        img = self._read_rgb(img_file)
        img = cv.resize(img, (1024, 512))

        self.i += 1
        gt = self.color_to_gt(segm, self.colors)
        img = self.data_prepoc.process(img)
        return img, gt

    # Python 3 iterator protocol. Python 2 uses ``next`` directly.
    __next__ = next

    def get_num_classes(self):
        """Return the number of classes, one per entry in ``self.colors``."""
        return len(self.colors)

    @staticmethod
    def locate(pattern, root_path):
        """Yield the absolute path of every file under ``root_path`` matching ``pattern`` (fnmatch)."""
        for path, dirs, files in os.walk(os.path.abspath(root_path)):
            for filename in fnmatch.filter(files, pattern):
                yield os.path.join(path, filename)

    @staticmethod
    def rreplace(s, old, new, occurrence=1):
        """Replace the last ``occurrence`` occurrences of ``old`` in ``s`` with ``new``."""
        li = s.rsplit(old, occurrence)
        return new.join(li)


class TorchModel(Framework):
    """Reference framework that runs a Lua Torch model loaded through ``load_lua``.

    NOTE(review): ``torch.utils.serialization.load_lua`` only exists in legacy
    PyTorch releases, so this class needs one of them.
    """
    net = object

    def __init__(self, model_file):
        self.net = load_lua(model_file)

    def get_name(self):
        return 'Torch'

    def get_output(self, input_blob):
        """Wrap ``input_blob`` in a FloatTensor, run ``self.net.forward`` and return the result as NumPy."""
        result = self.net.forward(torch.FloatTensor(input_blob))
        return result.numpy()


class DnnTorchModel(DnnCaffeModel):
    """Runs a Torch model through OpenCV's dnn module (``cv.dnn.readNetFromTorch``)."""
    net = cv.dnn.Net()

    def __init__(self, model_file):
        self.net = cv.dnn.readNetFromTorch(model_file)

    def get_output(self, input_blob):
        """Run a forward pass on ``input_blob`` and return the output of the last layer.

        Newer OpenCV dnn builds expose ``setInput``; there, ``forward()`` with no
        name returns the last layer's output. Older builds only have the
        ``setBlob``/``getBlob`` API, which is used as a fallback.
        """
        if hasattr(self.net, 'setInput'):
            self.net.setInput(input_blob)
            return self.net.forward()
        self.net.setBlob("", input_blob)
        self.net.forward()
        return self.net.getBlob(self.net.getLayerNames()[-1])

if __name__ == "__main__":
    # Compare the Torch and OpenCV dnn ENet implementations on Cityscapes validation data.
    parser = argparse.ArgumentParser()
    # The three paths below are used unconditionally, so argparse must reject a
    # missing one up front. Otherwise it fails later inside os.path / load_lua.
    parser.add_argument("--imgs_dir", required=True,
                        help="path to Cityscapes validation images dir, imgsfine/leftImg8bit/val")
    parser.add_argument("--segm_dir", required=True,
                        help="path to Cityscapes dir with segmentation, gtfine/gtFine/val")
    parser.add_argument("--model", required=True,
                        help="path to torch model, download it here: "
                        "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa")
    parser.add_argument("--log", help="path to logging file")
    args = parser.parse_args()

    prep = NormalizePreproc()
    df = CityscapesDataFetch(args.imgs_dir, args.segm_dir, prep)

    fw = [TorchModel(args.model),
          DnnTorchModel(args.model)]

    segm_eval = SemSegmEvaluation(args.log)
    segm_eval.process(fw, df)