
Commit 83d3224: transformer solver
1 parent b68d75f

File tree: 11 files changed (+2596, -0 lines)

solvers/enumerative/README.md

Lines changed: 34 additions & 0 deletions
# Enumerative puzzle solvers

This folder contains the code for the enumerative models used in our Programming Puzzles paper.
We used Python 3.8.0 and the libraries in the `requirements.txt` file.

On a Linux machine with Python 3.8.0 installed, the following commands will set up the environment:
```
virtualenv -p /usr/bin/python3.8 env_solvers
source env_solvers/bin/activate
pip install -r requirements.txt
```

## Uniform solver
```
bash run_uniform.sh
```
This will run the uniform solver for a maximum of 10k trials per puzzle. Running it is required before training the other parameterized solvers.

To run the uniform solver with 1M trials per puzzle, simply change the `max_n_progs` argument in the bash script.

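For intuition, the core idea of the uniform solver is to sample candidate answers uniformly at random and check each one against the puzzle's satisfiability function. The sketch below is a hypothetical illustration of that loop, not the repository's actual code: the puzzle `sat`, the toy integer search space, and the `uniform_solve` helper are all made up for this example.

```python
import random

def sat(x):
    # Hypothetical puzzle: find an integer whose square is 361.
    return isinstance(x, int) and x * x == 361

def uniform_solve(sat, max_n_progs=10_000, seed=0):
    """Sample candidates uniformly at random; return the first one the puzzle accepts."""
    rng = random.Random(seed)
    for _ in range(max_n_progs):
        candidate = rng.randrange(-50, 50)  # uniform over a toy search space
        try:
            if sat(candidate):
                return candidate
        except Exception:
            pass  # candidates that crash the checker simply don't count as solutions
    return None

solution = uniform_solve(sat)
```

With 10k trials over a 100-value space this toy search all but certainly finds 19 or -19; the real solver enumerates programs rather than integers, with `max_n_progs` bounding the trial budget in the same way.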
## Bigram random forest solver
```
bash run_bigram.sh
```
This will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command in the script will train a model without self-bootstrapping.

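Self-bootstrapping means that solutions found so far are used to bias the sampling of the next iteration. As a rough sketch of the bigram statistics only (the actual solver trains a random forest over richer features; the token lists below are made up), counting token bigrams over earlier solutions yields a sampling distribution for the next token:

```python
from collections import Counter, defaultdict

def bigram_counts(solutions):
    """Count token-bigram frequencies over tokenized solutions, with start/end sentinels."""
    counts = defaultdict(Counter)
    for tokens in solutions:
        for prev, nxt in zip(["<s>"] + tokens, tokens + ["</s>"]):
            counts[prev][nxt] += 1
    return counts

def next_token_probs(counts, prev):
    """Normalize the counts after `prev` into a sampling distribution."""
    total = sum(counts[prev].values())
    return {tok: c / total for tok, c in counts[prev].items()}

# First bootstrapping iteration: seed the counts with the uniform solver's solutions.
uniform_solutions = [["x", "*", "x"], ["x", "+", "1"], ["x", "*", "2"]]
counts = bigram_counts(uniform_solutions)
probs = next_token_probs(counts, "x")  # e.g. "*" follows "x" in 2 of 4 bigrams
```

Each bootstrapping round would re-fit these statistics on the union of all solutions found so far, so the sampler drifts toward token patterns that have actually solved puzzles.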
## Transformers solver
```
bash download_pretrained_roberta.sh
bash run_transformer.sh
```
The first script will download the RoBERTa-Base model that we pretrained on Python code.

The second script will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command in the script will train a model without self-bootstrapping.
Lines changed: 3 additions & 0 deletions
from challenges.challenge import *
from challenges.solutions import *


def contains_node(root, x_node):
    return root is x_node or (hasattr(root, "children") and any(contains_node(k, x_node) for k in root.children))
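For illustration, `contains_node` works on any tree whose nodes expose a `children` attribute, and it compares by identity (`is`), not by equality. A minimal usage sketch with a hypothetical `Node` class (the function is repeated here so the sketch is self-contained):

```python
def contains_node(root, x_node):
    return root is x_node or (hasattr(root, "children") and any(contains_node(k, x_node) for k in root.children))

class Node:
    def __init__(self, *children):
        self.children = list(children)

leaf = Node()
stranger = Node()  # structurally identical to `leaf`, but a different object
root = Node(Node(leaf), Node())

found = contains_node(root, leaf)          # leaf is reachable from root
not_found = contains_node(root, stranger)  # identity check, so stranger is not found
```

The identity semantics matter when the same subtree object is shared between trees: two structurally equal but distinct nodes are treated as different.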
Lines changed: 21 additions & 0 deletions
#! /bin/bash

# Linux commands to download our RoBERTa model pretrained on Python code.
# Newer versions of huggingface transformers don't require this, but we would need to adjust the rest of the code for them.

set -ex

mkdir -p tals/roberta_python

cd tals/roberta_python

wget https://huggingface.co/tals/roberta_python/resolve/main/config.json
wget https://huggingface.co/tals/roberta_python/resolve/main/merges.txt
wget https://huggingface.co/tals/roberta_python/resolve/main/pytorch_model.bin
wget https://huggingface.co/tals/roberta_python/resolve/main/special_tokens_map.json
wget https://huggingface.co/tals/roberta_python/resolve/main/tokenizer_config.json
wget https://huggingface.co/tals/roberta_python/resolve/main/training_args.bin
wget https://huggingface.co/tals/roberta_python/resolve/main/vocab.json

cd ../..

solvers/enumerative/models/transformers/__init__.py

Whitespace-only changes.
