Commit 2b32633 (1 parent: ca7bb99)
add active sampling capability within desirable feature and target region
Showing 7 changed files with 618 additions and 2 deletions.
@@ -0,0 +1,186 @@
# default readout-column distance functionals
"""
    QuadraticStandardizedDistance(; λ=1)

Return an anonymous function `(x, col; prior) -> λ * (x .- col).^2 / (2*σ2)`, where `σ2` is the variance of `col` calculated with respect to `prior`.
"""
QuadraticStandardizedDistance(; λ = 1) = function (x, col; prior = ones(length(col)))
    σ2 = var(col, Weights(prior); corrected = false)

    return λ * (x .- col) .^ 2 / (2 * σ2)
end

"""
    DiscreteMetric(; λ=1)

Return an anonymous function `(x, col) -> λ * (x .== col)`.
"""
DiscreteMetric(; λ = 1) = function (x, col; _...)
    return map(y -> y == x ? λ : 0.0, col)
end

# default similarity functional
"""
    Exponential(; λ=1)

Return an anonymous function `x -> exp(-λ * sum(x; init=0))`.
"""
Exponential(; λ = 1) = x -> exp(-λ * sum(x; init = 0))

# default uncertainty functionals
compute_variance(data::AbstractVector; weights) = var(data, Weights(weights))

compute_variance(data; weights) = sum(var(Matrix(data), Weights(weights), 1))

""" | ||
Variance(data; prior) | ||
Return a function of `weights` that computes the percentage of variance in the data, compared to the variance calculated with respect to a specified `prior`. | ||
""" | ||
function Variance(data; prior) | ||
initial = compute_variance(data; weights = prior) | ||
return weights -> (compute_variance(data; weights) / initial) | ||
end | ||
|
||
function compute_entropy(labels; weights) | ||
aggregate_weights = collect(values(countmap(labels, Weights(weights)))) | ||
return entropy(aggregate_weights ./ sum(aggregate_weights)) | ||
end | ||
|
||
""" | ||
Entropy(labels; prior) | ||
Return a function of `weights` that computes the percentage of information entropy, compared to the entropy calculated with respect to a specified `prior`. | ||
""" | ||
function Entropy(labels; prior) | ||
@assert elscitype(labels) <: Multiclass "labels must be of `Multiclass` scitype, but `elscitype(labels)=$(elscitype(labels))`" | ||
initial = compute_entropy(labels; weights = prior) | ||
return (weights -> compute_entropy(labels; weights) / initial) | ||
end | ||
|
||
""" | ||
DistanceBased(data, target, uncertainty, similarity=Exponential(), distances=Dict(); prior=ones(nrow(data))) | ||
Compute distances between experimental evidence and historical readouts, and apply a 'similarity' functional to obtain probability mass for each row. | ||
# Return Value | ||
A named tuple with the following fields: | ||
- `sampler`: a function of `(evidence, features, rng)`, in which `evidence` denotes the current experimental evidence, `features` represent the set of features we want to sample from, and `rng` is a random number generator; it returns a dictionary mapping the features to outcomes. | ||
- `uncertainty`: a function of `evidence`; it returns the measure of variance or uncertainty about the target variable, conditioned on the experimental evidence acquired so far. | ||
- `weights`: a function of `evidence`; it returns probabilities (posterior) acrss the rows in `data`. | ||
# Arguments | ||
- `data`: a dataframe with historical data. | ||
- `target`: target column name or a vector of target columns names. | ||
- `uncertainty`: a function that takes the subdataframe containing columns in targets along with prior, and returns an anonymous function taking a single argument (a probability vector over observations) and returns an uncertainty measure over targets. | ||
- `similarity`: a function that, for each row, takes distances between `row[col]` and `readout[col]`, and returns a non-negative probability mass for the row. | ||
- `distances`: a dictionary of pairs `colname => similarity functional`, where a similarity functional must implement the signature `(readout, col; prior)`. Defaults to [`QuadraticStandardizedDistance`](@ref) and [`DiscreteMetric`](@ref) for `Continuous` and `Multiclass` scitypes, respectively. | ||
# Keyword Argumets | ||
- `prior`: prior across rows, uniform by default. | ||
# Example | ||
```julia | ||
(; sampler, uncertainty, weights) = | ||
DistanceBased(data, "HeartDisease", Entropy, Exponential(; λ = 5)); | ||
``` | ||
""" | ||
function DistanceBased_active(
    data::DataFrame,
    target,
    uncertainty,
    similarity = Exponential(),
    distances = Dict();
    prior = ones(nrow(data)),
    desirable_range = Dict(),
    importance_sampling = false,
    target_constraints = Dict(),
)
    distances = Dict(
        try
            if haskey(distances, colname)
                string(colname) => distances[colname]
            elseif elscitype(data[!, colname]) <: Continuous
                string(colname) => QuadraticStandardizedDistance()
            elseif elscitype(data[!, colname]) <: Multiclass
                string(colname) => DiscreteMetric()
            else
                error()
            end
        catch
            error(
                """column $colname has scitype $(elscitype(data[!, colname])), which is not supported by default.
                Please provide a custom readout-column distance functional with the signature `(x, col; prior)`.""",
            )
        end for colname in names(data[!, Not(target)])
    )

    prior = Weights(prior)
    targets = target isa AbstractVector ? target : [target]
    # compute posterior weights over rows, taking desirable ranges and importance sampling into account
    compute_weights = function (evidence::Evidence)
        if isempty(evidence)
            return prior
        else
            array_distances = zeros((nrow(data), length(evidence)))
            for (i, colname) in enumerate(keys(evidence))
                if colname ∈ targets
                    continue
                else
                    array_distances[:, i] .=
                        distances[colname](evidence[colname], data[!, colname]; prior)
                end
            end

            similarities =
                prior .*
                map(i -> similarity(array_distances[i, :]), 1:size(array_distances, 1))

            # hard match on target columns
            for colname in collect(keys(evidence)) ∩ targets
                similarities .*= data[!, colname] .== evidence[colname]
            end

            # apply desirable range constraints
            for (colname, range) in desirable_range
                if colname in keys(evidence)
                    within_range =
                        (data[!, colname] .>= range[1]) .& (data[!, colname] .<= range[2])
                    similarities .*= within_range
                end
            end

            # apply importance sampling if enabled
            if importance_sampling
                # compute importance weights based on target constraints
                importance_weights = ones(nrow(data))
                for (colname, constraint) in target_constraints
                    if colname in keys(evidence)
                        importance_weights .*= constraint(data[!, colname])
                    end
                end
                similarities .*= importance_weights
            end

            return Weights(similarities ./ sum(similarities))
        end
    end

    sampler = function (evidence::Evidence, columns, rng = default_rng())
        observed = data[sample(rng, compute_weights(evidence)), :]

        return Dict(c => observed[c] for c in columns)
    end

    f_uncertainty = uncertainty(data[!, target]; prior)
    compute_uncertainty = function (evidence::Evidence)
        return f_uncertainty(compute_weights(evidence))
    end

    return (; sampler, uncertainty = compute_uncertainty, weights = compute_weights)
end
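For orientation, here is a minimal usage sketch of the new keyword arguments, written against the signature above. It is not part of this commit: the data frame `data` and the column names `"Age"` and `"HeartDisease"` are hypothetical stand-ins, assumed to carry `Continuous` and `Multiclass` scitypes, respectively.

```julia
# A sketch only: assumes `data` is a DataFrame with a Continuous "Age" column and a
# Multiclass "HeartDisease" target (hypothetical names, not part of the commit).
using CEED, CEED.GenerativeDesigns

(; sampler, uncertainty, weights) = DistanceBased_active(
    data,
    "HeartDisease",
    Entropy,
    Exponential(; λ = 5);
    # rows with "Age" outside [30, 60] get zero posterior weight once "Age" is observed
    desirable_range = Dict("Age" => (30, 60)),
    # constraints receive a whole column, so the condition must be broadcast elementwise
    importance_sampling = true,
    target_constraints = Dict("HeartDisease" => x -> ifelse.(x .== 1, 2.0, 1.0)),
)

evidence = Evidence("Age" => 45)
weights(evidence)                     # posterior `Weights` over the rows of `data`
uncertainty(evidence)                 # entropy relative to the prior
sampler(evidence, ["HeartDisease"])   # sample readouts for the requested columns
```

Note that, mirroring the `colname in keys(evidence)` checks in the code above, both the desirable-range filter and the target constraints only take effect once the corresponding column has actually been observed as evidence.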
@@ -0,0 +1,59 @@
using Test
using DataFrames
using CEED.GenerativeDesigns: DistanceBased_active, Evidence
using ScientificTypes
using CEED, CEED.GenerativeDesigns

# Define the types for each column
types =
    Dict(:A => Continuous, :B => Continuous, :Target1 => Multiclass, :Target2 => Continuous)

# Sample data for testing with all numerical values
data = DataFrame(;
    A = 1:10,
    B = 11:20,
    Target1 = rand(1:2, 10),
    Target2 = rand(1:10, 10),
)

# Coerce the data to the correct types
data = coerce(data, types)

# Define a dummy uncertainty function
dummy_uncertainty(data; prior) = weights -> sum(weights)

# Define a dummy similarity function
dummy_similarity = x -> exp(-sum(x))

# Define target constraints for importance sampling; each constraint maps a whole
# column to elementwise importance weights, so the conditions are broadcast
target_constraints = Dict(
    "Target1" => x -> ifelse.(x .== 1, 2.0, 1.0),
    "Target2" => x -> ifelse.(x .> 5, 1.5, 1.0),
)

# Define desirable ranges for each dimension
desirable_range = Dict("A" => (3, 7), "B" => (15, 18))
# Create the DistanceBased function with the new features
distance_based_result = DistanceBased_active(
    data,
    ["Target1", "Target2"],
    dummy_uncertainty,
    dummy_similarity,
    Dict();
    prior = ones(nrow(data)),
    desirable_range = desirable_range,
    importance_sampling = true,
    target_constraints = target_constraints,
)

# Test the weights computation considering the desirable range and importance sampling
@testset "DistanceBased Function Tests" begin
    evidence = Evidence("A" => 5, "B" => 16)
    weights = distance_based_result.weights(evidence)

    # Check if weights are zero outside the desirable range
    @test all(weights[1:2] .== 0.0)
    @test all(weights[8:10] .== 0.0)

    # Rows 3 and 4 lie outside the desirable range for "B", so both weights are zero
    @test weights[3] >= weights[4]
end
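One possible tightening of this test, given the data, ranges, and evidence above (a sketch, not part of the commit): only rows 5 to 7 satisfy both `A ∈ [3, 7]` and `B ∈ [15, 18]`, so all posterior mass should concentrate there and remain normalized.

```julia
# Sketch of an additional assertion block; reuses `data` and `distance_based_result`
# defined above.
@testset "Posterior mass stays inside the desirable region" begin
    evidence = Evidence("A" => 5, "B" => 16)
    w = distance_based_result.weights(evidence)

    @test findall(>(0), w) ⊆ 5:7   # only rows 5–7 can carry weight
    @test sum(w) ≈ 1.0             # weights are renormalized after filtering
end
```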