diff --git a/colicoords/cell.py b/colicoords/cell.py index 16796b4..a0afcea 100644 --- a/colicoords/cell.py +++ b/colicoords/cell.py @@ -736,7 +736,7 @@ def sub_par(self, par_dict): def calc_xc(self, xp, yp): """ Calculates the coordinate xc on p(x) closest to xp, yp. - + All coordinates are cartesian. Solutions are found by solving the cubic equation. Parameters @@ -757,7 +757,7 @@ def calc_xc(self, xp, yp): a0, a1, a2 = self.coeff # xp, yp = xp.astype('float32'), yp.astype('float32') # Converting of cell spine polynomial coefficients to coefficients of polynomial giving distance r - a, b, c, d = 4 * a2 ** 2, 6 * a1 * a2, 4 * a0 * a2 + 2 * a1 ** 2 - 4 * a2 * yp + 2, 2 * a0 * a1 - 2 * a1 * yp - 2 * xp + a, b, c, d = 2 * a2 ** 2, 3 * a1 * a2, 2 * a0 * a2 + 1 * a1 ** 2 - 2 * a2 * yp + 1, 1 * a0 * a1 - 1 * a1 * yp - 1 * xp # a: float, b: float, c: array, d: array discr = 18 * a * b * c * d - 4 * b ** 3 * d + b ** 2 * c ** 2 - 4 * a * c ** 3 - 27 * a ** 2 * d ** 2 @@ -773,11 +773,50 @@ def calc_xc(self, xp, yp): general_part = solve_general(a, b, c[mask], d[mask]) trig_part = solve_trig(a, b, c[~mask], d[~mask]) + # Trig_part returns 3 roots. Evaluate each and pick the one that corresponds to the minimum distance + trig_part = self._pick_root(trig_part, xp[~mask], yp[~mask]) + x_c[mask] = general_part x_c[~mask] = trig_part return x_c + def _pick_root(self, roots, xp, yp): + """ + Evaluate the expression for r^2 for each of 3 roots and pick the smallest - which corresponds to the + minimum distance from midline. + + Parameters + ---------- + roots : :class:`~numpy.ndarray` + ndarray of shape (3,n), containing 3 roots each of n total polynomials + xp : :class:`~numpy.ndarray` + ndarray of shape (n,) with x-coordinate + yp : :class:`~numpy.ndarray` + ndarray of shape (n,) with y-coordinate + + Returns + ------- + xc : :class:`~numpy.ndarray` + ndarray of shape (n,), containing n roots which minimize the r^2 + """ + + (_, n) = roots.shape + + a0, a1, a2 = self.coeff + + # Broadcast to common shape + xp = np.broadcast_to(xp, (3, n)) + yp = np.broadcast_to(yp, (3, n)) + + # Calculate r^2 for all roots + rsquare = (roots - xp) ** 2 + (a0 + a1 * roots + a2 * roots ** 2 - yp) ** 2 + + # Find roots that minimize r^2 and return + mask = np.argmin(rsquare, axis=0) + + return roots[mask, np.arange(0, n, 1)] + @allow_scalars def calc_xc_mask(self, xp, yp): """ @@ -1934,7 +1973,7 @@ def solve_trig(a, b, c, d): Returns ------- array : array_like - First real root solution. + All 3 real root solutions. .. [1] https://en.wikipedia.org/wiki/Cubic_function#Trigonometric_solution_for_three_real_roots @@ -1943,19 +1982,18 @@ def solve_trig(a, b, c, d): p = (3. * a * c - b ** 2.) / (3. * a ** 2.) q = (2. * b ** 3. - 9. * a * b * c + 27. * a ** 2. * d) / (27. * a ** 3.) assert (np.all(p < 0)) - k = 0. - t_k = 2. * np.sqrt(-p / 3.) * np.cos( - (1 / 3.) * np.arccos(((3. * q) / (2. * p)) * np.sqrt(-3. / p)) - (2 * np.pi * k) / 3.) - x_r = t_k - (b / (3 * a)) - try: - assert (np.all( - x_r > 0)) # dont know if this is guaranteed otherwise boundaries need to be passed and choosing from 3 slns - except AssertionError: - pass - # todo find out if this is bad or not - # raise ValueError - return x_r + #Find all 3 real roots + x_r_array = [np.zeros(p.shape), np.zeros(p.shape), np.zeros(p.shape)] + for i, k in enumerate([0., 1., 2.]): #TODO faster solution available if you vectorize via numpy, but this is fast enough + + t_k = 2. * np.sqrt(-p / 3.) * np.cos( + (1 / 3.) * np.arccos(((3. * q) / (2. * p)) * np.sqrt(-3. / p)) - (2 * np.pi * k) / 3.) + x_r = t_k - (b / (3 * a)) + + x_r_array[i] = x_r + + return np.asarray(x_r_array) def calc_lc(xl, xr, coeff): """