Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Support fitting arbitrary StatisticalModels with DataFrames #571

Merged
merged 3 commits into from
Mar 31, 2014
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/DataFrames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ include(joinpath("dataframe", "show.jl"))
include(joinpath("dataframe", "join.jl"))
include(joinpath("groupeddataframe", "grouping.jl"))
include(joinpath("dataframe", "reshape.jl"))
include(joinpath("formula", "formula.jl"))
include(joinpath("statsmodels", "formula.jl"))
include(joinpath("statsmodels", "statsmodel.jl"))
include(joinpath("dataframe", "io.jl"))
include("RDA.jl")
include("deprecated.jl")
Expand Down
File renamed without changes.
91 changes: 91 additions & 0 deletions src/statsmodels/statsmodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
##############################################################################
#
# @delegate: generate forwarding methods for a wrapper type
#
# Invoked as
#
#     @delegate MyContainer.elems [size, length, Base.ndims]
#
# (function names given bare or module-qualified, the way the invocations
# further down this file give them), it emits one method per listed function,
# each forwarding to the named field and passing any extra positional
# arguments through:
#
#     size(a::MyContainer, args...) = size(a.elems, args...)
#     length(a::MyContainer, args...) = length(a.elems, args...)
#     Base.ndims(a::MyContainer, args...) = Base.ndims(a.elems, args...)
#
# `source`  - the expression `TypeName.fieldname`
# `targets` - a bracketed list of the functions to forward
#
##############################################################################

macro delegate(source, targets)
    # Escape everything taken from the call site so it resolves in the
    # caller's scope; `a` and `args` below stay hygienic.
    wrapped = esc(source.args[1])
    field = esc(Expr(:quote, source.args[2].args[1]))

    # Empty nested block; the generated methods are appended to the inner one.
    out = quote begin end end
    forwarders = out.args[2].args

    for target in targets.args
        fn = esc(target)
        push!(forwarders, quote
            ($fn)(a::($wrapped), args...) = ($fn)(a.($field), args...)
        end)
    end
    return out
end

# Wrapper that keeps a fitted StatisticalModel together with the
# formula-derived objects it was fit from. Instances are created by the
# formula-based `StatsBase.fit` method generated in this file.
#
# Type parameters:
#   M - type of the wrapped model object
#   T - type parameter of the stored ModelMatrix
immutable DataFrameStatisticalModel{M,T} <: StatisticalModel
# model object returned by the underlying `fit` call
model::M
# ModelFrame built from the formula and the data frame
mf::ModelFrame
# ModelMatrix built from `mf`
mm::ModelMatrix{T}
end

# Wrapper that keeps a fitted RegressionModel together with the
# formula-derived objects it was fit from. Same layout as
# DataFrameStatisticalModel, but subtyping RegressionModel so that
# regression-only functions can be dispatched on it.
#
# Type parameters:
#   M - type of the wrapped model object
#   T - type parameter of the stored ModelMatrix
immutable DataFrameRegressionModel{M,T} <: RegressionModel
# model object returned by the underlying `fit` call
model::M
# ModelFrame built from the formula and the data frame
mf::ModelFrame
# ModelMatrix built from `mf`
mm::ModelMatrix{T}
end

# Generate one formula-based `StatsBase.fit` method per abstract model family.
# In each pair, the first element is the name of the abstract type that bounds
# `T`, and the second is the wrapper type used for the result.
#
# Each generated method:
#   1. builds a ModelFrame from the formula and data frame, and a ModelMatrix
#      from that frame;
#   2. extracts the response with `model_response`;
#   3. calls `fit` for the requested type `T` on the raw matrix `mm.m` and the
#      response, forwarding extra positional and keyword arguments unchanged;
#   4. wraps the fitted model together with `mf` and `mm` in the wrapper type.
for (modeltype, dfmodeltype) in ((:StatisticalModel, DataFrameStatisticalModel),
(:RegressionModel, DataFrameRegressionModel))
@eval begin
function StatsBase.fit{T<:$modeltype}(::Type{T}, f::Formula, df::AbstractDataFrame,
args...; kwargs...)
mf = ModelFrame(f, df)
mm = ModelMatrix(mf)
y = model_response(mf)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This raises the question of how we should handle unsupervised methods that won't take a y input. R does this often with formulas that have a . on the left-hand side. Not sure we need that, but seems worth thinking about.

$dfmodeltype(fit(T, mm.m, y, args...; kwargs...), mf, mm)
end
end
end

# Forward StatsBase accessors from the DataFrame wrappers to the wrapped model.

# Either of the two wrapper types defined above.
typealias DataFrameModels Union(DataFrameStatisticalModel, DataFrameRegressionModel)

# Accessors forwarded for both wrapper types.
@delegate DataFrameModels.model [StatsBase.coef,
                                 StatsBase.confint,
                                 StatsBase.deviance,
                                 StatsBase.loglikelihood,
                                 StatsBase.nobs,
                                 StatsBase.stderr,
                                 StatsBase.vcov]

# Accessors forwarded for the regression wrapper only.
@delegate DataFrameRegressionModel.model [StatsBase.residuals,
                                          StatsBase.model_response,
                                          StatsBase.predict,
                                          StatsBase.predict!]

# Coefficient table for a DataFrame-backed model.
#
# Takes the table produced by the wrapped model and, when the number of rows
# matches the number of coefficient names derived from the ModelFrame, replaces
# the row labels with those names. On a length mismatch the table is returned
# with the wrapped model's own labels.
function StatsBase.coeftable(model::DataFrameModels)
    table = coeftable(model.model)
    labels = coefnames(model.mf)
    # Only relabel when the two line up one-to-one.
    length(table.rownms) == length(labels) && (table.rownms = labels)
    return table
end

# Display a DataFrame-backed model on `io`.
#
# Prints the wrapper type followed by the coefficient table. If `coeftable`
# is not available for the wrapped model (signalled by an error whose message
# starts with "coeftable is not defined", presumably the generic fallback),
# the wrapped model's own `show` is used instead. Any other error is rethrown.
function Base.show(io::IO, model::DataFrameModels)
    try
        ct = coeftable(model)
        println(io, "$(typeof(model)):\n\nCoefficients:")
        show(io, ct)
    catch e
        # `error("...")` throws an ErrorException carrying the text in `e.msg`;
        # a bare string is caught only if something did `throw("...")`.
        # Checking only for a string would never match an `error(...)` call,
        # so accept both forms.
        unsupported = (isa(e, ErrorException) &&
                       beginswith(e.msg, "coeftable is not defined")) ||
                      (isa(e, String) && beginswith(e, "coeftable is not defined"))
        if unsupported
            show(io, model.model)
        else
            rethrow(e)
        end
    end
end