@@ -806,22 +806,22 @@ void check_phase_transitions(Raid* env) {
806806 // Track when claws go down and apply imbalance penalty
807807 if (olm -> left_claw_hp <= 0 && olm -> left_claw_down_tick == -1 ) {
808808 olm -> left_claw_down_tick = env -> tick ;
809- // Penalty proportional to right claw's remaining HP
810- float imbalance = ( float ) olm -> right_claw_hp / CLAW_MAX_HP ;
811- float penalty = imbalance * env -> reward_claw_imbalance ;
812- for ( int i = 0 ; i < env -> num_players ; i ++ ) {
813- env -> rewards [i ] -= penalty ;
814- env -> players [ i ]. episode_return -= penalty ;
809+ // Penalty of -1 if other claw has > 50 HP remaining
810+ if ( olm -> right_claw_hp > 50 ) {
811+ for ( int i = 0 ; i < env -> num_players ; i ++ ) {
812+ env -> rewards [ i ] -= 1.0f ;
813+ env -> players [i ]. episode_return -= 1.0f ;
814+ }
815815 }
816816 }
817817 if (olm -> right_claw_hp <= 0 && olm -> right_claw_down_tick == -1 ) {
818818 olm -> right_claw_down_tick = env -> tick ;
819- // Penalty proportional to left claw's remaining HP
820- float imbalance = ( float ) olm -> left_claw_hp / CLAW_MAX_HP ;
821- float penalty = imbalance * env -> reward_claw_imbalance ;
822- for ( int i = 0 ; i < env -> num_players ; i ++ ) {
823- env -> rewards [i ] -= penalty ;
824- env -> players [ i ]. episode_return -= penalty ;
819+ // Penalty of -1 if other claw has > 50 HP remaining
820+ if ( olm -> left_claw_hp > 50 ) {
821+ for ( int i = 0 ; i < env -> num_players ; i ++ ) {
822+ env -> rewards [ i ] -= 1.0f ;
823+ env -> players [i ]. episode_return -= 1.0f ;
824+ }
825825 }
826826 }
827827
0 commit comments