Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 22 additions & 47 deletions demo_global.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -19,7 +14,7 @@
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim

import _init_paths
import torchvision.transforms as transforms
import torchvision.datasets as dset
from scipy.misc import imread
Expand All @@ -31,8 +26,8 @@
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.utils.blob import im_list_to_blob
# from model.faster_rcnn.vgg16_global import vgg16
# from model.faster_rcnn.resnet_global import resnet
from model.faster_rcnn.vgg16_MEAA import vgg16
from model.faster_rcnn.resnet_MEAA import resnet
import pdb

try:
Expand Down Expand Up @@ -178,29 +173,19 @@ def _get_image_blob(im):
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor'])
pascal_classes = np.asarray(['__background__', # always index 0
'aeroplane', 'bicycle', 'bus', 'car',
'horse', 'knife', 'motorcycle', 'person',
'plant', 'skateboard', 'train', 'truck'])
foggy_classes = ('__background__',
'bus', 'bicycle', 'car', 'motorcycle', 'person', 'rider', 'train', 'truck')
# pascal_classes = np.asarray(['__background__', # always index 0
# 'aeroplane', 'bicycle', 'bus', 'car',
# 'horse', 'knife', 'motorcycle', 'person',
# 'plant', 'skateboard', 'train', 'truck'])
# initilize the network here.
#pascal_classes = np.asarray(['__background__', 'car'])
#pascal_classes = ('__background__', # always index 0
# 'bus', 'bicycle', 'car', 'motorcycle', 'person', 'rider', 'train', 'truck')
# initilize the network here.
from model.faster_rcnn.vgg16_SCL import vgg16

if args.net == 'vgg16':
fasterRCNN = vgg16(foggy_classes, pretrained=True)

# fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic,lc=args.lc,gc=args.gc)
fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic,lc=args.lc,gc=args.gc)
elif args.net == 'res101':
fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic,context=args.context)
elif args.net == 'res50':
fasterRCNN = resnet(pascal_classes, 50, pretrained=False, class_agnostic=args.class_agnostic)
elif args.net == 'res152':
fasterRCNN = resnet(pascal_classes, 152, pretrained=False, class_agnostic=args.class_agnostic)
fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic,gc1 = False, gc2=False, gc3 = False)
else:
print("network is not defined")
pdb.set_trace()
Expand Down Expand Up @@ -236,12 +221,11 @@ def _get_image_blob(im):
num_boxes = num_boxes.cuda()
gt_boxes = gt_boxes.cuda()

with torch.no_grad():
# make variable
im_data = Variable(im_data)#, volatile=True)
im_info = Variable(im_info)#, volatile=True)
num_boxes = Variable(num_boxes)#, volatile=True)
gt_boxes = Variable(gt_boxes)#, volatile=True)
# make variable
im_data = Variable(im_data, volatile=True)
im_info = Variable(im_info, volatile=True)
num_boxes = Variable(num_boxes, volatile=True)
gt_boxes = Variable(gt_boxes, volatile=True)

if args.cuda > 0:
cfg.CUDA = True
Expand Down Expand Up @@ -281,7 +265,6 @@ def _get_image_blob(im):
raise RuntimeError("Webcam could not open. Please check connection.")
ret, frame = cap.read()
im_in = np.array(frame)

# Load the demo image
else:
im_file = os.path.join(args.image_dir, imglist[num_images])
Expand Down Expand Up @@ -310,15 +293,13 @@ def _get_image_blob(im):
# pdb.set_trace()
det_tic = time.time()

# rois, cls_prob, bbox_pred, \
# rpn_loss_cls, rpn_loss_box, \
# RCNN_loss_cls, RCNN_loss_bbox, \
# rois_label, _, = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label, d_pred, domain_p1, domain_p2, domain_p3,\
out_d11, out_d12, out_d13 = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
out_d11, out_d12, out_d13 = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)


scores = cls_prob.data
boxes = rois.data[:, :, 1:5]
Expand All @@ -344,7 +325,7 @@ def _get_image_blob(im):
else:
box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) \
+ torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
box_deltas = box_deltas.view(1, -1, 4 * len(foggy_classes))
box_deltas = box_deltas.view(1, -1, 4 * len(pascal_classes))

pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
Expand All @@ -361,7 +342,7 @@ def _get_image_blob(im):
misc_tic = time.time()
if vis:
im2show = np.copy(im)
for j in xrange(1, len(foggy_classes)):
for j in xrange(1, len(pascal_classes)):
inds = torch.nonzero(scores[:,j]>thresh).view(-1)
# if there is det
if inds.numel() > 0:
Expand All @@ -378,28 +359,22 @@ def _get_image_blob(im):
keep = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS)
cls_dets = cls_dets[keep.view(-1).long()]
if vis:
im2show = vis_detections(im2show, foggy_classes[j], cls_dets.cpu().numpy(), 0.8)
im2show = vis_detections(im2show, pascal_classes[j], cls_dets.cpu().numpy(), 0.8)

misc_toc = time.time()
nms_time = misc_toc - misc_tic

if webcam_num == -1:
print('\nHello from web ==-1!')

sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r' \
.format(num_images + 1, len(imglist), detect_time, nms_time))
sys.stdout.flush()

if vis and webcam_num == -1:
print('\nHello from vis + web ==-1!')

# cv2.imshow('test', im2show)
# cv2.waitKey(0)
result_path = os.path.join('/home/basic/mm20-may10/images_det',imglist[num_images][:-4] + "_det2.jpg")
print('\result_path: ', result_path)
result_path = os.path.join('images_det',imglist[num_images][:-4] + "_det2.jpg")
cv2.imwrite(result_path, im2show)
else:
print('\nHello!')
im2showRGB = cv2.cvtColor(im2show, cv2.COLOR_BGR2RGB)
cv2.imshow("frame", im2showRGB)
total_toc = time.time()
Expand All @@ -410,4 +385,4 @@ def _get_image_blob(im):
break
if webcam_num >= 0:
cap.release()
cv2.destroyAllWindows()
cv2.destroyAllWindows()