Skip to content

Commit

Permalink
mutate(rowwise_df) makes a rowwise_df. closes #463
Browse files Browse the repository at this point in the history
  • Loading branch information
romainfrancois committed Jun 18, 2014
1 parent b76bdad commit d8de133
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 6 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

* The `AsIs` class is white listed (#453).

* `mutate` makes a `rowwise_df` when given a `rowwise_df` (#463).

# dplyr 0.2

## Piping
Expand Down
9 changes: 8 additions & 1 deletion inst/include/dplyr/tbl_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ namespace dplyr {
Rcpp::IntegerVector::get_na(), -n) ;
}

template <typename Data>
inline Rcpp::CharacterVector classes_grouped(){
return Rcpp::CharacterVector::create( "grouped_df", "tbl_df", "tbl", "data.frame") ;
}

template <>
inline Rcpp::CharacterVector classes_grouped<RowwiseDataFrame>(){
return Rcpp::CharacterVector::create( "rowwise_df", "tbl_df", "tbl", "data.frame") ;
}

inline Rcpp::CharacterVector classes_not_grouped(){
return Rcpp::CharacterVector::create( "tbl_df", "tbl", "data.frame") ;
}
Expand Down Expand Up @@ -39,7 +46,7 @@ namespace dplyr {
set_rownames(data, nr ) ;

if( source.nvars() > 1){
data.attr( "class" ) = classes_grouped() ;
data.attr( "class" ) = classes_grouped<Data>() ;
List vars = source.data().attr("vars") ;
vars.erase( source.nvars() - 1) ;
data.attr( "vars") = vars ;
Expand Down
4 changes: 2 additions & 2 deletions src/dplyr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ SEXP integer_filter_grouped(GroupedDataFrame gdf, const List& args, const DataDo

}

DataFrame res = subset( data, indx, names, classes_grouped() ) ;
DataFrame res = subset( data, indx, names, classes_grouped<GroupedDataFrame>() ) ;
res.attr( "vars") = data.attr("vars") ;

return res ;
Expand Down Expand Up @@ -1493,7 +1493,7 @@ SEXP mutate_grouped(const DataFrame& df, List args, const DataDots& dots){
accumulator.set( name, variable) ;
}

return structure_mutate(accumulator, df, classes_grouped() );
return structure_mutate(accumulator, df, classes_grouped<Data>() );
}

SEXP mutate_not_grouped(DataFrame df, List args, const DataDots& dots){
Expand Down
6 changes: 3 additions & 3 deletions src/filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ DataFrame filter_grouped_single_env( const GroupedDataFrame& gdf, const List& ar
}
}
}
DataFrame res = subset( data, test, names, classes_grouped() ) ;
DataFrame res = subset( data, test, names, classes_grouped<GroupedDataFrame>() ) ;
res.attr( "vars") = data.attr("vars") ;

return res ;
Expand Down Expand Up @@ -157,7 +157,7 @@ DataFrame filter_grouped_multiple_env( const GroupedDataFrame& gdf, const List&
}
}
}
DataFrame res = subset( data, test, names, classes_grouped() ) ;
DataFrame res = subset( data, test, names, classes_grouped<GroupedDataFrame>() ) ;
res.attr( "vars") = data.attr("vars") ;

return res ;
Expand Down Expand Up @@ -255,7 +255,7 @@ SEXP filter_impl( DataFrame df, List args, Environment env){
if( what[0] == TRUE ){
return df ;
} else {
return empty_subset( df, df.names(), is<GroupedDataFrame>(df) ? classes_grouped() : classes_not_grouped() ) ;
return empty_subset( df, df.names(), is<GroupedDataFrame>(df) ? classes_grouped<GroupedDataFrame>() : classes_not_grouped() ) ;
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test-mutate.r
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,20 @@ test_that("mutate remove variables with = NULL syntax (#462)", {
expect_false( "cyl" %in% names(data) )
})

test_that("mutate(rowwise_df) makes a rowwise_df (#463)", {
one_mod <- data.frame(grp = "a", x = runif(5,0,1)) %>%
tbl_df %>%
mutate(y = rnorm(x,x*2,1)) %>%
group_by(grp) %>%
do(mod = lm(y~x,data = .))

out <- one_mod %>%
mutate(rsq = summary(mod)$r.squared) %>%
mutate(aic = AIC(mod))

expect_is(out, "rowwise_df")
expect_equal(nrow(out), 1L)
expect_is(out$mod, "list")
expect_is(out$mod[[1L]], "lm" )
})

0 comments on commit d8de133

Please sign in to comment.