|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "c429ee79", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Part I: Preparing the Dataset\n", |
| 9 | + "\n", |
| 10 | + "This notebook showcases transforming a dataset for finetuning an embedding model with NeMo Microservices.\n", |
| 11 | + "\n", |
| 12 | + "\n", |
| 13 | + "It covers the following -\n", |
| 14 | + "1. Download the SPECTER dataset\n", |
| 15 | + "2. Prepare data for embedding fine-tuning.\n", |
| 16 | + "\n", |
| 17 | + "*Dataset Disclaimer: Each user is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.*" |
| 18 | + ] |
| 19 | + }, |
| 20 | + { |
| 21 | + "cell_type": "code", |
| 22 | + "execution_count": 1, |
| 23 | + "id": "a8c70985", |
| 24 | + "metadata": {}, |
| 25 | + "outputs": [], |
| 26 | + "source": [ |
| 27 | + "import os\n", |
| 28 | + "import json\n", |
| 29 | + "import random\n", |
| 30 | + "from datasets import load_dataset" |
| 31 | + ] |
| 32 | + }, |
| 33 | + { |
| 34 | + "cell_type": "markdown", |
| 35 | + "id": "d9b165f5", |
| 36 | + "metadata": {}, |
| 37 | + "source": [ |
| 38 | + "The following code cell sets a random seed for reproducibility, and sets data path.\n", |
| 39 | + "It also configures the fraction of training data to use for demonstration purposes as training with the whole [SPECTER](https://huggingface.co/datasets/embedding-data/SPECTER) data may take several hours." |
| 40 | + ] |
| 41 | + }, |
| 42 | + { |
| 43 | + "cell_type": "code", |
| 44 | + "execution_count": 2, |
| 45 | + "id": "403c5cf7", |
| 46 | + "metadata": {}, |
| 47 | + "outputs": [], |
| 48 | + "source": [ |
| 49 | + "SEED = 42\n", |
| 50 | + "random.seed(SEED)\n", |
| 51 | + "\n", |
| 52 | + "DATA_SAVE_PATH = \"./data\"\n", |
| 53 | + "\n", |
| 54 | + "# Configuration for data fraction\n", |
| 55 | + "USE_FRACTION = True # Set to False to use full dataset\n", |
| 56 | + "FRACTION = 0.1 # Use 10% of the dataset (0.1 = 10%, 0.01 = 1%, etc.)" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "markdown", |
| 61 | + "id": "c9cf17cc", |
| 62 | + "metadata": {}, |
| 63 | + "source": [ |
| 64 | + "## Step 1: Download the SPECTER Dataset\n", |
| 65 | + "\n", |
| 66 | + "The SPECTER dataset contains scientific paper triples for training embedding models. This step loads the dataset from Hugging Face." |
| 67 | + ] |
| 68 | + }, |
| 69 | + { |
| 70 | + "cell_type": "code", |
| 71 | + "execution_count": 3, |
| 72 | + "id": "e76c59e3", |
| 73 | + "metadata": {}, |
| 74 | + "outputs": [], |
| 75 | + "source": [ |
| 76 | + "from config import HF_TOKEN\n", |
| 77 | + "\n", |
| 78 | + "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", |
| 79 | + "os.environ[\"HF_ENDPOINT\"] = \"https://huggingface.co\"" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": 4, |
| 85 | + "id": "20d9650c", |
| 86 | + "metadata": {}, |
| 87 | + "outputs": [ |
| 88 | + { |
| 89 | + "name": "stdout", |
| 90 | + "output_type": "stream", |
| 91 | + "text": [ |
| 92 | + "Dataset info: DatasetDict({\n", |
| 93 | + " train: Dataset({\n", |
| 94 | + " features: ['set'],\n", |
| 95 | + " num_rows: 684100\n", |
| 96 | + " })\n", |
| 97 | + "})\n" |
| 98 | + ] |
| 99 | + } |
| 100 | + ], |
| 101 | + "source": [ |
| 102 | + "# Load the dataset directly from Hugging Face\n", |
| 103 | + "dataset = load_dataset(\"embedding-data/SPECTER\")\n", |
| 104 | + "print(f\"Dataset info: {dataset}\")" |
| 105 | + ] |
| 106 | + }, |
| 107 | + { |
| 108 | + "cell_type": "code", |
| 109 | + "execution_count": 5, |
| 110 | + "id": "189a3558", |
| 111 | + "metadata": {}, |
| 112 | + "outputs": [ |
| 113 | + { |
| 114 | + "data": { |
| 115 | + "text/plain": [ |
| 116 | + "[['Millimeter-wave CMOS digital controlled artificial dielectric differential mode transmission lines for reconfigurable ICs',\n", |
| 117 | + " 'CMP network-on-chip overlaid with multi-band RF-interconnect',\n", |
| 118 | + " 'Route packets, not wires: on-chip interconnection networks'],\n", |
| 119 | + " ['Millimeter-wave CMOS digital controlled artificial dielectric differential mode transmission lines for reconfigurable ICs',\n", |
| 120 | + " 'CMP network-on-chip overlaid with multi-band RF-interconnect',\n", |
| 121 | + " 'Entheses: tendon and ligament attachment sites'],\n", |
| 122 | + " ['Millimeter-wave CMOS digital controlled artificial dielectric differential mode transmission lines for reconfigurable ICs',\n", |
| 123 | + " 'CMP network-on-chip overlaid with multi-band RF-interconnect',\n", |
| 124 | + " 'Packet leashes: a defense against wormhole attacks in wireless networks']]" |
| 125 | + ] |
| 126 | + }, |
| 127 | + "execution_count": 5, |
| 128 | + "metadata": {}, |
| 129 | + "output_type": "execute_result" |
| 130 | + } |
| 131 | + ], |
| 132 | + "source": [ |
| 133 | + "# Inspect the first 3 rows\n", |
| 134 | + "dataset[\"train\"][:3][\"set\"]" |
| 135 | + ] |
| 136 | + }, |
| 137 | + { |
| 138 | + "cell_type": "markdown", |
| 139 | + "id": "08700f5d", |
| 140 | + "metadata": {}, |
| 141 | + "source": [ |
| 142 | + "Each row in the dataset contains three sentences (or triplets): query, positive passage, and negative passage, in order.\n", |
| 143 | + "\n", |
| 144 | + "During training of the embedding model, contrastive learning is used to maximize the similarity between the query and the passage that contains the answer, while minimizing the similarity between the query and sampled negative passage not useful to answer the question." |
| 145 | + ] |
| 146 | + }, |
| 147 | + { |
| 148 | + "cell_type": "markdown", |
| 149 | + "id": "ba489337", |
| 150 | + "metadata": {}, |
| 151 | + "source": [ |
| 152 | + "## Step 2: Prepare Data for Customization" |
| 153 | + ] |
| 154 | + }, |
| 155 | + { |
| 156 | + "cell_type": "markdown", |
| 157 | + "id": "a8c498d7", |
| 158 | + "metadata": {}, |
| 159 | + "source": [ |
| 160 | + "For customizing embedding models, the NeMo Microservices platform leverages a JSONL format, where each row is:\n", |
| 161 | + "```\n", |
| 162 | + "{\n", |
| 163 | + " \"query\": \"query text\",\n", |
| 164 | + " \"pos_doc\": \"positive document text\",\n", |
| 165 | + " \"neg_doc\": [\"negative document text 1\", \"negative document text 2\", ...]\n", |
| 166 | + "}\n", |
| 167 | + "```\n", |
| 168 | + "\n", |
| 169 | + "The following code cell -\n", |
| 170 | + "1. Defines a helper for data splitting\n", |
| 171 | + "2. Uses a fraction of the data, and converts each row to the required format\n", |
| 172 | + "3. Saves the data splits to jsonl files" |
| 173 | + ] |
| 174 | + }, |
| 175 | + { |
| 176 | + "cell_type": "code", |
| 177 | + "execution_count": 6, |
| 178 | + "id": "a96197d1", |
| 179 | + "metadata": {}, |
| 180 | + "outputs": [ |
| 181 | + { |
| 182 | + "name": "stdout", |
| 183 | + "output_type": "stream", |
| 184 | + "text": [ |
| 185 | + "Total examples in dataset: 684100\n", |
| 186 | + "Using fraction of dataset: 68410/684100 examples (10.0%)\n", |
| 187 | + "Formatted 68410 examples\n", |
| 188 | + "\n", |
| 189 | + "Train set: 61569 examples\n", |
| 190 | + "Validation set: 3420 examples\n", |
| 191 | + "Test set: 3421 examples\n", |
| 192 | + "Saving data to: ./data/specter_10pct\n", |
| 193 | + "Saved 61569 examples to ./data/specter_10pct/training/training.jsonl\n", |
| 194 | + "Saved 3420 examples to ./data/specter_10pct/validation/validation.jsonl\n", |
| 195 | + "Saved 3421 examples to ./data/specter_10pct/testing/testing.jsonl\n", |
| 196 | + "\n", |
| 197 | + "First few examples from training set:\n", |
| 198 | + "Example 1:\n", |
| 199 | + " Query: Rhythm, Metrics, and the Link to Phonology\n", |
| 200 | + " Positive: Rhythm, Timing and the Timing of Rhythm\n", |
| 201 | + " Negative: ['Social software and participatory learning: Pedagogical choices with technology affordances in the Web 2.0 era']\n", |
| 202 | + "\n", |
| 203 | + "Example 2:\n", |
| 204 | + " Query: underwater image processing : state of the art of restoration and image enhancement methods .\n", |
| 205 | + " Positive: Image quality assessment: from error visibility to structural similarity\n", |
| 206 | + " Negative: ['An overview of home automation systems']\n", |
| 207 | + "\n", |
| 208 | + "Example 3:\n", |
| 209 | + " Query: Marginal Space Deep Learning: Efficient Architecture for Volumetric Image Parsing.\n", |
| 210 | + " Positive: Mitosis Detection in Breast Cancer Histology Images with Deep Neural Networks\n", |
| 211 | + " Negative: ['Regularized multi--task learning']\n", |
| 212 | + "\n" |
| 213 | + ] |
| 214 | + } |
| 215 | + ], |
| 216 | + "source": [ |
| 217 | + "def split_data(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):\n", |
| 218 | + " \"\"\"\n", |
| 219 | + " Splits the data into training, validation, and test sets.\n", |
| 220 | + " \"\"\"\n", |
| 221 | + " assert train_ratio + val_ratio + test_ratio == 1.0, \"Ratios must sum to 1\"\n", |
| 222 | + " \n", |
| 223 | + " # Compute split indices\n", |
| 224 | + " train_end = int(len(data) * train_ratio)\n", |
| 225 | + " val_end = train_end + int(len(data) * val_ratio)\n", |
| 226 | + " \n", |
| 227 | + " # Split the data\n", |
| 228 | + " train_set = data[:train_end]\n", |
| 229 | + " val_set = data[train_end:val_end]\n", |
| 230 | + " test_set = data[val_end:]\n", |
| 231 | + " \n", |
| 232 | + " return train_set, val_set, test_set\n", |
| 233 | + "\n", |
| 234 | + "\n", |
| 235 | + "try:\n", |
| 236 | + " # Get the raw data\n", |
| 237 | + " raw_data = dataset['train']['set']\n", |
| 238 | + " print(f\"Total examples in dataset: {len(raw_data)}\")\n", |
| 239 | + " \n", |
| 240 | + " # Shuffle the data once at the beginning\n", |
| 241 | + " raw_data_list = list(raw_data)\n", |
| 242 | + " random.shuffle(raw_data_list)\n", |
| 243 | + " \n", |
| 244 | + " # Apply fraction if specified (after shuffling)\n", |
| 245 | + " if USE_FRACTION:\n", |
| 246 | + " original_size = len(raw_data_list)\n", |
| 247 | + " fraction_size = int(len(raw_data_list) * FRACTION)\n", |
| 248 | + " raw_data_list = raw_data_list[:fraction_size]\n", |
| 249 | + " print(f\"Using fraction of dataset: {len(raw_data_list)}/{original_size} examples ({FRACTION*100:.1f}%)\")\n", |
| 250 | + " else:\n", |
| 251 | + " print(f\"Using full dataset: {len(raw_data_list)} examples\")\n", |
| 252 | + " \n", |
| 253 | + " # Format the data\n", |
| 254 | + " data = []\n", |
| 255 | + " for example in raw_data_list:\n", |
| 256 | + " data.append({\n", |
| 257 | + " \"query\": example[0],\n", |
| 258 | + " \"pos_doc\": example[1], \n", |
| 259 | + " \"neg_doc\": [example[2]] # neg_doc as a list of strings\n", |
| 260 | + " })\n", |
| 261 | + " print(f\"Formatted {len(data)} examples\")\n", |
| 262 | + " \n", |
| 263 | + " # Split the data\n", |
| 264 | + " train, val, test = split_data(data, train_ratio=0.90, val_ratio=0.05, test_ratio=0.05)\n", |
| 265 | + " \n", |
| 266 | + " print(f\"\\nTrain set: {len(train)} examples\")\n", |
| 267 | + " print(f\"Validation set: {len(val)} examples\")\n", |
| 268 | + " print(f\"Test set: {len(test)} examples\")\n", |
| 269 | + " \n", |
| 270 | + " # Generate save path with fraction suffix if using a fraction of the dataset\n", |
| 271 | + " if USE_FRACTION:\n", |
| 272 | + " # Convert fraction to percentage for folder name (e.g., 0.1 -> 10pct, 0.01 -> 1pct)\n", |
| 273 | + " fraction_pct = int(FRACTION * 100)\n", |
| 274 | + " folder_name = f\"specter_{fraction_pct}pct\"\n", |
| 275 | + " else:\n", |
| 276 | + " folder_name = \"specter_full\"\n", |
| 277 | + " \n", |
| 278 | + " save_path = os.path.join(DATA_SAVE_PATH, folder_name)\n", |
| 279 | + " print(f\"Saving data to: {save_path}\")\n", |
| 280 | + " \n", |
| 281 | + " # Create directories for each split\n", |
| 282 | + " for split_name in [\"training\", \"validation\", \"testing\"]:\n", |
| 283 | + " split_dir = os.path.join(save_path, split_name)\n", |
| 284 | + " os.makedirs(split_dir, exist_ok=True)\n", |
| 285 | + " \n", |
| 286 | + " # Save to JSONL files in respective folders\n", |
| 287 | + " for fname, ds, folder in ((\"training.jsonl\", train, \"training\"), \n", |
| 288 | + " (\"validation.jsonl\", val, \"validation\"), \n", |
| 289 | + " (\"testing.jsonl\", test, \"testing\")):\n", |
| 290 | + " file_path = os.path.join(save_path, folder, fname)\n", |
| 291 | + " with open(file_path, \"w\") as out:\n", |
| 292 | + " for obj in ds:\n", |
| 293 | + " out.write(json.dumps(obj) + \"\\n\")\n", |
| 294 | + " print(f\"Saved {len(ds)} examples to {file_path}\")\n", |
| 295 | + " \n", |
| 296 | + " # Display first few examples from training set\n", |
| 297 | + " print(\"\\nFirst few examples from training set:\")\n", |
| 298 | + " for i, example in enumerate(train[:3]):\n", |
| 299 | + " print(f\"Example {i+1}:\")\n", |
| 300 | + " print(f\" Query: {example['query']}\")\n", |
| 301 | + " print(f\" Positive: {example['pos_doc']}\")\n", |
| 302 | + " print(f\" Negative: {example['neg_doc']}\")\n", |
| 303 | + " print()\n", |
| 304 | + " \n", |
| 305 | + "except Exception as e:\n", |
| 306 | + " print(f\"Error loading dataset: {e}\")" |
| 307 | + ] |
| 308 | + } |
| 309 | + ], |
| 310 | + "metadata": { |
| 311 | + "kernelspec": { |
| 312 | + "display_name": "Python 3 (ipykernel)", |
| 313 | + "language": "python", |
| 314 | + "name": "python3" |
| 315 | + }, |
| 316 | + "language_info": { |
| 317 | + "codemirror_mode": { |
| 318 | + "name": "ipython", |
| 319 | + "version": 3 |
| 320 | + }, |
| 321 | + "file_extension": ".py", |
| 322 | + "mimetype": "text/x-python", |
| 323 | + "name": "python", |
| 324 | + "nbconvert_exporter": "python", |
| 325 | + "pygments_lexer": "ipython3", |
| 326 | + "version": "3.11.12" |
| 327 | + } |
| 328 | + }, |
| 329 | + "nbformat": 4, |
| 330 | + "nbformat_minor": 5 |
| 331 | +} |
0 commit comments