Skip to content

Commit 0588c30

Browse files
authored
Merge pull request #14 from TalSchuster/baselines
Enumerative solvers
2 parents 7946dae + 83d3224 commit 0588c30

31 files changed

+6767
-0
lines changed

solvers/enumerative/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Enumerative puzzle solvers
2+
3+
This folder contains the code for the enumerative models used in our Programming Puzzles paper.
4+
We used python 3.8.0 and the libraries in the `requirements.txt` file.
5+
6+
On a Linux machine with Python 3.8.0 installed, the following commands will set up the environment:
7+
```
8+
virtualenv -p /usr/bin/python3.8 env_solvers
9+
source env_solvers/bin/activate
10+
pip install -r requirements.txt
11+
```
12+
13+
## Uniform solver
14+
```
15+
bash run_uniform.sh
16+
```
17+
This will run the uniform solver for a maximum of 10k trials per puzzle. This is required before training the other parameterized solvers.
18+
19+
To run the uniform solver with 1M trials per puzzle, simply change the `max_n_progs` argument in the bash script.
20+
21+
## Bigram random forest solver
22+
```
23+
bash run_bigram.sh
24+
```
25+
This will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command will train a model without self-bootstrapping.
26+
27+
## Transformers solver
28+
```
29+
bash download_pretrained_roberta.sh
30+
bash run_transformer.sh
31+
```
32+
The first script will download the RoBERTa-Base model that we trained on Python code.
33+
34+
The second script will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command will train a model without self-bootstrapping.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from challenges.challenge import *
2+
from challenges.solutions import *
3+
4+
def contains_node(root, x_node):
    """Return True if `x_node` is `root` itself or occurs anywhere beneath it.

    Nodes are matched by identity (`is`), not by equality. Any object
    without a `children` attribute is treated as a leaf.
    """
    if root is x_node:
        return True
    if not hasattr(root, "children"):
        return False
    # Depth-first search through the sub-trees; stop at the first hit.
    for child in root.children:
        if contains_node(child, x_node):
            return True
    return False
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import utils
2+
import tython
3+
import logging
4+
from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence
5+
from tython import Program, nt
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def extract_constants(prog) -> Dict:
    """
    Collect the constant sub-trees of a program, grouped by non-terminal kind.

    Does not (yet) allow copying of comprehensions, e.g., '[i*i for i in range(10)]'.

    Args:
        prog: a program object exposing a `.tree` root node, or None.

    Returns:
        A plain dict mapping a non-terminal kind to the list of nodes of that
        kind that were found. Empty when `prog` is None.
    """
    from collections import defaultdict

    found = defaultdict(list)

    def collect_args(arg_list):
        # A 'cast:ARGS' node only wraps the real argument list.
        if arg_list.rule.name == 'cast:ARGS':
            collect_args(arg_list.children[0])
            return

        kids = arg_list.children
        # Annotated argument: record its name node under the annotated kind.
        if len(kids) >= 3 and kids[1].nt == nt.TYPE:
            # NOTE: the annotation source is evaluated as Python code.
            kind = nt.type2nt(eval(kids[1].src()))
            found[kind].append(kids[0])
        # The remaining arguments hang off the last child.
        if kids and kids[-1].nt in {nt.ARGS, nt.DEFAULT_ARGS}:
            collect_args(kids[-1])

    def visit(node):
        """Record constants below `node`; return True iff `node` itself is constant."""
        if node.rule.name == 'def':  # it's a function
            name_node, args_node, body_node = node.children
            if name_node.src() == 'sat':
                # skip first arg for `def sat`
                collect_args(args_node.children[-1])
            else:
                collect_args(args_node)
            visit(body_node)
            return False

        if node.nt in {nt.NAME}:
            return False

        if node.nt in {nt.STMT}:
            for child in node.children:
                visit(child)
            return False

        if node.rule.name not in {"int-const", "str-const"}:
            # Build the full list first: every child must be visited even
            # after one of them turns out not to be constant.
            child_is_const = [visit(child) for child in node.children]
            if not all(child_is_const):
                return False

        if node.nt.isa(nt.LIST, nt.SET, nt.DICT, nt.TUPLE, nt.RANGE,
                       nt.INT, nt.FLOAT, nt.BOOL, nt.STR):
            found[node.nt].append(node)
        return True

    if prog is not None:
        visit(prog.tree)

    return dict(found)
56+
57+
#
58+
# q = Program("""
59+
# def sat(i: List[str], a=5):
60+
# return i==['5']
61+
# """)
62+
#
63+
# extract_constants(q)
64+
#
65+
#
66+
# %%
67+
class Solution:
    """Plain record describing one solution to a challenge.

    Every field is optional and defaults to None.
    """

    def __init__(self, string=None, prog=None, likelihood=None, time=None, count=None):
        # Source text of the solution.
        self.string = string
        # Parsed program for the solution, once available.
        self.prog = prog
        # presumably a model score for this solution -- TODO confirm against callers
        self.likelihood = likelihood
        # Time and number of tries associated with finding the solution.
        self.time = time
        self.count = count
