88from peps_ad .ctmrg import calc_ctmrg_env , calc_ctmrg_env_custom_rule
99from peps_ad .mapping import Map_To_PEPS_Model
1010
11- from typing import Sequence , Tuple , cast , Optional , Callable
11+ from typing import Sequence , Tuple , cast , Optional , Callable , Dict
1212
1313
1414def _map_tensors (
@@ -65,6 +65,7 @@ def calc_ctmrg_expectation(
6565 unitcell : PEPS_Unit_Cell ,
6666 expectation_func : Expectation_Model ,
6767 convert_to_unitcell_func : Optional [Map_To_PEPS_Model ],
68+ additional_input : Dict [str , jnp .ndarray ] = dict (),
6869 * ,
6970 enforce_elementwise_convergence : Optional [bool ] = None ,
7071) -> Tuple [jnp .ndarray , PEPS_Unit_Cell ]:
@@ -82,6 +83,9 @@ def calc_ctmrg_expectation(
8283 convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
8384 Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
8485 it is assumed that a PEPS unitcell is the input.
86+ additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
87+ Optional dict with additional inputs which should be considered in the
88+ calculation of the expectation value.
8589 Keyword args:
8690 enforce_elementwise_convergence (obj:`bool`):
8791 Enforce elementwise convergence of the CTM tensors instead of only
@@ -90,7 +94,8 @@ def calc_ctmrg_expectation(
9094 :obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~peps_ad.peps.PEPS_Unit_Cell`):
9195 Tuple consisting of the calculated expectation value and the new unitcell.
9296 """
93- if expectation_func .is_spiral_peps :
97+ spiral_vectors = additional_input .get ("spiral_vectors" )
98+ if expectation_func .is_spiral_peps and spiral_vectors is None :
9499 peps_tensors , unitcell , spiral_vectors = _map_tensors (
95100 input_tensors , unitcell , convert_to_unitcell_func , True
96101 )
@@ -128,6 +133,7 @@ def calc_preconverged_ctmrg_value_and_grad(
128133 unitcell : PEPS_Unit_Cell ,
129134 expectation_func : Expectation_Model ,
130135 convert_to_unitcell_func : Optional [Map_To_PEPS_Model ],
136+ additional_input : Dict [str , jnp .ndarray ] = dict (),
131137 * ,
132138 calc_preconverged : bool = True ,
133139) -> Tuple [Tuple [jnp .ndarray , PEPS_Unit_Cell ], Sequence [jnp .ndarray ]]:
@@ -150,6 +156,9 @@ def calc_preconverged_ctmrg_value_and_grad(
150156 convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
151157 Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
152158 it is assumed that a PEPS unitcell is the input.
159+ additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
160+ Optional dict with additional inputs which should be considered in the
161+ calculation of the expectation value.
153162 Keyword args:
154163 calc_preconverged (:obj:`bool`):
155164 Flag if the above described procedure to calculate a pre-converged
@@ -161,7 +170,8 @@ def calc_preconverged_ctmrg_value_and_grad(
161170 unitcell.
162171 2. The calculated gradient.
163172 """
164- if expectation_func .is_spiral_peps :
173+ spiral_vectors = additional_input .get ("spiral_vectors" )
174+ if expectation_func .is_spiral_peps and spiral_vectors is None :
165175 peps_tensors , unitcell , spiral_vectors = _map_tensors (
166176 input_tensors , unitcell , convert_to_unitcell_func , True
167177 )
@@ -196,6 +206,7 @@ def calc_ctmrg_expectation_custom(
196206 unitcell : PEPS_Unit_Cell ,
197207 expectation_func : Expectation_Model ,
198208 convert_to_unitcell_func : Optional [Map_To_PEPS_Model ],
209+ additional_input : Dict [str , jnp .ndarray ] = dict (),
199210) -> Tuple [jnp .ndarray , PEPS_Unit_Cell ]:
200211 """
201212 Calculate the CTMRG environment and the (energy) expectation value for a
@@ -211,11 +222,15 @@ def calc_ctmrg_expectation_custom(
211222 convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
212223 Function to convert the `input_tensors` to a PEPS unitcell. If ommited,
213224 it is assumed that a PEPS unitcell is the input.
225+ additional_input (:obj:`dict` of :obj:`str` to :obj:`jax.numpy.ndarray` mapping):
226+ Dict with additional inputs which should be considered in the
227+ calculation of the expectation value.
214228 Returns:
215229 :obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~peps_ad.peps.PEPS_Unit_Cell`):
216230 Tuple consisting of the calculated expectation value and the new unitcell.
217231 """
218- if expectation_func .is_spiral_peps :
232+ spiral_vectors = additional_input .get ("spiral_vectors" )
233+ if expectation_func .is_spiral_peps and spiral_vectors is None :
219234 peps_tensors , unitcell , spiral_vectors = _map_tensors (
220235 input_tensors , unitcell , convert_to_unitcell_func , True
221236 )
0 commit comments