@@ -100,7 +100,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:
100100 )
101101 assert_dtype (sim .state .object .data , cdtype )
102102 assert_dtype (sim .state .probe .data , cdtype )
103-
103+ ## FIXME: can this apply iter constraints be moved after the position update?
104104 sim = sim .apply_iter_constraints ()
105105
106106 if iter_update_positions :
@@ -109,12 +109,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:
109109
110110 # subtract mean position update
111111 pos_update -= xp .mean (pos_update , tuple (range (pos_update .ndim - 1 )))
112- pos_update , position_solver_state = position_solver .perform_update (sim .state .scan , pos_update , position_solver_state )
112+ pos_update , position_solver_state = position_solver .perform_update (sim .state .scan . data , pos_update , position_solver_state )
113113 # subtract mean again (this can change with momentum)
114114 pos_update -= xp .mean (pos_update , tuple (range (pos_update .ndim - 1 )))
115115 pos_update_rms = float (xp .mean (xp .linalg .norm (pos_update , axis = - 1 , keepdims = True )))
116116 logger .info (f"Position update: mean { pos_update_rms } " )
117- sim .state .scan += pos_update
117+ sim .state .scan . data += pos_update
118118 assert_dtype (sim .state .scan , dtype )
119119
120120 # check positions are at least overlapping object
0 commit comments