From 9938634e340c2895238a5f2602dfb78396e0bae4 Mon Sep 17 00:00:00 2001 From: zyfone <478756030@qq.com> Date: Tue, 7 Dec 2021 19:28:16 +0800 Subject: [PATCH] add new demo_global.py --- demo_global.py | 69 ++++++++++++++++---------------------------------- 1 file changed, 22 insertions(+), 47 deletions(-) diff --git a/demo_global.py b/demo_global.py index 25eb052..7206226 100644 --- a/demo_global.py +++ b/demo_global.py @@ -1,8 +1,3 @@ -# -------------------------------------------------------- -# Tensorflow Faster R-CNN -# Licensed under The MIT License [see LICENSE for details] -# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick -# -------------------------------------------------------- from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -19,7 +14,7 @@ from torch.autograd import Variable import torch.nn as nn import torch.optim as optim - +import _init_paths import torchvision.transforms as transforms import torchvision.datasets as dset from scipy.misc import imread @@ -31,8 +26,8 @@ from model.rpn.bbox_transform import bbox_transform_inv from model.utils.net_utils import save_net, load_net, vis_detections from model.utils.blob import im_list_to_blob -# from model.faster_rcnn.vgg16_global import vgg16 -# from model.faster_rcnn.resnet_global import resnet +from model.faster_rcnn.vgg16_MEAA import vgg16 +from model.faster_rcnn.resnet_MEAA import resnet import pdb try: @@ -178,29 +173,19 @@ def _get_image_blob(im): 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']) - pascal_classes = np.asarray(['__background__', # always index 0 - 'aeroplane', 'bicycle', 'bus', 'car', - 'horse', 'knife', 'motorcycle', 'person', - 'plant', 'skateboard', 'train', 'truck']) - foggy_classes = ('__background__', - 'bus', 'bicycle', 'car', 'motorcycle', 'person', 'rider', 'train', 'truck') + # pascal_classes = np.asarray(['__background__', # always index 0 + # 'aeroplane', 'bicycle', 'bus', 'car', + # 'horse', 'knife', 'motorcycle', 'person', + # 'plant', 'skateboard', 'train', 'truck']) # initilize the network here. #pascal_classes = np.asarray(['__background__', 'car']) #pascal_classes = ('__background__', # always index 0 # 'bus', 'bicycle', 'car', 'motorcycle', 'person', 'rider', 'train', 'truck') # initilize the network here. - from model.faster_rcnn.vgg16_SCL import vgg16 - if args.net == 'vgg16': - fasterRCNN = vgg16(foggy_classes, pretrained=True) - - # fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic,lc=args.lc,gc=args.gc) + fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic,lc=args.lc,gc=args.gc) elif args.net == 'res101': - fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic,context=args.context) - elif args.net == 'res50': - fasterRCNN = resnet(pascal_classes, 50, pretrained=False, class_agnostic=args.class_agnostic) - elif args.net == 'res152': - fasterRCNN = resnet(pascal_classes, 152, pretrained=False, class_agnostic=args.class_agnostic) + fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic,gc1 = False, gc2=False, gc3 = False) else: print("network is not defined") pdb.set_trace() @@ -236,12 +221,11 @@ def _get_image_blob(im): num_boxes = num_boxes.cuda() gt_boxes = gt_boxes.cuda() - with torch.no_grad(): - # make variable - im_data = Variable(im_data)#, volatile=True) - im_info = Variable(im_info)#, volatile=True) - num_boxes = Variable(num_boxes)#, volatile=True) - gt_boxes = Variable(gt_boxes)#, volatile=True) + # make variable + im_data = Variable(im_data, volatile=True) + im_info = Variable(im_info, volatile=True) + num_boxes = Variable(num_boxes, volatile=True) + gt_boxes = Variable(gt_boxes, volatile=True) if args.cuda > 0: cfg.CUDA = True @@ -281,7 +265,6 @@ def _get_image_blob(im): raise RuntimeError("Webcam could not open. Please check connection.") ret, frame = cap.read() im_in = np.array(frame) - # Load the demo image else: im_file = os.path.join(args.image_dir, imglist[num_images]) @@ -310,15 +293,13 @@ def _get_image_blob(im): # pdb.set_trace() det_tic = time.time() - # rois, cls_prob, bbox_pred, \ - # rpn_loss_cls, rpn_loss_box, \ - # RCNN_loss_cls, RCNN_loss_bbox, \ - # rois_label, _, = fasterRCNN(im_data, im_info, gt_boxes, num_boxes) + rois, cls_prob, bbox_pred, \ rpn_loss_cls, rpn_loss_box, \ RCNN_loss_cls, RCNN_loss_bbox, \ rois_label, d_pred, domain_p1, domain_p2, domain_p3,\ - out_d11, out_d12, out_d13 = fasterRCNN(im_data, im_info, gt_boxes, num_boxes) + out_d11, out_d12, out_d13 = fasterRCNN(im_data, im_info, gt_boxes, num_boxes) + scores = cls_prob.data boxes = rois.data[:, :, 1:5] @@ -344,7 +325,7 @@ def _get_image_blob(im): else: box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) \ + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS) - box_deltas = box_deltas.view(1, -1, 4 * len(foggy_classes)) + box_deltas = box_deltas.view(1, -1, 4 * len(pascal_classes)) pred_boxes = bbox_transform_inv(boxes, box_deltas, 1) pred_boxes = clip_boxes(pred_boxes, im_info.data, 1) @@ -361,7 +342,7 @@ def _get_image_blob(im): misc_tic = time.time() if vis: im2show = np.copy(im) - for j in xrange(1, len(foggy_classes)): + for j in xrange(1, len(pascal_classes)): inds = torch.nonzero(scores[:,j]>thresh).view(-1) # if there is det if inds.numel() > 0: @@ -378,28 +359,22 @@ def _get_image_blob(im): keep = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS) cls_dets = cls_dets[keep.view(-1).long()] if vis: - im2show = vis_detections(im2show, foggy_classes[j], cls_dets.cpu().numpy(), 0.8) + im2show = vis_detections(im2show, pascal_classes[j], cls_dets.cpu().numpy(), 0.8) misc_toc = time.time() nms_time = misc_toc - misc_tic if webcam_num == -1: - print('\nHello from web ==-1!') - sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r' \ .format(num_images + 1, len(imglist), detect_time, nms_time)) sys.stdout.flush() if vis and webcam_num == -1: - print('\nHello from vis + web ==-1!') - # cv2.imshow('test', im2show) # cv2.waitKey(0) - result_path = os.path.join('/home/basic/mm20-may10/images_det',imglist[num_images][:-4] + "_det2.jpg") - print('\result_path: ', result_path) + result_path = os.path.join('images_det',imglist[num_images][:-4] + "_det2.jpg") cv2.imwrite(result_path, im2show) else: - print('\nHello!') im2showRGB = cv2.cvtColor(im2show, cv2.COLOR_BGR2RGB) cv2.imshow("frame", im2showRGB) total_toc = time.time() @@ -410,4 +385,4 @@ def _get_image_blob(im): break if webcam_num >= 0: cap.release() - cv2.destroyAllWindows() + cv2.destroyAllWindows() \ No newline at end of file