Commit 7c57ec8

Added GPT3 benchmark solver
1 parent 0588c30 commit 7c57ec8

13 files changed: +3023 -0 lines changed

solvers/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# Solvers

This folder contains two subfolders for recreating the benchmarks in the [paper](https://arxiv.org/abs/2106.05784).
* [gpt3](/benchmarks/gpt3): the GPT-3 experiments.
* [enumerative](/benchmarks/enumerative): the enumerative top-down search solvers.

Each folder has a separate README explaining how to run the experiments.

solvers/gpt3/README.md

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Running GPT-3 experiments

These are instructions for re-running the GPT-3 experiments. The results will differ slightly from those in
the paper because the API is non-deterministic.

## Installation and execution

You will need an OpenAI GPT-3 API key, which you can sign up for [here](https://openai.com/join/).
You will then need to set it as the `OPENAI_API_KEY` environment variable.

The requirements can be installed with `pip3 install -r requirements.txt`.

The experiments were run with Python 3.6.9 (sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]'),
but they should be compatible with later versions as well.

Then simply run `python run_gpt3_experiments.py`; the results are written to stdout. The script caches API responses,
so the first run is quite slow and verbose while it queries the API. Subsequent runs read from the cache, finish much
faster, and just print the results. Because of the caching, re-running the script is deterministic and gives exactly
the same results.

## Contact

If you are interested in reproducing the exact results of the paper, please contact the authors to obtain the exact
same query results.
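As a rough sketch of the run steps above (illustrative only, not part of the repository; only the script name and environment variable come from these instructions), the same flow from Python:

```python
import os
import subprocess

# Hypothetical pre-flight check: the solver code asserts this variable is set.
assert "OPENAI_API_KEY" in os.environ, "export OPENAI_API_KEY=<your key> first"

# Run the experiments; results (and, on the first run, verbose query logs) go to stdout.
subprocess.run(["python3", "run_gpt3_experiments.py"], check=True)
```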

solvers/gpt3/ezlog.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import os
import logging
import inspect
import io

my_path = os.path.dirname(__file__)


def color_str(obj, code="\033[0;36m"):
    return code + str(obj) + '\033[0m'


_configured = False


def configure_logging(stdio_level=logging.INFO,
                      file_level=logging.DEBUG,
                      filename=".easy.log",
                      filepath=os.path.join(my_path, "logs")):
    os.makedirs(filepath, exist_ok=True)
    filename = os.path.join(filepath, filename)
    global _configured
    if _configured:
        warning("Re-configuring logging")
    stdio_handler = logging.StreamHandler()
    stdio_handler.setLevel(stdio_level)
    file_handler = logging.FileHandler(filename)
    file_handler.setLevel(file_level)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=min(stdio_level, file_level),
        handlers=[stdio_handler, file_handler]
    )
    #
    # fh = logging.FileHandler('spam.log')
    # fh.setLevel(logging.DEBUG)
    # # create console handler with a higher log level
    # ch = logging.StreamHandler()
    # ch.setLevel(logging.ERROR)
    # # create formatter and add it to the handlers
    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # fh.setFormatter(formatter)
    # ch.setFormatter(formatter)
    # # add the handlers to the logger
    # logger.addHandler(fh)
    # logger.addHandler(ch)

    _configured = True
    _get_or_create_logger().debug("Configured logging")


_loggers = {}


def _get_or_create_logger():
    # Name the logger after the first module in the call stack outside this file.
    global _configured, _loggers
    if not _configured:
        configure_logging()
    try:
        for frame in inspect.stack():
            name = inspect.getmodule(frame[0]).__name__
            if name != __name__:
                break
    except Exception:
        name = "_"
    if name not in _loggers:
        _loggers[name] = logging.getLogger(name)
    return _loggers[name]


def print_to_string(*args, end="", **kwargs):
    # Render the arguments exactly as print() would, but return the string.
    with io.StringIO() as buf:
        print(*args, file=buf, end=end, **kwargs)
        return buf.getvalue()


def debug(*args, **kwargs):
    _get_or_create_logger().debug(print_to_string(*args, **kwargs))


def info(*args, **kwargs):
    _get_or_create_logger().info(print_to_string(*args, **kwargs))


log = info


def warning(*args, **kwargs):
    _get_or_create_logger().warning(print_to_string(*args, **kwargs))


warn = warning


def error(*args, **kwargs):
    _get_or_create_logger().error(print_to_string(*args, **kwargs))
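A minimal usage sketch for this logging helper (illustrative only, not part of the commit; it assumes `ezlog.py` is importable from the working directory):

```python
import ezlog

# Logging is configured lazily on first use; calling configure_logging() explicitly
# lets you change levels or the log file location.
ezlog.configure_logging()                # writes to logs/.easy.log next to ezlog.py
ezlog.info("starting experiment", 42)    # print-style arguments, joined like print()
ezlog.warn("cache is empty")             # warn is an alias for warning
ezlog.debug("goes only to the log file at the default stdio level")
```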

solvers/gpt3/lm_solve/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from lm_solve.run import *

solvers/gpt3/lm_solve/gpt3_lib.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
import os
import json
import openai
import ezlog
import time
import datetime

assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`"
openai.api_key = os.environ['OPENAI_API_KEY']


_CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache")
_CACHE_FILENAME = os.path.join(_CACHE_PATH, "gpt3.cache")
_ENCODING = "utf-8"

_cache = None


# The cache file is a list of lines, each a (key, result list) pair where the key is the
# query-params dictionary encoded as a string (without n). Multiple queries with the same
# params (except for n) are merged into a single big list.
def _save_line(item, comment=None):
    global _cache
    assert _cache is not None
    with open(_CACHE_FILENAME, "a", encoding=_ENCODING) as f:
        f.write(str(item) + ((" # " + comment + "\n") if comment else "\n"))


def _load_cache():
    global _cache

    assert _cache is None, "gpt3 cache already loaded"

    if not os.path.exists(_CACHE_PATH):
        ezlog.warn("Creating cache path")
        os.makedirs(_CACHE_PATH)

    _cache = {}

    if os.path.exists(_CACHE_FILENAME):
        time0 = time.perf_counter()
        with open(_CACHE_FILENAME, "r", encoding=_ENCODING) as f:
            for k, v in [eval(line) for line in f.readlines()]:
                if k not in _cache:
                    _cache[k] = v
                else:
                    _cache[k].extend(v)
        ezlog.info(f"Loaded gpt3 cache in {time.perf_counter()-time0:.1f}s")
    else:
        ezlog.warn("No gpt3 cache yet")


def query(prompt, n=10, max_tokens=150, temp=1.0, max_batch=32, stop=None, notes=None, cache_only=False, verbose=True):
    """Query gpt3

    :param prompt: Up to 2048 tokens (about 3-4k chars)
    :param n: number of answers, None returns all cached answers
    :param max_tokens:
    :param temp: 0.9 seems to work well
    :param max_batch: max to query at once
    :param stop: string to stop at or '' if not to stop
    :param notes: notes you want to save or change in case you want to run the same query more than once!
    :param cache_only: if True, fail instead of querying the API for uncached entries
    :param verbose: print the prompt and query parameters before calling the API
    :return: list of the n answer strings (cached answers first)
    """
    global _cache
    if _cache is None:
        _load_cache()

    if temp == 0 and n > 1:
        ezlog.debug("Temp 0: no point in running more than one query")
        n = 1

    key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, max_batch=max_batch, stop=stop, rep=notes))
    cached = _cache.get(key, [])
    if n is None:
        return cached[:]

    if len(cached) >= n:
        return cached[:n]

    assert not cache_only, "Entry not found in cache"

    if verbose:
        print("/" * 100)
        print("Querying GPT3 with prompt:")
        print(prompt)
        s = stop and stop.replace('\n', '\\n')
        print(f"/// n={n} ({n - len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}")
        print("/" * 100)

    time0 = time.perf_counter()

    new = []
    n -= len(cached)

    while n > 0:
        m = min(n, max_batch)

        res = openai.Completion.create(
            engine="davinci-msft",
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temp,
            n=m,
            stop=stop or None
        )

        new += [c["text"] for c in res["choices"]]
        n -= m

    _save_line((key, new), f"{time.perf_counter() - time0:.1f}s {datetime.datetime.now()}")
    ans = _cache[key] = cached + new
    return ans[:]

# old code
# # to persist calls to the API...
# _disk_cache = joblib.Memory(os.path.join(os.path.dirname(__file__), ".cache"), verbose=1).cache
#
#
# @_disk_cache
# def query(prompt, n=10, max_tokens=150, temperature=1.0, max_batch=32):
#     """Query gpt3
#
#     :param prompt: Up to 2048 tokens (about 3-4k chars)
#     :param n: number of answers
#     :param max_tokens:
#     :param temperature:
#     :param max_batch: max to query at once
#     :return: list of answers and then the response items
#     """
#     if temperature == 0 and n > 1:
#         ezlog.debug("Temp 0: no point in running more than one query")
#         n = 1
#
#     responses = []
#     while n > 0:
#         m = min(n, max_batch)
#         prompt_summary = prompt if len(prompt) < 80 else f"{prompt[:40]}...{prompt[-40:]}"
#         ezlog.warn(f"**** Running GPT3 query: temp {temperature}, n={m}, prompt={prompt_summary}")
#         time0 = time.perf_counter()
#         responses.append(openai.Completion.create(
#             engine="davinci-msft",
#             prompt=prompt,
#             max_tokens=max_tokens,
#             temperature=temperature,
#             n=m
#         ))
#         ezlog.info(f"**** Got response in {time.perf_counter()-time0}s...")
#         n -= m
#
#     return [c["text"] for r in responses for c in r["choices"]], responses
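A minimal usage sketch for the caching query wrapper above (illustrative only, not part of the commit; it assumes the `lm_solve` package is importable and `OPENAI_API_KEY` is set, and the prompt and parameters are made up):

```python
from lm_solve import gpt3_lib  # importing asserts that OPENAI_API_KEY is set

# First call hits the API (slow) and appends the results to ../.cache/gpt3.cache.
answers = gpt3_lib.query("def add(a, b):\n    return", n=4, max_tokens=16, temp=0.9, stop="\n")

# An identical call is served from the cache, so re-runs are deterministic.
same_answers = gpt3_lib.query("def add(a, b):\n    return", n=4, max_tokens=16, temp=0.9, stop="\n")
assert answers == same_answers

# cache_only=True fails instead of querying the API when an entry is missing.
cached = gpt3_lib.query("def add(a, b):\n    return", n=4, max_tokens=16, temp=0.9, stop="\n",
                        cache_only=True)
```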
