
How to generate accurate masks for an image from Mask R-CNN prediction in PyTorch?

I have trained a Mask R-CNN network for instance segmentation of apples. I can load the weights and generate predictions for my test images. The generated masks appear in the correct locations, but each mask has no clear shape; it just looks like a scattered bunch of pixels.

Training is based on the dataset from this paper, and here is the GitHub link to the code used to train the model and generate the weights.

The code for prediction is as follows (I have omitted the parts where I create the path variables and assign the paths):

import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput

import torch
import torch.utils.data
import torchvision

from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import utility.utils as utils
import utility.transforms as T

from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline


def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def get_maskrcnn_model_instance(num_classes):
    # build a Mask R-CNN instance segmentation model; pretrained=False because
    # the weights are loaded from the training checkpoint below
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

num_classes = 2
device = torch.device('cpu')

model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
model.load_state_dict(checkpoint['model'], strict=False)

dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]
model.eval()

with torch.no_grad():
    prediction = model([img.to(device)])

prediction

Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

(Unable to load the image here since it's over 2 MB.)

Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

Here is an Imgur link to the original image. Below is the predicted mask for one of the instances:

Mask output for one instance

Also, could you please help me understand the data structure of the prediction shown below? How do I access the masks so that I can generate a single image with all masks displayed?

[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
          [1418.7872, 1467.0619, 1732.0828, 1796.1527],
          [1608.0396, 2064.6482, 1710.7534, 2206.5535],
          [2326.3750, 1690.3418, 2542.2112, 1883.2626],
          [2213.2024, 1864.3657, 2299.8933, 1963.0178],
          [1112.9083, 1732.5953, 1236.7600, 1823.0170],
          [1150.8256,  614.0334, 1218.8584,  711.4094],
          [ 942.7086,  794.6043, 1138.2318, 1008.0430],
          [1065.4371,  723.0493, 1192.7570,  870.3763],
          [1002.3103,  883.4616, 1146.9994, 1006.6841],
          [1315.2816, 1680.8625, 1531.3210, 1989.3317],
          [1244.5769, 1925.0903, 1459.5417, 2175.3252],
          [1725.2191, 2082.6187, 1934.0227, 2274.2952],
          [ 936.3065, 1554.3765, 1014.2722, 1659.4229],
          [ 934.8851, 1541.3331, 1090.4736, 1657.3751],
          [2486.0120,  776.4577, 2547.2329,  847.9725],
          [2336.1675,  698.6327, 2508.6492,  921.4550],
          [2368.4077, 1954.1102, 2448.4004, 2049.5796],
          [1899.1403, 1775.2371, 2035.7561, 1962.6923],
          [2176.0664, 1075.1553, 2398.6084, 1267.2555],
          [2274.8899,  641.6769, 2395.9634,  791.3353],
          [2535.1580,  874.4780, 2642.8213,  966.4614],
          [2183.4236,  619.9688, 2288.5676,  758.6825],
          [2183.9832, 1122.9382, 2334.9583, 1263.3226],
          [1135.7822,  779.0529, 1225.9871,  890.0135],
          [ 317.3954, 1328.6995,  397.3900, 1467.7740],
          [ 945.4811, 1833.3708,  997.2318, 1878.8607],
          [1992.4447,  679.4969, 2134.6667,  835.8701],
          [1098.5416, 1452.7799, 1429.1808, 1771.4460],
          [1657.3193, 1405.5405, 1781.6273, 1574.6780],
          [1443.8911, 1747.1544, 1739.0361, 2076.9724],
          [1092.6003, 1165.3340, 1206.0881, 1383.8314],
          [2466.4170, 1945.5931, 2555.1931, 2039.8368],
          [2561.8508, 1616.2659, 2672.1033, 1742.2332],
          [1894.4806,  907.9214, 2097.1875, 1182.6473],
          [2321.5005, 1701.3344, 2368.3699, 1865.3914],
          [2180.0781,  567.5969, 2344.6357,  763.4360],
          [1845.7612,  668.6808, 2045.2688,  899.8501],
          [1858.9216, 2145.7097, 1961.8870, 2273.5088],
          [ 261.4607, 1314.0154,  396.9288, 1486.9498],
          [2488.1682, 1585.2357, 2669.0178, 1794.9926],
          [2696.9548,  936.0087, 2802.7961, 1025.2294],
          [1593.6837, 1489.8641, 1720.3124, 1627.8135],
          [2517.9468,  857.1713, 2567.1125,  929.4335],
          [1943.2167,  636.3422, 2151.4419,  853.8924],
          [2143.5664, 1100.0521, 2308.1570, 1290.7125],
          [2140.9231, 1947.9692, 2238.6956, 2000.6249],
          [1461.6316, 2105.2593, 1559.7675, 2189.0264],
          [2114.0781,  374.8153, 2222.8838,  559.9851],
          [2350.5320,  726.5779, 2466.8140,  878.2617]]),
  'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1]),
  'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
          0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
          0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
          0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
          0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
          0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
  'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          ...,


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]]])}]


Answer

The prediction from the Mask R-CNN has the following structure:

During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows:

boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with x values between 0 and W and y values between 0 and H  
labels (Int64Tensor[N]): the predicted label for each detection  
scores (Tensor[N]): the score of each detection  
masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. To obtain the final segmentation masks, the soft masks can be thresholded, generally with a value of 0.5 (mask >= 0.5).
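
To get from this structure to a single image containing every mask, one option is to drop low-scoring detections, threshold each soft mask, and merge what remains. The sketch below assumes the tensors have already been converted to NumPy arrays (for example with prediction[0]['masks'].cpu().numpy() and prediction[0]['scores'].cpu().numpy()). The function name combine_masks and both 0.5 cut-offs are illustrative choices, not part of torchvision.

```python
import numpy as np


def combine_masks(masks, scores, score_thresh=0.5, mask_thresh=0.5):
    """Merge per-instance soft masks into one binary image and one label map.

    masks:  float array of shape [N, 1, H, W] with values in [0, 1]
    scores: float array of shape [N]
    Returns (union, labels): union is uint8 [H, W] with 255 on any kept mask,
    labels is int32 [H, W] with 0 for background and k > 0 for the k-th kept instance.
    """
    masks = np.asarray(masks, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    height, width = masks.shape[-2:]

    labels = np.zeros((height, width), dtype=np.int32)
    instance_id = 0
    # Visit detections from highest to lowest score so that, where masks
    # overlap, the more confident detection keeps the pixel.
    for idx in np.argsort(-scores):
        if scores[idx] < score_thresh:
            continue  # skip low-confidence detections
        binary = masks[idx, 0] >= mask_thresh  # soft mask -> boolean mask
        instance_id += 1
        labels[binary & (labels == 0)] = instance_id

    union = (labels > 0).astype(np.uint8) * 255
    return union, labels
```

The union array can be passed straight to Image.fromarray to view all masks at once, while the label map keeps the instances distinguishable. In the prediction shown in the question, many detections score below 0.2, so some score cut-off is probably worth applying before drawing anything.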

You can use OpenCV’s findContours and drawContours functions to draw masks as follows:

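import cv2

# cv2.imread returns a BGR image; its second argument is an IMREAD_* flag,
# not a color-conversion code, so none is passed here
img_cv = cv2.imread('input.jpg')

for i in range(len(prediction[0]['masks'])):
    # iterate over masks
    mask = prediction[0]['masks'][i, 0]
    # threshold the soft mask at 0.5, then scale to 0/255 for OpenCV
    mask = (mask >= 0.5).byte().mul(255).cpu().numpy()
    # OpenCV 4.x returns (contours, hierarchy); OpenCV 3.x returns three values
    contours, _ = cv2.findContours(
            mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)

cv2.imshow('img output', img_cv)
cv2.waitKey(0)  # keep the window open until a key is pressed
cv2.destroyAllWindows()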

Sample output:

sample output
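
If filled regions are preferred over outlines, the instances can also be alpha-blended onto the image. This is only a sketch under a few assumptions: the image is an RGB uint8 NumPy array, and an integer label map (0 for background, k for instance k) has already been built from the thresholded masks. The function name overlay_masks, the random palette, and the default alpha are hypothetical choices made for illustration.

```python
import numpy as np


def overlay_masks(image, labels, alpha=0.5, seed=0):
    """Blend one random color per instance onto an RGB image.

    image:  uint8 array of shape [H, W, 3]
    labels: int array of shape [H, W], 0 = background, k > 0 = instance k
    """
    image = np.asarray(image, dtype=np.uint8)
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)

    # One color per label value; row 0 (background) is never blended in.
    num_colors = int(labels.max()) + 1
    palette = rng.integers(0, 256, size=(num_colors, 3)).astype(np.float32)

    colors = palette[labels]               # [H, W, 3] color for every pixel
    foreground = (labels > 0)[..., None]   # [H, W, 1] True on any instance
    blended = (1.0 - alpha) * image + alpha * colors

    # Keep the original pixels on the background, blended pixels on instances.
    out = np.where(foreground, blended, image)
    return out.round().astype(np.uint8)
```

The result can be shown with Image.fromarray(overlay_masks(image, labels)). Fixing the seed keeps the instance colors stable between runs.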

User contributions licensed under: CC BY-SA