diff --git a/Project.toml b/Project.toml index 7fdfcb0..4f4d295 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SourceCodeMcCormick" uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960" authors = ["Robert Gottlieb "] -version = "0.5.0" +version = "0.5.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/kernel_writer/kernel_write.jl b/src/kernel_writer/kernel_write.jl index 3e8eb46..8a49b0a 100644 --- a/src/kernel_writer/kernel_write.jl +++ b/src/kernel_writer/kernel_write.jl @@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic) function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool) # Create a hash of the expression and check if the function already exists - expr_hash = string(hash(num+sum(gradlist)), base=62) + expr_hash = string(hash(string(num)*string(gradlist)), base=62) if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))) try func_name = eval(Meta.parse("f_"*expr_hash)) return func_name @@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons elseif splitting==:high # Formerly default split_point = 1500 max_size = 2000 - # elseif splitting==:high # More splitting - # split_point = 1000 - # max_size = 1200 elseif splitting==:max # Extremely small split_point = 500 max_size = 750 @@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons sparsity = detect_sparsity(factored, gradlist) # Decide if the kernel needs to be split - if (n_vars[end] < 31) && (n_lines[end] <= max_size) + if (n_vars[end] < 31) && 
((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines))) # Complexity is fairly low; only a single kernel needed create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity) push!(kernel_nums, 1) @@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons while !complete # Determine which line to break at line_ID = findfirst(x -> x > split_point, n_lines) - vars_ID = findfirst(x -> x == 31, n_vars) + vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars) if isnothing(vars_ID) new_ID = line_ID elseif isnothing(line_ID) @@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons n_lines = complexity(factored) n_vars = var_counts(factored) - # If the total number of lines (not including the final line) is below 2000 + # If the total number of lines (not including the final line) is below the max size # and the number of variables is below 32, we can make the final kernel and be done if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size)) create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity) @@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a") # Put in the preamble. 
- write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist))) + if isempty(vars) + write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist))) + else + write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist))) + end + # Depending on the format of the expression, compose the kernel differently if typeof(expr) <: Real @@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num end end else # There must be two elements in the dictionary - binary_vars = string.(get_name.(keys(key.dict))) + binary_vars = string.(get_name.(keys(expr.dict))) binary_vars = binary_vars[sort_vars(binary_vars)] - write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist)) + write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist)) end elseif exprtype(expr)==ADD @@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num # EAGO already does this and bypasses the need to calculate relaxations. # But, for compatibility with McCormick-style relaxations in ParBB, # it's easier to simply calculate what ParBB is expecting.) 
- write(file, postamble_quadaff(string.(vars), varlist)) + if isempty(varlist) + write(file, postamble_quadaff(String[], String[])) + elseif isempty(vars) + write(file, postamble_quadaff(String[], varlist)) + else + write(file, postamble_quadaff(string.(vars), varlist)) + end close(file) # Include this kernel so SCMC knows what it is @@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num # Add onto the file the "main" CPU function that calls the kernel blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)) file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a") - write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist))) + if isempty(gradlist) + write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[])) + elseif isempty(vars) + write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist))) + else + write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist))) + end close(file) # Include the file again to get the final kernel @@ -731,6 +745,7 @@ end # 7) log(inv(x1)) = -log(x1) [EAGO paper] # 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1 # 9) 1 / (1 + exp(-x)) = Sigmoid(x) +# 10) sin(x) = cos(x - pi/2) # # Forms that aren't relevant yet: # 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers) @@ -826,7 +841,7 @@ function perform_substitutions(old_factored::Vector{Equation}) end end # Create a factorization of this new expr - new_factorization = factor(new_expr) + new_factorization = factor(new_expr, split_div=true) # Scan through the new factorization to see if we can merge elements # with the original factored list done = false @@ -1191,7 +1206,7 @@ function perform_substitutions(old_factored::Vector{Equation}) new_expr *= arg end # Create a factorization of this new expr - new_factorization = factor(new_expr) + new_factorization = factor(new_expr, split_div=true) 
# Scan through the new factorization to see if we can merge elements @@ -1315,6 +1330,38 @@ function perform_substitutions(old_factored::Vector{Equation}) end end end + + # 10) sin(x) = cos(x - pi/2) + if exprtype(factored[index0].rhs)==TERM + if factored[index0].rhs.f==sin + # We found sin(arg). Check if (arg - pi/2) exists, + # and if so, also check if cos(arg - pi/2) exists. + scan_flag = true + index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored) + if !isnothing(index1) + index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored) + if !isnothing(index2) + # cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2 + for i in eachindex(factored) + @eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs)) + end + deleteat!(factored, index0) + else + # arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change + # index0 to be cos of index1.lhs instead of sin of arg + @eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs) + end + else + # (arg - pi/2) doesn't exist, so we need to create it + newsym = gensym(:aux) + newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) + newvar = genvar(newsym) + insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2)) + @eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar) + end + break + end + end end end @@ -1511,6 +1558,10 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity)) elseif RHS.f==sqrt write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity)) + elseif RHS.f==cos + write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity)) + elseif RHS.f==abs + write(file, SCMC_abs_kernel(inputs..., gradlist, sparsity)) else close(file) error("Some function was used 
that we can't handle yet ($RHS)") @@ -1845,6 +1896,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star else total_lines += 190 end + new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end elseif exprtype(RHS) == TERM if RHS.f==exp total_lines += 212 # Ranges from 212--310 @@ -1866,8 +1921,24 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star end elseif RHS.f==sqrt total_lines += 190 + new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end + elseif RHS.f==cos || RHS.f==sin + total_lines += 300 + new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end + elseif RHS.f==abs + total_lines += 280 + new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end else - error("Unknown function") + error("Some function was used that we can't handle yet ($RHS)") end elseif exprtype(RHS) == SYM nothing diff --git a/src/kernel_writer/math_kernels.jl b/src/kernel_writer/math_kernels.jl index 103952c..76932c4 100644 --- a/src/kernel_writer/math_kernels.jl +++ b/src/kernel_writer/math_kernels.jl @@ -6,6 +6,12 @@ # these same functions, but in buffer/string form for the purposes of writing # new kernels. +# NOTE: These kernels might all be faster if we flip the ordering of indices. +# I.e., instead of having each row be a unique point to evaluate, make +# each column a unique point to evaluate. 
Preliminary checking on my
+#       workstation says this could be ~25% faster (tried for multiplication,
+#       100000 unique points)
+
 #= Unitary Rules =#
@@ -1144,6 +1150,162 @@ function SCMC_inv_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
     return nothing
 end
