import os
import json
import openai
import ezlog
import time
import datetime

assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`"
openai.api_key = os.environ['OPENAI_API_KEY']


_CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache")
_CACHE_FILENAME = os.path.join(_CACHE_PATH, "gpt3.cache")
_ENCODING = "utf-8"

_cache = None

# The cache file is a list of lines, each of the form (query params dict encoded as a string, without n, list of results).
# Multiple queries with the same params (except for n) are merged into a single big list.
def _save_line(item, comment=None):
    global _cache
    assert _cache is not None
    with open(_CACHE_FILENAME, "a", encoding=_ENCODING) as f:
        f.write(str(item) + ((" # " + comment + "\n") if comment else "\n"))

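# Illustrative sketch (not part of the original file) of roughly what one cache line written by
# _save_line is expected to look like; the prompt, completions, and timing comment are made up,
# and each entry actually occupies a single line in the file:
#
#   ("{'prompt': 'Hello', 'max_tokens': 150, 'temp': 0.9, 'max_batch': 32, 'stop': None, 'rep': None}",
#    [' world!', ' there.'])  # 3.2s 2021-06-01 12:00:00.000000
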
def _load_cache():
    global _cache

    assert _cache is None, "gpt3 cache already loaded"

    if not os.path.exists(_CACHE_PATH):
        ezlog.warn("Creating cache path")
        os.makedirs(_CACHE_PATH)

    _cache = {}

    if os.path.exists(_CACHE_FILENAME):
        time0 = time.perf_counter()
        with open(_CACHE_FILENAME, "r", encoding=_ENCODING) as f:
            for k, v in [eval(line) for line in f.readlines()]:  # each line is a (key, list of completions) tuple
                if k not in _cache:
                    _cache[k] = v
                else:
                    _cache[k].extend(v)
        ezlog.info(f"Loaded gpt3 cache in {time.perf_counter() - time0:.1f}s")
    else:
        ezlog.warn("No gpt3 cache yet")


def query(prompt, n=10, max_tokens=150, temp=1.0, max_batch=32, stop=None, notes=None, cache_only=False, verbose=True):
    """Query gpt3

    :param prompt: Up to 2048 tokens (about 3-4k chars)
    :param n: number of answers; None returns all cached answers
    :param max_tokens: maximum number of tokens per completion
    :param temp: temperature; 0.9 seems to work well
    :param max_batch: maximum number of completions to request at once
    :param stop: string to stop at, or '' not to stop
    :param notes: notes to save with the query; change them to force re-running the same query
    :param cache_only: if True, only return cached answers and raise an error if they are not cached
    :param verbose: if True, print the prompt and query parameters before calling the API
    :return: list of answers
    """
    global _cache
    if _cache is None:
        _load_cache()

    if temp == 0 and n is not None and n > 1:
        ezlog.debug("Temp 0: no point in running more than one query")
        n = 1

    key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, max_batch=max_batch, stop=stop, rep=notes))
    cached = _cache.get(key, [])
    if n is None:
        return cached[:]

    if len(cached) >= n:
        return cached[:n]

    assert not cache_only, "Entry not found in cache"
    if verbose:
        print("/" * 100)
        print("Querying GPT3 with prompt:")
        print(prompt)
        s = stop and stop.replace('\n', '\\n')
        print(f"/// n={n} ({n - len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}")
        print("/" * 100)

    time0 = time.perf_counter()

    new = []
    n -= len(cached)

    # request completions in batches of at most max_batch until we have n new answers
    while n > 0:
        m = min(n, max_batch)

        res = openai.Completion.create(
            engine="davinci-msft",
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temp,
            n=m,
            stop=stop or None
        )

        new += [c["text"] for c in res["choices"]]
        n -= m

    # persist the new completions and merge them into the in-memory cache
    _save_line((key, new), f"{time.perf_counter() - time0:.1f}s {datetime.datetime.now()}")
    ans = _cache[key] = cached + new
    return ans[:]

# old code
# # to persist calls to the API...
# _disk_cache = joblib.Memory(os.path.join(os.path.dirname(__file__), ".cache"), verbose=1).cache
#
#
# @_disk_cache
# def query(prompt, n=10, max_tokens=150, temperature=1.0, max_batch=32):
#     """Query gpt3
#
#     :param prompt: Up to 2048 tokens (about 3-4k chars)
#     :param n: number of answers
#     :param max_tokens:
#     :param temperature:
#     :param max_batch: max to query at once
#     :return: list of answers and then the response items
#     """
#     if temperature == 0 and n > 1:
#         ezlog.debug("Temp 0: no point in running more than one query")
#         n = 1
#
#     responses = []
#     while n > 0:
#         m = min(n, max_batch)
#         prompt_summary = prompt if len(prompt) < 80 else f"{prompt[:40]}...{prompt[-40:]}"
#         ezlog.warn(f"**** Running GPT3 query: temp {temperature}, n={m}, prompt={prompt_summary}")
#         time0 = time.perf_counter()
#         responses.append(openai.Completion.create(
#             engine="davinci-msft",
#             prompt=prompt,
#             max_tokens=max_tokens,
#             temperature=temperature,
#             n=m
#         ))
#         ezlog.info(f"**** Got response in {time.perf_counter()-time0}s...")
#         n -= m
#
#     return [c["text"] for r in responses for c in r["choices"]], responses


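# Illustrative usage sketch (not part of the original file): the prompt, parameters, and stop
# string below are made up for demonstration only.
if __name__ == "__main__":
    answers = query("def add(a, b):\n    ", n=3, max_tokens=50, temp=0.9, stop="\n\n")
    for i, ans in enumerate(answers):
        print(f"--- completion {i} ---")
        print(ans)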