74+
75+
76+
class SolverSolution(Solution):
    """A `Solution` produced by a solver; carries the same five fields."""

    def __init__(self, string=None, prog=None, likelihood=None, time=None, count=None):
        # Hand every field to the base initializer, which stores them all.
        super().__init__(
            string=string,
            prog=prog,
            likelihood=likelihood,
            time=time,
            count=count,
        )
81+
82+
83+
def get_arg_type_str(sat_str):
    """Return the type annotation text of the first argument of a `def sat(...)` source.

    The annotation is the text between the first ':' and the first ',' or ')'
    that is not nested inside square brackets, so subscripted types such as
    `Dict[str, int]` are returned whole. Leading whitespace is stripped.

    Args:
        sat_str: source code that starts with "def sat(" and contains a ':'.

    Returns:
        The annotation as a string, e.g. "List[int]".

    Raises:
        AssertionError: if `sat_str` does not start with "def sat(", has no
            ':', or has no top-level ',' or ')' terminating the first argument.
            Raised explicitly (not via `assert`) so the checks still run
            under `python -O`.
    """
    if not (sat_str.startswith("def sat(") and ":" in sat_str):
        raise AssertionError(
            f"expected source starting with 'def sat(' and containing ':', got {sat_str!r}"
        )

    type_start = sat_str.index(":") + 1
    depth = 0  # current nesting level of square brackets
    for i, c in enumerate(sat_str):
        if c == '[':
            depth += 1
        elif c == ']':
            depth -= 1
        elif c in ")," and depth == 0:
            return sat_str[type_start:i].lstrip()

    raise AssertionError(f"could not find the end of the first argument in {sat_str!r}")
94+
95+
96+
class Challenge:
    """One puzzle: its `sat` source, the type of its answer and its known solutions.

    Built from a config mapping with the keys "name", "sat" and "sols", and
    optionally "sol_tries" and "sol_time" (per-solution values, matched to
    "sols" by position).
    """

    def __init__(self, challenge_config, max_ticks=100000000):
        self.name = challenge_config["name"]
        sat_src = challenge_config["sat"]
        self.f_str = sat_src
        self.type_str = get_arg_type_str(sat_src)
        # NOTE: the annotation text is evaluated as Python code.
        self.type = eval(self.type_str)

        self.gold_solutions = [Solution(string=sol) for sol in challenge_config["sols"]]
        self.solver_solutions = []

        # Attach the optional per-solution statistics by position.
        for attr, key in (("count", "sol_tries"), ("time", "sol_time")):
            if key in challenge_config:
                for idx, value in enumerate(challenge_config[key]):
                    setattr(self.gold_solutions[idx], attr, value)

        self.solution_strs = challenge_config["sols"]
        self.max_ticks = max_ticks

        self._parse_challenge()

    def _parse_challenge(self):
        '''
        Converts the challenge string to a tython program.

        Sets `sol_kind`, `prog` and `f`. A failure while parsing or running
        the program is logged as a warning and not re-raised; whichever of
        `prog` / `f` was not assigned before the failure stays None.
        '''
        self.sol_kind = tython.nt.type2nt(self.type)
        self.prog = self.f = None
        try:
            self.prog = tython.Program(self.f_str)
            self.f = self.prog.run(max_ticks=self.max_ticks)
        except Program.EvalException as err:
            logger.warning(f"Exception evaluating {self.name} '{self.f_str}': {err}")
        except Exception as err:
            logger.warning(f"Exception parsing {self.name} '{self.f_str}': {err}")
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import List, Set, Dict, Callable, Tuple
2+
import logging
3+
from challenges import extract_constants
4+
5+
from tython import Program, TastNode, _RULES_BY_KIND, RULES, Rule, str2name
6+
from tython.rules import DEF_RULE
7+
8+
logger = logging.getLogger(__name__)
9+
logger.setLevel(logging.INFO)
10+
11+
12+
def generatable_answer(q: Program, a: Program):
    """Return whether answer `a` can be built using only constants copied from `q`.

    The answer qualifies when its tree contains no node whose rule is named
    "literal" and every "COPY" node reproduces (by source text) a constant of
    the same kind extracted from `q`.

    Args:
        q: the question program, passed to `extract_constants`.
        a: the answer program whose tree is checked.

    Returns:
        True if every node of `a.tree` passes the checks above, else False.
    """
    def get_src(node):
        return Program(node).src(safe=False, simplify=False)

    consts = extract_constants(q)
    # Source text of each constant, grouped by kind; sets give O(1) membership.
    const_srcs = {kind: {get_src(c) for c in nodes} for kind, nodes in consts.items()}

    def helper(anode):
        if anode.rule.name == "COPY":
            # If `q` has no constants of this kind the copy cannot be
            # generated: answer False instead of raising KeyError.
            return get_src(anode.children[0]) in const_srcs.get(anode.rule.nt, ())
        return anode.rule.name != "literal" and all(helper(n) for n in anode.children)

    return helper(a.tree)