+
+# Absolute value
+# Relaxation of abs(x), computed one row at a time. In both `x` and `OUT`, columns
+# 1:4 are used as (cv, cc, lo, hi); the trailing 2*colmax columns are used as the
+# cv subgradient (first colmax entries) followed by the cc subgradient.
+# Returns `nothing`; all results are written into `OUT`.
+# max threads: ???
+function SCMC_abs_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
+    idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
+    stride = blockDim().x * gridDim().x
+    colmax = Int32((size(OUT,2)-4)/2)
+
+    while idx <= Int32(size(OUT,1))
+        # Reset the column counter
+        col = Int32(1)
+
+        # Get interval extension: lower bound is 0 when the interval straddles 0,
+        # otherwise the smaller endpoint magnitude; upper bound is the larger one
+        if x[idx,4] >= 0.0 && x[idx,3] <= 0.0
+            OUT[idx,3] = 0.0
+        else
+            OUT[idx,3] = min(abs(x[idx,3]), abs(x[idx,4]))
+        end
+        OUT[idx,4] = max(abs(x[idx,3]), abs(x[idx,4]))
+
+        # Calculate eps_min and eps_max (where abs is smallest / largest on [lo, hi])
+        if x[idx,3] >= 0.0
+            eps_min = x[idx,3]
+        elseif x[idx,4] <= 0.0
+            eps_min = x[idx,4]
+        else
+            eps_min = 0.0
+        end
+        if abs(x[idx,4]) >= abs(x[idx,3])
+            eps_max = x[idx,4]
+        else
+            eps_max = x[idx,3]
+        end
+
+        # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)
+        # cv_id/cc_id record which input was picked: 1 = cc, 2 = cv, 3 = eps point
+        midcv, cv_id, midcc, cc_id = midvals(x[idx,1], x[idx,2], eps_min, eps_max)
+
+        # Get derivative values
+        if x[idx,4] - x[idx,3] == 0.0
+            # Zero-width interval: no secant exists, so use abs itself and its sign
+            OUT[idx,2] = abs(midcc)
+            if midcc > 0.0
+                dcc = 1.0
+            elseif midcc < 0.0
+                dcc = -1.0
+            else
+                dcc = 0.0
+            end
+        else
+            # Concave part is the secant through (lo, |lo|) and (hi, |hi|)
+            OUT[idx,2] = (abs(x[idx,3])*(x[idx,4] - midcc) + abs(x[idx,4])*(midcc - x[idx,3]))/(x[idx,4]-x[idx,3])
+            dcc = (abs(x[idx,4]) - abs(x[idx,3]))/(x[idx,4]-x[idx,3])
+        end
+        OUT[idx,1] = abs(midcv)
+        # dcv is the slope of abs at midcv (the point used for OUT[idx,1]), so both
+        # sign tests must look at midcv, not midcc
+        if midcv > 0.0
+            dcv = 1.0
+        elseif midcv < 0.0
+            dcv = -1.0
+        else
+            dcv = 0.0
+        end
+
+        # Calculate subgradients: chain rule with whichever input subgradient
+        # (cc -> end-1*colmax block, cv -> end-2*colmax block) was selected by the
+        # ids; an eps point (id 3) contributes a zero subgradient
+        if cv_id==1
+            if cc_id==1
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        elseif cv_id==2
+            if cc_id==1
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        else
+            if cc_id==1
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                col = Int32(1)
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        end
+
+        # Perform the cut operation: clip cv/cc to the interval bounds and zero the
+        # matching subgradient whenever a clip happens
+        if OUT[idx,1] < OUT[idx,3]
+            OUT[idx,1] = OUT[idx,3]
+            col = Int32(1)
+            while col <= colmax
+                OUT[idx,end-2*colmax+col] = 0.0
+                col += Int32(1)
+            end
+        end
+        if OUT[idx,2] > OUT[idx,4]
+            OUT[idx,2] = OUT[idx,4]
+            col = Int32(1)
+            while col <= colmax
+                OUT[idx,end-1*colmax+col] = 0.0
+                col += Int32(1)
+            end
+        end
+
+        idx += stride
+    end
+    return nothing
+end
+
+
 # Multiplication by a constant
 # max threads: 640
 function SCMC_cmul_kernel(OUT::CuDeviceMatrix, CONST::Real, x::CuDeviceMatrix)
@@ -4455,6 +4617,182 @@ function SCMC_large_float_power_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix, c
     return nothing
end
+# Cosine (argument should be in radians)
+# NOTE: Sine can be cos(x - pi/2)
+# Each row of `x`/`OUT` is processed on its own. Columns 1:4 are used as
+# (cv, cc, lo, hi); the trailing 2*colmax columns are used as the cv subgradient
+# followed by the cc subgradient. Returns `nothing`; results are written to `OUT`.
+function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
+    idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
+    stride = blockDim().x * gridDim().x
+    colmax = Int32((size(OUT,2)-4)/2)
+
+    while idx <= Int32(size(OUT,1))
+        # Reset the column counter
+        col = Int32(1)
+
+        # Get lower and upper bounds from the interval
+        # (an interval at least 2*pi wide covers a full period, so cos spans [-1, 1])
+        if (x[idx,4] - x[idx,3]) >= 2.0*pi
+            OUT[idx,3] = -1.0
+            OUT[idx,4] = 1.0
+        else
+            # quadrant() returns a quadrant code plus the endpoint reduced by rem2pi
+            lo_quadrant, lo = quadrant(x[idx,3])
+            hi_quadrant, hi = quadrant(x[idx,4])
+
+            if lo_quadrant == hi_quadrant
+                # Same quadrant code but more than pi apart: endpoints wrapped around
+                if x[idx,4] - x[idx,3] > 3.141592653589793
+                    OUT[idx,3] = -1.0
+                    OUT[idx,4] = 1.0
+                elseif lo_quadrant==2 || lo_quadrant==3
+                    OUT[idx,3] = cos(lo)
+                    OUT[idx,4] = cos(hi)
+                else
+                    OUT[idx,3] = cos(hi)
+                    OUT[idx,4] = cos(lo)
+                end
+            elseif lo_quadrant==2 && hi_quadrant==3
+                OUT[idx,3] = cos(lo)
+                OUT[idx,4] = cos(hi)
+            elseif lo_quadrant==0 && hi_quadrant==1
+                OUT[idx,3] = cos(hi)
+                OUT[idx,4] = cos(lo)
+            elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 || hi_quadrant==1)
+                OUT[idx,3] = min(cos(lo), cos(hi))
+                OUT[idx,4] = 1.0
+            elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3)
+                OUT[idx,3] = -1.0
+                OUT[idx,4] = max(cos(lo), cos(hi))
+            else
+                OUT[idx,3] = -1.0
+                OUT[idx,4] = 1.0
+            end
+        end
+
+
+        # get eps_min and eps_max
+        # NOTE(review): these look like the points where cos is smallest / largest on
+        # [lo, hi], found after shifting the interval by 2*pi*kL; both are NaN when
+        # the shifted lower bound falls outside [-pi, pi] -- confirm intent
+        kL = Base.ceil(-0.5 - x[idx,3]/(2.0*pi))
+        xL1 = x[idx,3] + 2.0*pi*kL
+        xU1 = x[idx,4] + 2.0*pi*kL
+        if (xL1 < -pi) || (xL1 > pi)
+            eps_min = NaN
+            eps_max = NaN
+        elseif xL1 <= 0.0
+            if xU1 <= 0.0
+                eps_min = x[idx,3]
+                eps_max = x[idx,4]
+            elseif xU1 >= pi
+                eps_min = pi - 2.0*pi*kL
+                eps_max = -2.0*pi*kL
+            else
+                eps_min = (cos(xL1) <= cos(xU1)) ? x[idx,3] : x[idx,4]
+                eps_max = -2.0*pi*kL
+            end
+        elseif xU1 <= pi
+            eps_min = x[idx,4]
+            eps_max = x[idx,3]
+        elseif xU1 >= 2.0*pi
+            eps_min = pi - 2.0*pi*kL
+            eps_max = 2.0*pi - 2.0*pi*kL
+        else
+            eps_min = pi - 2.0*pi*kL
+            eps_max = (cos(xL1) >= cos(xU1)) ? x[idx,3] : x[idx,4]
+        end
+
+        # Middle values of (cv, cc, eps_min) and (cv, cc, eps_max); the ids record
+        # which input was picked: 1 = cc, 2 = cv, 3 = eps point
+        midcv, cv_id, midcc, cc_id = midvals(x[idx,1], x[idx,2], eps_min, eps_max)
+
+        # Call cv normally
+        cv, dcv = SCMC_cv_cos(midcv, x[idx,3], x[idx,4])
+        OUT[idx,1] = cv
+
+        # Call cc by shifting and negating the cv path
+        # (uses cos(t) = -cos(t - pi), so the negated result of the cv routine on the
+        # shifted point and interval is used as the concave side)
+        neg_cc, neg_dcc = SCMC_cv_cos(midcc - pi, x[idx,3] - pi, x[idx,4] - pi)
+        OUT[idx,2] = -neg_cc
+        dcc = -neg_dcc
+
+        # Now we need mid_grad things...
+        # Chain rule with whichever input subgradient the ids selected
+        # (cc -> end-1*colmax block, cv -> end-2*colmax block, eps point -> zero).
+        # `col` is still 1 here; exactly one of the loops below runs.
+        if cv_id==1
+            if cc_id==1
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        elseif cv_id==2
+            if cc_id==1
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        else
+            if cc_id==1
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            elseif cc_id==2
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
+                    col += Int32(1)
+                end
+            else
+                while col <= colmax
+                    OUT[idx,end-2*colmax+col] = 0.0
+                    OUT[idx,end-1*colmax+col] = 0.0
+                    col += Int32(1)
+                end
+            end
+        end
+
+        # Perform the cut operation
+        # (clip cv/cc to the interval bounds; a clipped side gets a zero subgradient)
+        if OUT[idx,1] < OUT[idx,3]
+            OUT[idx,1] = OUT[idx,3]
+            col = Int32(1)
+            while col <= colmax
+                OUT[idx,end-2*colmax+col] = 0.0
+                col += Int32(1)
+            end
+        end
+        if OUT[idx,2] > OUT[idx,4]
+            OUT[idx,2] = OUT[idx,4]
+            col = Int32(1)
+            while col <= colmax
+                OUT[idx,end-1*colmax+col] = 0.0
+                col += Int32(1)
+            end
+        end
+
+        idx += stride
+    end
+    return nothing
+end
+
+
 #= Binary Rules =#
