Commit bc39e93

feat: examples/train_pointcloud.py
```bash
python examples/train_pointcloud.py --cuda --dataset="datasets/modelnet40"
```
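An interrupted run can be resumed via the `--checkpoint` flag. A minimal sketch, with assumed paths, that launches the same entry point programmatically (`main()` accepts an argv list, as the file's `__main__` guard shows):

```python
# Hypothetical resume sketch; the dataset and checkpoint paths are assumptions.
import runpy
import sys

sys.argv = [
    "examples/train_pointcloud.py",
    "--cuda",
    "--dataset=datasets/modelnet40",
    "--checkpoint=checkpoint.pth.tar",  # restores weights, optimizers, lr_scheduler, epoch
]

# Run the script as if invoked from the command line.
runpy.run_path("examples/train_pointcloud.py", run_name="__main__")
```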
1 parent da4b717 commit bc39e93

File tree

1 file changed: +373 -0


examples/train_pointcloud.py

Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision.transforms import Compose

import compressai.transforms as transforms

from compressai.datasets import ModelNetDataset
from compressai.losses import ChamferPccRateDistortionLoss
from compressai.optimizers import net_aux_optimizer
from compressai.registry import MODELS
from compressai.zoo import pointcloud_models


class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods."""

    def __getattr__(self, key):
        try:
            return super().__getattr__(key)
        except AttributeError:
            return getattr(self.module, key)


def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.
    Return two optimizers"""
    conf = {
        "net": {"type": "Adam", "lr": args.learning_rate},
        "aux": {"type": "Adam", "lr": args.aux_learning_rate},
    }
    optimizer = net_aux_optimizer(net, conf)
    return optimizer["net"], optimizer["aux"]


def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device

    for i, d in enumerate(train_dataloader):
        d = {k: v.to(device) for k, v in d.items()}

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)
        out_criterion["loss"].backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i * d['pos'].shape[0]}/{len(train_dataloader.dataset)} "
                f"({100. * i / len(train_dataloader):.0f}%)] "
                f'Loss: {out_criterion["loss"].item():.3f} | '
                f'Bpp loss: {out_criterion["bpp_loss"].item():.3f} | '
                f'Rec loss: {out_criterion["rec_loss"].item():.4f} | '
                # f'Aux loss: {aux_loss.item():.0f} | '
                "\n"
            )


def test_epoch(epoch, test_dataloader, model, criterion):
    model.eval()
    model.update(force=True, update_quantiles=True)
    device = next(model.parameters()).device

    meter_keys = ["loss", "bpp_loss", "rec_loss", "aux_loss"]
    meters = {key: AverageMeter() for key in meter_keys}

    with torch.no_grad():
        for d in test_dataloader:
            d = {k: v.to(device) for k, v in d.items()}

            out_net = model(d)
            out_criterion = criterion(out_net, d)
            out_criterion["aux_loss"] = model.aux_loss()

            for key in meters:
                if key in out_criterion:
                    meters[key].update(out_criterion[key])

    print(
        f"Test epoch {epoch}: Average losses: "
        f'Loss: {meters["loss"].avg:.3f} | '
        f'Bpp loss: {meters["bpp_loss"].avg:.3f} | '
        f'Rec loss: {meters["rec_loss"].avg:.4f} | '
        # f'Aux loss: {meters["aux_loss"].avg:.0f} | '
        "\n"
    )

    return meters["loss"].avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "checkpoint_best_loss.pth.tar")


def parse_args(argv):
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="sfu2023-pcc-rec-pointnet",
        choices=pointcloud_models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=4,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=100,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument(
        "--save", action="store_true", default=True, help="Save model to disk"
    )
    parser.add_argument("--seed", type=int, help="Set random seed for reproducibility")
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="Gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args


def main(argv):
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    num_points = 1024

    train_dataset = ModelNetDataset(
        args.dataset,
        split="train",
        pre_transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.SamplePointsV2(
                    num=8192, remove_faces=True, include_normals=True, static_seed=1234
                ),
                transforms.NormalizeScaleV2(center=True, scale_method="l2"),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
        transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.RandomSample(num=num_points, attrs=["pos", "normal"]),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
    )

    test_dataset = ModelNetDataset(
        args.dataset,
        split="test",
        pre_transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.SamplePointsV2(
                    num=8192, remove_faces=True, include_normals=True, static_seed=1234
                ),
                transforms.NormalizeScaleV2(center=True, scale_method="l2"),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
        transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.RandomSample(
                    num=num_points, attrs=["pos", "normal"], static_seed=1234
                ),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
    )

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    net = MODELS[args.model]()
    net = net.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = ChamferPccRateDistortionLoss(lmbda={"bpp": 1.0, "rec": args.lmbda})

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )


if __name__ == "__main__":
    main(sys.argv[1:])


# NOTE: A more complete trainer with experiment tracking, visualizations, etc
# that uses CompressAI Trainer can be found at:
#
# https://github.com/multimedialabsfu/learned-point-cloud-compression-for-classification
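After training, the best checkpoint can be reloaded for evaluation. A minimal sketch, assuming the default model name and the checkpoint filename written by `save_checkpoint` above, and reusing only calls that already appear in the script:

```python
# Minimal evaluation sketch (assumptions: default model name and the default
# checkpoint filename produced by save_checkpoint above).
import torch

from compressai.registry import MODELS

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the architecture exactly as main() does, then restore its weights.
net = MODELS["sfu2023-pcc-rec-pointnet"]().to(device)
checkpoint = torch.load("checkpoint_best_loss.pth.tar", map_location=device)
net.load_state_dict(checkpoint["state_dict"])

# As in test_epoch(): refresh the entropy model state before measuring rate.
net.update(force=True, update_quantiles=True)
net.eval()
```

From here, the `test_epoch` loop above can be reused unchanged with a `ChamferPccRateDistortionLoss` criterion and a test `DataLoader`.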
