|
7 | 7 | #include "form/expression_util.hpp" |
8 | 8 | #include "form/formula_util.hpp" |
9 | 9 |
|
| 10 | +// Helper function to replace all function calls by name with a replacement |
| 11 | +// expression |
| 12 | +void replaceFunctionByName(Expression& expr, const std::string& funcName, |
| 13 | + const Expression& replacement) { |
| 14 | + // First recurse into children |
| 15 | + for (auto& child : expr.children) { |
| 16 | + replaceFunctionByName(child, funcName, replacement); |
| 17 | + } |
| 18 | + // Then check if this expression matches |
| 19 | + if (expr.type == Expression::Type::FUNCTION && expr.name == funcName) { |
| 20 | + expr = replacement; |
| 21 | + } |
| 22 | +} |
| 23 | + |
10 | 24 | // Helper function to collect simple functions from formula |
11 | 25 | std::set<std::string> collectSimpleFunctions(const Formula& formula) { |
12 | 26 | std::set<std::string> funcs; |
@@ -560,3 +574,87 @@ void FormulaSimplify::replaceSimpleRecursiveRefs(Formula& formula) { |
560 | 574 | processedRecursiveFuncs.insert(funcName); |
561 | 575 | } |
562 | 576 | } |
| 577 | + |
| 578 | +// Helper function to check if a function is a constant identity (f(n) = f(n-k)) |
| 579 | +bool isConstantIdentityFunction(const Formula& formula, |
| 580 | + const std::string& funcName) { |
| 581 | + // Find the general definition (the one with a parameter, not a constant) |
| 582 | + Expression funcKey = ExpressionUtil::newFunction(funcName); |
| 583 | + auto it = formula.entries.find(funcKey); |
| 584 | + if (it == formula.entries.end()) { |
| 585 | + return false; // No general definition found |
| 586 | + } |
| 587 | + |
| 588 | + const auto& rhs = it->second; |
| 589 | + |
| 590 | + // Check if RHS is f(n-k) where k > 0 |
| 591 | + if (rhs.type != Expression::Type::FUNCTION || rhs.children.size() != 1) { |
| 592 | + return false; |
| 593 | + } |
| 594 | + |
| 595 | + // Must reference itself |
| 596 | + if (rhs.name != funcName) { |
| 597 | + return false; |
| 598 | + } |
| 599 | + |
| 600 | + const auto& arg = rhs.children.front(); |
| 601 | + |
| 602 | + // Check if argument is n-k where k > 0 |
| 603 | + Number offset; |
| 604 | + if (!extractArgumentOffset(arg, offset)) { |
| 605 | + return false; |
| 606 | + } |
| 607 | + |
| 608 | + // Must be a backwards reference (offset < 0) |
| 609 | + if (offset >= Number::ZERO) { |
| 610 | + return false; |
| 611 | + } |
| 612 | + |
| 613 | + // Check that there are no initial terms (base cases) for this function |
| 614 | + for (const auto& entry : formula.entries) { |
| 615 | + if (entry.first.type == Expression::Type::FUNCTION && |
| 616 | + entry.first.name == funcName && entry.first.children.size() == 1 && |
| 617 | + entry.first.children.front().type == Expression::Type::CONSTANT) { |
| 618 | + // Found an initial term - this is not a constant identity function |
| 619 | + return false; |
| 620 | + } |
| 621 | + } |
| 622 | + |
| 623 | + return true; |
| 624 | +} |
| 625 | + |
| 626 | +bool FormulaSimplify::replaceConstantIdentityFunctions(Formula& formula) { |
| 627 | + // Collect all function names |
| 628 | + auto funcs = FormulaUtil::getDefinitions(formula, Expression::Type::FUNCTION); |
| 629 | + |
| 630 | + // Find constant identity functions |
| 631 | + std::set<std::string> constantFuncs; |
| 632 | + for (const auto& funcName : funcs) { |
| 633 | + if (isConstantIdentityFunction(formula, funcName)) { |
| 634 | + constantFuncs.insert(funcName); |
| 635 | + } |
| 636 | + } |
| 637 | + |
| 638 | + if (constantFuncs.empty()) { |
| 639 | + return false; |
| 640 | + } |
| 641 | + |
| 642 | + // Replace references to constant identity functions with 0 |
| 643 | + Expression zero = ExpressionUtil::newConstant(0); |
| 644 | + for (const auto& funcName : constantFuncs) { |
| 645 | + // Remove all entries for this function |
| 646 | + FormulaUtil::removeFunctionEntries(formula, funcName); |
| 647 | + |
| 648 | + // Replace all references to funcName(...) with 0 in other entries |
| 649 | + for (auto& entry : formula.entries) { |
| 650 | + replaceFunctionByName(entry.second, funcName, zero); |
| 651 | + } |
| 652 | + } |
| 653 | + |
| 654 | + // Normalize after replacements |
| 655 | + for (auto& entry : formula.entries) { |
| 656 | + ExpressionUtil::normalize(entry.second); |
| 657 | + } |
| 658 | + |
| 659 | + return true; |
| 660 | +} |
0 commit comments