Skip to content

Commit 0b80bcd

Browse files
committed
Add check to automatically catch common bugs in ncon definitions
1 parent 2e259f7 commit 0b80bcd

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

peps_ad/contractions/definitions.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Definitions for contractions in this module.
33
"""
44

5+
from collections import Counter
6+
57
import tensornetwork as tn
68

79
from typing import Dict, Optional, Union, List, Tuple, Sequence
@@ -66,6 +68,7 @@ class Definitions:
6668
@staticmethod
6769
def _create_filter_and_network(
6870
contraction: Definition,
71+
name: str,
6972
) -> Tuple[
7073
List[Sequence[str]],
7174
List[str],
@@ -81,7 +84,7 @@ def _create_filter_and_network(
8184
for t in contraction["tensors"]:
8285
if isinstance(t, (list, tuple)):
8386
if len(filter_additional_tensors) != 0:
84-
raise ValueError("Invalid specification for contraction.")
87+
raise ValueError(f'Invalid specification for contraction "{name}".')
8588

8689
filter_peps_tensors.append(t)
8790
else:
@@ -92,19 +95,19 @@ def _create_filter_and_network(
9295
isinstance(ni, (list, tuple)) for ni in n
9396
):
9497
if len(network_additional_tensors) != 0:
95-
raise ValueError("Invalid specification for contraction.")
98+
raise ValueError(f'Invalid specification for contraction "{name}".')
9699

97100
network_peps_tensors.append(n) # type: ignore
98101
elif isinstance(n, (list, tuple)) and all(isinstance(ni, int) for ni in n):
99102
network_additional_tensors.append(n) # type: ignore
100103
else:
101-
raise ValueError("Invalid specification for contraction.")
104+
raise ValueError(f'Invalid specification for contraction "{name}".')
102105

103106
if len(network_peps_tensors) != len(filter_peps_tensors) or not all(
104107
len(network_peps_tensors[i]) == len(filter_peps_tensors[i])
105108
for i in range(len(filter_peps_tensors))
106109
):
107-
raise ValueError("Invalid specification for contraction.")
110+
raise ValueError(f'Invalid specification for contraction "{name}".')
108111

109112
return (
110113
filter_peps_tensors,
@@ -114,17 +117,31 @@ def _create_filter_and_network(
114117
)
115118

116119
@classmethod
117-
def _process_def(cls, e):
120+
def _process_def(cls, e, name):
118121
(
119122
filter_peps_tensors,
120123
filter_additional_tensors,
121124
network_peps_tensors,
122125
network_additional_tensors,
123-
) = cls._create_filter_and_network(e)
126+
) = cls._create_filter_and_network(e, name)
124127

125128
ncon_network = [
126129
j for i in network_peps_tensors for j in i
127130
] + network_additional_tensors
131+
132+
flatted_ncon_list = [j for i in ncon_network for j in i]
133+
counter_ncon_list = Counter(flatted_ncon_list)
134+
for ind, c in counter_ncon_list.items():
135+
if (ind > 0 and c != 2) or (ind < 0 and c != 1) or ind == 0:
136+
raise ValueError(
137+
f'Invalid definition found for "{name}": Element {ind:d} has counter {c:d}.'
138+
)
139+
sorted_ncon_list = sorted(c for c in counter_ncon_list if c > 0)
140+
if len(sorted_ncon_list) != sorted_ncon_list[-1]:
141+
raise ValueError(
142+
f'Non-monotonous indices in definition "{name}". Please check!'
143+
)
144+
128145
(
129146
mapped_ncon_network,
130147
mapping,
@@ -150,13 +167,13 @@ def _process_def(cls, e):
150167

151168
@classmethod
152169
def _prepare_defs(cls):
153-
for e in dir(cls):
154-
if e.startswith("_"):
170+
for name in dir(cls):
171+
if name.startswith("_"):
155172
continue
156173

157-
e = getattr(cls, e)
174+
e = getattr(cls, name)
158175

159-
cls._process_def(e)
176+
cls._process_def(e, name)
160177

161178
density_matrix_one_site: Definition = {
162179
"tensors": [

0 commit comments

Comments
 (0)