1
-
2
1
# Alex: make sure `Num`s are not processed here as they'd break it.
3
2
_postprocess_root (x) = x
4
3
@@ -32,30 +31,30 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
32
31
! iscall (x) && return x
33
32
34
33
x = Symbolics. term (operation (x), map (_postprocess_root, arguments (x))... )
34
+ oper = operation (x)
35
35
36
36
# sqrt(0), cbrt(0) => 0
37
37
# sqrt(1), cbrt(1) => 1
38
- if iscall (x) &&
39
- (operation (x) === sqrt || operation (x) === cbrt || operation (x) === ssqrt ||
40
- operation (x) === scbrt)
38
+ if (oper === sqrt || oper === cbrt || oper === ssqrt ||
39
+ oper === scbrt)
41
40
arg = arguments (x)[1 ]
42
41
if isequal (arg, 0 ) || isequal (arg, 1 )
43
42
return arg
44
43
end
45
44
end
46
45
47
46
# (X)^0 => 1
48
- if iscall (x) && operation (x) === (^ ) && isequal (arguments (x)[2 ], 0 )
47
+ if oper === (^ ) && isequal (arguments (x)[2 ], 0 )
49
48
return 1
50
49
end
51
50
52
51
# (X)^1 => X
53
- if iscall (x) && operation (x) === (^ ) && isequal (arguments (x)[2 ], 1 )
52
+ if oper === (^ ) && isequal (arguments (x)[2 ], 1 )
54
53
return arguments (x)[1 ]
55
54
end
56
55
57
56
# sqrt((N / D)^2 * M) => N / D * sqrt(M)
58
- if iscall (x) && ( operation (x) === sqrt || operation (x) === ssqrt)
57
+ if (oper === sqrt || oper === ssqrt)
59
58
function squarefree_decomp (x:: Integer )
60
59
square, squarefree = big (1 ), big (1 )
61
60
for (p, d) in collect (Primes. factor (abs (x)))
@@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
90
89
end
91
90
92
91
# (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2))
93
- if iscall (x) && operation (x) === (^ )
92
+ if oper === (^ )
94
93
arg1, arg2 = arguments (x)
95
94
if iscall (arg1) && (operation (arg1) === sqrt || operation (arg1) === ssqrt)
96
95
if arg2 isa Integer
@@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
105
104
end
106
105
end
107
106
107
+ x = convert_consts (x)
108
+
109
+ if oper === (+ )
110
+ args = arguments (x)
111
+ for arg in args
112
+ if isequal (arg, 0 )
113
+ after_removing = setdiff (args, arg)
114
+ isone (length (after_removing)) && return after_removing[1 ]
115
+ return Symbolics. term (+ , after_removing)
116
+ end
117
+ end
118
+ end
119
+
108
120
return x
109
121
end
110
122
@@ -122,3 +134,54 @@ function postprocess_root(x)
122
134
end
123
135
x # unreachable
124
136
end
137
+
138
+
139
+ inv_exacts = [0 , Symbolics. term (* , pi ),
140
+ Symbolics. term (/ ,pi ,3 ),
141
+ Symbolics. term (/ , pi , 2 ),
142
+ Symbolics. term (/ , Symbolics. term (* , 2 , pi ), 3 ),
143
+ Symbolics. term (/ , pi , 6 ),
144
+ Symbolics. term (/ , Symbolics. term (* , 5 , pi ), 6 ),
145
+ Symbolics. term (/ , pi , 4 )
146
+ ]
147
+ inv_evald = Symbolics. symbolic_to_float .(inv_exacts)
148
+
149
+ const inv_pairs = collect (zip (inv_exacts, inv_evald))
150
+ """
151
+ function convert_consts(x)
152
+ This function takes BasicSymbolic terms as input (x) and attempts
153
+ to simplify these basic symbolic terms using known values.
154
+ Currently, this function only supports inverse trigonometric functions.
155
+
156
+ ## Examples
157
+ ```jldoctest
158
+ julia> Symbolics.convert_consts(Symbolics.term(acos, 0))
159
+ π / 2
160
+
161
+ julia> Symbolics.convert_consts(Symbolics.term(atan, 0))
162
+ 0
163
+
164
+ julia> Symbolics.convert_consts(Symbolics.term(atan, 1))
165
+ π / 4
166
+ ```
167
+ """
168
+ function convert_consts (x)
169
+ ! iscall (x) && return x
170
+
171
+ oper = operation (x)
172
+ inv_opers = [asin, acos, atan]
173
+
174
+ if any (isequal (oper, o) for o in inv_opers) && isempty (Symbolics. get_variables (x))
175
+ val = Symbolics. symbolic_to_float (x)
176
+ for (exact, evald) in inv_pairs
177
+ if isapprox (evald, val)
178
+ return exact
179
+ elseif isapprox (- evald, val)
180
+ return - exact
181
+ end
182
+ end
183
+ end
184
+
185
+ # add [sin, cos, tan] simplifications in the future?
186
+ return x
187
+ end
0 commit comments