Skip to content

Commit bc10e47

Browse files
authored
Merge pull request #130 from CCPBioSim/127-add-averaging-over-timesteps
Refactor Entropy Calculations with Averaging Over Timesteps, Robust State Handling, and Enhanced Eigenvalue Filtering
2 parents 998f8cb + 700f6bb commit bc10e47

File tree

6 files changed

+266
-88
lines changed

6 files changed

+266
-88
lines changed

CodeEntropy/config/arg_config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
"grouping": {
6868
"type": str,
6969
"help": "How to group molecules for averaging",
70-
"default": "each",
70+
"default": "molecules",
7171
},
7272
}
7373

CodeEntropy/entropy.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _get_trajectory_bounds(self):
247247
Tuple of (start, end, step) frame indices.
248248
"""
249249
start = self._args.start or 0
250-
end = self._args.end or -1
250+
end = len(self._universe.trajectory) if self._args.end == -1 else self._args.end
251251
step = self._args.step or 1
252252

253253
return start, end, step
@@ -343,11 +343,9 @@ def _process_united_atom_entropy(
343343

344344
f_matrix = force_matrix[key]
345345
f_matrix = self._level_manager.filter_zero_rows_columns(f_matrix)
346-
f_matrix = f_matrix / number_frames
347346

348347
t_matrix = torque_matrix[key]
349348
t_matrix = self._level_manager.filter_zero_rows_columns(t_matrix)
350-
t_matrix = t_matrix / number_frames
351349

352350
S_trans_res = ve.vibrational_entropy_calculation(
353351
f_matrix, "force", self._args.temperature, highest
@@ -356,8 +354,16 @@ def _process_united_atom_entropy(
356354
t_matrix, "torque", self._args.temperature, highest
357355
)
358356

359-
S_conf_res = ce.conformational_entropy_calculation(
360-
states[key], number_frames
357+
values = states[key]
358+
359+
contains_non_empty_states = (
360+
np.any(values) if isinstance(values, np.ndarray) else any(values)
361+
)
362+
363+
S_conf_res = (
364+
ce.conformational_entropy_calculation(values, number_frames)
365+
if contains_non_empty_states
366+
else 0
361367
)
362368

363369
S_trans += S_trans_res
@@ -395,10 +401,8 @@ def _process_vibrational_entropy(
395401
level.
396402
"""
397403
force_matrix = self._level_manager.filter_zero_rows_columns(force_matrix)
398-
force_matrix = force_matrix / number_frames
399404

400405
torque_matrix = self._level_manager.filter_zero_rows_columns(torque_matrix)
401-
torque_matrix = torque_matrix / number_frames
402406

