Skip to content

Commit defd87a

Browse files
committed
Merge remote-tracking branch 'origin/master' into migrate-to-Expronicon
2 parents 79f524d + b3ce057 commit defd87a

20 files changed

+393
-59
lines changed

.typos.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[default.extend-words]
22
numer = "numer"
33
Commun = "Commun"
4-
nd = "nd"
4+
nd = "nd"
5+
assum = "assum"

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Symbolics"
22
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
33
authors = ["Shashi Gowda <gowda@mit.edu>"]
4-
version = "6.11.0"
4+
version = "6.13.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -93,7 +93,7 @@ StaticArraysCore = "1.4"
9393
SymPy = "2.2"
9494
SymbolicIndexingInterface = "0.3.14"
9595
SymbolicLimits = "0.2.2"
96-
SymbolicUtils = "2, 3"
96+
SymbolicUtils = "3.7"
9797
TermInterface = "2"
9898
julia = "1.10"
9999

ext/SymbolicsGroebnerExt.jl

-4
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,9 @@ end
320320
# Helps with precompilation time
321321
PrecompileTools.@setup_workload begin
322322
@variables a b c x y z
323-
equation1 = a*log(x)^b + c ~ 0
324-
equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3
325323
simple_linear_equations = [x - y, y + 2z]
326324
equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z]
327325
PrecompileTools.@compile_workload begin
328-
symbolic_solve(equation1, x)
329-
symbolic_solve(equation_actually_polynomial)
330326
symbolic_solve(simple_linear_equations, [x, y], warns=false)
331327
symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
332328
end

ext/SymbolicsNemoExt.jl

+6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ end
6161
PrecompileTools.@setup_workload begin
6262
@variables a b c x y z
6363
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
64+
equation1 = a*log(x)^b + c ~ 0
65+
equation_polynomial = 9^x + 3^x + 2
66+
exp_eq = 5*2^(x+1) + 7^(x+3)
6467
PrecompileTools.@compile_workload begin
68+
symbolic_solve(equation1, x)
69+
symbolic_solve(equation_polynomial, x)
70+
symbolic_solve(exp_eq)
6571
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
6672
symbolic_solve(x^10 - a^10, x, dropmultiplicity=false)
6773
end

src/Symbolics.jl

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export Inequality, ≲, ≳
9292
include("inequality.jl")
9393

9494
import Bijections, DynamicPolynomials
95+
export tosymbol
9596
include("utils.jl")
9697

9798
using ConstructionBase

src/arrays.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ end
6363
ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T}
6464

6565
function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m)
66+
args = map(args) do arg
67+
if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic()
68+
return Ref(only(arguments(arg)))
69+
else
70+
return arg
71+
end
72+
end
73+
6674
t = f(args...)
6775
t isa Symbolic && !isnothing(m) ?
6876
metadata(t, m) : t
@@ -968,7 +976,7 @@ end
968976
### Codegen
969977

970978
function SymbolicUtils.Code.toexpr(x::ArrayOp, st)
971-
haskey(st.symbolify, x) && return st.symbolify[x]
979+
haskey(st.rewrites, x) && return st.rewrites[x]
972980

973981
if iscall(x.term)
974982
toexpr(x.term, st)

src/solver/attract.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,8 @@ function attract_trig(lhs, var)
197197
r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x))
198198
@acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2)
199199
@acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2)
200-
@acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 *
201-
~x))
202-
@acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 *
203-
~x))
200+
@acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2*~x))
201+
@acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2*~x))
204202
@acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2)
205203
@acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x))
206204
@acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x))

src/solver/ia_main.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function isolate(lhs, var; warns=true, conditions=[])
123123
new_var = (@variables $new_var)[1]
124124
rhs = map(
125125
sol -> term(rev_oper[oper], sol) +
126-
term(*, Base.MathConstants.pi, 2 * new_var),
126+
term(*, Base.MathConstants.pi, new_var),
127127
rhs)
128128
@info string(new_var) * " ϵ" * " Ζ"
129129

src/solver/main.jl

+16-6
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
195195
for e in expr
196196
for var in x
197197
if !check_poly_inunivar(e, var)
198-
warns && @warn("This system can not be currently solved by solve.")
198+
warns && @warn("This system can not be currently solved by `symbolic_solve`.")
199199
return nothing
200200
end
201201
end
@@ -276,7 +276,7 @@ function solve_univar(expression, x; dropmultiplicity=true)
276276
end
277277
end
278278

