-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add sherlock scripts * Allow setting hyperparams for synthetic comparison via command-line args. * Add birdsong run scripts for sherlock * Add scripts for speech runs on sherlock * fix speech.jl script for sherlock
- Loading branch information
Showing
6 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
|
||
using Revise | ||
import JLD | ||
import PyCall | ||
import CMF | ||
|
||
# Benchmark the CMF solvers on the songbird dataset and save every fit to
# "songbird.jld".

# Read the "SONG" entry of the MATLAB file through scipy.io.
song = PyCall.pyimport("scipy.io").loadmat("../../cmf_data/MackeviciusData.mat")["SONG"]

# Model size passed to every fit.
K = 3
L = 50

# Display label => solver symbol handed to CMF.fit_cnmf.
algorithms = Dict(
    "HALS" => :hals,
    "Mult" => :mult,
    "ANLS" => :anls
)
max_time = 60
# Deterministic seed derived from the character codes of "INITIALIZE".
seed = sum(Int, "INITIALIZE");

# Warmstart: a single iteration per solver before the real runs
# (presumably so compilation does not count against `max_time`).
for solver in values(algorithms)
    CMF.fit_cnmf(
        song, K=K, L=L,
        alg=solver, max_itr=1, max_time=Inf
    )
end
println("Finished warmstart")

# Real runs: no iteration cap, bounded by `max_time`, all with the same seed.
res = Dict()
for (label, solver) in algorithms
    res[label] = CMF.fit_cnmf(
        song, K=K, L=L,
        alg=solver, max_itr=Inf, max_time=max_time, seed=seed
    )

    println("Finished ", label)
end

JLD.save("songbird.jld", res)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash
#
# Slurm batch script: run birdsong.jl as a single task.
# Submit from the directory that contains birdsong.jl.
#
#SBATCH --job-name=birdsong
#
# Time limit is in minutes:seconds, i.e. 200 minutes.
#SBATCH --time=200:00
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem-per-cpu=16G

# Load the cluster modules the script relies on.
ml julia/1.0.0
ml viz
ml py-matplotlib
srun julia birdsong.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
|
||
using CMF | ||
using PyPlot | ||
using DSP | ||
using WAV | ||
using JLD | ||
using ArgParse | ||
|
||
# Fit a convolutive NMF model to the log-spectrogram of a speech recording
# and save the fit to "speech_run_<alg>.jld".
#
# Usage: julia speech.jl [--alg ALG]    (e.g. mult, anls, hals)

# Parse command-line arguments first so that a bad invocation fails before
# any audio is loaded or transformed.
settings = ArgParseSettings()
@add_arg_table settings begin
    "--alg"
        help = "Alg to use for speech fit"
        arg_type = String
        default = "mult"
end
parsed_args = parse_args(ARGS, settings)

alg_str = parsed_args["alg"]
alg = Symbol(alg_str)

path = "../../cmf_data/ira_glass.wav"
samples, fs = WAV.wavread(path);

# Downsample to roughly 8 kHz by keeping every p-th sample of the first
# channel. Indexing the channel explicitly keeps additional channels from
# being interleaved into the signal; `max(1, ...)` guards against a zero
# stride when the file's sample rate is below fs_new / 2.
# NOTE(review): plain decimation applies no anti-aliasing filter — confirm
# this is acceptable for the data.
fs_new = 8e3
p = max(1, round(Int, fs / fs_new))
signal = samples[1:p:end, 1];

# Log-transform the spectrogram and shift it so the smallest entry is zero.
# Power is floored at the smallest positive normal float so that an exactly
# zero bin cannot produce -Inf (and then NaN after the shift).
S = spectrogram(signal, 512, 384; window=hanning)
data = log10.(max.(S.power, floatmin(eltype(S.power))))
data = data .- minimum(data)

# Hyperparameters shared by the warm-up fit and the real fit.
fit_kwargs = (L=12, K=20, alg=alg, max_time=6000, check_convergence=false)

# Fit once with a single iteration to compile.
CMF.fit_cnmf(data; fit_kwargs..., max_itr=1)

results = CMF.fit_cnmf(data; fit_kwargs..., max_itr=Inf)