25+
26+
27+
def verify_solutions(challenges):
    '''
    Verify all provided solutions to the given challenges and store the parsed solution
    program in the challenge object.

    For each challenge, `gold_solutions` is replaced by the subset of solutions
    that satisfied `sat`; each kept solution gets its `prog` attribute set.
    Empty solution strings and solutions that fail to parse are dropped.
    A summary line is logged at the end. Returns None.
    '''
    n_verified = 0
    for ch in challenges:
        sat = ch.prog.run(max_ticks=ch.max_ticks)['sat']
        args_node = ch.prog.tree.children[0].children[1].children[-1]
        kept = []

        for gold in ch.gold_solutions:
            src = gold.string
            if src == '':
                continue

            # Verify solution is correct.
            assert src.startswith('def ')
            try:
                sol_prog = Program(src)
            except Exception as e:
                logger.error(f"Exception parsing solution for {ch.name} '{src}': {e}")
                continue

            # Inject the value assignments for the variables to the call of the sol func.
            body = sol_prog.tree.children[0].children[2]
            sol_prog.tree.children[0].children[1].children = args_node.children
            sol_prog = Program(TastNode(DEF_RULE, [str2name("sol"), args_node, body]))

            # The rendered source is not used; the call is kept as in the original.
            sol_prog.src(simplify=False)
            answer = sol_prog.run(max_ticks=ch.max_ticks)["sol"]()
            ch.prog.reset_clock()

            verdict = sat(answer)
            assert isinstance(verdict, bool)

            if not generatable_answer(ch.prog, sol_prog):
                logger.error(f'Challenge "{ch.name}" cannot be used to automatically generate solution "{src}"')

            # TODO: also check that the answer's type matches ch.type.
            if verdict is not True:  # checks both False and None
                logger.error(f'Challenge "{ch.name}" not satisfied by solution "{src}"')
                continue

            gold.prog = sol_prog
            kept.append(gold)
            n_verified += 1

        ch.gold_solutions = kept

    logger.info(
        f"Tython confirmed {n_verified:,} solutions to {len(challenges)} challenges."
    )
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#! /bin/bash

# Linux commands to download our Roberta model pretrained on Python code.
# Newer versions of huggingface transformers don't require this but we need to adjust the rest of the code for them.
#
# The files are saved under ./tals/roberta_python relative to the directory
# the script is run from.

set -ex

base_url=https://huggingface.co/tals/roberta_python/resolve/main

# -p creates both levels and does not fail (and abort under `set -e`)
# when the directories are left over from an earlier run.
mkdir -p tals/roberta_python

cd tals/roberta_python

for f in \
    config.json \
    merges.txt \
    pytorch_model.bin \
    special_tokens_map.json \
    tokenizer_config.json \
    training_args.bin \
    vocab.json
do
    wget "${base_url}/${f}"
done

cd ../..
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
import sys
3+
import os
4+
5+
6+
# Usage: python SCRIPT INPUT.json [SHIFT]
#
# Reads a JSON list of puzzle records and, for each threshold, writes
# "<input without extension>_<threshold>.json". A record is kept whole when
# its last solution is non-empty and its last `sol_tries` value plus SHIFT is
# within the threshold; otherwise only name/sat are kept with empty `sols`.
inp_file = sys.argv[1]
shift = int(sys.argv[2]) if len(sys.argv) > 2 else 0
out_file = os.path.splitext(inp_file)[0]

with open(inp_file, 'r') as f:
    data = json.load(f)

thresholds = [100, 1000, 10000, 100000, 1000000]
for t in thresholds:
    out = []
    suc = 0
    for p in data:
        # Disabled filter kept from the original:
        # if not p["name"].startswith("Study"): continue
        solved = p["sols"][-1] != "" and p["sol_tries"][-1] + shift <= t
        if not solved:
            out.append(dict(name=p["name"], sat=p["sat"], sols=[]))
            continue
        out.append(p)
        suc += 1

    print(f"t={t}: solutions: {suc}/ {len(out)}")

    with open(out_file + f"_{t}.json", "w") as fw:
        json.dump(out, fw, indent=4)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Maps a model name to the object registered under it.
MODEL_REGISTRY = {}


def RegisterModel(model_name):
    """Return a decorator that stores its target in MODEL_REGISTRY[model_name].

    The decorated object is returned unchanged. Registering a second object
    under the same name replaces the first.
    """
    def register(model):
        MODEL_REGISTRY[model_name] = model
        return model

    return register
8+
9+
from models.uniform import *
10+
#from models.bigram import *
11+
#from models.ml_bow_unigram import *
12+
from models.ml_bow_bigram import *

0 commit comments

Comments
 (0)