[Feature] add formula prior #285

Closed
GCaptainNemo opened this issue Mar 24, 2023 · 5 comments

@GCaptainNemo
Contributor

Given input variables x1, x2, x3, x4 and a label y, I have prior knowledge that the target formula has the form y = f1(x1, x4) + f2(x2, x3); that is, the input variables are grouped and the formula is constrained to a tree form. Could PySR support this kind of prior? Thanks.
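One way to read this prior: if y = f1(x1, x4) + f2(x2, x3), then varying the (x1, x4) group and the (x2, x3) group independently leaves a mixed difference of zero (up to floating-point rounding). The sketch below checks that property numerically. The functions `f1`, `f2`, `y`, and `mixed_difference` are hypothetical names chosen for illustration; none of them is part of PySR.

```python
import math

# Hypothetical example functions; any f1(x1, x4) and f2(x2, x3) would do.
def f1(x1, x4):
    return x1 * x4

def f2(x2, x3):
    return math.cos(x2) - x3

def y(x1, x2, x3, x4):
    # The grouped, additive form described in the question.
    return f1(x1, x4) + f2(x2, x3)

def mixed_difference(target, a, a_alt, b, b_alt):
    """Return target(a, b) - target(a_alt, b) - target(a, b_alt) + target(a_alt, b_alt).

    a and a_alt are (x1, x4) pairs; b and b_alt are (x2, x3) pairs.
    The result is zero (up to rounding) when target is additive across the groups.
    """
    def call(group_a, group_b):
        x1, x4 = group_a
        x2, x3 = group_b
        return target(x1, x2, x3, x4)

    return call(a, b) - call(a_alt, b) - call(a, b_alt) + call(a_alt, b_alt)
```

A target that mixes the groups, such as x1 * x2, generally gives a non-zero mixed difference, so this can serve as a quick sanity check on data generated from a known function.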

@GCaptainNemo GCaptainNemo added the enhancement New feature or request label Mar 24, 2023
@MilesCranmer
Owner

Check out #276 - I think this supports the use case?

@MilesCranmer
Owner

For posterity, here's how you would do it:

from pysr import PySRRegressor

objective = """
function my_custom_objective(tree, dataset::Dataset{T}, options) where {T<:Real}
    # Require root node to be binary, so we can split it,
    # otherwise return a large loss:
    tree.degree != 2 && return T(10000)

    f1 = tree.l
    f2 = tree.r

    # Evaluate f1:
    f1_value, flag = eval_tree_array(f1, dataset.X, options)
    !flag && return T(10000)

    # Evaluate f2:
    f2_value, flag = eval_tree_array(f2, dataset.X, options)
    !flag && return T(10000)

    # Impose functional form:
    prediction = f1_value .+ f2_value

    # See if x2 or x3 in an expression:
    function contains_x2_x3(t)
        if t.degree == 0
            return !t.constant && t.feature in (2, 3)
        elseif t.degree == 1
            return contains_x2_x3(t.l)
        else
            return contains_x2_x3(t.l) || contains_x2_x3(t.r)
        end
    end

    # See if x1 or x4 in an expression:
    function contains_x1_x4(t)
        if t.degree == 0
            return !t.constant && t.feature in (1, 4)
        elseif t.degree == 1
            return contains_x1_x4(t.l)
        else
            return contains_x1_x4(t.l) || contains_x1_x4(t.r)
        end
    end

    f1_violating = contains_x2_x3(f1)
    f2_violating = contains_x1_x4(f2)

    regularization = T(100) * f1_violating + T(100) * f2_violating

    prediction_loss = sum((prediction .- dataset.y) .^ 2) / dataset.n
    return prediction_loss + regularization
end
"""

model = PySRRegressor(
    binary_operators=["*", "+", "-"],
    full_objective=objective
)

This won't strictly constrain the expression to that form. The penalty is soft, which lets the genetic algorithm explore violating expressions along the way. The final expressions should still have the desired form, because the regularization penalizes violations.
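To make the soft penalty concrete, here is a minimal Python sketch of the same check, with the two `contains_*` helpers folded into one parameterized function. The `Node` class is a hypothetical stand-in that only mirrors the fields used in the Julia objective (`degree`, `constant`, `feature`, `l`, `r`); it is not a PySR type.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    # Hypothetical stand-in for an expression tree node.
    degree: int                  # 0 = leaf, 1 = unary operator, 2 = binary operator
    constant: bool = False       # only meaningful for leaves
    feature: int = 0             # 1-based feature index, only for non-constant leaves
    l: Optional["Node"] = None
    r: Optional["Node"] = None

def uses_feature(t, features):
    """Return True if any leaf under t is one of the given features."""
    if t.degree == 0:
        return (not t.constant) and t.feature in features
    if t.degree == 1:
        return uses_feature(t.l, features)
    return uses_feature(t.l, features) or uses_feature(t.r, features)

def penalty(tree, weight=100.0):
    """Add `weight` once for each subtree that uses a forbidden feature."""
    f1, f2 = tree.l, tree.r
    return weight * uses_feature(f1, (2, 3)) + weight * uses_feature(f2, (1, 4))
```

A tree whose left subtree uses only x1 and x4, and whose right subtree uses only x2 and x3, gets a penalty of zero. Each violating side adds 100 on top of the prediction loss.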

Note, however, that the returned expression and its printed format will not have the form you specified: the objective ignores the operator at the root and always adds the two subtrees. You will have to manually parse the output into that form. I could look at adding this in the future, but it will be a bit tricky to write generally.
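As a rough sketch of that manual parsing step: since the objective adds `tree.l` and `tree.r` whatever the root operator is, one option is to split the printed equation at its root binary operator and rejoin the two halves with `+`. The helper below uses Python's standard `ast` module (Python 3.9+ for `ast.unparse`). It assumes the equation string is valid Python infix syntax using only operators like `+`, `-`, `*`, which matches the `binary_operators` above; a power operator printed as `^` would not parse with the intended precedence. `split_root` and `as_sum` are hypothetical names, not PySR functions.

```python
import ast

def split_root(expression: str):
    """Split an infix expression string at its root binary operator.

    Returns the left and right operands as strings.
    """
    root = ast.parse(expression, mode="eval").body
    if not isinstance(root, ast.BinOp):
        raise ValueError("root node is not a binary operator")
    return ast.unparse(root.left), ast.unparse(root.right)

def as_sum(expression: str) -> str:
    """Rewrite the expression as f1 + f2, matching what the objective evaluates."""
    f1, f2 = split_root(expression)
    return f"({f1}) + ({f2})"
```

For example, a printed equation `(x1 * x4) * (x2 - x3)` was scored by the objective as `(x1 * x4) + (x2 - x3)`, and `as_sum` recovers that form.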

There may also be a way to make this easier in the future. For example, you could just write out your desired functional form, though I will have to think about how to implement that. For now, just write a custom objective like this.

@GCaptainNemo
Contributor Author

Thank you very much! The code and explanation are very clear!

@MilesCranmer
Owner

MilesCranmer commented Mar 25, 2023

Awesome. Also note that I just pushed a quick fix to that PR. If you tried it and it didn't work, it should hopefully work if you try again now.

In the future I'll make it so you can specify a custom function for printing too.

@MilesCranmer
Owner

Let me know if there are any other issues. Cheers,
Miles