403407
S_trans = ve.vibrational_entropy_calculation(
404408
force_matrix, "force", self._args.temperature, highest
@@ -425,7 +429,22 @@ def _process_conformational_entropy(
425429
start, end, step (int): Frame bounds.
426430
n_frames (int): Number of frames used.
427431
"""
428-
S_conf = ce.conformational_entropy_calculation(states[group_id], number_frames)
432+
group_states = states[group_id] if group_id < len(states) else None
433+
434+
if group_states is not None:
435+
contains_state_data = (
436+
group_states.any()
437+
if isinstance(group_states, np.ndarray)
438+
else any(group_states)
439+
)
440+
else:
441+
contains_state_data = False
442+
443+
S_conf = (
444+
ce.conformational_entropy_calculation(group_states, number_frames)
445+
if contains_state_data
446+
else 0
447+
)
429448

430449
self._data_logger.add_results_data(group_id, level, "Conformational", S_conf)
431450

@@ -619,13 +638,19 @@ def frequency_calculation(self, lambdas, temp):
619638
lambdas = np.array(lambdas) # Ensure input is a NumPy array
620639
logger.debug(f"Eigenvalues (lambdas): {lambdas}")
621640

622-
# Check for negatives and raise an error if any are found
623-
if np.any(lambdas < 0):
624-
logger.error(f"Negative eigenvalues encountered: {lambdas[lambdas < 0]}")
625-
raise ValueError(
626-
f"Negative eigenvalues encountered: {lambdas[lambdas < 0]}"
641+
lambdas = np.real_if_close(lambdas, tol=1000)
642+
valid_mask = (
643+
np.isreal(lambdas) & (lambdas > 0) & (~np.isclose(lambdas, 0, atol=1e-07))
644+
)
645+
646+
if len(lambdas) > np.count_nonzero(valid_mask):
647+
logger.warning(
648+
f"{len(lambdas) - np.count_nonzero(valid_mask)} "
649+
f"invalid eigenvalues excluded (complex, non-positive, or near-zero)."
627650
)
628651

652+
lambdas = lambdas[valid_mask].real
653+
629654
# Compute frequencies safely
630655
frequencies = 1 / (2 * pi) * np.sqrt(lambdas / kT)
631656
logger.debug(f"Calculated frequencies: {frequencies}")
@@ -748,8 +773,11 @@ def assign_conformation(
748773

749774
# get the values of the angle for the dihedral
750775
# dihedral angle values have a range from -180 to 180
751-
for timestep in data_container.trajectory[start:end:step]:
752-
timestep_index = timestep.frame - start
776+
indices = list(range(start, end, step))
777+
for timestep_index, _ in zip(
778+
indices, data_container.trajectory[start:end:step]
779+
):
780+
timestep_index = timestep_index - start
753781
value = dihedral.value()
754782
# we want postive values in range 0 to 360 to make the peak assignment
755783
# work using the fact that dihedrals have circular symetry

CodeEntropy/levels.py

Lines changed: 138 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def get_matrices(
171171
f"{force_matrix.shape}, new {force_block.shape}"
172172
)
173173
else:
174-
force_matrix += force_block
174+
force_matrix = force_block
175175

176176
if torque_matrix is None:
177177
torque_matrix = np.zeros_like(torque_block)
@@ -181,7 +181,7 @@ def get_matrices(
181181
f"{torque_matrix.shape}, new {torque_block.shape}"
182182
)
183183
else:
184-
torque_matrix += torque_block
184+
torque_matrix = torque_block
185185

186186
return force_matrix, torque_matrix
187187

@@ -290,18 +290,27 @@ def compute_dihedral_conformations(
290290
- dihedrals (list): List of dihedral angle definitions.
291291
"""
292292
dihedrals = self.get_dihedrals(selector, level)
293-
num_dihedrals = len(dihedrals)
294293

295-
conformation = np.zeros((num_dihedrals, number_frames))
296-
for i, dihedral in enumerate(dihedrals):
297-
conformation[i] = ce.assign_conformation(
298-
selector, dihedral, number_frames, bin_width, start, end, step
299-
)
294+
if len(dihedrals) == 0:
295+
logger.debug("No dihedrals found; skipping conformation assignment.")
296+
states = []
297+
else:
298+
num_dihedrals = len(dihedrals)
299+
conformation = np.zeros((num_dihedrals, number_frames))
300300

301-
states = [
302-
"".join(str(int(conformation[d][f])) for d in range(num_dihedrals))
303-
for f in range(number_frames)
304-
]
301+
for i, dihedral in enumerate(dihedrals):
302+
conformation[i] = ce.assign_conformation(
303+
selector, dihedral, number_frames, bin_width, start, end, step
304+
)
305+
306+
states = [
307+
state
308+
for state in (
309+
"".join(str(int(conformation[d][f])) for d in range(num_dihedrals))
310+
for f in range(number_frames)
311+
)
312+
if state
313+
]
305314

306315
return states
307316

@@ -733,40 +742,58 @@ def build_covariance_matrices(
733742
number_frames,
734743
):
735744
"""
736-
Construct force and torque covariance matrices for all molecules and levels.
745+
Construct average force and torque covariance matrices for all molecules and
746+
entropy levels.
737747
738-
Parameters:
739-
entropy_manager (EntropyManager): Instance of the EntropyManager
740-
reduced_atom (Universe): The reduced atom selection.
741-
number_molecules (int): Number of molecules in the system.
742-
levels (list): List of entropy levels per molecule.
743-
start (int): Start frame index.
744-
end (int): End frame index.
745-
step (int): Step size for frame iteration.
746-
number_frames (int): Total number of frames to process.
748+
Parameters
749+
----------
750+
entropy_manager : EntropyManager
751+
Instance of the EntropyManager.
752+
reduced_atom : Universe
753+
The reduced atom selection.
754+
levels : dict
755+
Dictionary mapping molecule IDs to lists of entropy levels.
756+
groups : dict
757+
Dictionary mapping group IDs to lists of molecule IDs.
758+
start : int
759+
Start frame index.
760+
end : int
761+
End frame index.
762+
step : int
763+
Step size for frame iteration.
764+
number_frames : int
765+
Total number of frames to process.
747766
748-
Returns:
749-
tuple: A tuple containing:
750-
- force_matrices (dict): Force covariance matrices by level.
751-
- torque_matrices (dict): Torque covariance matrices by level.
767+
Returns
768+
-------
769+
tuple
770+
force_avg : dict
771+
Averaged force covariance matrices by entropy level.
772+
torque_avg : dict
773+
Averaged torque covariance matrices by entropy level.
752774
"""
753775
number_groups = len(groups)
754-
force_matrices = {
776+
777+
force_avg = {
755778
"ua": {},
756779
"res": [None] * number_groups,
757780
"poly": [None] * number_groups,
758781
}
759-
torque_matrices = {
782+
torque_avg = {
760783
"ua": {},
761784
"res": [None] * number_groups,
762785
"poly": [None] * number_groups,
763786
}
787+
frame_counts = {
788+
"ua": {},
789+
"res": np.zeros(number_groups, dtype=int),
790+
"poly": np.zeros(number_groups, dtype=int),
791+
}
764792

765-
for timestep in reduced_atom.trajectory[start:end:step]:
766-
time_index = timestep.frame - start
793+
indices = list(range(start, end, step))
794+
for time_index, _ in zip(indices, reduced_atom.trajectory[start:end:step]):
767795

768-
for group_id in groups.keys():
769-
molecules = groups[group_id]
796+
for group_id, molecules in groups.items():
770797
for mol_id in molecules:
771798
mol = entropy_manager._get_molecule_container(reduced_atom, mol_id)
772799
for level in levels[mol_id]:
@@ -776,13 +803,14 @@ def build_covariance_matrices(
776803
group_id,
777804
level,
778805
levels[mol_id],
779-
time_index,
806+
time_index - start,
780807
number_frames,
781-
force_matrices,
782-
torque_matrices,
808+
force_avg,
809+
torque_avg,
810+
frame_counts,
783811
)
784812

785-
return force_matrices, torque_matrices
813+
return force_avg, torque_avg
786814

787815
def update_force_torque_matrices(
788816
self,
@@ -793,22 +821,55 @@ def update_force_torque_matrices(
793821
level_list,
794822
time_index,
795823
num_frames,
796-
force_matrices,
797-
torque_matrices,
824+
force_avg,
825+
torque_avg,
826+
frame_counts,
798827
):
799828
"""
800-
Update force and torque matrices for a given molecule and entropy level.
829+
Update the running averages of force and torque covariance matrices
830+
for a given molecule and entropy level.
801831
802-
Parameters:
803-
entropy_manager (EntropyManager): Instance of the EntropyManager
804-
mol (AtomGroup): The molecule to process.
805-
group_id (int): Index of the group.
806-
level (str): Current entropy level ("united_atom", "residue", or "polymer").
807-
level_list (list): List of levels for the molecule.
808-
time_index (int): Index of the current frame.
809-
num_frames (int): Total number of frames.
810-
force_matrices (dict): Dictionary of force matrices to update.
811-
torque_matrices (dict): Dictionary of torque matrices to update.
832+
This function computes the force and torque covariance matrices for the
833+
current frame and updates the existing averages in-place using the incremental
834+
mean formula:
835+
836+
new_avg = old_avg + (value - old_avg) / n
837+
838+
where n is the number of frames processed so far for that molecule/level
839+
combination. This ensures that the averages are maintained without storing
840+
all previous frame data.
841+
842+
Parameters
843+
----------
844+
entropy_manager : EntropyManager
845+
Instance of the EntropyManager.
846+
mol : AtomGroup
847+
The molecule to process.
848+
group_id : int
849+
Index of the group to which the molecule belongs.
850+
level : str
851+
Current entropy level ("united_atom", "residue", or "polymer").
852+
level_list : list
853+
List of entropy levels for the molecule.
854+
time_index : int
855+
Index of the current frame relative to the start of the trajectory slice.
856+
num_frames : int
857+
Total number of frames to process.
858+
force_avg : dict
859+
Dictionary holding the running average force matrices, keyed by entropy
860+
level.
861+
torque_avg : dict
862+
Dictionary holding the running average torque matrices, keyed by entropy
863+
level.
864+
frame_counts : dict
865+
Dictionary holding the count of frames processed for each molecule/level
866+
combination.
867+
868+
Returns
869+
-------
870+
None
871+
Updates are performed in-place on `force_avg`, `torque_avg`, and
872+
`frame_counts`.
812873
"""
813874
highest = level == level_list[-1]
814875

@@ -825,11 +886,19 @@ def update_force_torque_matrices(
825886
level,
826887
num_frames,
827888
highest,
828-
force_matrices["ua"].get(key),
829-
torque_matrices["ua"].get(key),
889+
None if key not in force_avg["ua"] else force_avg["ua"][key],
890+
None if key not in torque_avg["ua"] else torque_avg["ua"][key],
830891
)
831-
force_matrices["ua"][key] = f_mat
832-
torque_matrices["ua"][key] = t_mat
892+
893+
if key not in force_avg["ua"]:
894+
force_avg["ua"][key] = f_mat.copy()
895+
torque_avg["ua"][key] = t_mat.copy()
896+
frame_counts["ua"][key] = 1
897+
else:
898+
frame_counts["ua"][key] += 1
899+
n = frame_counts["ua"][key]
900+
force_avg["ua"][key] += (f_mat - force_avg["ua"][key]) / n
901+
torque_avg["ua"][key] += (t_mat - torque_avg["ua"][key]) / n
833902

834903
elif level in ["residue", "polymer"]:
835904
mol.trajectory[time_index]
@@ -839,11 +908,23 @@ def update_force_torque_matrices(
839908
level,
840909
num_frames,
841910
highest,
842-
force_matrices[key][group_id],
843-
torque_matrices[key][group_id],
911+
None if force_avg[key][group_id] is None else force_avg[key][group_id],
912+
(
913+
None
914+
if torque_avg[key][group_id] is None
915+
else torque_avg[key][group_id]
916+
),
844917
)
845-
force_matrices[key][group_id] = f_mat
846-
torque_matrices[key][group_id] = t_mat
918+
919+
if force_avg[key][group_id] is None:
920+
force_avg[key][group_id] = f_mat.copy()
921+
torque_avg[key][group_id] = t_mat.copy()
922+
frame_counts[key][group_id] = 1
923+
else:
924+
frame_counts[key][group_id] += 1
925+
n = frame_counts[key][group_id]
926+
force_avg[key][group_id] += (f_mat - force_avg[key][group_id]) / n
927+
torque_avg[key][group_id] += (t_mat - torque_avg[key][group_id]) / n
847928

848929
def filter_zero_rows_columns(self, arg_matrix):
849930
"""

0 commit comments

Comments
 (0)