
Commit 4c3b1fc

Fabio Salerno authored and committed
add data/ directory
1 parent 99ac466 commit 4c3b1fc

File tree

4 files changed: +682 -0 lines changed


data/README.md

Lines changed: 31 additions & 0 deletions
# Data Directory Documentation

This directory contains the scripts and tools for dataset filtering and sample creation, organized into two main subdirectories.

## Directory Structure

```
data/
├── filtering/
└── samples-creation/
```

## Filtering Directory

Contains the script that filters and samples the source dataset to create the fine-tuning dataset.

### Datasets
- **Fine-tuning Dataset**: Available on HuggingFace at [AISE-TUDelft/memtune-tuning_data](https://huggingface.co/datasets/AISE-TUDelft/memtune-tuning_data); see the loading sketch below
- **Source Dataset**: The original dataset that was filtered is available at [AISE-TUDelft/the-heap](https://huggingface.co/datasets/AISE-TUDelft/the-heap)
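Both datasets can be pulled directly from the Hub with the `datasets` library. A minimal sketch (the split name for the fine-tuning dataset is an assumption; check the dataset card for the exact configuration):

```python
from datasets import load_dataset

# Fine-tuning dataset produced by data/filtering/filter.py
# (the "train" split name is an assumption; see the dataset card)
tuning_data = load_dataset("AISE-TUDelft/memtune-tuning_data", split="train")

# Source dataset that the filtering script starts from
the_heap = load_dataset("AISE-TUDelft/the-heap", split="train")

print(tuning_data)
print(the_heap)
```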
## Sample Creation Directory

Contains the scripts that generate the data extraction benchmarks used in the attacks.

### Scripts
- `sample_identification_mem.py`: Generates the benchmark dataset for the **fine-tuning code attack**
- `sample_identification_forget.py`: Generates the benchmark dataset for the **pre-training code attack**
- Note: Requires downloading the [TheStackV2 Java subset](https://huggingface.co/datasets/bigcode/the-stack-v2-dedup/viewer/Java); see the loading sketch below
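A minimal sketch of pulling that subset with `datasets`. The `"Java"` config name is inferred from the viewer URL, access may require accepting the dataset's terms on the Hub, and the rows of The Stack v2 carry file metadata and blob IDs rather than raw file contents:

```python
from datasets import load_dataset

# TheStackV2 (deduplicated) Java subset.
# Assumes access has been granted on the Hub and that `huggingface-cli login`
# (or the HF_TOKEN environment variable) has been set up beforehand.
stack_v2_java = load_dataset("bigcode/the-stack-v2-dedup", "Java", split="train")

print(stack_v2_java)
```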
### Benchmarks
The generated data extraction benchmarks are available at [AISE-TUDelft/memtune-data_attack](https://huggingface.co/datasets/AISE-TUDelft/memtune-data_attack).
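The benchmark can be loaded the same way; a minimal sketch (the split name is an assumption; check the dataset card):

```python
from datasets import load_dataset

# Data extraction benchmark; the "train" split name is an assumption
benchmark = load_dataset("AISE-TUDelft/memtune-data_attack", split="train")
print(benchmark)
```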

data/filtering/filter.py

Lines changed: 162 additions & 0 deletions
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

# parallel processing
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=16)
from tqdm import tqdm
tqdm.pandas()

# utility
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
import os

seed = 42
n_train = 20000
n_valid = 1000
n_test = 10000
n_samples = n_train + n_valid + n_test  # 31000
hf_dir = "AISE-TUDelft/the-heap"

# Filter 0: filter out files with fewer than 300 tokens
# Load the tokenizer from the HF hub
checkpoint = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Tokenize the sequences.
# No truncation is applied here, and padding is avoided on purpose:
# we do not want to sample placeholder (pad) tokens later.
def tokenize_input(batch):
    return tokenizer(batch['content'], padding='do_not_pad', return_tensors='pt')

# Filters:
def longline_filter(file):
    """
    input: a file string
    output: a boolean, True if the file passes the filter, False if not
    """
    # collapse double newlines and split the file into a list of lines
    lines = file.replace('\n\n', '\n').split('\n')
    # if the number of lines exceeds 100k, filter the file out
    if len(lines) >= 100000:
        return False
    # otherwise, apply the long-line check: lines longer than 1000 characters
    # are ignored, and the file is kept only if the remaining lines have an
    # average length of at most 100 characters
    line_lengths = [len(line) for line in lines if len(line) <= 1000]
    if not line_lengths:
        return False
    else:
        return np.mean(line_lengths) <= 100

def alpha_filter(file):
    """
    input: a file string
    output: the percentage of alphabetic characters
    """
    total_chars = len(file)
    if total_chars == 0:
        return 0
    alpha_chars = sum(c.isalpha() for c in file)
    return (alpha_chars / total_chars) * 100

base64_regex = re.compile(r'[a-zA-Z0-9+/\n=]{64,}')
hex_regex = re.compile(r'(?:\b(?:0x|\\x)?[0-9a-fA-F]{2}(?:,|\b\s*)){8,}')
unicode_regex = re.compile(r'(?:\\u[0-9a-fA-F]{4}){8,}')

def encoded_data_filter(file):
    """
    input: a file string
    output: a boolean, True if the file passes the filter, False if not
    """
    total_length = len(file)
    # Find all matches of the regex patterns
    base64_matches = base64_regex.findall(file)
    hex_matches = hex_regex.findall(file)
    unicode_matches = unicode_regex.findall(file)

    # Concatenate all matches into one list
    all_matches = base64_matches + hex_matches + unicode_matches

    # Calculate the total length of all matched strings
    matched_length = sum(len(match) for match in all_matches)

    # Check if any match exceeds 1024 characters or if the matched fraction is more than 50%
    if any(len(match) > 1024 for match in all_matches) or (matched_length / total_length > 0.5):
        return False

    return True

# Following TheStackV2 paper: remove files classified as auto-generated by the is_generated function of go-enry
# go-enry Java regexes
auto_gen1 = re.compile(r'Generated by the protocol buffer compiler\. DO NOT EDIT!')
auto_gen2 = re.compile(r'Autogenerated by Thrift Compiler')
auto_gen3 = re.compile(r'/\* The following code was generated by JFlex ')
auto_gen4 = re.compile(r'// This is a generated file\. Not intended for manual editing\.')
auto_gen5 = re.compile(r'Generated by Haxe')
auto_gen6 = re.compile(r'This file is generated by jOOQ\.')
# additional regex (implemented by TheStackV2)
auto_gen7 = re.compile(r'auto-?generated|automatically\s*generated|generated\s*automatically|this\s*file\s*is\s*generated')
# pattern for repetitive lines (added by me; defined here but not applied in autogen below)
auto_gen8 = re.compile(r'(.*)\1{3,}')

def autogen(file):
    match1 = auto_gen1.findall(file)
    match2 = auto_gen2.findall(file)
    match3 = auto_gen3.findall(file)
    match4 = auto_gen4.findall(file)
    match5 = auto_gen5.findall(file)
    match6 = auto_gen6.findall(file)
    match7 = auto_gen7.findall(file)

    all_matches = match1 + match2 + match3 + match4 + match5 + match6 + match7

    # True if no auto-generation marker was found, False otherwise
    if all_matches == []:
        return True
    else:
        return False

# load the dataset
dataset = load_dataset(hf_dir, split='train')
# Filter out files that are near-duplicates of at least one file from TheStackV2
dataset = dataset.filter(lambda sample: len(sample["near_dups_stkv2_idx"]) == 0)
# tokenize
dataset = dataset.map(tokenize_input, batched=False, num_proc=64)

# Filter 0: remove files with fewer than 300 tokens
# convert to a pandas DataFrame
df = dataset.to_pandas()
# input_ids is a list of single-element lists; unwrap it
df['input_ids'] = df['input_ids'].progress_apply(lambda x: x[0])
# keep only files with more than 300 tokens
df['n_tok'] = df['input_ids'].progress_apply(len)
df = df[df['n_tok'] > 300]

# Longline filter
df_try = df.copy()
df_try['longline'] = df_try.progress_apply(lambda x: longline_filter(x['content']), axis=1)
# Alpha filter
df_try['alpha'] = df_try.progress_apply(lambda x: alpha_filter(x['content']), axis=1)
# Encoded data filter
df_try['encoded'] = df_try.progress_apply(lambda x: encoded_data_filter(x['content']), axis=1)
# Autogen filter
df_try['autogen'] = df_try.progress_apply(lambda x: autogen(x['content']), axis=1)

# Keep only the files that pass all the filters
df_filtered = df_try[(df_try['longline'] == True) & (df_try['alpha'] > 25) & (df_try['encoded'] == True) & (df_try['autogen'] == True)]

# Sampling
df_sampled = df_filtered.sample(n=n_samples, replace=False, random_state=seed)
df_train = df_sampled.iloc[0:n_train]
df_valid = df_sampled.iloc[n_train:n_train + n_valid]
df_test = df_sampled.iloc[n_train + n_valid:]

print(f"train: {df_train.shape}\nvalid: {df_valid.shape}\ntest: {df_test.shape}")

# Saving the files
# You can download these files from this link: https://huggingface.co/datasets/AISE-TUDelft/memtune-tuning_data
df_train.to_parquet('./train_java.parquet', index=False)
df_valid.to_parquet('./valid_java.parquet', index=False)
df_test.to_parquet('./test_java.parquet', index=False)
