[Feature] add formula prior #285
Check out #276; I think it supports this use case.
For posterity, here's how you would do it:

from pysr import PySRRegressor

objective = """
function my_custom_objective(tree, dataset::Dataset{T}, options) where {T<:Real}
    # Require the root node to be binary so we can split it;
    # otherwise return a large loss:
    tree.degree != 2 && return T(10000)
    f1 = tree.l
    f2 = tree.r
    # Evaluate f1:
    f1_value, flag = eval_tree_array(f1, dataset.X, options)
    !flag && return T(10000)
    # Evaluate f2 (check its own flag, not f1's):
    f2_value, r_flag = eval_tree_array(f2, dataset.X, options)
    !r_flag && return T(10000)
    # Impose the functional form f1 + f2:
    prediction = f1_value .+ f2_value
    # Check whether x2 or x3 appears in an expression:
    function contains_x2_x3(t)
        if t.degree == 0
            return !t.constant && t.feature in (2, 3)
        elseif t.degree == 1
            return contains_x2_x3(t.l)
        else
            return contains_x2_x3(t.l) || contains_x2_x3(t.r)
        end
    end
    # Check whether x1 or x4 appears in an expression:
    function contains_x1_x4(t)
        if t.degree == 0
            return !t.constant && t.feature in (1, 4)
        elseif t.degree == 1
            return contains_x1_x4(t.l)
        else
            return contains_x1_x4(t.l) || contains_x1_x4(t.r)
        end
    end
    # f1 should only use x1, x4; f2 should only use x2, x3:
    f1_violating = contains_x2_x3(f1)
    f2_violating = contains_x1_x4(f2)
    regularization = T(100) * f1_violating + T(100) * f2_violating
    prediction_loss = sum((prediction .- dataset.y) .^ 2) / dataset.n
    return prediction_loss + regularization
end
"""

model = PySRRegressor(
    binary_operators=["*", "+", "-"],
    full_objective=objective,
)

This won't strictly constrain the search to expressions of that form, and that's deliberate: it can help for the genetic algorithm to explore expressions that violate it. The final expressions should still have the desired form, because the regularization term penalizes violations. Note, however, that the returned expression and its printed format will not be written in the form you specified (for example, the root operator may be "*" or "-", since the loss always adds the two subtrees), so you will have to parse them into that form manually. I might add this in the future, but it will be a bit tricky to write generally. Perhaps there's also a way to make this easier, e.g., letting you write out your desired functional form directly; I'll have to think about how to implement that. For now, just write a custom objective like this.
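For anyone who wants to experiment with the penalty logic outside Julia, here is a minimal Python sketch of the same idea. It is an illustration under stated assumptions, not PySR's or SymbolicRegression.jl's API: the `Node` class, its fields, the operator tables, and the helpers `var`, `const`, `binary`, `evaluate`, `contains_any`, and `grouped_objective` are all hypothetical, loosely modeled on the Julia code above. Feature indices are 1-based to match x1..x4, and the NaN/Inf check stands in for the `flag` returned by `eval_tree_array`.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

# Hypothetical operator tables for this sketch only (not PySR's).
BINARY = {"+": np.add, "-": np.subtract, "*": np.multiply}
UNARY = {"cos": np.cos}


@dataclass
class Node:
    """Hypothetical expression-tree node, loosely mirroring the Julia fields."""
    degree: int                 # 0 = leaf, 1 = unary op, 2 = binary op
    feature: int = 0            # 1-based feature index (variable leaves)
    constant: bool = False      # True for constant leaves
    value: float = 0.0          # value of a constant leaf
    op: Optional[str] = None    # operator name for internal nodes
    l: Optional["Node"] = None  # left (or only) child
    r: Optional["Node"] = None  # right child


def var(i):
    return Node(degree=0, feature=i)


def const(c):
    return Node(degree=0, constant=True, value=c)


def binary(op, left, right):
    return Node(degree=2, op=op, l=left, r=right)


def evaluate(t, X):
    """Evaluate tree t on X of shape (n_samples, n_features)."""
    if t.degree == 0:
        if t.constant:
            return np.full(X.shape[0], t.value)
        return X[:, t.feature - 1]
    if t.degree == 1:
        return UNARY[t.op](evaluate(t.l, X))
    return BINARY[t.op](evaluate(t.l, X), evaluate(t.r, X))


def contains_any(t, features):
    """True if any variable leaf of t uses one of the given feature indices."""
    if t.degree == 0:
        return (not t.constant) and t.feature in features
    if t.degree == 1:
        return contains_any(t.l, features)
    return contains_any(t.l, features) or contains_any(t.r, features)


def grouped_objective(tree, X, y, penalty=100.0, bad_loss=10000.0):
    """Penalized MSE for the prior y = f1(x1, x4) + f2(x2, x3)."""
    # The root must be binary so it can be split into f1 and f2.
    if tree.degree != 2:
        return bad_loss
    f1, f2 = tree.l, tree.r
    with np.errstate(all="ignore"):
        # The root operator is ignored: the two subtrees are always added.
        prediction = evaluate(f1, X) + evaluate(f2, X)
    # Reject NaN/Inf predictions (stand-in for the Julia eval flag).
    if not np.all(np.isfinite(prediction)):
        return bad_loss
    # f1 may not use x2/x3, f2 may not use x1/x4.
    regularization = (penalty * contains_any(f1, {2, 3})
                      + penalty * contains_any(f2, {1, 4}))
    mse = np.mean((prediction - y) ** 2)
    return float(mse + regularization)


# Usage: synthetic data that follows the grouped form.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
y = X[:, 0] * X[:, 3] + (X[:, 1] - X[:, 2])
good = binary("+", binary("*", var(1), var(4)), binary("-", var(2), var(3)))
print(grouped_objective(good, X, y))
```

Because the root operator never enters the loss, a tree with "*" at the root scores exactly like the same tree with "+" at the root, which is why the printed equation may not look additive.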
Thank you very much! The code and explanation are very clear!
Awesome. Also note that I just pushed a quick fix to that PR, so if you tried it and it didn't work, try again; it should hopefully work now. In the future I'll make it possible to specify a custom printing function too.
Let me know if there are any other issues. Cheers,
For given input variables x1, x2, x3, x4 and label y, I have prior knowledge that the target formula has the form y = f1(x1, x4) + f2(x2, x3); that is, the input variables are grouped, and the formula has a tree-structure constraint. Could PySR support this kind of prior? Thanks.
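Written as a formula, the custom objective given in the answer scores a candidate tree whose root children are $f_1$ and $f_2$ roughly as follows. The penalty weight 100 comes from that code; $\mathbb{1}[\cdot]$ is notation introduced here for an indicator that is 1 when the condition holds and 0 otherwise. When the root is not binary or evaluation fails, the code instead returns a fixed loss of 10000.

```latex
\mathcal{L}(f_1, f_2) = \frac{1}{n}\sum_{i=1}^{n}\bigl(f_1(\mathbf{x}_i) + f_2(\mathbf{x}_i) - y_i\bigr)^2
  + 100\,\mathbb{1}\bigl[f_1 \text{ uses } x_2 \text{ or } x_3\bigr]
  + 100\,\mathbb{1}\bigl[f_2 \text{ uses } x_1 \text{ or } x_4\bigr]
```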