Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace the call to sort by select in stdlib_stats_median #584

Merged
merged 7 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions doc/specs/stdlib_stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ and if `n` is an even number, the median is:
median(array) = mean( array_sorted( floor( (n + 1) / 2.):floor( (n + 1) / 2.) + 1 ) )
```

The current implementation is a quite naive implementation that relies on sorting
the whole array, using the subroutine `[[stdlib_sorting(module):ord_sort(interface)]]`
provided by the `[[stdlib_sorting(module)]]` module.
The current implementation relies on a selection algorithm applied to a copy of
the whole array, using the subroutine `[[stdlib_selection(module):select(interface)]]`
provided by the `[[stdlib_selection(module)]]` module.

### Syntax

Expand All @@ -220,11 +220,11 @@ Generic subroutine

### Arguments

`array`: Shall be an array of type `integer` or `real`.
`array`: Shall be an array of type `integer` or `real`. It is an `intent(in)` argument.

`dim`: Shall be a scalar of type `integer` with a value in the range from 1 to `n`, where `n` is the rank of `array`.
`dim`: Shall be a scalar of type `integer` with a value in the range from 1 to `n`, where `n` is the rank of `array`. It is an `intent(in)` argument.

`mask` (optional): Shall be of type `logical` and either a scalar or an array of the same shape as `array`.
`mask` (optional): Shall be of type `logical` and either a scalar or an array of the same shape as `array`. It is an `intent(in)` argument.

### Return value

Expand Down
2 changes: 1 addition & 1 deletion src/Makefile.manual
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ stdlib_stats_mean.o: \
stdlib_stats_median.o: \
stdlib_optval.o \
stdlib_kinds.o \
stdlib_sorting.o \
stdlib_selection.o \
stdlib_stats.o
stdlib_stats_moment.o: \
stdlib_optval.o \
Expand Down
66 changes: 40 additions & 26 deletions src/stdlib_stats_median.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ submodule (stdlib_stats) stdlib_stats_median
use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan, ieee_is_nan
use stdlib_error, only: error_stop
use stdlib_optval, only: optval
! Use "ord_sort" rather than "sort" because the former can be much faster for arrays
! that are already partly sorted. While it is slightly slower for random arrays,
! ord_sort seems a better overall choice.
use stdlib_sorting, only: sort => ord_sort
use stdlib_selection, only: select
implicit none

contains
Expand All @@ -24,6 +21,7 @@ contains
real(${o1}$) :: res

integer(kind = int64) :: c, n
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (.not.optval(mask, .true.) .or. size(x) == 0) then
Expand All @@ -43,16 +41,18 @@ contains

x_tmp = reshape(x, [n])

call sort(x_tmp)
call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
val1 = minval(x_tmp(c+1:n)) !instead of call select(x_tmp, c+1, val1, left = c)
#:if t1[0] == 'r'
res = sum(x_tmp(c:c+1)) / 2._${o1}$
res = (val + val1) / 2._${o1}$
#:else
res = sum( real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
res = (real(val, kind=${o1}$) + &
real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else
res = x_tmp(c)
res = val
end if

end function ${name}$
Expand All @@ -74,6 +74,7 @@ contains
integer :: j${fj}$
#:endfor
#:endif
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (.not.optval(mask, .true.) .or. size(x) == 0) then
Expand Down Expand Up @@ -107,17 +108,18 @@ contains
end if
#:endif

call sort(x_tmp)
call select(x_tmp, c, val)

if (mod(n, 2) == 0) then
val1 = minval(x_tmp(c+1:n))
res${reduce_subvector('j', rank, fi)}$ = &
#:if t1[0] == 'r'
sum(x_tmp(c:c+1)) / 2._${o1}$
(val + val1) / 2._${o1}$
#:else
sum(real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
res${reduce_subvector('j', rank, fi)}$ = val
end if
#:for fj in range(1, rank)
end do
Expand All @@ -141,6 +143,7 @@ contains
real(${o1}$) :: res

integer(kind = int64) :: c, n
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (any(shape(x) .ne. shape(mask))) then
Expand All @@ -156,21 +159,26 @@ contains

x_tmp = pack(x, mask)

call sort(x_tmp)

n = size(x_tmp, kind=int64)
c = floor( (n + 1) / 2._${o1}$, kind=int64)

if (n == 0) then
res = ieee_value(1._${o1}$, ieee_quiet_nan)
else if (mod(n, 2_int64) == 0) then
return
end if

c = floor( (n + 1) / 2._${o1}$, kind=int64)

call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
val1 = minval(x_tmp(c+1:n))
#:if t1[0] == 'r'
res = sum(x_tmp(c:c+1)) / 2._${o1}$
res = (val + val1) / 2._${o1}$
#:else
res = sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
res = (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else if (mod(n, 2_int64) == 1) then
res = x_tmp(c)
res = val
end if

end function ${name}$
Expand All @@ -192,6 +200,7 @@ contains
integer :: j${fj}$
#:endfor
#:endif
${t1}$ :: val, val1
${t1}$, allocatable :: x_tmp(:)

if (any(shape(x) .ne. shape(mask))) then
Expand Down Expand Up @@ -220,23 +229,28 @@ contains
end if
#:endif

call sort(x_tmp)

n = size(x_tmp, kind=int64)
c = floor( (n + 1) / 2._${o1}$, kind=int64 )

if (n == 0) then
res${reduce_subvector('j', rank, fi)}$ = &
ieee_value(1._${o1}$, ieee_quiet_nan)
else if (mod(n, 2_int64) == 0) then
return
end if

c = floor( (n + 1) / 2._${o1}$, kind=int64 )

call select(x_tmp, c, val)

if (mod(n, 2_int64) == 0) then
val1 = minval(x_tmp(c+1:n))
res${reduce_subvector('j', rank, fi)}$ = &
#:if t1[0] == 'r'
sum(x_tmp(c:c+1)) / 2._${o1}$
(val + val1) / 2._${o1}$
#:else
sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
#:endif
else if (mod(n, 2_int64) == 1) then
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
res${reduce_subvector('j', rank, fi)}$ = val
end if

deallocate(x_tmp)
Expand Down
6 changes: 0 additions & 6 deletions src/tests/stats/test_median.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ contains
call check(error, median(d2odd_${k1}$), 1._dp&
, 'median(d2odd_${k1}$): uncorrect answer'&
, thr = dptol)
if (allocated(error)) return

call check(error, median(d2odd_${k1}$), 1._dp&
, 'median(d2odd_${k1}$): uncorrect answer'&
, thr = dptol)
if (allocated(error)) return

end subroutine

Expand Down