import torch
from datasets import load_dataset
from transformers import AutoTokenizer

# parallel processing
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=16)
from tqdm import tqdm
tqdm.pandas()

# utility
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
import os
seed = 42
n_train = 20000
n_valid = 1000
n_test = 10000
n_samples = n_train + n_valid + n_test  # 31000
hf_dir = "AISE-TUDelft/the-heap"

# Filter 0: drop files with fewer than 300 tokens (the filter itself is applied further below).
# Load the tokenizer from the HF hub.
checkpoint = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Tokenize the sequences. Padding is disabled, since we don't want to sample
# placeholder tokens; no truncation is applied either.
def tokenize_input(batch):
    return tokenizer(batch['content'], padding='do_not_pad', return_tensors='pt')
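
# Note: with return_tensors='pt' and a single string, the tokenizer returns
# input_ids of shape (1, seq_len). Since dataset.map() below runs with
# batched=False, each call receives one file, and the extra batch dimension
# is stripped with input_ids[0] after converting to pandas.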

# Filters:
def longline_filter(file):
    """
    input: a file string
    output: a boolean, True if the file passes the filter, False if not
    """
    # collapse double newlines and split the file into a list of lines
    lines = file.replace('\n\n', '\n').split('\n')
    # if the number of lines reaches 100k, filter the file out
    if len(lines) >= 100000:
        return False
    # second filter (line lengths): ignore lines longer than 1000 characters,
    # then keep the file only if the average length of the remaining lines is
    # at most 100 characters
    line_lengths = [len(line) for line in lines if len(line) <= 1000]
    if line_lengths == []:
        # every line is longer than 1000 characters
        return False
    else:
        return np.mean(line_lengths) <= 100
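
# Quick sanity checks (illustrative only):
#   longline_filter("x" * 2000)                  -> False (the only line exceeds 1000 characters)
#   longline_filter("\n".join(["x" * 80] * 50))  -> True  (average line length 80 <= 100)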

def alpha_filter(file):
    """
    input: a file string
    output: the percentage of alphabetic characters
    """
    total_chars = len(file)
    if total_chars == 0:
        return 0
    alpha_chars = sum(c.isalpha() for c in file)
    return (alpha_chars / total_chars) * 100
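
# Quick sanity check (illustrative only): alpha_filter("abc123") returns 50.0.
# Unlike the other filters, this one returns a percentage rather than a boolean;
# the > 25% threshold is applied where the filters are combined further below.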

base64_regex = re.compile(r'[a-zA-Z0-9+/\n=]{64,}')
hex_regex = re.compile(r'(?:\b(?:0x|\\x)?[0-9a-fA-F]{2}(?:,|\b\s*)){8,}')
unicode_regex = re.compile(r'(?:\\u[0-9a-fA-F]{4}){8,}')

def encoded_data_filter(file):
    """
    input: a file string
    output: a boolean, True if the file passes the filter, False if not
    """
    total_length = len(file)

    # Find all matches of the regex patterns
    base64_matches = base64_regex.findall(file)
    hex_matches = hex_regex.findall(file)
    unicode_matches = unicode_regex.findall(file)

    # Concatenate all matches into one list
    all_matches = base64_matches + hex_matches + unicode_matches

    # Calculate the total length of all matched strings
    matched_length = sum(len(match) for match in all_matches)

    # Check if any match exceeds 1024 characters or if the matched fraction is more than 50%
    if any(len(match) > 1024 for match in all_matches) or (matched_length / total_length > 0.5):
        return False

    return True
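
# Quick sanity checks (illustrative only):
#   encoded_data_filter("A" * 2000)                   -> False (a base64-like run longer than 1024 characters)
#   encoded_data_filter("public class Foo {}\n" * 20) -> True  (no long encoded runs)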

# Following the TheStackV2 paper: remove files classified as auto-generated by go-enry's is_generated function
# go-enry Java regexes
auto_gen1 = re.compile(r'Generated by the protocol buffer compiler\. DO NOT EDIT!')
auto_gen2 = re.compile(r'Autogenerated by Thrift Compiler')
auto_gen3 = re.compile(r'/\* The following code was generated by JFlex ')
auto_gen4 = re.compile(r'// This is a generated file\. Not intended for manual editing.')
auto_gen5 = re.compile(r'Generated by Haxe')
auto_gen6 = re.compile(r'This file is generated by jOOQ.')
# additional regex (implemented by TheStackV2)
auto_gen7 = re.compile(r'auto-?generated|automatically\s*generated|generated\s*automatically|this\s*file\s*is\s*generated')
# pattern for repetitive lines (added by me); defined here but not applied in autogen() below
auto_gen8 = re.compile(r'(.*)\1{3,}')

def autogen(file):
    match1 = auto_gen1.findall(file)
    match2 = auto_gen2.findall(file)
    match3 = auto_gen3.findall(file)
    match4 = auto_gen4.findall(file)
    match5 = auto_gen5.findall(file)
    match6 = auto_gen6.findall(file)
    match7 = auto_gen7.findall(file)

    all_matches = match1 + match2 + match3 + match4 + match5 + match6 + match7

    # True if no auto-generation marker was found, False otherwise
    if all_matches == []:
        return True
    else:
        return False
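
# Quick sanity checks (illustrative only):
#   autogen("// Autogenerated by Thrift Compiler")  -> False (auto-generation marker found)
#   autogen("public class Foo { int x; }")          -> True  (no markers)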

# load the dataset
dataset = load_dataset(hf_dir, split='train')
# Filter out files that are near-duplicates of at least one file from TheStackV2
dataset = dataset.filter(lambda sample: len(sample["near_dups_stkv2_idx"]) == 0)
# tokenize
dataset = dataset.map(tokenize_input, batched=False, num_proc=64)

# Filter 0: remove files with fewer than 300 tokens
# convert to a pandas DataFrame
df = dataset.to_pandas()
# each input_ids entry carries an extra batch dimension (a list of lists); keep only the inner sequence
df['input_ids'] = df['input_ids'].progress_apply(lambda x: x[0])
# keep only files with more than 300 tokens
df['n_tok'] = df['input_ids'].progress_apply(len)
df = df[df['n_tok'] > 300]

# Longline filter
df_try = df.copy()
df_try['longline'] = df_try.progress_apply(lambda x: longline_filter(x['content']), axis=1)
# Alpha filter
df_try['alpha'] = df_try.progress_apply(lambda x: alpha_filter(x['content']), axis=1)
# Encoded data filter
df_try['encoded'] = df_try.progress_apply(lambda x: encoded_data_filter(x['content']), axis=1)
# Autogen filter
df_try['autogen'] = df_try.progress_apply(lambda x: autogen(x['content']), axis=1)

# Keep the files that pass all the filters
df_filtered = df_try[(df_try['longline'] == True) & (df_try['alpha'] > 25) & (df_try['encoded'] == True) & (df_try['autogen'] == True)]

# Sampling
df_sampled = df_filtered.sample(n=n_samples, replace=False, random_state=seed)
df_train = df_sampled.iloc[0:n_train]
df_valid = df_sampled.iloc[n_train:n_train+n_valid]
df_test = df_sampled.iloc[n_train+n_valid:]

print(f" train: {df_train.shape}\n valid: {df_valid.shape}\n test: {df_test.shape}")

# Saving the files
# You can download these files from this link: https://huggingface.co/datasets/AISE-TUDelft/memtune-tuning_data
df_train.to_parquet('./train_java.parquet', index=False)
df_valid.to_parquet('./valid_java.parquet', index=False)
df_test.to_parquet('./test_java.parquet', index=False)