Skip to content

Commit

Permalink
Adjust GPU functions according to the change JuliaGNSS#31
Browse files Browse the repository at this point in the history
  • Loading branch information
coezmaden committed Nov 14, 2021
1 parent 504afbb commit e5a8398
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/tracking_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ end

# CUDA dispatch
function TrackingState(
prn::Integer,
system::S,
carrier_doppler,
code_phase;
Expand Down Expand Up @@ -236,6 +237,7 @@ function TrackingState(
carrier = CarrierReplicaGPU(num_samples) #nothing
code = nothing
TrackingState{S, C, CALF, COLF, CN, typeof(downconverted_signal), typeof(carrier), Nothing}(
prn,
system,
carrier_doppler,
code_doppler,
Expand Down
5 changes: 2 additions & 3 deletions test/cuda/tracking_loop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ end
prn = 1
range = 0:3999
start_carrier_phase = π / 2
state = TrackingState(gpsl1, carrier_doppler - 20Hz, start_code_phase, num_samples = 4000)
state = TrackingState(prn, gpsl1, carrier_doppler - 20Hz, start_code_phase, num_samples = 4000)

carrier_phases = cu(2π .* carrier_doppler .* range ./ sampling_frequency .+ start_carrier_phase)
code_phases = get_code_frequency(gpsl1) / sampling_frequency .* range .+ start_code_phase
upsampled_code = gpsl1.codes[1 .+ mod.(floor.(Int, code_phases), 1023), prn]
signal = StructArray{ComplexF32}((cos.(carrier_phases), sin.(carrier_phases)))
@. signal *= upsampled_code

track_result = @inferred track(signal, state, prn, sampling_frequency)
track_result = @inferred track(signal, state, sampling_frequency)

iterations = 2000
code_phases = zeros(iterations)
Expand Down Expand Up @@ -147,7 +147,6 @@ end
track_result = @inferred track(
signal,
get_state(track_result),
prn,
sampling_frequency
)
comp_carrier_phase = mod2pi(2π * carrier_doppler * 4000 * (i + 1) /
Expand Down
4 changes: 2 additions & 2 deletions test/cuda/tracking_results.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "Tracking results GPU" begin
gpsl1 = GPSL1(use_gpu = Val(true))
results = Tracking.TrackingResults(
TrackingState(gpsl1, 100Hz, 100, num_samples = 2500),
TrackingState(1, gpsl1, 100Hz, 100, num_samples = 2500),
EarlyPromptLateCorrelator(NumAnts(2)),
SVector(-1, 0, 1),
1,
Expand Down Expand Up @@ -31,7 +31,7 @@
@test @inferred(get_cn0(results)) == 45dBHz

results = Tracking.TrackingResults(
TrackingState(gpsl1, 100Hz, 100, num_samples = 2500),
TrackingState(1, gpsl1, 100Hz, 100, num_samples = 2500),
EarlyPromptLateCorrelator(NumAnts(2)),
SVector(-1, 0, 1),
1,
Expand Down
9 changes: 6 additions & 3 deletions test/cuda/tracking_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
code_phase = 100
gpsl1 = GPSL1(use_gpu = Val(true))
num_samples = 2500
state = TrackingState(gpsl1, carrier_doppler, code_phase, num_samples = num_samples)
state = TrackingState(1, gpsl1, carrier_doppler, code_phase, num_samples = num_samples)

@test_throws UndefKeywordError TrackingState(gpsl1, carrier_doppler, code_phase)
@test_throws UndefKeywordError TrackingState(1, gpsl1, carrier_doppler, code_phase)

@test @inferred(Tracking.get_prn(state)) == 1
@test @inferred(Tracking.get_code_phase(state)) == 100
@test @inferred(Tracking.get_carrier_phase(state)) == 0.0
@test @inferred(Tracking.get_init_code_doppler(state)) == 100Hz / 1540
Expand All @@ -22,6 +23,7 @@
@test @inferred(Tracking.get_integrated_samples(state)) == 0

state = TrackingState(
1,
gpsl1,
carrier_doppler,
code_phase;
Expand All @@ -36,6 +38,7 @@
prompt_accumulator = zero(ComplexF64)
)

@test @inferred(Tracking.get_prn(state)) == 1
@test @inferred(Tracking.get_code_phase(state)) == 100
@test @inferred(Tracking.get_carrier_phase(state)) == 0.0
@test @inferred(Tracking.get_init_code_doppler(state)) == 100Hz / 1540
Expand All @@ -49,7 +52,7 @@
@test @inferred(Tracking.get_prompt_accumulator(state)) == 0.0
@test @inferred(Tracking.get_integrated_samples(state)) == 0

state = TrackingState(gpsl1, carrier_doppler, code_phase, num_samples = num_samples, num_ants = NumAnts(2))
state = TrackingState(1, gpsl1, carrier_doppler, code_phase, num_samples = num_samples, num_ants = NumAnts(2))
@test @inferred(Tracking.get_correlator(state)) == EarlyPromptLateCorrelator(NumAnts(2))

end

0 comments on commit e5a8398

Please sign in to comment.