Skip to content

Commit 68aacce

Browse files
authored
Merge pull request #80 from Huangzizhou/slim
smooth_step after line search for slim
2 parents 0cd21ba + 5e8eb3d commit 68aacce

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

src/polysolve/nonlinear/Problem.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ namespace polysolve::nonlinear
9090
/// @param new_x New solution.
9191
virtual void solution_changed(const TVector &new_x) {}
9292

93+
virtual bool after_line_search_custom_operation(const TVector &x0, const TVector &x1) { return false; }
94+
9395
/// @brief Callback function used to determine if the solver should stop.
9496
/// @param state Current state of the solver.
9597
/// @param x Current solution.

src/polysolve/nonlinear/Solver.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,13 @@ namespace polysolve::nonlinear
427427
continue;
428428
}
429429

430-
x += rate * delta_x;
430+
{
431+
TVector x1 = x + rate * delta_x;
432+
if (objFunc.after_line_search_custom_operation(x, x1))
433+
objFunc.solution_changed(x1);
434+
x = x1;
435+
}
436+
431437
old_energy = energy;
432438

433439
// Reset this for the next iterations
@@ -574,6 +580,7 @@ namespace polysolve::nonlinear
574580
void Solver::verify_gradient(Problem &objFunc, const TVector &x, const TVector &grad)
575581
{
576582
bool match = false;
583+
double J = objFunc(x);
577584

578585
switch (gradient_fd_strategy)
579586
{
@@ -591,16 +598,18 @@ namespace polysolve::nonlinear
591598
objFunc.solution_changed(x1);
592599
double J1 = objFunc(x1);
593600

594-
double fd = (J2 - J1) / 2 / gradient_fd_eps;
601+
double fd_centered = (J2 - J1) / 2 / gradient_fd_eps;
602+
double fd_right = (J2 - J) / gradient_fd_eps;
603+
double fd_left = (J - J1) / gradient_fd_eps;
595604
double analytic = direc.dot(grad);
596605

597-
match = abs(fd - analytic) < 1e-8 || abs(fd - analytic) < 1e-4 * abs(analytic);
606+
match = abs(fd_centered - analytic) < 1e-8 || abs(fd_centered - analytic) < 1e-4 * abs(analytic);
598607

599608
// Log error in either case to make it more visible in the logs.
600609
if (match)
601-
m_logger.debug("step size: {}, finite difference: {}, derivative: {}", gradient_fd_eps, fd, analytic);
610+
m_logger.debug("step size: {}, finite difference: {} {} {}, derivative: {}", gradient_fd_eps, fd_centered, fd_left, fd_right, analytic);
602611
else
603-
m_logger.error("step size: {}, finite difference: {}, derivative: {}", gradient_fd_eps, fd, analytic);
612+
m_logger.error("step size: {}, finite difference: {} {} {}, derivative: {}", gradient_fd_eps, fd_centered, fd_left, fd_right, analytic);
604613
}
605614
break;
606615
case FiniteDiffStrategy::FULL_FINITE_DIFF:

0 commit comments

Comments
 (0)