Skip to content

Commit 6f31092

Browse files
committed
Add feature to supply additional input to the optimizer routine
1 parent a6cb9f2 commit 6f31092

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

peps_ad/optimization/inner_function.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from peps_ad.ctmrg import calc_ctmrg_env, calc_ctmrg_env_custom_rule
99
from peps_ad.mapping import Map_To_PEPS_Model
1010

11-
from typing import Sequence, Tuple, cast, Optional, Callable
11+
from typing import Sequence, Tuple, cast, Optional, Callable, Dict
1212

1313

1414
def _map_tensors(
@@ -65,6 +65,7 @@ def calc_ctmrg_expectation(
6565
unitcell: PEPS_Unit_Cell,
6666
expectation_func: Expectation_Model,
6767
convert_to_unitcell_func: Optional[Map_To_PEPS_Model],
68+
additional_input: Dict[str, jnp.ndarray] = dict(),
6869
*,
6970
enforce_elementwise_convergence: Optional[bool] = None,
7071
) -> Tuple[jnp.ndarray, PEPS_Unit_Cell]:
@@ -82,6 +83,9 @@ def calc_ctmrg_expectation(
8283
convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
8384
Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
8485
it is assumed that a PEPS unitcell is the input.
86+
additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
87+
Optional dict with additional inputs which should be considered in the
88+
calculation of the expectation value.
8589
Keyword args:
8690
enforce_elementwise_convergence (obj:`bool`):
8791
Enforce elementwise convergence of the CTM tensors instead of only
@@ -90,7 +94,8 @@ def calc_ctmrg_expectation(
9094
:obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~peps_ad.peps.PEPS_Unit_Cell`):
9195
Tuple consisting of the calculated expectation value and the new unitcell.
9296
"""
93-
if expectation_func.is_spiral_peps:
97+
spiral_vectors = additional_input.get("spiral_vectors")
98+
if expectation_func.is_spiral_peps and spiral_vectors is None:
9499
peps_tensors, unitcell, spiral_vectors = _map_tensors(
95100
input_tensors, unitcell, convert_to_unitcell_func, True
96101
)
@@ -128,6 +133,7 @@ def calc_preconverged_ctmrg_value_and_grad(
128133
unitcell: PEPS_Unit_Cell,
129134
expectation_func: Expectation_Model,
130135
convert_to_unitcell_func: Optional[Map_To_PEPS_Model],
136+
additional_input: Dict[str, jnp.ndarray] = dict(),
131137
*,
132138
calc_preconverged: bool = True,
133139
) -> Tuple[Tuple[jnp.ndarray, PEPS_Unit_Cell], Sequence[jnp.ndarray]]:
@@ -150,6 +156,9 @@ def calc_preconverged_ctmrg_value_and_grad(
150156
convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
151157
Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
152158
it is assumed that a PEPS unitcell is the input.
159+
additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
160+
Optional dict with additional inputs which should be considered in the
161+
calculation of the expectation value.
153162
Keyword args:
154163
calc_preconverged (:obj:`bool`):
155164
Flag if the above described procedure to calculate a pre-converged
@@ -161,7 +170,8 @@ def calc_preconverged_ctmrg_value_and_grad(
161170
unitcell.
162171
2. The calculated gradient.
163172
"""
164-
if expectation_func.is_spiral_peps:
173+
spiral_vectors = additional_input.get("spiral_vectors")
174+
if expectation_func.is_spiral_peps and spiral_vectors is None:
165175
peps_tensors, unitcell, spiral_vectors = _map_tensors(
166176
input_tensors, unitcell, convert_to_unitcell_func, True
167177
)
@@ -196,6 +206,7 @@ def calc_ctmrg_expectation_custom(
196206
unitcell: PEPS_Unit_Cell,
197207
expectation_func: Expectation_Model,
198208
convert_to_unitcell_func: Optional[Map_To_PEPS_Model],
209+
additional_input: Dict[str, jnp.ndarray] = dict(),
199210
) -> Tuple[jnp.ndarray, PEPS_Unit_Cell]:
200211
"""
201212
Calculate the CTMRG environment and the (energy) expectation value for a
@@ -211,11 +222,15 @@ def calc_ctmrg_expectation_custom(
211222
convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
212223
Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
213224
it is assumed that a PEPS unitcell is the input.
225+
additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
226+
Dict with additional inputs which should be considered in the
227+
calculation of the expectation value.
214228
Returns:
215229
:obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~peps_ad.peps.PEPS_Unit_Cell`):
216230
Tuple consisting of the calculated expectation value and the new unitcell.
217231
"""
218-
if expectation_func.is_spiral_peps:
232+
spiral_vectors = additional_input.get("spiral_vectors")
233+
if expectation_func.is_spiral_peps and spiral_vectors is None:
219234
peps_tensors, unitcell, spiral_vectors = _map_tensors(
220235
input_tensors, unitcell, convert_to_unitcell_func, True
221236
)

peps_ad/optimization/line_search.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
calc_ctmrg_expectation_custom_value_and_grad,
1919
)
2020

21-
from typing import Sequence, Tuple, List, Union, Optional
21+
from typing import Sequence, Tuple, List, Union, Optional, Dict
2222

2323

2424
@jit
@@ -146,6 +146,7 @@ def line_search(
146146
convert_to_unitcell_func: Optional[Map_To_PEPS_Model] = None,
147147
generate_unitcell: bool = False,
148148
spiral_indices: Optional[Sequence[int]] = None,
149+
additional_input: Dict[str, jnp.ndarray] = {},
149150
) -> Tuple[
150151
List[jnp.ndarray],
151152
PEPS_Unit_Cell,
@@ -260,6 +261,7 @@ def line_search(
260261
new_unitcell,
261262
expectation_func,
262263
convert_to_unitcell_func,
264+
additional_input,
263265
enforce_elementwise_convergence=enforce_elementwise_convergence,
264266
)
265267

@@ -296,6 +298,7 @@ def line_search(
296298
unitcell,
297299
expectation_func,
298300
convert_to_unitcell_func,
301+
additional_input,
299302
)
300303
else:
301304
(
@@ -306,6 +309,7 @@ def line_search(
306309
unitcell,
307310
expectation_func,
308311
convert_to_unitcell_func,
312+
additional_input,
309313
calc_preconverged=True,
310314
)
311315
gradient = [elem.conj() for elem in tmp_gradient_seq]
@@ -334,6 +338,7 @@ def line_search(
334338
new_unitcell,
335339
expectation_func,
336340
convert_to_unitcell_func,
341+
additional_input,
337342
)
338343
else:
339344
(
@@ -344,6 +349,7 @@ def line_search(
344349
new_unitcell,
345350
expectation_func,
346351
convert_to_unitcell_func,
352+
additional_input,
347353
calc_preconverged=True,
348354
)
349355
new_gradient = [elem.conj() for elem in new_gradient_seq]
@@ -381,6 +387,7 @@ def line_search(
381387
unitcell,
382388
expectation_func,
383389
convert_to_unitcell_func,
390+
additional_input,
384391
)
385392
else:
386393
(
@@ -391,6 +398,7 @@ def line_search(
391398
unitcell,
392399
expectation_func,
393400
convert_to_unitcell_func,
401+
additional_input,
394402
calc_preconverged=True,
395403
)
396404
gradient = [elem.conj() for elem in tmp_gradient_seq]

peps_ad/optimization/optimizer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def optimize_peps_network(
205205
autosave_func: Callable[
206206
[PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None
207207
] = autosave_function,
208+
additional_input: Dict[str, jnp.ndarray] = {},
208209
) -> Tuple[Sequence[jnp.ndarray], PEPS_Unit_Cell, Union[float, jnp.ndarray]]:
209210
"""
210211
Optimize a PEPS unitcell using a variational method.
@@ -279,7 +280,11 @@ def random_noise(a):
279280
signal_reset_descent_dir = False
280281

281282
spiral_indices = None
282-
if hasattr(expectation_func, "is_spiral_peps") and expectation_func.is_spiral_peps:
283+
if (
284+
hasattr(expectation_func, "is_spiral_peps")
285+
and expectation_func.is_spiral_peps
286+
and additional_input.get("spiral_vectors") is None
287+
):
283288
if isinstance(input_tensors, collections.abc.Sequence) and isinstance(
284289
input_tensors[0], PEPS_Unit_Cell
285290
):
@@ -321,6 +326,7 @@ def random_noise(a):
321326
working_unitcell,
322327
expectation_func,
323328
convert_to_unitcell_func,
329+
additional_input,
324330
)
325331
else:
326332
(
@@ -331,6 +337,7 @@ def random_noise(a):
331337
working_unitcell,
332338
expectation_func,
333339
convert_to_unitcell_func,
340+
additional_input,
334341
calc_preconverged=(count == 0),
335342
)
336343
except (CTMRGNotConvergedError, CTMRGGradientNotConvergedError) as e:
@@ -427,6 +434,7 @@ def random_noise(a):
427434
convert_to_unitcell_func,
428435
generate_unitcell,
429436
spiral_indices,
437+
additional_input,
430438
)
431439
except NoSuitableStepSizeError:
432440
if peps_ad_config.optimizer_fail_if_no_step_size_found:
@@ -521,6 +529,7 @@ def random_noise(a):
521529
working_unitcell,
522530
expectation_func,
523531
convert_to_unitcell_func,
532+
additional_input,
524533
enforce_elementwise_convergence=peps_ad_config.ad_use_custom_vjp,
525534
)
526535
descent_dir = None

0 commit comments

Comments
 (0)