@@ -4862,6 +5200,269 @@ function SCMC_add_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix, y::CuDeviceMatr
 end
+
+##################
+# Helper functions for some kernels to use
+
+# midvals: return `(midcv, cv_id, midcc, cc_id)`, where `midcv` is the middle value
+# of (xcv, xcc, eps_min) and `midcc` is the middle value of (xcv, xcc, eps_max).
+# Each id is an Int32 naming the input that was selected: 1 = xcc, 2 = xcv,
+# 3 = the eps point. The `xcc < xcv` ordering is handled by the outer elseif/else arms.
+function midvals(xcv::Float64, xcc::Float64, eps_min::Float64, eps_max::Float64)
+    if xcc >= xcv
+        if xcv == xcc
+            midcc = xcv
+            cc_id = Int32(2)
+            midcv = xcv
+            cv_id = Int32(2)
+        elseif xcv >= eps_max
+            if xcv >= eps_min
+                midcc = xcv
+                cc_id = Int32(2)
+                midcv = xcv
+                cv_id = Int32(2)
+            elseif eps_min >= xcc
+                midcc = xcv
+                cc_id = Int32(2)
+                midcv = xcc
+                cv_id = Int32(1)
+            else
+                midcc = xcv
+                cc_id = Int32(2)
+                midcv = eps_min
+                cv_id = Int32(3)
+            end
+        elseif eps_max >= xcc
+            if xcv >= eps_min
+                midcc = xcc
+                cc_id = Int32(1)
+                midcv = xcv
+                cv_id = Int32(2)
+            elseif eps_min >= xcc
+                midcc = xcc
+                cc_id = Int32(1)
+                midcv = xcc
+                cv_id = Int32(1)
+            else
+                midcc = xcc
+                cc_id = Int32(1)
+                midcv = eps_min
+                cv_id = Int32(3)
+            end
+        else
+            if xcv >= eps_min
+                midcc = eps_max
+                cc_id = Int32(3)
+                midcv = xcv
+                cv_id = Int32(2)
+            elseif eps_min >= xcc
+                midcc = eps_max
+                cc_id = Int32(3)
+                midcv = xcc
+                cv_id = Int32(1)
+            else
+                midcc = eps_max
+                cc_id = Int32(3)
+                midcv = eps_min
+                cv_id = Int32(3)
+            end
+        end
+    elseif eps_max >= xcv
+        if eps_min >= xcv
+            midcc = xcv
+            cc_id = Int32(2)
+            midcv = xcv
+            cv_id = Int32(2)
+        elseif xcc >= eps_min
+            midcc = xcv
+            cc_id = Int32(2)
+            midcv = xcc
+            cv_id = Int32(1)
+        else
+            midcc = xcv
+            cc_id = Int32(2)
+            midcv = eps_min
+            cv_id = Int32(3)
+        end
+    elseif xcc >= eps_max
+        if eps_min >= xcv
+            midcc = xcc
+            cc_id = Int32(1)
+            midcv = xcv
+            cv_id = Int32(2)
+        elseif xcc >= eps_min
+            midcc = xcc
+            cc_id = Int32(1)
+            midcv = xcc
+            cv_id = Int32(1)
+        else
+            midcc = xcc
+            cc_id = Int32(1)
+            midcv = eps_min
+            cv_id = Int32(3)
+        end
+    else
+        if eps_min >= xcv
+            midcc = eps_max
+            cc_id = Int32(3)
+            midcv = xcv
+            cv_id = Int32(2)
+        elseif xcc >= eps_min
+            midcc = eps_max
+            cc_id = Int32(3)
+            midcv = xcc
+            cv_id = Int32(1)
+        else
+            midcc = eps_max
+            cc_id = Int32(3)
+            midcv = eps_min
+            cv_id = Int32(3)
+        end
+    end
+    return midcv, cv_id, midcc, cc_id
+end
+
+# SCMC_cv_cos: return `(value, slope)` of the convex-side function for cos at `x`
+# over [xL, xU]. Depending on where `x` and the 2*pi-shifted bounds fall, the result
+# is cos itself, the secant through the endpoints, the envelope from SCMC_cv_cosin,
+# or the constant -1 with zero slope.
+@inline function SCMC_cv_cos(x::Float64, xL::Float64, xU::Float64)
+    kL = Base.ceil(-0.5 - xL/(2.0*pi))
+    if x <= (pi - 2.0*pi*kL)
+        xL1 = xL + 2.0*pi*kL
+        if xL1 >= pi/2.0
+            return cos(x), -sin(x)
+        end
+        xU1 = min(xU + 2.0*pi*kL, pi)
+        if (xL1 >= -pi/2) && (xU1 <= pi/2)
+            # Near-degenerate interval: avoid dividing by (xU - xL)
+            if abs(xU - xL) < 1E-10
+                return cos(xL), 0.0
+            else
+                return cos(xL) + (x - xL)*(cos(xU) - cos(xL))/(xU - xL),
+                       (cos(xU) - cos(xL))/(xU - xL)
+            end
+        end
+        return SCMC_cv_cosin(x + 2.0*pi*kL, xL1, xU1)
+    end
+    kU = Base.floor(0.5 - xU/(2.0*pi))
+    if (x >= -pi - 2.0*pi*kU)
+        xU2 = xU + 2.0*pi*kU
+        if xU2 <= -pi/2.0
+            return cos(x), -sin(x)
+        end
+        return SCMC_cv_cosin(x + 2.0*pi*kU, max(xL + 2.0*pi*kU, -pi), xU2)
+    end
+    return -1.0, 0.0
+end
+
+# Needs to return only two things (inlining to make comparisons with x)
+# Returns `(value, slope)`: cos itself on the far side of the junction point `xj`
+# (as seen from the endpoint `xm` with the smaller magnitude), otherwise the line
+# through (xm, cos(xm)) and (xj, cos(xj)).
+@inline function SCMC_cv_cosin(x::Float64, xL::Float64, xU::Float64)
+    if abs(xL) <= abs(xU)
+        left = false
+        x0 = xU
+        xm = xL
+    else
+        left = true
+        x0 = xL
+        xm = xU
+    end
+    # xj: root of the envelope equation, started from x0 with xm as the anchor point
+    xj = cos_newton_or_golden_section(x0, xL, xU, xm)
+    if (left && (x <= xj)) || (~left && (x >= xj))
+        return cos(x), -sin(x)
+    else
+        if abs(xm - xj) < 1e-10
+            return cos(xm), 0.0
+        else
+            return cos(xm) + (x - xm)*(cos(xm) - cos(xj))/(xm - xj), (cos(xm) - cos(xj))/(xm - xj)
+        end
+    end
+end
+
+# cos_newton_or_golden_section: find xk in [xL, xU] with
+#   (xk - envp)*sin(xk) + cos(xk) - cos(envp) == 0
+# using up to 100 clamped Newton steps from x0 (tolerance 1e-10 on the residual).
+# If Newton meets a zero derivative or runs out of iterations, a golden-section
+# style bracketing search on [xL, xU] is used instead; that search returns NaN when
+# the residual has the same sign at both ends.
+function cos_newton_or_golden_section(x0::Float64, xL::Float64, xU::Float64, envp::Float64)
+    dfk = 0.0
+    xk = max(xL, min(x0, xU))
+    fk = (xk - envp)*sin(xk) + cos(xk) - cos(envp)
+    iter = Int32(1)
+    while iter <= Int32(100)
+        dfk = (xk - envp)*cos(xk)
+        if abs(fk) < 1e-10
+            return xk
+        end
+        if iszero(dfk)
+            xk = 0.0
+            break # Need to do golden section
+        end
+        # Stop at a bound when the Newton step would push further outside it
+        if (xk == xL) && (fk/dfk > 0.0)
+            return xk
+        elseif (xk == xU) && (fk/dfk < 0.0)
+            return xk
+        end
+        xk = max(xL, min(xU, xk - fk/dfk))
+        fk = (xk - envp)*sin(xk) + cos(xk) - cos(envp)
+        iter += Int32(1)
+    end
+
+    # If flag, we need to do golden section instead
+    # (there is no flag variable here: this point is reached after the `break` above
+    # or once the Newton loop has used all 100 iterations)
+    a_golden = xL
+    fa_golden = (a_golden - envp)*sin(a_golden) + cos(a_golden) - cos(envp)
+    c_golden = xU
+    fc_golden = (c_golden - envp)*sin(c_golden) + cos(c_golden) - cos(envp)
+
+    # No sign change across [xL, xU]: no bracketed root
+    if fa_golden*fc_golden > 0
+        xk = NaN
+        return xk
+    end
+
+    b_golden = xU - (2.0 - Base.MathConstants.golden)*(xU - xL)
+    fb_golden = (b_golden - envp)*sin(b_golden) + cos(b_golden) - cos(envp)
+
+    iter = Int32(1)
+    while iter <= Int32(100)
+        # Probe inside the larger of the two sub-intervals (a, b) and (b, c)
+        if (c_golden - b_golden > b_golden - a_golden)
+            x_golden = b_golden + (2.0 - Base.MathConstants.golden)*(c_golden - b_golden)
+            if abs(c_golden-a_golden) < 1.0e-10*(abs(b_golden) + abs(x_golden)) || iter == Int32(100)
+                xk = (c_golden + a_golden)/2.0
+                return xk
+            end
+            iter += Int32(1)
+            fx_golden = (x_golden - envp)*sin(x_golden) + cos(x_golden) - cos(envp)
+            if fa_golden*fx_golden < 0.0
+                c_golden = x_golden
+                fc_golden = fx_golden
+            else
+                a_golden = b_golden
+                fa_golden = fb_golden
+                b_golden = x_golden
+                fb_golden = fx_golden
+            end
+        else
+            x_golden = b_golden - (2.0 - Base.MathConstants.golden)*(b_golden - a_golden)
+            if abs(c_golden-a_golden) < 1.0e-10*(abs(b_golden) + abs(x_golden)) || iter == Int32(100)
+                xk = (c_golden + a_golden)/2.0
+                return xk
+            end
+            iter += Int32(1)
+            fx_golden = (x_golden - envp)*sin(x_golden) + cos(x_golden) - cos(envp)
+            if fa_golden*fb_golden < 0.0
+                c_golden = b_golden
+                fc_golden = fb_golden
+                b_golden = x_golden
+                fb_golden = fx_golden
+            else
+                a_golden = x_golden
+                fa_golden = fx_golden
+            end
+        end
+    end
+
+    # Should never get to this point, but for completeness...
+    return xk
+end
+
+# Directly from IntervalArithmetic.jl
+# Returns `(code, x_mod2pi)` with x_mod2pi = rem2pi(x, RoundNearest). The code is
+# 2 below -pi/2, 3 in [-pi/2, 0), 0 in [0, pi/2], and 1 above pi/2.
+function quadrant(x::Float64)
+    x_mod2pi = rem2pi(x, RoundNearest)
+
+    x_mod2pi < -(pi/2.0) && return (Int32(2), x_mod2pi)
+    x_mod2pi < 0 && return (Int32(3), x_mod2pi)
+    x_mod2pi <= (pi/2.0) && return (Int32(0), x_mod2pi)
+
+    return Int32(1), x_mod2pi
+end
+
+
+
+
 ##################
 # Some templates that are useful for writing new kernels.