279-
subs, filtered_expr = filter_poly(expression, x)
279+
subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
280280
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
281281
degree = sdegree(coeffs, x)
282282

@@ -296,18 +296,28 @@ function solve_univar(expression, x; dropmultiplicity=true)
296296
append!(arr_roots, og_arr_roots)
297297
end
298298
end
299-
300-
return arr_roots
301299
end
302300

303301
if length(factors) != 1
304-
for factor in factors_subbed
305-
roots = solve_univar(factor, x, dropmultiplicity = dropmultiplicity)
302+
for i in eachindex(factors_subbed)
303+
if !any(isequal(x, var) for var in get_variables(factors[i]))
304+
continue
305+
end
306+
roots = solve_univar(factors_subbed[i], x, dropmultiplicity = dropmultiplicity)
306307
append!(arr_roots, roots)
307308
end
308309
end
309310

311+
for i in reverse(eachindex(arr_roots))
312+
for j in eachindex(assumptions)
313+
if isequal(substitute(assumptions[j], Dict(x=>arr_roots[i])), 0)
314+
deleteat!(arr_roots, i)
315+
end
316+
end
317+
end
318+
310319
if isequal(arr_roots, [])
320+
@assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`."
311321
return [RootsOf(wrap(expression), wrap(x))]
312322
end
313323

src/solver/postprocess.jl

+71-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Alex: make sure `Num`s are not processed here as they'd break it.
32
_postprocess_root(x) = x
43

@@ -32,30 +31,30 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
3231
!iscall(x) && return x
3332

3433
x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...)
34+
oper = operation(x)
3535

3636
# sqrt(0), cbrt(0) => 0
3737
# sqrt(1), cbrt(1) => 1
38-
if iscall(x) &&
39-
(operation(x) === sqrt || operation(x) === cbrt || operation(x) === ssqrt ||
40-
operation(x) === scbrt)
38+
if (oper === sqrt || oper === cbrt || oper === ssqrt ||
39+
oper === scbrt)
4140
arg = arguments(x)[1]
4241
if isequal(arg, 0) || isequal(arg, 1)
4342
return arg
4443
end
4544
end
4645

4746
# (X)^0 => 1
48-
if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 0)
47+
if oper === (^) && isequal(arguments(x)[2], 0)
4948
return 1
5049
end
5150

5251
# (X)^1 => X
53-
if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 1)
52+
if oper === (^) && isequal(arguments(x)[2], 1)
5453
return arguments(x)[1]
5554
end
5655

5756
# sqrt((N / D)^2 * M) => N / D * sqrt(M)
58-
if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt)
57+
if (oper === sqrt || oper === ssqrt)
5958
function squarefree_decomp(x::Integer)
6059
square, squarefree = big(1), big(1)
6160
for (p, d) in collect(Primes.factor(abs(x)))
@@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
9089
end
9190

9291
# (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2))
93-
if iscall(x) && operation(x) === (^)
92+
if oper === (^)
9493
arg1, arg2 = arguments(x)
9594
if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt)
9695
if arg2 isa Integer
@@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
105104
end
106105
end
107106

107+
x = convert_consts(x)
108+
109+
if oper === (+)
110+
args = arguments(x)
111+
for arg in args
112+
if isequal(arg, 0)
113+
after_removing = setdiff(args, arg)
114+
isone(length(after_removing)) && return after_removing[1]
115+
return Symbolics.term(+, after_removing)
116+
end
117+
end
118+
end
119+
108120
return x
109121
end
110122

@@ -122,3 +134,54 @@ function postprocess_root(x)
122134
end
123135
x # unreachable
124136
end
137+
138+
139+
inv_exacts = [0, Symbolics.term(*, pi),
140+
Symbolics.term(/,pi,3),
141+
Symbolics.term(/, pi, 2),
142+
Symbolics.term(/, Symbolics.term(*, 2, pi), 3),
143+
Symbolics.term(/, pi, 6),
144+
Symbolics.term(/, Symbolics.term(*, 5, pi), 6),
145+
Symbolics.term(/, pi, 4)
146+
]
147+
inv_evald = Symbolics.symbolic_to_float.(inv_exacts)
148+
149+
const inv_pairs = collect(zip(inv_exacts, inv_evald))
150+
"""
151+
function convert_consts(x)
152+
This function takes BasicSymbolic terms as input (x) and attempts
153+
to simplify these basic symbolic terms using known values.
154+
Currently, this function only supports inverse trigonometric functions.
155+
156+
## Examples
157+
```jldoctest
158+
julia> Symbolics.convert_consts(Symbolics.term(acos, 0))
159+
π / 2
160+
161+
julia> Symbolics.convert_consts(Symbolics.term(atan, 0))
162+
0
163+
164+
julia> Symbolics.convert_consts(Symbolics.term(atan, 1))
165+
π / 4
166+
```
167+
"""
168+
function convert_consts(x)
169+
!iscall(x) && return x
170+
171+
oper = operation(x)
172+
inv_opers = [asin, acos, atan]
173+
174+
if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x))
175+
val = Symbolics.symbolic_to_float(x)
176+
for (exact, evald) in inv_pairs
177+
if isapprox(evald, val)
178+
return exact
179+
elseif isapprox(-evald, val)
180+
return -exact
181+
end
182+
end
183+
end
184+
185+
# add [sin, cos, tan] simplifications in the future?
186+
return x
187+
end

src/solver/preprocess.jl

+13-6
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,19 @@ function clean_f(filtered_expr, var, subs)
4040
unwrapped_f = unwrap(filtered_expr)
4141
!iscall(unwrapped_f) && return filtered_expr
4242
oper = operation(unwrapped_f)
43+
assumptions = []
4344

4445
if oper === (/)
4546
args = arguments(unwrapped_f)
4647
if any(isequal(var, x) for x in get_variables(args[2]))
47-
return filtered_expr
48+
filtered_expr = expand(args[1] * args[2])
49+
push!(assumptions, substitute(args[2], subs, fold=false))
50+
return filtered_expr, assumptions
4851
end
4952
filtered_expr = args[1]
5053
@info "Assuming $(substitute(args[2], subs, fold=false) != 0)"
5154
end
52-
return filtered_expr
55+
return filtered_expr, assumptions
5356
end
5457

5558
"""
@@ -238,15 +241,17 @@ julia> filter_poly((x+1)*term(log, 3), x)
238241
(Dict{Any, Any}(var"##247" => log(3)), var"##247"*(1 + x))
239242
```
240243
"""
241-
function filter_poly(og_expr, var)
244+
function filter_poly(og_expr, var; assumptions=false)
242245
expr = deepcopy(og_expr)
243246
expr = unwrap(expr)
244247
vars = get_variables(expr)
245248

246249
# handle edge cases
247250
if !isequal(vars, []) && isequal(vars[1], expr)
251+
assumptions && return Dict{Any, Any}(), expr, []
248252
return (Dict{Any, Any}(), expr)
249253
elseif isequal(vars, [])
254+
assumptions && return filter_stuff(expr), []
250255
return filter_stuff(expr)
251256
end
252257

@@ -256,14 +261,16 @@ function filter_poly(og_expr, var)
256261
# reassemble expr to avoid variables remembering original values issue and clean
257262
args = arguments(expr)
258263
oper = operation(expr)
259-
new_expr = clean_f(term(oper, args...), var, subs)
264+
new_expr, assum_array = clean_f(term(oper, args...), var, subs)
260265

266+
assumptions && return subs, new_expr, assum_array
261267
return subs, new_expr
262268
end
263-
function filter_poly(og_expr)
269+
270+
function filter_poly(og_expr; assumptions=false)
264271
new_var = gensym()
265272
new_var = (@variables $(new_var))[1]
266-
return filter_poly(og_expr, new_var)
273+
return filter_poly(og_expr, new_var; assumptions=assumptions)
267274
end
268275

269276

src/solver/solve_helpers.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function check_expr_validity(expr)
7878
valid_type = false
7979

8080
if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} ||
81-
type_expr == Complex{Num} || type_expr == ComplexTerm{Real}
81+
type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}}
8282
valid_type = true
8383
end
8484
iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported"

src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi
131131
string(nameof(arguments(oldop)[1]))
132132
elseif oldop == getindex
133133
args = arguments(O)
134-
opname = string(tosymbol(args[1]), "[", map(tosymbol, args[2:end])..., "]")
135-
return _Sym(symtype(O), Symbol(opname, d_separator, ds))
134+
opname = string(tosymbol(args[1]))
135+
return metadata(_Sym(symtype(args[1]), Symbol(opname, d_separator, ds)), metadata(args[1]))[args[2:end]...]
136136
elseif oldop isa Function
137137
return nothing
138138
else

0 commit comments

Comments
 (0)