diff --git a/ipsae.py b/ipsae.py index 45f64de..b32f3c4 100644 --- a/ipsae.py +++ b/ipsae.py @@ -694,7 +694,7 @@ def classify_chains(chains, residue_types): ptm_matrix_d0chn=ptm_func_vec(pae_matrix,d0chn[chain1][chain2]) valid_pairs_iptm = (chains == chain2) - valid_pairs_matrix = (chains == chain2) & (pae_matrix < pae_cutoff) + valid_pairs_matrix = np.outer(chains == chain1, chains == chain2) & (pae_matrix < pae_cutoff) for i in range(numres): @@ -740,7 +740,7 @@ def classify_chains(chains, residue_types): ptm_matrix_d0dom = np.zeros((numres,numres)) ptm_matrix_d0dom = ptm_func_vec(pae_matrix,d0dom[chain1][chain2]) - valid_pairs_matrix = (chains == chain2) & (pae_matrix < pae_cutoff) + valid_pairs_matrix = np.outer(chains == chain1, chains == chain2) & (pae_matrix < pae_cutoff) # Assuming valid_pairs_matrix is already defined n0res_byres_all = np.sum(valid_pairs_matrix, axis=1)