diff --git a/src/kernel_writer/string_math_kernels.jl b/src/kernel_writer/string_math_kernels.jl
index 869b6e1..dfecb34 100644
--- a/src/kernel_writer/string_math_kernels.jl
+++ b/src/kernel_writer/string_math_kernels.jl
@@ -2096,7 +2096,7 @@ function SCMC_log_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
     return String(take!(buffer))
 end
-# Inversion (DONE)
+# Inversion
 # max threads: 768
 function SCMC_inv_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}; sum_output::Bool=false)
     if sum_output
@@ -3255,15 +3255,9 @@ function SCMC_inv_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
     return String(take!(buffer))
 end
-# Multiplication by a constant
-# max threads: 640
-function SCMC_cmul_kernel(OUT::String, v1::String, CONST::Real, varlist::Vector{String}, sparsity::Vector{Int}; sum_output::Bool=false)
-    if sum_output
-        eq = "+="
-    else
-        eq = "="
-    end
-
+# Absolute value
+# max threads: ???
+function SCMC_abs_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}) if startswith(v1, "temp") v1_cv = "$(v1)_cv" v1_cc = "$(v1)_cc" @@ -3302,7 +3296,7 @@ function SCMC_cmul_kernel(OUT::String, v1::String, CONST::Real, varlist::Vector{ # Get the anti-sparsity list (elements NOT being used) antisparsity = collect(1:length(varlist)) - antisparsity = antisparsity[antisparsity .∉ Ref(sparsity)] + antisparsity = antisparsity[antisparsity .∉ Ref(sparsity)] # Determine the sparsity case: # 1) Use sparsity list @@ -3322,216 +3316,419 @@ function SCMC_cmul_kernel(OUT::String, v1::String, CONST::Real, varlist::Vector{ buffer = Base.IOBuffer() # Write all the lines to the buffer - - if CONST >= 0.0 - if startswith(v1, r"aux|temp") - write(buffer, " ###########################################\n") - write(buffer, " ## Multiplication by a Positive Constant ##\n") - write(buffer, " ###########################################\n") - write(buffer, "\n") - write(buffer, " # Reset the column counter\n") - write(buffer, " col = Int32(1)\n") - write(buffer, "\n") - write(buffer, " # Begin rule\n") - write(buffer, " $OUT_lo $eq $CONST*$v1_lo\n") - write(buffer, " $OUT_hi $eq $CONST*$v1_hi\n") - write(buffer, " $OUT_cv $eq $CONST*$v1_cv\n") - write(buffer, " $OUT_cc $eq $CONST*$v1_cc\n") - write(buffer, " while col <= colmax\n") - if sparsity_case == 1 - write(buffer, " if $sparsity_string\n") - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - elseif sparsity_case == 2 - write(buffer, " if $antisparsity_string\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") - write(buffer, " end\n") - 
else - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") - end - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, "\n") - write(buffer, " # Cut\n") - write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv = $OUT_lo\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc = $OUT_hi\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, "\n") + if startswith(v1, r"aux|temp") + write(buffer, " ####################\n") + write(buffer, " ## Absolute value ##\n") + write(buffer, " ####################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if $v1_hi >= 0.0 && $v1_lo <= 0.0\n") + write(buffer, " $OUT_lo = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = min(abs($v1_lo), abs($v1_hi))\n") + write(buffer, " end\n") + write(buffer, " $OUT_hi = max(abs($v1_lo), abs($v1_hi))\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " if $v1_lo >= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " elseif $v1_hi <= 0.0\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " else\n") + write(buffer, " eps_min = 0.0\n") + write(buffer, " end\n") + write(buffer, " if abs($v1_hi) >= abs($v1_lo)\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " else\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and 
midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n")
+        write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n")
+        write(buffer, "\n")
+        write(buffer, " # Get derivative values\n")
+        # Zero-width interval: emit abs and its sign; otherwise emit the secant
+        # through (lo, |lo|) and (hi, |hi|) for the concave side
+        write(buffer, " if $v1_hi - $v1_lo == 0.0\n")
+        write(buffer, " $OUT_cc = abs(midcc)\n")
+        write(buffer, " if midcc > 0.0\n")
+        write(buffer, " dcc = 1.0\n")
+        write(buffer, " elseif midcc < 0.0\n")
+        write(buffer, " dcc = -1.0\n")
+        write(buffer, " else\n")
+        write(buffer, " dcc = 0.0\n")
+        write(buffer, " end\n")
+        write(buffer, " else\n")
+        write(buffer, " $OUT_cc = (abs($v1_lo)*($v1_hi - midcc) + abs($v1_hi)*(midcc - $v1_lo))/($v1_hi-$v1_lo)\n")
+        write(buffer, " dcc = (abs($v1_hi) - abs($v1_lo))/($v1_hi-$v1_lo)\n")
+        write(buffer, " end\n")
+        write(buffer, " $OUT_cv = abs(midcv)\n")
+        # dcv is the slope of abs at midcv (the point used for $OUT_cv), so the
+        # generated sign chain must test midcv in both branches, not midcc
+        write(buffer, " if midcv > 0.0\n")
+        write(buffer, " dcv = 1.0\n")
+        write(buffer, " elseif midcv < 0.0\n")
+        write(buffer, " dcv = -1.0\n")
+        write(buffer, " else\n")
+        write(buffer, " dcv = 0.0\n")
+        write(buffer, " end\n")
+        write(buffer, "\n")
+        write(buffer, " # Calculate subgradients\n")
+        write(buffer, " if cv_id==1\n")
+        write(buffer, " if cc_id==1\n")
+        write(buffer, " while col <= colmax\n")
+        # sparsity_case selects which column test wraps the emitted gradient update
+        if sparsity_case == 1
+            write(buffer, " if $sparsity_string\n")
+            write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n")
+            write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n")
+            write(buffer, " else\n")
+            write(buffer, " $OUT_cvgrad = 0.0\n")
+            write(buffer, " $OUT_ccgrad = 0.0\n")
+            write(buffer, " end\n")
+        elseif sparsity_case == 2
+            write(buffer, " if $antisparsity_string\n")
+            write(buffer, " $OUT_cvgrad = 0.0\n")
+            write(buffer, " $OUT_ccgrad = 0.0\n")
+            write(buffer, " else\n")
+            write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n")
+            write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n")
+            write(buffer, " end\n")
         else
-            ID = findfirst(==(v1), varlist)
-            isnothing(ID) && error("Empty varlist")
-            write(buffer, " 
###########################################\n") - write(buffer, " ## Multiplication by a Positive Constant ##\n") - write(buffer, " ###########################################\n") - write(buffer, "\n") - write(buffer, " # Reset the column counter\n") - write(buffer, " col = Int32(1)\n") - write(buffer, "\n") - write(buffer, " # Begin rule\n") - write(buffer, " $OUT_lo $eq $CONST*$v1_lo\n") - write(buffer, " $OUT_hi $eq $CONST*$v1_hi\n") - write(buffer, " $OUT_cv $eq $CONST*$v1_cv\n") - write(buffer, " $OUT_cc $eq $CONST*$v1_cc\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad $eq $CONST\n") - write(buffer, " $OUT_ccgrad $eq $CONST\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, "\n") - write(buffer, " # Cut\n") - write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv = $OUT_lo\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc = $OUT_hi\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, "\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") end - else - if startswith(v1, r"aux|temp") - write(buffer, " ###########################################\n") - write(buffer, " ## Multiplication by a Negative Constant ##\n") - write(buffer, " ###########################################\n") - write(buffer, "\n") - write(buffer, " # Reset the column counter\n") - write(buffer, 
" col = Int32(1)\n") - write(buffer, "\n") - write(buffer, " # Begin rule\n") - write(buffer, " $OUT_lo $eq $CONST*$v1_hi\n") - write(buffer, " $OUT_hi $eq $CONST*$v1_lo\n") - write(buffer, " $OUT_cv $eq $CONST*$v1_cc\n") - write(buffer, " $OUT_cc $eq $CONST*$v1_cv\n") - write(buffer, " while col <= colmax\n") - if sparsity_case == 1 - write(buffer, " if $sparsity_string\n") - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_cvgrad\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - elseif sparsity_case == 2 - write(buffer, " if $antisparsity_string\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_cvgrad\n") - write(buffer, " end\n") - else - write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") - write(buffer, " $OUT_ccgrad $eq $CONST*$v1_cvgrad\n") - end - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, "\n") - write(buffer, " # Cut\n") - write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv = $OUT_lo\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc = $OUT_hi\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, "\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if 
$sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") else - ID = findfirst(==(v1), varlist) - isnothing(ID) && error("Empty varlist") - write(buffer, " ###########################################\n") - write(buffer, " ## Multiplication by a Negative Constant ##\n") - write(buffer, " ###########################################\n") - write(buffer, "\n") - write(buffer, " # Reset the column counter\n") - write(buffer, " col = Int32(1)\n") - write(buffer, "\n") - write(buffer, " # Begin rule\n") - write(buffer, " $OUT_lo $eq $CONST*$v1_hi\n") - write(buffer, " $OUT_hi $eq $CONST*$v1_lo\n") - write(buffer, " $OUT_cv $eq $CONST*$v1_cc\n") - write(buffer, " $OUT_cc $eq $CONST*$v1_cv\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad $eq $CONST\n") - write(buffer, " $OUT_ccgrad $eq $CONST\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, "\n") - write(buffer, " # Cut\n") - write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv = $OUT_lo\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc = 
$OUT_hi\n") - write(buffer, " col = Int32(1)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad = 0.0\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, "\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " elseif cv_id==2\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + 
write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if 
sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, 
" col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + else + ID = findfirst(==(v1), varlist) + isnothing(ID) && error("Empty varlist") + write(buffer, " ####################\n") + write(buffer, " ## Absolute value ##\n") + write(buffer, " ####################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if $v1_hi >= 0.0 && $v1_lo <= 0.0\n") + write(buffer, " $OUT_lo = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = min(abs($v1_lo), abs($v1_hi))\n") + write(buffer, " end\n") + write(buffer, " $OUT_hi = max(abs($v1_lo), abs($v1_hi))\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " if $v1_lo >= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " elseif $v1_hi <= 0.0\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " else\n") + write(buffer, " eps_min = 0.0\n") + write(buffer, " end\n") + write(buffer, " if abs($v1_hi) >= abs($v1_lo)\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " else\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n") + write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n") + write(buffer, "\n") + write(buffer, " # Get derivative values\n") + write(buffer, " if $v1_hi - $v1_lo == 0.0\n") + 
write(buffer, " $OUT_cc = abs(midcc)\n") + write(buffer, " if midcc > 0.0\n") + write(buffer, " dcc = 1.0\n") + write(buffer, " elseif midcc < 0.0\n") + write(buffer, " dcc = -1.0\n") + write(buffer, " else\n") + write(buffer, " dcc = 0.0\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " $OUT_cc = (abs($v1_lo)*($v1_hi - midcc) + abs($v1_hi)*(midcc - $v1_lo))/($v1_hi-$v1_lo)\n") + write(buffer, " dcc = (abs($v1_hi) - abs($v1_lo))/($v1_hi-$v1_lo)\n") + write(buffer, " end\n") + write(buffer, " $OUT_cv = abs(midcv)\n") + write(buffer, " if midcv > 0.0\n") + write(buffer, " dcv = 1.0\n") + write(buffer, " elseif midcc < 0.0\n") + write(buffer, " dcv = -1.0\n") + write(buffer, " else\n") + write(buffer, " dcv = 0.0\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Calculate subgradients\n") + write(buffer, " if cv_id==1 || cv_id==2\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + write(buffer, " 
$OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") end return String(take!(buffer)) end -# Sigmoid function +# Multiplication by a constant # max threads: 640 -function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}; sum_output::Bool=false) +function SCMC_cmul_kernel(OUT::String, v1::String, CONST::Real, varlist::Vector{String}, sparsity::Vector{Int}; sum_output::Bool=false) if sum_output eq = "+=" else @@ -3596,79 +3793,353 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s buffer = Base.IOBuffer() # Write all the lines to the buffer - - if startswith(v1, r"aux|temp") - write(buffer, " #############\n") - write(buffer, " ## Sigmoid ##\n") - write(buffer, " #############\n") - write(buffer, "\n") - write(buffer, " # Reset the column 
counter\n") - write(buffer, " col = Int32(1)\n") - write(buffer, "\n") - write(buffer, " # Calculate the interval separately\n") - write(buffer, " $OUT_lo $eq 1.0/(1.0+exp(-$v1_lo))\n") - write(buffer, " $OUT_hi $eq 1.0/(1.0+exp(-$v1_hi))\n") - write(buffer, "\n") - write(buffer, " # Begin rule\n") - write(buffer, " if $v1_lo >= 0.0\n") - write(buffer, " if $v1_cc >= $v1_cv\n") - write(buffer, " if $v1_cv >= $v1_hi\n") - write(buffer, " if $v1_cv >= $v1_lo\n") - write(buffer, " if $v1_lo == $v1_hi\n") - write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_cv))\n") - write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") - write(buffer, " while col <= colmax\n") - if sparsity_case == 1 - write(buffer, " if $sparsity_string\n") - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - elseif sparsity_case == 2 - write(buffer, " if $antisparsity_string\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " end\n") + + if CONST >= 0.0 + if startswith(v1, r"aux|temp") + write(buffer, " ###########################################\n") + write(buffer, " ## Multiplication by a Positive Constant ##\n") + write(buffer, " ###########################################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " $OUT_lo $eq $CONST*$v1_lo\n") + write(buffer, " $OUT_hi $eq $CONST*$v1_hi\n") + write(buffer, " $OUT_cv $eq $CONST*$v1_cv\n") + 
write(buffer, " $OUT_cc $eq $CONST*$v1_cc\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") + write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") + write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_cvgrad\n") + write(buffer, " $OUT_ccgrad $eq $CONST*$v1_ccgrad\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") else - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + ID = findfirst(==(v1), varlist) + isnothing(ID) && error("Empty varlist") + write(buffer, " ###########################################\n") + write(buffer, " ## Multiplication by a Positive Constant ##\n") + write(buffer, " 
###########################################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " $OUT_lo $eq $CONST*$v1_lo\n") + write(buffer, " $OUT_hi $eq $CONST*$v1_hi\n") + write(buffer, " $OUT_cv $eq $CONST*$v1_cv\n") + write(buffer, " $OUT_cc $eq $CONST*$v1_cc\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad $eq $CONST\n") + write(buffer, " $OUT_ccgrad $eq $CONST\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") end - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_cv) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") - write(buffer, " while col <= colmax\n") - if sparsity_case == 1 - write(buffer, " if $sparsity_string\n") - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_ccgrad 
$eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - elseif sparsity_case == 2 - write(buffer, " if $antisparsity_string\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") - write(buffer, " end\n") + else + if startswith(v1, r"aux|temp") + write(buffer, " ###########################################\n") + write(buffer, " ## Multiplication by a Negative Constant ##\n") + write(buffer, " ###########################################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " $OUT_lo $eq $CONST*$v1_hi\n") + write(buffer, " $OUT_hi $eq $CONST*$v1_lo\n") + write(buffer, " $OUT_cv $eq $CONST*$v1_cc\n") + write(buffer, " $OUT_cc $eq $CONST*$v1_cv\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") + write(buffer, " $OUT_ccgrad $eq $CONST*$v1_cvgrad\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") + write(buffer, " $OUT_ccgrad $eq $CONST*$v1_cvgrad\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad $eq $CONST*$v1_ccgrad\n") + write(buffer, " $OUT_ccgrad $eq 
$CONST*$v1_cvgrad\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") else - write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + ID = findfirst(==(v1), varlist) + isnothing(ID) && error("Empty varlist") + write(buffer, " ###########################################\n") + write(buffer, " ## Multiplication by a Negative Constant ##\n") + write(buffer, " ###########################################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " $OUT_lo $eq $CONST*$v1_hi\n") + write(buffer, " $OUT_hi $eq $CONST*$v1_lo\n") + write(buffer, " $OUT_cv $eq $CONST*$v1_cc\n") + write(buffer, " $OUT_cc $eq $CONST*$v1_cv\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad $eq $CONST\n") + write(buffer, " $OUT_ccgrad $eq $CONST\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + 
write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") end - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - write(buffer, " end\n") - write(buffer, " elseif $v1_cc == $v1_cv\n") - write(buffer, " if $v1_lo == $v1_hi\n") + end + return String(take!(buffer)) +end + +# Sigmoid function +# max threads: 640 +function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}; sum_output::Bool=false) + if sum_output + eq = "+=" + else + eq = "=" + end + + if startswith(v1, "temp") + v1_cv = "$(v1)_cv" + v1_cc = "$(v1)_cc" + v1_lo = "$(v1)_lo" + v1_hi = "$(v1)_hi" + v1_cvgrad = "$(v1)_cvgrad[col]" + v1_ccgrad = "$(v1)_ccgrad[col]" + elseif startswith(v1, "aux") + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,2]" + v1_lo = "$(v1)[idx,3]" + v1_hi = "$(v1)[idx,4]" + v1_cvgrad = "$(v1)[idx,end-2*colmax+col]" + v1_ccgrad = "$(v1)[idx,end-1*colmax+col]" + else + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,1]" + v1_lo = "$(v1)[idx,2]" + v1_hi = "$(v1)[idx,3]" + end + if startswith(OUT, "temp") + OUT_cv = "$(OUT)_cv" + OUT_cc = "$(OUT)_cc" + OUT_lo = "$(OUT)_lo" + OUT_hi = "$(OUT)_hi" + OUT_cvgrad = "$(OUT)_cvgrad[col]" + OUT_ccgrad = "$(OUT)_ccgrad[col]" + else + OUT_cv = "$(OUT)[idx,1]" + OUT_cc = "$(OUT)[idx,2]" + OUT_lo = "$(OUT)[idx,3]" + OUT_hi = "$(OUT)[idx,4]" + OUT_cvgrad = "$(OUT)[idx,end-2*colmax+col]" + OUT_ccgrad = "$(OUT)[idx,end-1*colmax+col]" + end + 
+ # Get the anti-sparsity list (elements NOT being used) + antisparsity = collect(1:length(varlist)) + antisparsity = antisparsity[antisparsity .∉ Ref(sparsity)] + + # Determine the sparsity case: + # 1) Use sparsity list + # 2) Use antisparsity list (because it's shorter than the sparsity list) + # 3) Don't use either, simply calculate all elements + if length(sparsity) <= length(antisparsity) + sparsity_case = 1 + sparsity_string = join(["col == Int32($(x))" for x in sparsity], " || ") + elseif length(antisparsity) > 0 + antisparsity_string = join(["col == Int32($(x))" for x in antisparsity], " || ") + sparsity_case = 2 + else + sparsity_case = 3 + end + + # Create the buffer that we will write to + buffer = Base.IOBuffer() + + # Write all the lines to the buffer + + if startswith(v1, r"aux|temp") + write(buffer, " #############\n") + write(buffer, " ## Sigmoid ##\n") + write(buffer, " #############\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Calculate the interval separately\n") + write(buffer, " $OUT_lo $eq 1.0/(1.0+exp(-$v1_lo))\n") + write(buffer, " $OUT_hi $eq 1.0/(1.0+exp(-$v1_hi))\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if $v1_lo >= 0.0\n") + write(buffer, " if $v1_cc >= $v1_cv\n") + write(buffer, " if $v1_cv >= $v1_hi\n") + write(buffer, " if $v1_cv >= $v1_lo\n") + write(buffer, " if $v1_lo == $v1_hi\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_cv))\n") + write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + 
write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_cv) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad $eq $v1_cvgrad * (1.0/(1.0 + exp(-$v1_hi)) - 1.0/(1.0 + exp(-$v1_lo)))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_ccgrad $eq $v1_cvgrad * exp(-$v1_cv)/(1.0 + exp(-$v1_cv))^2\n") + end + 
write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " elseif $v1_cc == $v1_cv\n") + write(buffer, " if $v1_lo == $v1_hi\n") write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " while col <= colmax\n") @@ -3803,7 +4274,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4021,7 +4492,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cc))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4167,7 +4638,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_hi))\n") write(buffer, " while col <= colmax\n") write(buffer, " $OUT_cvgrad $eq 0.0\n") @@ -4314,7 +4785,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s 
write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4478,7 +4949,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cc))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4624,7 +5095,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_hi))\n") write(buffer, " while col <= colmax\n") write(buffer, " $OUT_cvgrad $eq 0.0\n") @@ -9550,149 +10021,671 @@ function SCMC_float_power_kernel(OUT::String, v1::String, POW::T, varlist::Vecto write(buffer, "\n") write(buffer, " # Cut\n") write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv $eq $OUT_lo\n") + write(buffer, " $OUT_cv $eq $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, 
" $OUT_cc $eq $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + else + ID = findfirst(==(v1), varlist) + isnothing(ID) && error("Empty varlist") + if POW==0.5 + write(buffer, " #################\n") + write(buffer, " ## Square Root ##\n") + write(buffer, " #################\n") + write(buffer, "\n") + else + L = length(string(POW)) + write(buffer, " ###################$("#"^L)####\n") + write(buffer, " ## Floating Power ($POW) ##\n") + write(buffer, " ###################$("#"^L)####\n") + write(buffer, "\n") + end + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " if $v1_lo < 0.0\n") + write(buffer, " $OUT_lo = NaN\n") + write(buffer, " $OUT_hi = NaN\n") + write(buffer, " $OUT_cv = NaN\n") + write(buffer, " $OUT_cc = NaN\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = NaN\n") + write(buffer, " $OUT_ccgrad = NaN\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif $v1_lo == $v1_hi\n") + if POW==0.5 + write(buffer, " $OUT_lo = sqrt($v1_lo)\n") + write(buffer, " $OUT_hi = sqrt($v1_hi)\n") + write(buffer, " $OUT_cv = sqrt($v1_cv)\n") + write(buffer, " $OUT_cc = sqrt($v1_cc)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = 0.5 / sqrt($v1_cv)\n") + write(buffer, " $OUT_ccgrad = 0.5 / sqrt($v1_cc)\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_lo = $v1_lo^$POW\n") + write(buffer, " $OUT_hi = $v1_hi^$POW\n") + write(buffer, " $OUT_cv = $v1_cv^$POW\n") + write(buffer, " $OUT_cc = 
$v1_cc^$POW\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = $POW*$v1_cv^$(POW-1)\n") + write(buffer, " $OUT_ccgrad = $POW*$v1_cc^$(POW-1)\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + end + write(buffer, " else\n") + if POW==0.5 + write(buffer, " $OUT_lo = sqrt($v1_lo)\n") + write(buffer, " $OUT_hi = sqrt($v1_hi)\n") + write(buffer, " $OUT_cv = (sqrt($v1_lo)*($v1_hi - $v1_cv) + sqrt($v1_hi)*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cc = sqrt($v1_cv)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = (sqrt($v1_hi) - sqrt($v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_ccgrad = 0.5 / sqrt($v1_cv)\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + elseif POW < 1.0 + write(buffer, " $OUT_lo = $v1_lo^$POW\n") + write(buffer, " $OUT_hi = $v1_hi^$POW\n") + write(buffer, " $OUT_cv = ($v1_lo^$POW*($v1_hi - $v1_cv) + $v1_hi^$POW*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cc = $v1_cv^$POW\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = ($v1_hi^$POW - $v1_lo^$POW)/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_ccgrad = $POW*$v1_cv^$(POW-1)\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_lo = $v1_lo^$POW\n") + write(buffer, " $OUT_hi = $v1_hi^$POW\n") + write(buffer, " $OUT_cv = $v1_cv^$POW\n") + write(buffer, " $OUT_cc = 
($v1_lo^$POW*($v1_hi - $v1_cv) + $v1_hi^$POW*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = $POW*$v1_cv^$(POW-1)\n") + write(buffer, " $OUT_ccgrad = ($v1_hi^$POW - $v1_lo^$POW)/($v1_hi - $v1_lo)\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + end + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv $eq $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc $eq $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + end + return String(take!(buffer)) +end + +# Cos +# max threads: ??? 
+function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}) + if startswith(v1, "temp") + v1_cv = "$(v1)_cv" + v1_cc = "$(v1)_cc" + v1_lo = "$(v1)_lo" + v1_hi = "$(v1)_hi" + v1_cvgrad = "$(v1)_cvgrad[col]" + v1_ccgrad = "$(v1)_ccgrad[col]" + elseif startswith(v1, "aux") + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,2]" + v1_lo = "$(v1)[idx,3]" + v1_hi = "$(v1)[idx,4]" + v1_cvgrad = "$(v1)[idx,end-2*colmax+col]" + v1_ccgrad = "$(v1)[idx,end-1*colmax+col]" + else + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,1]" + v1_lo = "$(v1)[idx,2]" + v1_hi = "$(v1)[idx,3]" + end + if startswith(OUT, "temp") + OUT_cv = "$(OUT)_cv" + OUT_cc = "$(OUT)_cc" + OUT_lo = "$(OUT)_lo" + OUT_hi = "$(OUT)_hi" + OUT_cvgrad = "$(OUT)_cvgrad[col]" + OUT_ccgrad = "$(OUT)_ccgrad[col]" + else + OUT_cv = "$(OUT)[idx,1]" + OUT_cc = "$(OUT)[idx,2]" + OUT_lo = "$(OUT)[idx,3]" + OUT_hi = "$(OUT)[idx,4]" + OUT_cvgrad = "$(OUT)[idx,end-2*colmax+col]" + OUT_ccgrad = "$(OUT)[idx,end-1*colmax+col]" + end + + # Get the anti-sparsity list (elements NOT being used) + antisparsity = collect(1:length(varlist)) + antisparsity = antisparsity[antisparsity .∉ Ref(sparsity)] + + # Determine the sparsity case: + # 1) Use sparsity list + # 2) Use antisparsity list (because it's shorter than the sparsity list) + # 3) Don't use either, simply calculate all elements + if length(sparsity) <= length(antisparsity) + sparsity_case = 1 + sparsity_string = join(["col == Int32($(x))" for x in sparsity], " || ") + elseif length(antisparsity) > 0 + antisparsity_string = join(["col == Int32($(x))" for x in antisparsity], " || ") + sparsity_case = 2 + else + sparsity_case = 3 + end + + # Create the buffer that we will write to + buffer = Base.IOBuffer() + + # Write all the lines to the buffer + if startswith(v1, r"aux|temp") + write(buffer, " ##############################\n") + write(buffer, " ## Cosine (Or Shifted Sine) ##\n") + write(buffer, " ##############################\n") + 
write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if ($v1_hi - $v1_lo) >= 2.0*pi\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " else\n") + write(buffer, " lo_quadrant, lo = SourceCodeMcCormick.quadrant($v1_lo)\n") + write(buffer, " hi_quadrant, hi = SourceCodeMcCormick.quadrant($v1_hi)\n") + write(buffer, "\n") + write(buffer, " if lo_quadrant == hi_quadrant\n") + write(buffer, " if $v1_hi - $v1_lo > 3.141592653589793\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif lo_quadrant==2 || lo_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " end\n") + write(buffer, " elseif lo_quadrant==2 && hi_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " elseif lo_quadrant==0 && hi_quadrant==1\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 || hi_quadrant==1)\n") + write(buffer, " $OUT_lo = min(cos(lo), cos(hi))\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3)\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = max(cos(lo), cos(hi))\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n") + write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n") + write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n") 
+ write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n") + write(buffer, " eps_min = NaN\n") + write(buffer, " eps_max = NaN\n") + write(buffer, " elseif xL1 <= 0.0\n") + write(buffer, " if xU1 <= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " elseif xU1 >= pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " end\n") + write(buffer, " elseif xU1 <= pi\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " elseif xU1 >= 2.0*pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = 2.0*pi - 2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = (cos(xL1) >= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n") + write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n") + write(buffer, "\n") + write(buffer, " # Call the SCMC_cv_cos function for both cv and cc\n") + write(buffer, " cv, dcv = SourceCodeMcCormick.SCMC_cv_cos(midcv, $v1_lo, $v1_hi)\n") + write(buffer, " neg_cc, neg_dcc = SourceCodeMcCormick.SCMC_cv_cos(midcc - pi, $v1_lo - pi, $v1_hi - pi)\n") + write(buffer, " $OUT_cv = cv\n") + write(buffer, " $OUT_cc = -neg_cc\n") + write(buffer, " dcc = -neg_dcc\n") + write(buffer, "\n") + write(buffer, " # Calculate subgradients\n") + write(buffer, " if cv_id==1\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 
$v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " 
$OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " elseif cv_id==2\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 
$v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + 
write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") write(buffer, " col = Int32(1)\n") write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_cvgrad = 0.0\n") write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " end\n") write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc $eq $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") write(buffer, " col = Int32(1)\n") write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " end\n") - write(buffer, "\n") else ID = findfirst(==(v1), varlist) isnothing(ID) && error("Empty varlist") - if POW==0.5 - write(buffer, " #################\n") - write(buffer, " ## Square Root ##\n") - write(buffer, " #################\n") - write(buffer, "\n") - else - L = length(string(POW)) - write(buffer, " 
###################$("#"^L)####\n") - write(buffer, " ## Floating Power ($POW) ##\n") - write(buffer, " ###################$("#"^L)####\n") - write(buffer, "\n") - end + write(buffer, " ##############################\n") + write(buffer, " ## Cosine (Or Shifted Sine) ##\n") + write(buffer, " ##############################\n") + write(buffer, "\n") write(buffer, " # Reset the column counter\n") write(buffer, " col = Int32(1)\n") write(buffer, "\n") - write(buffer, " if $v1_lo < 0.0\n") - write(buffer, " $OUT_lo = NaN\n") - write(buffer, " $OUT_hi = NaN\n") - write(buffer, " $OUT_cv = NaN\n") - write(buffer, " $OUT_cc = NaN\n") - write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad = NaN\n") - write(buffer, " $OUT_ccgrad = NaN\n") - write(buffer, " col += Int32(1)\n") + write(buffer, " # Begin rule\n") + write(buffer, " if ($v1_hi - $v1_lo) >= 2.0*pi\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " else\n") + write(buffer, " lo_quadrant, lo = SourceCodeMcCormick.quadrant($v1_lo)\n") + write(buffer, " hi_quadrant, hi = SourceCodeMcCormick.quadrant($v1_hi)\n") + write(buffer, "\n") + write(buffer, " if lo_quadrant == hi_quadrant\n") + write(buffer, " if $v1_hi - $v1_lo > 3.141592653589793\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif lo_quadrant==2 || lo_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " end\n") + write(buffer, " elseif lo_quadrant==2 && hi_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " elseif lo_quadrant==0 && hi_quadrant==1\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 ||
hi_quadrant==1)\n") + write(buffer, " $OUT_lo = min(cos(lo), cos(hi))\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3)\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = max(cos(lo), cos(hi))\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") write(buffer, " end\n") - write(buffer, " elseif $v1_lo == $v1_hi\n") - if POW==0.5 - write(buffer, " $OUT_lo = sqrt($v1_lo)\n") - write(buffer, " $OUT_hi = sqrt($v1_hi)\n") - write(buffer, " $OUT_cv = sqrt($v1_cv)\n") - write(buffer, " $OUT_cc = sqrt($v1_cc)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad = 0.5 / sqrt($v1_cv)\n") - write(buffer, " $OUT_ccgrad = 0.5 / sqrt($v1_cc)\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - else - write(buffer, " $OUT_lo = $v1_lo^$POW\n") - write(buffer, " $OUT_hi = $v1_hi^$POW\n") - write(buffer, " $OUT_cv = $v1_cv^$POW\n") - write(buffer, " $OUT_cc = $v1_cc^$POW\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad = $POW*$v1_cv^$(POW-1)\n") - write(buffer, " $OUT_ccgrad = $POW*$v1_cc^$(POW-1)\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - end + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n") + write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n") + write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n") + write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n") + write(buffer, " eps_min = NaN\n") + 
write(buffer, " eps_max = NaN\n") + write(buffer, " elseif xL1 <= 0.0\n") + write(buffer, " if xU1 <= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " elseif xU1 >= pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " end\n") + write(buffer, " elseif xU1 <= pi\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " elseif xU1 >= 2.0*pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = 2.0*pi - 2.0*pi*kL\n") write(buffer, " else\n") - if POW==0.5 - write(buffer, " $OUT_lo = sqrt($v1_lo)\n") - write(buffer, " $OUT_hi = sqrt($v1_hi)\n") - write(buffer, " $OUT_cv = (sqrt($v1_lo)*($v1_hi - $v1_cv) + sqrt($v1_hi)*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_cc = sqrt($v1_cv)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad = (sqrt($v1_hi) - sqrt($v1_lo))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_ccgrad = 0.5 / sqrt($v1_cv)\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - elseif POW < 1.0 - write(buffer, " $OUT_lo = $v1_lo^$POW\n") - write(buffer, " $OUT_hi = $v1_hi^$POW\n") - write(buffer, " $OUT_cv = ($v1_lo^$POW*($v1_hi - $v1_cv) + $v1_hi^$POW*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_cc = $v1_cv^$POW\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad = ($v1_hi^$POW - $v1_lo^$POW)/($v1_hi - $v1_lo)\n") - write(buffer, " $OUT_ccgrad = $POW*$v1_cv^$(POW-1)\n") - write(buffer, " else\n") - write(buffer, " 
$OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - else - write(buffer, " $OUT_lo = $v1_lo^$POW\n") - write(buffer, " $OUT_hi = $v1_hi^$POW\n") - write(buffer, " $OUT_cv = $v1_cv^$POW\n") - write(buffer, " $OUT_cc = ($v1_lo^$POW*($v1_hi - $v1_cv) + $v1_hi^$POW*($v1_cv - $v1_lo))/($v1_hi - $v1_lo)\n") - write(buffer, " while col <= colmax\n") - write(buffer, " if col == Int32($ID)\n") - write(buffer, " $OUT_cvgrad = $POW*$v1_cv^$(POW-1)\n") - write(buffer, " $OUT_ccgrad = ($v1_hi^$POW - $v1_lo^$POW)/($v1_hi - $v1_lo)\n") - write(buffer, " else\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") - write(buffer, " end\n") - write(buffer, " col += Int32(1)\n") - write(buffer, " end\n") - end + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = (cos(xL1) >= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n") + write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n") + write(buffer, "\n") + write(buffer, " # Call the SCMC_cv_cos function for both cv and cc\n") + write(buffer, " cv, dcv = SourceCodeMcCormick.SCMC_cv_cos(midcv, $v1_lo, $v1_hi)\n") + write(buffer, " neg_cc, neg_dcc = SourceCodeMcCormick.SCMC_cv_cos(midcc - pi, $v1_lo - pi, $v1_hi - pi)\n") + write(buffer, " $OUT_cv = cv\n") + write(buffer, " $OUT_cc = -neg_cc\n") + write(buffer, " dcc = -neg_dcc\n") + write(buffer, "\n") + write(buffer, " # Calculate subgradients\n") + write(buffer, " if cv_id==1 || cv_id==2\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + 
write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") write(buffer, " end\n") write(buffer, "\n") write(buffer, " # Cut\n") write(buffer, " if $OUT_cv < $OUT_lo\n") - write(buffer, " $OUT_cv $eq $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") write(buffer, " col = Int32(1)\n") write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_cvgrad $eq 0.0\n") + write(buffer, " $OUT_cvgrad = 0.0\n") write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " end\n") write(buffer, " if $OUT_cc > $OUT_hi\n") - write(buffer, " $OUT_cc $eq $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") write(buffer, " col = Int32(1)\n") write(buffer, " while col <= colmax\n") - write(buffer, " $OUT_ccgrad $eq 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") write(buffer, " col += 
Int32(1)\n") write(buffer, " end\n") write(buffer, " end\n") - write(buffer, "\n") end return String(take!(buffer)) end @@ -14107,6 +15100,23 @@ function SCMC_quadaff_initialize(CONST::Real) # Create the buffer that we will write to buffer = Base.IOBuffer() + # Reset subgradients to 0, since they only get added to in quadaff expressions + write(buffer, " ##################################\n") + write(buffer, " ## Reset Terms and Subgradients ##\n") + write(buffer, " ##################################\n") + write(buffer, "\n") + write(buffer, " temp1_cv = 0.0\n") + write(buffer, " temp1_cc = 0.0\n") + write(buffer, " temp1_lo = 0.0\n") + write(buffer, " temp1_hi = 0.0\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " temp1_cvgrad[col] = 0.0\n") + write(buffer, " temp1_ccgrad[col] = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + # Write the initialization of the quadratic constants # to the buffer write(buffer, " #############################\n") @@ -14116,6 +15126,7 @@ function SCMC_quadaff_initialize(CONST::Real) write(buffer, " intercept_cv = $(Float64(CONST))\n") write(buffer, " intercept_cc = $(Float64(CONST))\n") write(buffer, "\n") + return String(take!(buffer)) end diff --git a/src/transform/utilities.jl b/src/transform/utilities.jl index c7c9d25..ebbc545 100644 --- a/src/transform/utilities.jl +++ b/src/transform/utilities.jl @@ -284,6 +284,9 @@ function pull_vars(eqns::Vector{Equation}) end return vars end +function pull_vars(eqn::T) where T<:Real + return Num[] +end # Sorts variables in a more logical ordering, to be consistent # with McCormick.jl organization.