diff --git a/.venv/Scripts/python.exe b/.venv/Scripts/python.exe new file mode 100644 index 00000000..6b1b8d9e Binary files /dev/null and b/.venv/Scripts/python.exe differ diff --git a/.venv/Scripts/pythonw.exe b/.venv/Scripts/pythonw.exe new file mode 100644 index 00000000..ea5d460a Binary files /dev/null and b/.venv/Scripts/pythonw.exe differ diff --git a/.venv/pyvenv.cfg b/.venv/pyvenv.cfg new file mode 100644 index 00000000..07aca260 --- /dev/null +++ b/.venv/pyvenv.cfg @@ -0,0 +1,3 @@ +home = d:\Anaconda\envs\jupyter +include-system-site-packages = false +version = 3.9.16 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..6c2ff60b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "githubPullRequests.ignoredPullRequestBranches": [ + "master" + ] +} \ No newline at end of file diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 00000000..1170b2ca Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 00000000..5ace8ea9 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 00000000..d1c3a970 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 00000000..a7e14154 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 00000000..bbce2765 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 00000000..b50e4b6b Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 00000000..d6b4c5db Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 00000000..707a576b Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/ML/Pytorch/Basics/pytorch_simple_fullynet.py b/ML/Pytorch/Basics/pytorch_simple_fullynet.py index 36a399d8..d2da0153 100644 --- a/ML/Pytorch/Basics/pytorch_simple_fullynet.py +++ b/ML/Pytorch/Basics/pytorch_simple_fullynet.py @@ -95,6 +95,8 @@ def forward(self, x): for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(tqdm(train_loader)): # Get data to cuda if possible + print( data.shape) + print( targets.shape) data = 
data.to(device=device)
        targets = targets.to(device=device)

@@ -102,9 +104,9 @@ def forward(self, x):
         data = data.reshape(data.shape[0], -1)

         # Forward
-        scores = model(data)
+        scores = model.forward(data)
         loss = criterion(scores, targets)
-
+        print(f"Loss at epoch {epoch}, batch {batch_idx}: {loss.item()}")
         # Backward
         optimizer.zero_grad()
         loss.backward()
@@ -131,7 +133,7 @@ def check_accuracy(loader, model):
     num_correct = 0
     num_samples = 0
-    model.eval()
+    model.eval()  # evaluation mode; this disables dropout etc.

     # We don't need to keep track of gradients here so we wrap it in torch.no_grad()
     with torch.no_grad():
diff --git a/ML/Pytorch/Basics/pytorch_tensorbasics.py b/ML/Pytorch/Basics/pytorch_tensorbasics.py
index 3686dfa6..c1ac6a27 100644
--- a/ML/Pytorch/Basics/pytorch_tensorbasics.py
+++ b/ML/Pytorch/Basics/pytorch_tensorbasics.py
@@ -163,11 +163,12 @@
 values, indices = torch.min(x, dim=0)  # Can also do x.min(dim=0)
 abs_x = torch.abs(x)  # Returns x where abs function has been applied to every element
 z = torch.argmax(x, dim=0)  # Gets index of the maximum value
-z = torch.argmin(x, dim=0)  # Gets index of the minimum value
+z = torch.argmin(x, dim=0)
+print(z)  # Gets index of the minimum value
 mean_x = torch.mean(x.float(), dim=0)  # mean requires x to be float
 z = torch.eq(x, y)  # Element wise comparison, in this case z = [False, False, False]
 sorted_y, indices = torch.sort(y, dim=0, descending=False)
-
+print(indices)
 z = torch.clamp(x, min=0)  # All values < 0 set to 0 and values > 0 unchanged (this is exactly ReLU function)
 # If you want to values over max_val to be clamped, do torch.clamp(x, min=min_val, max=max_val)
@@ -207,7 +208,7 @@
 rows = torch.tensor([1, 0])
 cols = torch.tensor([4, 0])
 print(x[rows, cols])  # Gets second row fifth column and first row first column
-
+# which is the same as doing: [x[1,4], x[0,0]]  (advanced indexing)
 # More advanced indexing
 x = torch.arange(10)
 print(x[(x < 2) | (x > 8)])  # will be [0, 1, 9]
@@ -216,7 +217,9 @@
 # Useful operations for indexing
 print(
     torch.where(x > 5, x, x * 2)
-)  # gives [0, 2, 4, 6, 8, 10, 6, 7, 8, 9], all values x > 5 yield x, else x*2
+)
+# where the condition (first argument) holds, take the second argument; otherwise take the third
+# gives [0, 2, 4, 6, 8, 10, 6, 7, 8, 9], all values x > 5 yield x, else x*2
 x = torch.tensor([0, 0, 1, 2, 2, 3, 4]).unique()  # x = [0, 1, 2, 3, 4]
 print(
     x.ndimension()
@@ -231,7 +234,7 @@
 # ============================================================= #

 x = torch.arange(9)
-
+print(x.shape)  # Shape is [9]
 # Let's say we want to reshape it to be 3x3
 x_3x3 = x.view(3, 3)
@@ -256,7 +259,7 @@
 # using pointers to construct these matrices). This is a bit complicated and I need to explore this more
 # as well, at least you know it's a problem to be cautious of! A solution is to do the following
 print(y.contiguous().view(9))  # Calling .contiguous() before view and it works
-
+# (the underlying memory layout was non-contiguous, which is why view failed before)
 # Moving on to another operation, let's say we want to add two tensors dimensions togethor
 x1 = torch.rand(2, 5)
 x2 = torch.rand(2, 5)
@@ -284,7 +287,7 @@
 z = torch.chunk(x, chunks=2, dim=1)
 print(z[0].shape)
 print(z[1].shape)
-
+# chunk splits the tensor into several sub-tensors
 # Let's say we want to add an additional dimension
 x = torch.arange(
     10
diff --git a/ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py b/ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py
index 57aa419a..1aa68f78 100644
--- a/ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py
+++ b/ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py
@@ -81,7 +81,7 @@ def forward(self, x):
         noise = torch.randn(batch_size, z_dim).to(device)
         fake = gen(noise)
         disc_real = disc(real).view(-1)
-        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
+        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
         disc_fake = disc(fake).view(-1)
         lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
         lossD = (lossD_real + lossD_fake) / 2
diff --git a/ML/Pytorch/GANs/2. DCGAN/model.py b/ML/Pytorch/GANs/2. DCGAN/model.py
index 04b52d9d..7011f1d7 100644
--- a/ML/Pytorch/GANs/2. DCGAN/model.py
+++ b/ML/Pytorch/GANs/2. DCGAN/model.py
@@ -6,6 +6,15 @@
 * 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
 """
+import torch
+"""
+Discriminator and Generator implementation from DCGAN paper
+
+Programmed by Aladdin Persson
+* 2020-11-01: Initial coding
+* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
+"""
+
 import torch
 import torch.nn as nn
diff --git a/ML/Pytorch/README_ZH.md b/ML/Pytorch/README_ZH.md
new file mode 100644
index 00000000..2351cdae
--- /dev/null
+++ b/ML/Pytorch/README_ZH.md
@@ -0,0 +1,301 @@
+# PyTorch Learning Resource Library 📚
+
+This folder contains a comprehensive set of PyTorch deep-learning resources, covering topics from the basics through advanced material. It is organized by difficulty and application area, providing a learning path for beginners up to intermediate developers.
+
+---
+
+## 📁 Folder Structure Overview
+
+```
+Pytorch/
+├── Basics/                  # ⭐ Getting-started basics
+├── CNN_architectures/       # Convolutional neural network architectures
+├── GANs/                    # Generative adversarial networks
+├── huggingface/             # HuggingFace natural language processing
+├── image_segmentation/      # Image segmentation
+├── more_advanced/           # Advanced techniques
+├── object_detection/        # Object detection
+├── others/                  # Other tools and setups
+├── pytorch_lightning/       # The PyTorch Lightning framework
+└── recommender_systems/     # Recommender systems
+```
+
+---
+
+## 📖 Detailed Contents
+
+### 1️⃣ **Basics/** - PyTorch Fundamentals ⭐
+
+The most important introductory folder; it covers all of the core concepts you need to master.
+
+| File | Description |
+|-------|---------|
+| **pytorch_tensorbasics.py** | Basic tensor operations (creation, indexing, slicing, etc.) |
+| **pytorch_simple_fullynet.py** | A simple fully connected neural network |
+| **pytorch_simple_CNN.py** | A basic convolutional neural network (CNN) |
+| **pytorch_rnn_gru_lstm.py** | Recurrent neural networks (RNN, LSTM, GRU) |
+| **pytorch_bidirectional_lstm.py** | A bidirectional LSTM |
+| **pytorch_transforms.py** | Data transforms and preprocessing techniques |
+| **pytorch_tensorboard_.py** | Using TensorBoard for visualization |
+| **pytorch_loadsave.py** | Saving and loading models |
+| **pytorch_init_weights.py** | Weight-initialization methods |
+| **pytorch_std_mean.py** | Computing mean and standard deviation |
+| **pytorch_lr_ratescheduler.py** | Learning-rate schedulers (adjusting the learning rate dynamically) |
+| **pytorch_mixed_precision_example.py** | Mixed-precision training (faster training, lower memory use) |
+| **pytorch_pretrain_finetune.py** | Fine-tuning pretrained models |
+| **pytorch_progress_bar.py** | Progress-bar utilities |
+| **lightning_simple_CNN.py** | A simple CNN implemented with PyTorch Lightning |
+
+#### Subfolders:
+
+| Folder | Description |
+|-------|------|
+| **custom_dataset/** | Custom dataset loader (CSV format) |
+| **custom_dataset_txt/** | Custom dataset loader (text format) |
+| **albumentations_tutorial/** | Tutorial for the Albumentations data-augmentation library |
+| **set_deterministic_behavior/** | Setting random seeds for reproducible results |
+| **Imbalanced_classes/** | Techniques for handling imbalanced datasets |
+| **dataset/** | The downloaded MNIST dataset |
+
+---
+
+### 2️⃣ **CNN_architectures/** - Convolutional Network Architectures
+
+Implementations of the classic CNN architectures.
+
+| File | Architecture | Description |
+|-------|------|------|
+| **lenet5_pytorch.py** | LeNet-5 | One of the earliest CNNs, used for handwritten-digit recognition |
+| **pytorch_vgg_implementation.py** | VGG | A deep network built from small convolution kernels |
+| **pytorch_resnet.py** | ResNet | Residual network that addresses vanishing gradients |
+| **pytorch_efficientnet.py** | EfficientNet | An efficient network balancing accuracy and speed |
+| **pytorch_inceptionet.py** | Inception | Parallel branches with multiple kernel sizes |
+
+**Use cases**: image classification, feature extraction, etc.
+
+---
+
+### 3️⃣ **GANs/** - Generative Adversarial Networks (generative models)
+
+GAN implementations from simple to complex, used to generate realistic images.
+
+| Folder | Description | Difficulty |
+|-------|------|--------|
+| **1. SimpleGAN/** | The simplest possible GAN (fully connected layers) | ⭐ Beginner |
+| **2. DCGAN/** | Deep convolutional GAN (convolutions instead of fully connected layers) | ⭐⭐ Basic |
+| **3. WGAN/** | Wasserstein GAN (an improved loss function) | ⭐⭐⭐ Intermediate |
+| **4. WGAN-GP/** | WGAN with a gradient penalty (more stable training) | ⭐⭐⭐ Intermediate |
+| **CycleGAN/** | Unpaired image-to-image translation (e.g. photo ↔ painting) | ⭐⭐⭐⭐ Advanced |
+| **Pix2Pix/** | Conditional GAN for paired image-to-image translation | ⭐⭐⭐ Intermediate |
+| **SRGAN/** | Super-resolution GAN (low resolution → high resolution) | ⭐⭐⭐ Intermediate |
+| **ESRGAN/** | Enhanced SRGAN (improved super-resolution) | ⭐⭐⭐⭐ Advanced |
+| **StyleGAN/** | Style-based GAN (high-quality face generation) | ⭐⭐⭐⭐⭐ Expert |
+| **ProGAN/** | Progressive GAN (growing the resolution step by step) | ⭐⭐⭐⭐ Advanced |
+
+**Use cases**: image generation, image-to-image translation, super-resolution, data augmentation, etc.
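+
+> For orientation, here is a minimal, illustrative sketch of the alternating discriminator/generator update that the `1. SimpleGAN/fc_gan.py` example is built around. Names such as `gen`, `disc`, `z_dim`, and the layer sizes and learning rate below are simplified placeholders chosen for the sketch, not the repository's exact code.
+
+```python
+import torch
+import torch.nn as nn
+
+# Tiny stand-ins for the generator and discriminator in fc_gan.py
+z_dim, img_dim, batch_size = 64, 784, 32
+gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01), nn.Linear(256, img_dim), nn.Tanh())
+disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01), nn.Linear(128, 1), nn.Sigmoid())
+opt_gen = torch.optim.Adam(gen.parameters(), lr=3e-4)
+opt_disc = torch.optim.Adam(disc.parameters(), lr=3e-4)
+criterion = nn.BCELoss()
+
+real = torch.randn(batch_size, img_dim)   # placeholder for a batch of real images
+noise = torch.randn(batch_size, z_dim)
+fake = gen(noise)
+
+# Discriminator step: real images are labeled 1, generated images are labeled 0
+disc_real = disc(real).view(-1)
+lossD_real = criterion(disc_real, torch.ones_like(disc_real))
+disc_fake = disc(fake.detach()).view(-1)   # detach so this step does not update the generator
+lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
+lossD = (lossD_real + lossD_fake) / 2
+opt_disc.zero_grad()
+lossD.backward()
+opt_disc.step()
+
+# Generator step: try to make the discriminator output 1 on the fakes
+output = disc(fake).view(-1)
+lossG = criterion(output, torch.ones_like(output))
+opt_gen.zero_grad()
+lossG.backward()
+opt_gen.step()
+```
+
+The two separate optimizers and the `detach()` call are the key design choices: the discriminator is trained on fakes without propagating gradients into the generator, and only then is the generator updated against the (now slightly better) discriminator.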
+
+---
+
+### 4️⃣ **image_segmentation/** - Image Segmentation
+
+Pixel-level image analysis.
+
+| Folder | Description |
+|-------|------|
+| **semantic_segmentation_unet/** | U-Net implementation (medical image segmentation, etc.) |
+
+**Use cases**: medical image analysis, autonomous driving, satellite-image processing, etc.
+
+---
+
+### 5️⃣ **object_detection/** - Object Detection
+
+Detecting and localizing objects in images.
+
+| Folder | Description |
+|-------|------|
+| **metrics/** | Object-detection evaluation metrics (IoU, mAP, etc.) |
+| **YOLO/** | The YOLO detection algorithm |
+| **YOLOv3/** | YOLO v3 implementation |
+
+**Use cases**: security monitoring, face recognition, autonomous driving, etc.
+
+---
+
+### 6️⃣ **more_advanced/** - Advanced Techniques
+
+Various advanced and state-of-the-art deep-learning techniques.
+
+| Folder | Description |
+|-------|------|
+| **Seq2Seq/** | Sequence-to-sequence models (the basis of machine translation) |
+| **Seq2Seq_attention/** | Seq2Seq with an attention mechanism (improved version) |
+| **seq2seq_transformer/** | Seq2Seq using a Transformer |
+| **transformer_from_scratch/** | A Transformer implemented from scratch |
+| **VAE/** | Variational autoencoder (unsupervised learning) |
+| **image_captioning/** | Image captioning (vision + language) |
+| **neuralstyle/** | Neural style transfer (artistic style conversion) |
+| **torchtext/** | Tutorial for the PyTorch text-processing library |
+| **finetuning_whisper/** | Fine-tuning the Whisper speech-recognition model |
+
+**Use cases**: machine translation, speech recognition, image understanding, art generation, etc.
+
+---
+
+### 7️⃣ **huggingface/** - Natural Language Processing (NLP)
+
+NLP tasks using the HuggingFace Transformers library.
+
+| File | Purpose |
+|-------|------|
+| **learninghugg.py** | HuggingFace basics |
+| **model.py** | Model definition |
+| **train.py** | Training script |
+| **test.py** | Testing script |
+| **dataset.py** | Dataset loader |
+| **finetuning_t5_lightning.ipynb** | Fine-tuning T5 (with Lightning) |
+| **finetune_t5_small_cnndaily.ipynb** | Fine-tuning T5-small on the CNN/DailyMail dataset |
+| **cnndaily_t5_lightning_customdataloading.ipynb** | Custom data loading |
+| **learning.ipynb** | Study notebook |
+
+**Use cases**: text classification, machine translation, summarization, question answering, etc.
+
+---
+
+### 8️⃣ **pytorch_lightning/** - The PyTorch Lightning Framework
+
+Simplifying model training with PyTorch Lightning (it plays a role similar to what Keras does for TensorFlow).
+
+| Folder | Description |
+|-------|------|
+| **1. start code/** | Starter code for Lightning |
+| **2. LightningModule/** | The LightningModule components |
+| **3. Lightning Trainer/** | Configuring and using the Trainer |
+| **4. Metrics/** | Computing evaluation metrics |
+| **5. DataModule/** | Data modules (data-processing pipeline) |
+| **6. Restructuring/** | Refactoring and organizing code |
+| **7. Callbacks/** | Callbacks (early stopping, checkpointing, etc.) |
+| **8. Logging Tensorboard/** | TensorBoard logging |
+| **9. Profiler/** | Profiling tools |
+| **10. Multi-GPU/** | Multi-GPU training |
+
+**Advantages**:
+- Concise code that focuses on the model logic
+- Automatic GPU/TPU handling
+- Built-in multi-GPU distributed training
+- Many best practices integrated out of the box
+
+---
+
+### 9️⃣ **recommender_systems/** - Recommender Systems
+
+Models and algorithms for building recommender systems.
+
+| Folder | Description |
+|-------|------|
+| **neural_collaborative_filtering/** | Neural collaborative filtering (modeling user-item interactions) |
+
+**Use cases**: e-commerce, video, and music recommendation, etc.
+
+---
+
+### 🔟 **others/** - Other Tools and Setups
+
+| Folder | Description |
+|-------|------|
+| **default_setups/** | Default configurations and project templates |
+
+---
+
+## 🎯 Suggested Learning Path
+
+### For beginners:
+1. **Start with Basics** - work through the files in order to understand the core PyTorch concepts
+2. **pytorch_tensorbasics.py** → **pytorch_simple_fullynet.py** → **pytorch_simple_CNN.py**
+3. Learn how to load data with **custom_dataset**
+4. Try **pytorch_loadsave.py** to save and load models
+
+### For intermediate learners:
+1. Study the network architectures in **CNN_architectures**
+2. Try **GANs/1. SimpleGAN** to understand generative models
+3. Explore **pytorch_lightning** to improve your code structure
+
+### For advanced applications:
+1. The GAN variants in the **GANs** folder
+2. Transformers, VAEs, and more in **more_advanced**
+3. **huggingface** for natural-language-processing tasks
+4. **object_detection** and **image_segmentation** for computer-vision tasks
+
+---
+
+## 🛠️ Environment and Dependencies
+
+Running the code requires the following libraries:
+
+```bash
+# Core libraries
+pip install torch torchvision torchaudio
+
+# Higher-level tooling
+pip install pytorch-lightning tensorboard transformers
+
+# Data processing
+pip install numpy pandas scikit-learn albumentations
+
+# Visualization
+pip install matplotlib seaborn
+
+# NLP tools
+pip install torchtext huggingface-hub
+```
+
+---
+
+## 💡 Core Concept Cheat Sheet
+
+| Concept | Location | Example File |
+|-----|--------|----------|
+| **Tensor operations** | Basics | pytorch_tensorbasics.py |
+| **Building neural networks** | Basics | pytorch_simple_fullynet.py |
+| **CNNs** | Basics, CNN_architectures | pytorch_simple_CNN.py, lenet5_pytorch.py |
+| **RNN/LSTM** | Basics | pytorch_rnn_gru_lstm.py |
+| **Data loading** | Basics | custom_dataset/custom_dataset.py |
+| **Saving/loading models** | Basics | pytorch_loadsave.py |
+| **Generative models** | GANs | 1. SimpleGAN/fc_gan.py |
+| **Transfer learning** | Basics | pytorch_pretrain_finetune.py |
+| **Transformers** | more_advanced | transformer_from_scratch/ |
+| **Image segmentation** | image_segmentation | semantic_segmentation_unet/ |
+| **Object detection** | object_detection | YOLO/ |
+| **Simplified training framework** | pytorch_lightning | 1. start code/ |
+
+---
+
+## 📚 Recommended Learning Resources
+
+- **Official docs**: https://pytorch.org/docs/
+- **PyTorch Tutorials**: https://pytorch.org/tutorials/
+- **PyTorch Lightning**: https://www.pytorchlightning.ai/
+- **HuggingFace Docs**: https://huggingface.co/docs/
+
+---
+
+## ✅ Quick Start
+
+### Run your first PyTorch program:
+
+```bash
+# Change into the project directory
+cd d:\Machine-Learning-Collection\ML\Pytorch\Basics
+
+# Run the tensor basics tutorial
+python pytorch_tensorbasics.py
+
+# Or run a simple CNN
+python pytorch_simple_CNN.py
+```
+
+---
+
+**Happy learning! 🚀**
+
+If you have questions, take a look at the source-code comments in each file; they are documented in detail.
diff --git a/Test/tensor.py b/Test/tensor.py
new file mode 100644
index 00000000..eaa05df1
--- /dev/null
+++ b/Test/tensor.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
+x = torch.empty(size=(3, 3)).uniform_(0, 1)
+y = torch.diag(torch.ones(3))
+z = torch.ones(3)
+print(x)
+print(y)
+print(z)
+import numpy as np
+a = np.array([1, 2, 3])
+b = torch.from_numpy(a)  # convert the numpy array to a tensor
+print(a)
+print(b)
+c = b.numpy()  # convert the tensor back to a numpy array
+print(c.dtype)
+import torch
+x = torch.tensor([1, 2, 3])
+print(torch.diag(x))
+# Output:
+# tensor([[1, 0, 0],
+#         [0, 2, 0],
+#         [0, 0, 3]])
+A = torch.tensor([[1, 2], [3, 4]])
+print(torch.diag(A))
+# Output: tensor([1, 4])
+p = torch.rand(3, 4)
+print(p)
+q = torch.eye(4)
+print(q)
+z = torch.empty(3, 4).normal_(mean=0, std=1)
+print(z)
+j = torch.arange(1, 10, 2)
+print(j)
+k = torch.empty(3, 4)
+print(k)
\ No newline at end of file
diff --git a/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb b/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb
new file mode 100644
index 00000000..5151c183
--- /dev/null
+++ b/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb
@@ -0,0 +1,1178 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {
+    "id": "-uf9wy7_D2qF"
+   },
+   "outputs": [],
+   "source": [
+    "# For tips on running notebooks in Google Colab, see\n",
+    "# https://docs.pytorch.org/tutorials/beginner/colab\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "2LYmEix3D2qH"
+   },
+   "source": [
+    "Train a Mario-playing RL Agent\n",
+    "==============================\n",
+    "\n",
+    "**Authors:** [Yuansong Feng](https://github.com/YuansongFeng), [Suraj\n",
+    "Subramanian](https://github.com/suraj813), [Howard\n",
+    "Wang](https://github.com/hw26), [Steven\n",
+    "Guo](https://github.com/GuoYuzhang).\n",
+    "\n",
+    "This 
tutorial walks you through the fundamentals of Deep Reinforcement\n", + "Learning. At the end, you will implement an AI-powered Mario (using\n", + "[Double Deep Q-Networks](https://arxiv.org/pdf/1509.06461.pdf)) that can\n", + "play the game by itself.\n", + "\n", + "Although no prior knowledge of RL is necessary for this tutorial, you\n", + "can familiarize yourself with these RL\n", + "[concepts](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html),\n", + "and have this handy\n", + "[cheatsheet](https://colab.research.google.com/drive/1eN33dPVtdPViiS1njTW_-r-IYCDTFU7N)\n", + "as your companion. The full code is available\n", + "[here](https://github.com/yuansongFeng/MadMario/).\n", + "\n", + "![](https://pytorch.org/tutorials/_static/img/mario.gif)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gIa9lvU7D2qJ" + }, + "source": [ + "``` {.bash}\n", + "%%bash\n", + "pip install gym-super-mario-bros==7.4.0\n", + "pip install tensordict==0.3.0\n", + "pip install torchrl==0.3.0\n", + "```\n" + ] + }, + { + "cell_type": "code", + "source": [ + "%%bash\n", + "pip install gym==0.26.2 gym-super-mario-bros==7.3.0 nes-py==8.1.0\n", + "pip install tensordict\n", + "pip install torchrl" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aYk3GjmsEmAJ", + "outputId": "ddd2971f-f92f-4fae-f9e9-f20015cb8fff" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: gym==0.25.2 in /usr/local/lib/python3.12/dist-packages (0.25.2)\n", + "Collecting gym-super-mario-bros==7.3.0\n", + " Downloading gym_super_mario_bros-7.3.0-py2.py3-none-any.whl.metadata (9.4 kB)\n", + "Collecting nes-py==8.1.0\n", + " Downloading nes_py-8.1.0.tar.gz (73 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 73.1/73.1 kB 3.6 MB/s eta 0:00:00\n", + " Preparing metadata (setup.py): started\n", + " Preparing metadata (setup.py): finished with status 'done'\n", + "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (2.0.2)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (3.1.1)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (0.1.0)\n", + "Collecting pyglet>=1.3.2 (from nes-py==8.1.0)\n", + " Downloading pyglet-2.1.9-py3-none-any.whl.metadata (7.7 kB)\n", + "Requirement already satisfied: tqdm>=4.19.5 in /usr/local/lib/python3.12/dist-packages (from nes-py==8.1.0) (4.67.1)\n", + "Downloading gym_super_mario_bros-7.3.0-py2.py3-none-any.whl (198 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 198.6/198.6 kB 10.6 MB/s eta 0:00:00\n", + "Downloading pyglet-2.1.9-py3-none-any.whl (1.0 MB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 36.1 MB/s eta 0:00:00\n", + "Building wheels for collected packages: nes-py\n", + " Building wheel for nes-py (setup.py): started\n", + " Building wheel for nes-py (setup.py): finished with status 'done'\n", + " Created wheel for nes-py: filename=nes_py-8.1.0-cp312-cp312-linux_x86_64.whl size=504502 sha256=7d20998ab7c44177d003f2459ee1bb2136527fcd190496ab3e06445e79e88fd5\n", + " Stored in directory: /root/.cache/pip/wheels/a7/83/d9/f251e11d21aa7223824a74c79b52f63a3f5175ac20e9bac221\n", + "Successfully built nes-py\n", + "Installing collected packages: pyglet, nes-py, gym-super-mario-bros\n", + "Successfully installed gym-super-mario-bros-7.3.0 
nes-py-8.1.0 pyglet-2.1.9\n", + "Collecting tensordict\n", + " Downloading tensordict-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.3 kB)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (from tensordict) (2.8.0+cu126)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from tensordict) (2.0.2)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.12/dist-packages (from tensordict) (3.1.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from tensordict) (25.0)\n", + "Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from tensordict) (8.7.0)\n", + "Requirement already satisfied: orjson in /usr/local/lib/python3.12/dist-packages (from tensordict) (3.11.3)\n", + "Collecting pyvers<0.2.0,>=0.1.0 (from tensordict)\n", + " Downloading pyvers-0.1.0-py3-none-any.whl.metadata (5.4 kB)\n", + "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->tensordict) (3.23.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.5)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.1.6)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in 
/usr/local/lib/python3.12/dist-packages (from torch->tensordict) (2.27.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.4.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch->tensordict) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch->tensordict) (3.0.3)\n", + "Downloading tensordict-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl (449 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 450.0/450.0 kB 12.5 MB/s eta 0:00:00\n", + "Downloading pyvers-0.1.0-py3-none-any.whl (10 kB)\n", + "Installing collected packages: pyvers, tensordict\n", + "Successfully installed pyvers-0.1.0 tensordict-0.10.0\n", + "Collecting torchrl\n", + " Downloading torchrl-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (48 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 48.4/48.4 kB 2.8 MB/s eta 0:00:00\n", + "Requirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.12/dist-packages (from torchrl) (2.8.0+cu126)\n", + "Requirement already satisfied: pyvers in /usr/local/lib/python3.12/dist-packages (from torchrl) (0.1.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torchrl) (2.0.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from torchrl) (25.0)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.12/dist-packages (from torchrl) (3.1.1)\n", + "Requirement already satisfied: tensordict<0.11.0,>=0.10.0 in /usr/local/lib/python3.12/dist-packages (from torchrl) (0.10.0)\n", + "Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from tensordict<0.11.0,>=0.10.0->torchrl) (8.7.0)\n", + "Requirement already satisfied: orjson in /usr/local/lib/python3.12/dist-packages (from tensordict<0.11.0,>=0.10.0->torchrl) (3.11.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.5)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.1.6)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement 
already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (2.27.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.4.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.1.0->torchrl) (1.3.0)\n", + "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->tensordict<0.11.0,>=0.10.0->torchrl) (3.23.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.1.0->torchrl) (3.0.3)\n", + "Downloading torchrl-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl (1.8 MB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 42.0 MB/s eta 0:00:00\n", + "Installing collected packages: torchrl\n", + "Successfully installed torchrl-0.10.0\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Vp25M-ClD2qJ", + "outputId": "462e5ecc-ef04-4258-9ca6-5bb923c0027b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.\n", + "Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.\n", + "See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: 
DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torchvision import transforms as T\n", + "from PIL import Image\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from collections import deque\n", + "import random, datetime, os\n", + "\n", + "# Gym is an OpenAI toolkit for RL\n", + "import gym\n", + "from gym.spaces import Box\n", + "from gym.wrappers import FrameStack\n", + "\n", + "# NES Emulator for OpenAI Gym\n", + "from nes_py.wrappers import JoypadSpace\n", + "\n", + "# Super Mario environment for OpenAI Gym\n", + "import gym_super_mario_bros\n", + "\n", + "from tensordict import TensorDict\n", + "from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_c-_S9p_D2qJ" + }, + "source": [ + "RL Definitions\n", + "==============\n", + "\n", + "**Environment** The world that an agent interacts with and learns from.\n", + "\n", + "**Action** $a$ : How the Agent responds to the Environment. The set of\n", + "all possible Actions is called *action-space*.\n", + "\n", + "**State** $s$ : The current characteristic of the Environment. The set\n", + "of all possible States the Environment can be in is called\n", + "*state-space*.\n", + "\n", + "**Reward** $r$ : Reward is the key feedback from Environment to Agent.\n", + "It is what drives the Agent to learn and to change its future action. An\n", + "aggregation of rewards over multiple time steps is called **Return**.\n", + "\n", + "**Optimal Action-Value function** $Q^*(s,a)$ : Gives the expected return\n", + "if you start in state $s$, take an arbitrary action $a$, and then for\n", + "each future time step take the action that maximizes returns. $Q$ can be\n", + "said to stand for the \"quality\" of the action in a state. 
We try to\n", + "approximate this function.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s1MRtEQkD2qJ" + }, + "source": [ + "Environment\n", + "===========\n", + "\n", + "Initialize Environment\n", + "----------------------\n", + "\n", + "In Mario, the environment consists of tubes, mushrooms and other\n", + "components.\n", + "\n", + "When Mario makes an action, the environment responds with the changed\n", + "(next) state, reward and other info.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 530 + }, + "id": "ptVynUw2D2qK", + "outputId": "355001d4-10e8-4f4b-d01d-5ac97dd815d5" + }, + "outputs": [ + { + "output_type": "error", + "ename": "TypeError", + "evalue": "SuperMarioBrosEnv.__init__() got an unexpected keyword argument 'render_mode'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipython-input-4228824977.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgym_super_mario_bros\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mactions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m env = gym_super_mario_bros.make(\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;34m\"SuperMarioBros-1-1-v0\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Changed render_mode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, new_step_api, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 672\u001b[0m )\n\u001b[1;32m 673\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 674\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 675\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;31m# Copies the environment creation specification and kwargs to add to the environment specification details\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, new_step_api, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 662\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 663\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m 
if (\n", + "\u001b[0;31mTypeError\u001b[0m: SuperMarioBrosEnv.__init__() got an unexpected keyword argument 'render_mode'" + ] + } + ], + "source": [ + "# Initialize Super Mario environment\n", + "import gym_super_mario_bros\n", + "from nes_py.wrappers import JoypadSpace\n", + "import gym_super_mario_bros.actions as actions\n", + "\n", + "env = gym_super_mario_bros.make(\n", + " \"SuperMarioBros-1-1-v0\",\n", + " render_mode='rgb_array', # Changed render_mode\n", + " apply_api_compatibility=True\n", + ")\n", + "\n", + "# Limit the action-space to\n", + "# 0. walk right\n", + "# 1. jump right\n", + "env = JoypadSpace(env, actions.SIMPLE_MOVEMENT)\n", + "\n", + "env.reset()\n", + "next_state, reward, done, trunc, info = env.step(action=0)\n", + "print(f\"{next_state.shape},\\n {reward},\\n {done},\\n {info}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_1tsdX8ND2qK" + }, + "source": [ + "Preprocess Environment\n", + "======================\n", + "\n", + "Environment data is returned to the agent in `next_state`. As you saw\n", + "above, each state is represented by a `[3, 240, 256]` size array. Often\n", + "that is more information than our agent needs; for instance, Mario's\n", + "actions do not depend on the color of the pipes or the sky!\n", + "\n", + "We use **Wrappers** to preprocess environment data before sending it to\n", + "the agent.\n", + "\n", + "`GrayScaleObservation` is a common wrapper to transform an RGB image to\n", + "grayscale; doing so reduces the size of the state representation without\n", + "losing useful information. Now the size of each state: `[1, 240, 256]`\n", + "\n", + "`ResizeObservation` downsamples each observation into a square image.\n", + "New size: `[1, 84, 84]`\n", + "\n", + "`SkipFrame` is a custom wrapper that inherits from `gym.Wrapper` and\n", + "implements the `step()` function. Because consecutive frames don't vary\n", + "much, we can skip n-intermediate frames without losing much information.\n", + "The n-th frame aggregates rewards accumulated over each skipped frame.\n", + "\n", + "`FrameStack` is a wrapper that allows us to squash consecutive frames of\n", + "the environment into a single observation point to feed to our learning\n", + "model. 
This way, we can identify if Mario was landing or jumping based\n", + "on the direction of his movement in the previous several frames.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 211 + }, + "id": "OTdrBI4XD2qK", + "outputId": "ffc7e4d9-c2e3-4108-e8d6-a0957523b62f" + }, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "name 'env' is not defined", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipython-input-3910090054.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;31m# Apply Wrappers to environment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSkipFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGrayScaleObservation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResizeObservation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m84\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'env' is not defined" + ] + } + ], + "source": [ + "class SkipFrame(gym.Wrapper):\n", + " def __init__(self, env, skip):\n", + " \"\"\"Return only every `skip`-th frame\"\"\"\n", + " super().__init__(env)\n", + " self._skip = skip\n", + "\n", + " def step(self, action):\n", + " \"\"\"Repeat action, and sum reward\"\"\"\n", + " total_reward = 0.0\n", + " for i in range(self._skip):\n", + " # Accumulate reward and repeat the same action\n", + " obs, reward, done, trunk, info = self.env.step(action)\n", + " total_reward += reward\n", + " if done:\n", + " break\n", + " return obs, total_reward, done, trunk, info\n", + "\n", + "\n", + "class GrayScaleObservation(gym.ObservationWrapper):\n", + " def __init__(self, env):\n", + " super().__init__(env)\n", + " obs_shape = self.observation_space.shape[:2]\n", + " self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", + "\n", + " def permute_orientation(self, observation):\n", + " # permute [H, W, C] array to [C, H, W] tensor\n", + " observation = np.transpose(observation, (2, 0, 1))\n", + " observation = torch.tensor(observation.copy(), dtype=torch.float)\n", + " return observation\n", + "\n", + " def observation(self, observation):\n", + " observation = self.permute_orientation(observation)\n", + " transform = T.Grayscale()\n", + " observation = transform(observation)\n", + " return observation\n", + "\n", + "\n", + "class ResizeObservation(gym.ObservationWrapper):\n", + " def __init__(self, env, shape):\n", + " super().__init__(env)\n", + " if isinstance(shape, int):\n", + " self.shape = (shape, shape)\n", + " else:\n", + " self.shape = tuple(shape)\n", + "\n", + " obs_shape = self.shape + 
self.observation_space.shape[2:]\n", + " self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", + "\n", + " def observation(self, observation):\n", + " transforms = T.Compose(\n", + " [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]\n", + " )\n", + " observation = transforms(observation).squeeze(0)\n", + " return observation\n", + "\n", + "\n", + "# Apply Wrappers to environment\n", + "env = SkipFrame(env, skip=4)\n", + "env = GrayScaleObservation(env)\n", + "env = ResizeObservation(env, shape=84)\n", + "if gym.__version__ < '0.26':\n", + " env = FrameStack(env, num_stack=4, new_step_api=True)\n", + "else:\n", + " env = FrameStack(env, num_stack=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zdFsFd9OD2qL" + }, + "source": [ + "After applying the above wrappers to the environment, the final wrapped\n", + "state consists of 4 gray-scaled consecutive frames stacked together, as\n", + "shown above in the image on the left. Each time Mario makes an action,\n", + "the environment responds with a state of this structure. The structure\n", + "is represented by a 3-D array of size `[4, 84, 84]`.\n", + "\n", + "![](https://pytorch.org/tutorials/_static/img/mario_env.png)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nWAooKxLD2qL" + }, + "source": [ + "Agent\n", + "=====\n", + "\n", + "We create a class `Mario` to represent our agent in the game. Mario\n", + "should be able to:\n", + "\n", + "- **Act** according to the optimal action policy based on the current\n", + " state (of the environment).\n", + "- **Remember** experiences. Experience = (current state, current\n", + " action, reward, next state). Mario *caches* and later *recalls* his\n", + " experiences to update his action policy.\n", + "- **Learn** a better action policy over time\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pi2mirDYD2qL" + }, + "outputs": [], + "source": [ + "class Mario:\n", + " def __init__():\n", + " pass\n", + "\n", + " def act(self, state):\n", + " \"\"\"Given a state, choose an epsilon-greedy action\"\"\"\n", + " pass\n", + "\n", + " def cache(self, experience):\n", + " \"\"\"Add the experience to memory\"\"\"\n", + " pass\n", + "\n", + " def recall(self):\n", + " \"\"\"Sample experiences from memory\"\"\"\n", + " pass\n", + "\n", + " def learn(self):\n", + " \"\"\"Update online action value (Q) function with a batch of experiences\"\"\"\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PMIFF0aBD2qL" + }, + "source": [ + "In the following sections, we will populate Mario's parameters and\n", + "define his functions.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jklk8voTD2qL" + }, + "source": [ + "Act\n", + "===\n", + "\n", + "For any given state, an agent can choose to do the most optimal action\n", + "(**exploit**) or a random action (**explore**).\n", + "\n", + "Mario randomly explores with a chance of `self.exploration_rate`; when\n", + "he chooses to exploit, he relies on `MarioNet` (implemented in `Learn`\n", + "section) to provide the most optimal action.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-kYLEcR8D2qL" + }, + "outputs": [], + "source": [ + "class Mario:\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " self.state_dim = state_dim\n", + " self.action_dim = action_dim\n", + " self.save_dir = save_dir\n", + "\n", + " self.device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"\n", + "\n", + " # Mario's DNN to predict the most optimal action - we implement this in the Learn section\n", + " self.net = MarioNet(self.state_dim, self.action_dim).float()\n", + " self.net = self.net.to(device=self.device)\n", + "\n", + " self.exploration_rate = 1\n", + " self.exploration_rate_decay = 0.99999975\n", + " self.exploration_rate_min = 0.1\n", + " self.curr_step = 0\n", + "\n", + " self.save_every = 5e5 # no. of experiences between saving Mario Net\n", + "\n", + " def act(self, state):\n", + " \"\"\"\n", + " Given a state, choose an epsilon-greedy action and update value of step.\n", + "\n", + " Inputs:\n", + " state(``LazyFrame``): A single observation of the current state, dimension is (state_dim)\n", + " Outputs:\n", + " ``action_idx`` (``int``): An integer representing which action Mario will perform\n", + " \"\"\"\n", + " # EXPLORE\n", + " if np.random.rand() < self.exploration_rate:\n", + " action_idx = np.random.randint(self.action_dim)\n", + "\n", + " # EXPLOIT\n", + " else:\n", + " state = state[0].__array__() if isinstance(state, tuple) else state.__array__()\n", + " state = torch.tensor(state, device=self.device).unsqueeze(0)\n", + " action_values = self.net(state, model=\"online\")\n", + " action_idx = torch.argmax(action_values, axis=1).item()\n", + "\n", + " # decrease exploration_rate\n", + " self.exploration_rate *= self.exploration_rate_decay\n", + " self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)\n", + "\n", + " # increment step\n", + " self.curr_step += 1\n", + " return action_idx" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z_ud-4bvD2qM" + }, + "source": [ + "Cache and Recall\n", + "================\n", + "\n", + "These two functions serve as Mario's \"memory\" process.\n", + "\n", + "`cache()`: Each time Mario performs an action, he stores the\n", + "`experience` to his memory. 
His experience includes the current *state*,\n", + "*action* performed, *reward* from the action, the *next state*, and\n", + "whether the game is *done*.\n", + "\n", + "`recall()`: Mario randomly samples a batch of experiences from his\n", + "memory, and uses that to learn the game.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YHeLPsd4D2qM" + }, + "outputs": [], + "source": [ + "class Mario(Mario): # subclassing for continuity\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device(\"cpu\")))\n", + " self.batch_size = 32\n", + "\n", + " def cache(self, state, next_state, action, reward, done):\n", + " \"\"\"\n", + " Store the experience to self.memory (replay buffer)\n", + "\n", + " Inputs:\n", + " state (``LazyFrame``),\n", + " next_state (``LazyFrame``),\n", + " action (``int``),\n", + " reward (``float``),\n", + " done(``bool``))\n", + " \"\"\"\n", + " def first_if_tuple(x):\n", + " return x[0] if isinstance(x, tuple) else x\n", + " state = first_if_tuple(state).__array__()\n", + " next_state = first_if_tuple(next_state).__array__()\n", + "\n", + " state = torch.tensor(state)\n", + " next_state = torch.tensor(next_state)\n", + " action = torch.tensor([action])\n", + " reward = torch.tensor([reward])\n", + " done = torch.tensor([done])\n", + "\n", + " # self.memory.append((state, next_state, action, reward, done,))\n", + " self.memory.add(TensorDict({\"state\": state, \"next_state\": next_state, \"action\": action, \"reward\": reward, \"done\": done}, batch_size=[]))\n", + "\n", + " def recall(self):\n", + " \"\"\"\n", + " Retrieve a batch of experiences from memory\n", + " \"\"\"\n", + " batch = self.memory.sample(self.batch_size).to(self.device)\n", + " state, next_state, action, reward, done = (batch.get(key) for key in (\"state\", \"next_state\", \"action\", \"reward\", \"done\"))\n", + " return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qrnlw-hYD2qM" + }, + "source": [ + "Learn\n", + "=====\n", + "\n", + "Mario uses the [DDQN algorithm](https://arxiv.org/pdf/1509.06461) under\n", + "the hood. DDQN uses two ConvNets - $Q_{online}$ and $Q_{target}$ - that\n", + "independently approximate the optimal action-value function.\n", + "\n", + "In our implementation, we share feature generator `features` across\n", + "$Q_{online}$ and $Q_{target}$, but maintain separate FC classifiers for\n", + "each. $\\theta_{target}$ (the parameters of $Q_{target}$) is frozen to\n", + "prevent updating by backprop. 
Instead, it is periodically synced with\n", + "$\\theta_{online}$ (more on this later).\n", + "\n", + "Neural Network\n", + "--------------\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_aBKXrRD2qM" + }, + "outputs": [], + "source": [ + "class MarioNet(nn.Module):\n", + " \"\"\"mini CNN structure\n", + " input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, output_dim):\n", + " super().__init__()\n", + " c, h, w = input_dim\n", + "\n", + " if h != 84:\n", + " raise ValueError(f\"Expecting input height: 84, got: {h}\")\n", + " if w != 84:\n", + " raise ValueError(f\"Expecting input width: 84, got: {w}\")\n", + "\n", + " self.online = self.__build_cnn(c, output_dim)\n", + "\n", + " self.target = self.__build_cnn(c, output_dim)\n", + " self.target.load_state_dict(self.online.state_dict())\n", + "\n", + " # Q_target parameters are frozen.\n", + " for p in self.target.parameters():\n", + " p.requires_grad = False\n", + "\n", + " def forward(self, input, model):\n", + " if model == \"online\":\n", + " return self.online(input)\n", + " elif model == \"target\":\n", + " return self.target(input)\n", + "\n", + " def __build_cnn(self, c, output_dim):\n", + " return nn.Sequential(\n", + " nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),\n", + " nn.ReLU(),\n", + " nn.Flatten(),\n", + " nn.Linear(3136, 512),\n", + " nn.ReLU(),\n", + " nn.Linear(512, output_dim),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3givTueiD2qM" + }, + "source": [ + "TD Estimate & TD Target\n", + "=======================\n", + "\n", + "Two values are involved in learning:\n", + "\n", + "**TD Estimate** - the predicted optimal $Q^*$ for a given state $s$\n", + "\n", + "$${TD}_e = Q_{online}^*(s,a)$$\n", + "\n", + "**TD Target** - aggregation of current reward and the estimated $Q^*$ in\n", + "the next state $s'$\n", + "\n", + "$$a' = argmax_{a} Q_{online}(s', a)$$\n", + "\n", + "$${TD}_t = r + \\gamma Q_{target}^*(s',a')$$\n", + "\n", + "Because we don't know what next action $a'$ will be, we use the action\n", + "$a'$ maximizes $Q_{online}$ in the next state $s'$.\n", + "\n", + "Notice we use the\n", + "[\\@torch.no\\_grad()](https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad)\n", + "decorator on `td_target()` to disable gradient calculations here\n", + "(because we don't need to backpropagate on $\\theta_{target}$).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-2dL1LkWD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.gamma = 0.9\n", + "\n", + " def td_estimate(self, state, action):\n", + " current_Q = self.net(state, model=\"online\")[\n", + " np.arange(0, self.batch_size), action\n", + " ] # Q_online(s,a)\n", + " return current_Q\n", + "\n", + " @torch.no_grad()\n", + " def td_target(self, reward, next_state, done):\n", + " next_state_Q = self.net(next_state, model=\"online\")\n", + " best_action = torch.argmax(next_state_Q, axis=1)\n", + " next_Q = self.net(next_state, model=\"target\")[\n", + " np.arange(0, self.batch_size), best_action\n", + " ]\n", + " 
return (reward + (1 - done.float()) * self.gamma * next_Q).float()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8u_-CcTID2qN" + }, + "source": [ + "Updating the model\n", + "==================\n", + "\n", + "As Mario samples inputs from his replay buffer, we compute $TD_t$ and\n", + "$TD_e$ and backpropagate this loss down $Q_{online}$ to update its\n", + "parameters $\\theta_{online}$ ($\\alpha$ is the learning rate `lr` passed\n", + "to the `optimizer`)\n", + "\n", + "$$\\theta_{online} \\leftarrow \\theta_{online} + \\alpha \\nabla(TD_e - TD_t)$$\n", + "\n", + "$\\theta_{target}$ does not update through backpropagation. Instead, we\n", + "periodically copy $\\theta_{online}$ to $\\theta_{target}$\n", + "\n", + "$$\\theta_{target} \\leftarrow \\theta_{online}$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UKttGHnvD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)\n", + " self.loss_fn = torch.nn.SmoothL1Loss()\n", + "\n", + " def update_Q_online(self, td_estimate, td_target):\n", + " loss = self.loss_fn(td_estimate, td_target)\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " return loss.item()\n", + "\n", + " def sync_Q_target(self):\n", + " self.net.target.load_state_dict(self.net.online.state_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mj0D6D1TD2qN" + }, + "source": [ + "Save checkpoint\n", + "===============\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-20sIwqQD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def save(self):\n", + " save_path = (\n", + " self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n", + " )\n", + " torch.save(\n", + " dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n", + " save_path,\n", + " )\n", + " print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AM-1bAUkD2qN" + }, + "source": [ + "Putting it all together\n", + "=======================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qXqnlMixD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.burnin = 1e4 # min. experiences before training\n", + " self.learn_every = 3 # no. of experiences between updates to Q_online\n", + " self.sync_every = 1e4 # no. 
of experiences between Q_target & Q_online sync\n", + "\n", + " def learn(self):\n", + " if self.curr_step % self.sync_every == 0:\n", + " self.sync_Q_target()\n", + "\n", + " if self.curr_step % self.save_every == 0:\n", + " self.save()\n", + "\n", + " if self.curr_step < self.burnin:\n", + " return None, None\n", + "\n", + " if self.curr_step % self.learn_every != 0:\n", + " return None, None\n", + "\n", + " # Sample from memory\n", + " state, next_state, action, reward, done = self.recall()\n", + "\n", + " # Get TD Estimate\n", + " td_est = self.td_estimate(state, action)\n", + "\n", + " # Get TD Target\n", + " td_tgt = self.td_target(reward, next_state, done)\n", + "\n", + " # Backpropagate loss through Q_online\n", + " loss = self.update_Q_online(td_est, td_tgt)\n", + "\n", + " return (td_est.mean().item(), loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nPxSO3n3D2qN" + }, + "source": [ + "Logging\n", + "=======\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Vhp225CD2qN" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import time, datetime\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "class MetricLogger:\n", + " def __init__(self, save_dir):\n", + " self.save_log = save_dir / \"log\"\n", + " with open(self.save_log, \"w\") as f:\n", + " f.write(\n", + " f\"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}\"\n", + " f\"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}\"\n", + " f\"{'TimeDelta':>15}{'Time':>20}\\n\"\n", + " )\n", + " self.ep_rewards_plot = save_dir / \"reward_plot.jpg\"\n", + " self.ep_lengths_plot = save_dir / \"length_plot.jpg\"\n", + " self.ep_avg_losses_plot = save_dir / \"loss_plot.jpg\"\n", + " self.ep_avg_qs_plot = save_dir / \"q_plot.jpg\"\n", + "\n", + " # History metrics\n", + " self.ep_rewards = []\n", + " self.ep_lengths = []\n", + " self.ep_avg_losses = []\n", + " self.ep_avg_qs = []\n", + "\n", + " # Moving averages, added for every call to record()\n", + " self.moving_avg_ep_rewards = []\n", + " self.moving_avg_ep_lengths = []\n", + " self.moving_avg_ep_avg_losses = []\n", + " self.moving_avg_ep_avg_qs = []\n", + "\n", + " # Current episode metric\n", + " self.init_episode()\n", + "\n", + " # Timing\n", + " self.record_time = time.time()\n", + "\n", + " def log_step(self, reward, loss, q):\n", + " self.curr_ep_reward += reward\n", + " self.curr_ep_length += 1\n", + " if loss:\n", + " self.curr_ep_loss += loss\n", + " self.curr_ep_q += q\n", + " self.curr_ep_loss_length += 1\n", + "\n", + " def log_episode(self):\n", + " \"Mark end of episode\"\n", + " self.ep_rewards.append(self.curr_ep_reward)\n", + " self.ep_lengths.append(self.curr_ep_length)\n", + " if self.curr_ep_loss_length == 0:\n", + " ep_avg_loss = 0\n", + " ep_avg_q = 0\n", + " else:\n", + " ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)\n", + " ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)\n", + " self.ep_avg_losses.append(ep_avg_loss)\n", + " self.ep_avg_qs.append(ep_avg_q)\n", + "\n", + " self.init_episode()\n", + "\n", + " def init_episode(self):\n", + " self.curr_ep_reward = 0.0\n", + " self.curr_ep_length = 0\n", + " self.curr_ep_loss = 0.0\n", + " self.curr_ep_q = 0.0\n", + " self.curr_ep_loss_length = 0\n", + "\n", + " def record(self, episode, epsilon, step):\n", + " mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)\n", + " mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)\n", + " mean_ep_loss = 
np.round(np.mean(self.ep_avg_losses[-100:]), 3)\n", + " mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)\n", + " self.moving_avg_ep_rewards.append(mean_ep_reward)\n", + " self.moving_avg_ep_lengths.append(mean_ep_length)\n", + " self.moving_avg_ep_avg_losses.append(mean_ep_loss)\n", + " self.moving_avg_ep_avg_qs.append(mean_ep_q)\n", + "\n", + " last_record_time = self.record_time\n", + " self.record_time = time.time()\n", + " time_since_last_record = np.round(self.record_time - last_record_time, 3)\n", + "\n", + " print(\n", + " f\"Episode {episode} - \"\n", + " f\"Step {step} - \"\n", + " f\"Epsilon {epsilon} - \"\n", + " f\"Mean Reward {mean_ep_reward} - \"\n", + " f\"Mean Length {mean_ep_length} - \"\n", + " f\"Mean Loss {mean_ep_loss} - \"\n", + " f\"Mean Q Value {mean_ep_q} - \"\n", + " f\"Time Delta {time_since_last_record} - \"\n", + " f\"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}\"\n", + " )\n", + "\n", + " with open(self.save_log, \"a\") as f:\n", + " f.write(\n", + " f\"{episode:8d}{step:8d}{epsilon:10.3f}\"\n", + " f\"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}\"\n", + " f\"{time_since_last_record:15.3f}\"\n", + " f\"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\\n\"\n", + " )\n", + "\n", + " for metric in [\"ep_lengths\", \"ep_avg_losses\", \"ep_avg_qs\", \"ep_rewards\"]:\n", + " plt.clf()\n", + " plt.plot(getattr(self, f\"moving_avg_{metric}\"), label=f\"moving_avg_{metric}\")\n", + " plt.legend()\n", + " plt.savefig(getattr(self, f\"{metric}_plot\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N_qG20ecD2qO" + }, + "source": [ + "Let's play!\n", + "===========\n", + "\n", + "In this example we run the training loop for 40 episodes, but for Mario\n", + "to truly learn the ways of his world, we suggest running the loop for at\n", + "least 40,000 episodes!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wtG81GdfD2qO" + }, + "outputs": [], + "source": [ + "use_cuda = torch.cuda.is_available()\n", + "print(f\"Using CUDA: {use_cuda}\")\n", + "print()\n", + "\n", + "save_dir = Path(\"checkpoints\") / datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n", + "save_dir.mkdir(parents=True)\n", + "\n", + "mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)\n", + "\n", + "logger = MetricLogger(save_dir)\n", + "\n", + "episodes = 40\n", + "for e in range(episodes):\n", + "\n", + " state = env.reset()\n", + "\n", + " # Play the game!\n", + " while True:\n", + "\n", + " # Run agent on the state\n", + " action = mario.act(state)\n", + "\n", + " # Agent performs action\n", + " next_state, reward, done, trunc, info = env.step(action)\n", + "\n", + " # Remember\n", + " mario.cache(state, next_state, action, reward, done)\n", + "\n", + " # Learn\n", + " q, loss = mario.learn()\n", + "\n", + " # Logging\n", + " logger.log_step(reward, loss, q)\n", + "\n", + " # Update state\n", + " state = next_state\n", + "\n", + " # Check if end of game\n", + " if done or info[\"flag_get\"]:\n", + " break\n", + "\n", + " logger.log_episode()\n", + "\n", + " if (e % 20 == 0) or (e == episodes - 1):\n", + " logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "piaFDm7vD2qO" + }, + "source": [ + "Conclusion\n", + "==========\n", + "\n", + "In this tutorial, we saw how we can use PyTorch to train a game-playing\n", + "AI. 
You can use the same methods to train an AI to play any of the games\n", + "at the [OpenAI gym](https://gym.openai.com/). Hope you enjoyed this\n", + "tutorial, feel free to reach us at [our\n", + "github](https://github.com/yuansongFeng/MadMario/)!\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/isaacsim b/isaacsim new file mode 160000 index 00000000..aa503a9b --- /dev/null +++ b/isaacsim @@ -0,0 +1 @@ +Subproject commit aa503a9bbf92405bbbcfe5361e1c4a74fe10d689