|
| 1 | +# Copyright (c) 2021-2022, InterDigital Communications, Inc |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# Redistribution and use in source and binary forms, with or without |
| 5 | +# modification, are permitted (subject to the limitations in the disclaimer |
| 6 | +# below) provided that the following conditions are met: |
| 7 | + |
| 8 | +# * Redistributions of source code must retain the above copyright notice, |
| 9 | +# this list of conditions and the following disclaimer. |
| 10 | +# * Redistributions in binary form must reproduce the above copyright notice, |
| 11 | +# this list of conditions and the following disclaimer in the documentation |
| 12 | +# and/or other materials provided with the distribution. |
| 13 | +# * Neither the name of InterDigital Communications, Inc nor the names of its |
| 14 | +# contributors may be used to endorse or promote products derived from this |
| 15 | +# software without specific prior written permission. |
| 16 | + |
| 17 | +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY |
| 18 | +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND |
| 19 | +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT |
| 20 | +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A |
| 21 | +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR |
| 22 | +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, |
| 23 | +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, |
| 24 | +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; |
| 25 | +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, |
| 26 | +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR |
| 27 | +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF |
| 28 | +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 29 | + |
| 30 | +import argparse |
| 31 | +import random |
| 32 | +import shutil |
| 33 | +import sys |
| 34 | + |
| 35 | +import torch |
| 36 | +import torch.nn as nn |
| 37 | +import torch.optim as optim |
| 38 | + |
| 39 | +from torch.utils.data import DataLoader |
| 40 | +from torchvision.transforms import Compose |
| 41 | + |
| 42 | +import compressai.transforms as transforms |
| 43 | + |
| 44 | +from compressai.datasets import ModelNetDataset |
| 45 | +from compressai.losses import ChamferPccRateDistortionLoss |
| 46 | +from compressai.optimizers import net_aux_optimizer |
| 47 | +from compressai.registry import MODELS |
| 48 | +from compressai.zoo import pointcloud_models |
| 49 | + |
| 50 | + |
| 51 | +class AverageMeter: |
| 52 | + """Compute running average.""" |
| 53 | + |
| 54 | + def __init__(self): |
| 55 | + self.val = 0 |
| 56 | + self.avg = 0 |
| 57 | + self.sum = 0 |
| 58 | + self.count = 0 |
| 59 | + |
| 60 | + def update(self, val, n=1): |
| 61 | + self.val = val |
| 62 | + self.sum += val * n |
| 63 | + self.count += n |
| 64 | + self.avg = self.sum / self.count |
| 65 | + |
| 66 | + |
| 67 | +class CustomDataParallel(nn.DataParallel): |
| 68 | + """Custom DataParallel to access the module methods.""" |
| 69 | + |
| 70 | + def __getattr__(self, key): |
| 71 | + try: |
| 72 | + return super().__getattr__(key) |
| 73 | + except AttributeError: |
| 74 | + return getattr(self.module, key) |
| 75 | + |
| 76 | + |
| 77 | +def configure_optimizers(net, args): |
| 78 | + """Separate parameters for the main optimizer and the auxiliary optimizer. |
| 79 | + Return two optimizers""" |
| 80 | + conf = { |
| 81 | + "net": {"type": "Adam", "lr": args.learning_rate}, |
| 82 | + "aux": {"type": "Adam", "lr": args.aux_learning_rate}, |
| 83 | + } |
| 84 | + optimizer = net_aux_optimizer(net, conf) |
| 85 | + return optimizer["net"], optimizer["aux"] |
| 86 | + |
| 87 | + |
| 88 | +def train_one_epoch( |
| 89 | + model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm |
| 90 | +): |
| 91 | + model.train() |
| 92 | + device = next(model.parameters()).device |
| 93 | + |
| 94 | + for i, d in enumerate(train_dataloader): |
| 95 | + d = {k: v.to(device) for k, v in d.items()} |
| 96 | + |
| 97 | + optimizer.zero_grad() |
| 98 | + aux_optimizer.zero_grad() |
| 99 | + |
| 100 | + out_net = model(d) |
| 101 | + |
| 102 | + out_criterion = criterion(out_net, d) |
| 103 | + out_criterion["loss"].backward() |
| 104 | + if clip_max_norm > 0: |
| 105 | + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm) |
| 106 | + optimizer.step() |
| 107 | + |
| 108 | + aux_loss = model.aux_loss() |
| 109 | + aux_loss.backward() |
| 110 | + aux_optimizer.step() |
| 111 | + |
| 112 | + if i % 10 == 0: |
| 113 | + print( |
| 114 | + f"Train epoch {epoch}: [" |
| 115 | + f"{i*len(d)}/{len(train_dataloader.dataset)} " |
| 116 | + f"({100. * i / len(train_dataloader):.0f}%)] " |
| 117 | + f'Loss: {out_criterion["loss"].item():.3f} | ' |
| 118 | + f'Bpp loss: {out_criterion["bpp_loss"].item():.3f} | ' |
| 119 | + f'Rec loss: {out_criterion["rec_loss"].item():.4f} | ' |
| 120 | + # f'Aux loss: {aux_loss.item():.0f} | ' |
| 121 | + "\n" |
| 122 | + ) |
| 123 | + |
| 124 | + |
| 125 | +def test_epoch(epoch, test_dataloader, model, criterion): |
| 126 | + model.eval() |
| 127 | + model.update(force=True, update_quantiles=True) |
| 128 | + device = next(model.parameters()).device |
| 129 | + |
| 130 | + meter_keys = ["loss", "bpp_loss", "rec_loss", "aux_loss"] |
| 131 | + meters = {key: AverageMeter() for key in meter_keys} |
| 132 | + |
| 133 | + with torch.no_grad(): |
| 134 | + for d in test_dataloader: |
| 135 | + d = {k: v.to(device) for k, v in d.items()} |
| 136 | + |
| 137 | + out_net = model(d) |
| 138 | + out_criterion = criterion(out_net, d) |
| 139 | + out_criterion["aux_loss"] = model.aux_loss() |
| 140 | + |
| 141 | + for key in meters: |
| 142 | + if key in out_criterion: |
| 143 | + meters[key].update(out_criterion[key]) |
| 144 | + |
| 145 | + print( |
| 146 | + f"Test epoch {epoch}: Average losses: " |
| 147 | + f'Loss: {meters["loss"].avg:.3f} | ' |
| 148 | + f'Bpp loss: {meters["bpp_loss"].avg:.3f} | ' |
| 149 | + f'Rec loss: {meters["rec_loss"].avg:.4f} | ' |
| 150 | + # f'Aux loss: {meters["aux_loss"].avg:.0f} | ' |
| 151 | + "\n" |
| 152 | + ) |
| 153 | + |
| 154 | + return meters["loss"].avg |
| 155 | + |
| 156 | + |
| 157 | +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): |
| 158 | + torch.save(state, filename) |
| 159 | + if is_best: |
| 160 | + shutil.copyfile(filename, "checkpoint_best_loss.pth.tar") |
| 161 | + |
| 162 | + |
| 163 | +def parse_args(argv): |
| 164 | + parser = argparse.ArgumentParser(description="Example training script.") |
| 165 | + parser.add_argument( |
| 166 | + "-m", |
| 167 | + "--model", |
| 168 | + default="sfu2023-pcc-rec-pointnet", |
| 169 | + choices=pointcloud_models.keys(), |
| 170 | + help="Model architecture (default: %(default)s)", |
| 171 | + ) |
| 172 | + parser.add_argument( |
| 173 | + "-d", "--dataset", type=str, required=True, help="Training dataset" |
| 174 | + ) |
| 175 | + parser.add_argument( |
| 176 | + "-e", |
| 177 | + "--epochs", |
| 178 | + default=100, |
| 179 | + type=int, |
| 180 | + help="Number of epochs (default: %(default)s)", |
| 181 | + ) |
| 182 | + parser.add_argument( |
| 183 | + "-lr", |
| 184 | + "--learning-rate", |
| 185 | + default=1e-4, |
| 186 | + type=float, |
| 187 | + help="Learning rate (default: %(default)s)", |
| 188 | + ) |
| 189 | + parser.add_argument( |
| 190 | + "-n", |
| 191 | + "--num-workers", |
| 192 | + type=int, |
| 193 | + default=4, |
| 194 | + help="Dataloaders threads (default: %(default)s)", |
| 195 | + ) |
| 196 | + parser.add_argument( |
| 197 | + "--lambda", |
| 198 | + dest="lmbda", |
| 199 | + type=float, |
| 200 | + default=100, |
| 201 | + help="Bit-rate distortion parameter (default: %(default)s)", |
| 202 | + ) |
| 203 | + parser.add_argument( |
| 204 | + "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)" |
| 205 | + ) |
| 206 | + parser.add_argument( |
| 207 | + "--test-batch-size", |
| 208 | + type=int, |
| 209 | + default=64, |
| 210 | + help="Test batch size (default: %(default)s)", |
| 211 | + ) |
| 212 | + parser.add_argument( |
| 213 | + "--aux-learning-rate", |
| 214 | + type=float, |
| 215 | + default=1e-3, |
| 216 | + help="Auxiliary loss learning rate (default: %(default)s)", |
| 217 | + ) |
| 218 | + parser.add_argument( |
| 219 | + "--patch-size", |
| 220 | + type=int, |
| 221 | + nargs=2, |
| 222 | + default=(256, 256), |
| 223 | + help="Size of the patches to be cropped (default: %(default)s)", |
| 224 | + ) |
| 225 | + parser.add_argument("--cuda", action="store_true", help="Use cuda") |
| 226 | + parser.add_argument( |
| 227 | + "--save", action="store_true", default=True, help="Save model to disk" |
| 228 | + ) |
| 229 | + parser.add_argument("--seed", type=int, help="Set random seed for reproducibility") |
| 230 | + parser.add_argument( |
| 231 | + "--clip_max_norm", |
| 232 | + default=1.0, |
| 233 | + type=float, |
| 234 | + help="gradient clipping max norm (default: %(default)s", |
| 235 | + ) |
| 236 | + parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint") |
| 237 | + args = parser.parse_args(argv) |
| 238 | + return args |
| 239 | + |
| 240 | + |
| 241 | +def main(argv): |
| 242 | + args = parse_args(argv) |
| 243 | + |
| 244 | + if args.seed is not None: |
| 245 | + torch.manual_seed(args.seed) |
| 246 | + random.seed(args.seed) |
| 247 | + |
| 248 | + num_points = 1024 |
| 249 | + |
| 250 | + train_dataset = ModelNetDataset( |
| 251 | + args.dataset, |
| 252 | + split="train", |
| 253 | + pre_transform=Compose( |
| 254 | + [ |
| 255 | + transforms.ToDict(wrapper="torch_geometric.data.Data"), |
| 256 | + transforms.SamplePointsV2( |
| 257 | + num=8192, remove_faces=True, include_normals=True, static_seed=1234 |
| 258 | + ), |
| 259 | + transforms.NormalizeScaleV2(center=True, scale_method="l2"), |
| 260 | + transforms.ToDict(wrapper="dict"), |
| 261 | + ] |
| 262 | + ), |
| 263 | + transform=Compose( |
| 264 | + [ |
| 265 | + transforms.ToDict(wrapper="torch_geometric.data.Data"), |
| 266 | + transforms.RandomSample(num=num_points, attrs=["pos", "normal"]), |
| 267 | + transforms.ToDict(wrapper="dict"), |
| 268 | + ] |
| 269 | + ), |
| 270 | + ) |
| 271 | + |
| 272 | + test_dataset = ModelNetDataset( |
| 273 | + args.dataset, |
| 274 | + split="test", |
| 275 | + pre_transform=Compose( |
| 276 | + [ |
| 277 | + transforms.ToDict(wrapper="torch_geometric.data.Data"), |
| 278 | + transforms.SamplePointsV2( |
| 279 | + num=8192, remove_faces=True, include_normals=True, static_seed=1234 |
| 280 | + ), |
| 281 | + transforms.NormalizeScaleV2(center=True, scale_method="l2"), |
| 282 | + transforms.ToDict(wrapper="dict"), |
| 283 | + ] |
| 284 | + ), |
| 285 | + transform=Compose( |
| 286 | + [ |
| 287 | + transforms.ToDict(wrapper="torch_geometric.data.Data"), |
| 288 | + transforms.RandomSample( |
| 289 | + num=num_points, attrs=["pos", "normal"], static_seed=1234 |
| 290 | + ), |
| 291 | + transforms.ToDict(wrapper="dict"), |
| 292 | + ] |
| 293 | + ), |
| 294 | + ) |
| 295 | + |
| 296 | + device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" |
| 297 | + |
| 298 | + train_dataloader = DataLoader( |
| 299 | + train_dataset, |
| 300 | + batch_size=args.batch_size, |
| 301 | + num_workers=args.num_workers, |
| 302 | + shuffle=True, |
| 303 | + pin_memory=(device == "cuda"), |
| 304 | + ) |
| 305 | + |
| 306 | + test_dataloader = DataLoader( |
| 307 | + test_dataset, |
| 308 | + batch_size=args.test_batch_size, |
| 309 | + num_workers=args.num_workers, |
| 310 | + shuffle=False, |
| 311 | + pin_memory=(device == "cuda"), |
| 312 | + ) |
| 313 | + |
| 314 | + net = MODELS[args.model]() |
| 315 | + net = net.to(device) |
| 316 | + |
| 317 | + if args.cuda and torch.cuda.device_count() > 1: |
| 318 | + net = CustomDataParallel(net) |
| 319 | + |
| 320 | + optimizer, aux_optimizer = configure_optimizers(net, args) |
| 321 | + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") |
| 322 | + criterion = ChamferPccRateDistortionLoss(lmbda={"bpp": 1.0, "rec": args.lmbda}) |
| 323 | + |
| 324 | + last_epoch = 0 |
| 325 | + if args.checkpoint: # load from previous checkpoint |
| 326 | + print("Loading", args.checkpoint) |
| 327 | + checkpoint = torch.load(args.checkpoint, map_location=device) |
| 328 | + last_epoch = checkpoint["epoch"] + 1 |
| 329 | + net.load_state_dict(checkpoint["state_dict"]) |
| 330 | + optimizer.load_state_dict(checkpoint["optimizer"]) |
| 331 | + aux_optimizer.load_state_dict(checkpoint["aux_optimizer"]) |
| 332 | + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) |
| 333 | + |
| 334 | + best_loss = float("inf") |
| 335 | + for epoch in range(last_epoch, args.epochs): |
| 336 | + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") |
| 337 | + train_one_epoch( |
| 338 | + net, |
| 339 | + criterion, |
| 340 | + train_dataloader, |
| 341 | + optimizer, |
| 342 | + aux_optimizer, |
| 343 | + epoch, |
| 344 | + args.clip_max_norm, |
| 345 | + ) |
| 346 | + loss = test_epoch(epoch, test_dataloader, net, criterion) |
| 347 | + lr_scheduler.step(loss) |
| 348 | + |
| 349 | + is_best = loss < best_loss |
| 350 | + best_loss = min(loss, best_loss) |
| 351 | + |
| 352 | + if args.save: |
| 353 | + save_checkpoint( |
| 354 | + { |
| 355 | + "epoch": epoch, |
| 356 | + "state_dict": net.state_dict(), |
| 357 | + "loss": loss, |
| 358 | + "optimizer": optimizer.state_dict(), |
| 359 | + "aux_optimizer": aux_optimizer.state_dict(), |
| 360 | + "lr_scheduler": lr_scheduler.state_dict(), |
| 361 | + }, |
| 362 | + is_best, |
| 363 | + ) |
| 364 | + |
| 365 | + |
| 366 | +if __name__ == "__main__": |
| 367 | + main(sys.argv[1:]) |
| 368 | + |
| 369 | + |
| 370 | +# NOTE: A more complete trainer with experiment tracking, visualizations, etc |
| 371 | +# that uses CompressAI Trainer can be found at: |
| 372 | +# |
| 373 | +# https://github.com/multimedialabsfu/learned-point-cloud-compression-for-classification |
0 commit comments