diff --git a/CHANGELOG.md b/CHANGELOG.md index ddd6a7011..766daa3f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - Speed up MatrixExpr.add.reduce via quicksum - Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr - MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs +- Speed up `Expr.__neg__`, `SumExpr.__neg__`, `ProdExpr.__neg__` and `Constant.__neg__` via C-level API - Set `__array_priority__` for MatrixExpr and MatrixExprCons - changed addConsNode() and addConsLocal() to mirror addCons() and accept ExprCons instead of Constraint - Improved `chgReoptObjective()` performance diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 07d6ab031..79fb392ba 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -45,9 +45,10 @@ import math from typing import TYPE_CHECKING -from pyscipopt.scip cimport Variable, Solution -from cpython.dict cimport PyDict_Next +from cpython.dict cimport PyDict_Next, PyDict_SetItem +from cpython.object cimport Py_TYPE from cpython.ref cimport PyObject +from pyscipopt.scip cimport Variable, Solution import numpy as np @@ -308,8 +309,15 @@ cdef class Expr: else: raise TypeError(f"Unsupported base type {type(other)} for exponentiation.") - def __neg__(self): - return Expr({v:-c for v,c in self.terms.items()}) + def __neg__(self) -> Expr: + cdef dict res = {} + cdef Py_ssize_t pos = 0 + cdef PyObject* key_ptr + cdef PyObject* val_ptr + + while PyDict_Next(self.terms, &pos, &key_ptr, &val_ptr): + PyDict_SetItem(res, key_ptr, -(val_ptr)) + return Expr(res) def __sub__(self, other): return self + (-other) @@ -659,6 +667,23 @@ cdef class SumExpr(GenExpr): self.coefs = [] self.children = [] self._op = Operator.add + + def __neg__(self) -> SumExpr: + cdef int i = 0, n = len(self.coefs) + cdef list coefs = [0.0] * n + cdef double[:] dest_view = coefs + cdef double[:] src_view = self.coefs + + for i in range(n): + dest_view[i] = -src_view[i] + + cdef SumExpr res = SumExpr.__new__(SumExpr) + res.coefs = coefs + res.children = self.children.copy() + res.constant = -self.constant + res._op = Operator.add + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -666,7 +691,7 @@ cdef class SumExpr(GenExpr): cdef double res = self.constant cdef int i = 0, n = len(self.children) cdef list children = self.children - cdef list coefs = self.coefs + cdef double[:] coefs = self.coefs for i in range(n): res += coefs[i] * (children[i])._evaluate(sol) return res @@ -682,6 +707,13 @@ cdef class ProdExpr(GenExpr): self.children = [] self._op = Operator.prod + def __neg__(self) -> ProdExpr: + cdef ProdExpr res = ProdExpr.__new__(ProdExpr) + res.constant = -res.constant + self.children = self.children.copy() + res._op = Operator.prod + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -746,11 +778,16 @@ cdef class UnaryExpr(GenExpr): # class for constant expressions cdef class Constant(GenExpr): + cdef public number + def __init__(self,number): self.number = number self._op = Operator.const + def __neg__(self): + return Constant(-self.number) + def __repr__(self): return str(self.number) diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index ccd2028ef..620caa162 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -343,7 +343,7 @@ class Expr: def __lt__(self, other: object) -> bool: ... def __mul__(self, other: Incomplete) -> Incomplete: ... def __ne__(self, other: object) -> bool: ... - def __neg__(self) -> Incomplete: ... + def __neg__(self) -> Expr: ... def __pow__(self, other: Incomplete, modulo: Incomplete = ...) -> Incomplete: ... def __radd__(self, other: Incomplete) -> Incomplete: ... def __rmul__(self, other: Incomplete) -> Incomplete: ... @@ -386,7 +386,7 @@ class GenExpr: def __lt__(self, other: object) -> bool: ... def __mul__(self, other: Incomplete) -> Incomplete: ... def __ne__(self, other: object) -> bool: ... - def __neg__(self) -> Incomplete: ... + def __neg__(self) -> GenExpr: ... def __pow__(self, other: Incomplete, modulo: Incomplete = ...) -> Incomplete: ... def __radd__(self, other: Incomplete) -> Incomplete: ... def __rmul__(self, other: Incomplete) -> Incomplete: ... diff --git a/tests/test_expr.py b/tests/test_expr.py index c9135d2fa..13a167e13 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -2,8 +2,8 @@ import pytest -from pyscipopt import Model, sqrt, log, exp, sin, cos -from pyscipopt.scip import Expr, GenExpr, ExprCons, Term +from pyscipopt import Model, cos, exp, log, sin, sqrt +from pyscipopt.scip import Constant, Expr, ExprCons, GenExpr, ProdExpr, SumExpr, Term @pytest.fixture(scope="module") @@ -218,3 +218,33 @@ def test_getVal_with_GenExpr(): with pytest.raises(ZeroDivisionError): m.getVal(1 / z) + + +def test_neg(): + m = Model() + x = m.addVar(name="x") + + expr = (x + 1) ** 3 + neg_expr = -expr + assert isinstance(expr, Expr) + assert isinstance(neg_expr, Expr) + assert ( + str(neg_expr) + == "Expr({Term(x, x, x): -1.0, Term(x, x): -3.0, Term(x): -3.0, Term(): -1.0})" + ) + + base = sqrt(x) + expr = base * -1 + neg_expr = -expr + assert isinstance(expr, ProdExpr) + assert isinstance(neg_expr, ProdExpr) + assert str(neg_expr) == "prod(1.0,sqrt(sum(0.0,prod(1.0,x))))" + + expr = base + x - 1 + neg_expr = -expr + assert isinstance(expr, SumExpr) + assert isinstance(neg_expr, SumExpr) + assert str(neg_expr) == "sum(1.0,sqrt(sum(0.0,prod(1.0,x))),prod(1.0,x))" + assert list(neg_expr.coefs) == [-1, -1] + + assert str(-Constant(3.0)) == "-3.0"