JLD.save("speech_run_$alg_str.jld", "results", results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash
#
# Slurm batch script: run speech.jl once per algorithm as a 3-task job array.
# Each array task gets its own allocation and its own time limit.
# Submit from the directory that contains speech.jl.
#
#SBATCH --job-name=speech
#
# Time limit is in minutes:seconds, i.e. 200 minutes per array task.
#SBATCH --time=200:00
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem-per-cpu=16G
#SBATCH --array=0-2

# Load the cluster modules the script relies on.
ml julia/1.0.0
ml viz
ml py-matplotlib

# Array index -> algorithm name. `sbatch` cannot launch a program from inside
# a batch script (it expects a script to submit), so the run uses `srun`.
algs=(mult anls hals)
srun julia speech.jl --alg "${algs[${SLURM_ARRAY_TASK_ID:-0}]}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
|
||
import CMF | ||
import PyPlot; plt = PyPlot; | ||
import Random; Random.seed!(0); | ||
import JLD | ||
using ArgParse | ||
using Dates | ||
using Random | ||
include("../datasets/synthetic.jl") | ||
|
||
|
||
# Compare the CMF solvers on synthetic sequence data of several lengths and
# save every fit, together with the parsed arguments, to a timestamped
# JLD file.

# Command-line interface: model size and synthetic-data hyperparameters.
settings = ArgParseSettings()
@add_arg_table settings begin
    "--N"
        help = "Number of features"
        arg_type = Int
        default = 250
    "--K"
        help = "Number of components"
        arg_type = Int
        default = 5
    "--L"
        help = "Lag"
        arg_type = Int
        default = 20
    "--alpha"
        help = "Dirichlet param alpha for generating data"
        arg_type = Float64
        default = 0.1
    "--sigma"
        help = "Dirichlet param sigma for generating data"
        arg_type = Float64
        default = 0.2
    "--p_h"
        help = "Probability of nonzero h"
        arg_type = Float64
        default = 0.1
    "--noise_scale"
        help = "Noise scale for synthetic data"
        arg_type = Float64
        default = 0.1
end
parsed_args = parse_args(ARGS, settings)

# Model params
N, K, L = parsed_args["N"], parsed_args["K"], parsed_args["L"]

# Synthetic data params
alpha, sigma = parsed_args["alpha"], parsed_args["sigma"]
p_h, noise_scale = parsed_args["p_h"], parsed_args["noise_scale"]

# Sequence lengths to benchmark, and the `max_time` budget given to each.
T_list = [500, 2500, 10_000, 50_000]
runtimes = Dict(
    500 => 60,
    2500 => 120,
    10_000 => 400,
    50_000 => 1000
)

alg_list = [:hals, :mult, :anls]
labels = Dict(:hals => "HALS", :mult => "MULT", :anls => "ANLS")

# The parsed arguments are stored next to the fits under the "args" key.
results = Dict{Any,Any}("args" => parsed_args)

# Generate one dataset per sequence length; everything except T is shared.
synth_params = (
    K=K, N=N, L=L,
    alpha=alpha, p_h=p_h, sigma=sigma, noise_scale=noise_scale
)
data_list = Dict()
for T in T_list
    data, trueW, trueH = synthetic_sequences(; synth_params..., T=T)
    data_list[T] = (data, trueW, trueH)
end

# Run once to warmup algorithms: one iteration of each on the T = 500 data.
warmup_data = data_list[500][1]
for alg in alg_list
    CMF.fit_cnmf(warmup_data, alg=alg, max_itr=1)
end

for T in T_list
    fits = Dict()
    results[T] = fits

    for alg in alg_list
        # Reseed so that every algorithm starts from the same RNG state.
        Random.seed!(0)

        fits[alg] = CMF.fit_cnmf(
            data_list[T][1], alg=alg,
            K=K, L=L, max_time=runtimes[T], max_itr=Inf
        )
    end
end

JLD.save(string("./synthetic_comparison_", now(), ".jld"), "results", results)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#!/bin/bash
#
# Slurm batch script: run synthetic_comparison.jl as a single task.
# Submit from the directory that contains synthetic_comparison.jl.
#
#SBATCH --job-name=synthetic_comparison
#
# Time limit is in minutes:seconds, i.e. 300 minutes.
#SBATCH --time=300:00
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem-per-cpu=16G

# Load the cluster modules the script relies on.
ml julia/1.0.0
ml viz
ml py-matplotlib
srun julia synthetic_comparison.jl