Skip to content

Commit

Permalink
srnn wip
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinedaurat committed Jan 9, 2025
1 parent 932083c commit eb61251
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions mimikit/networks/sample_rnn_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def from_config(cls, config: "SampleRNN.Config") -> "SampleRNN":
# only one input module supported
spec_input_module = config.io_spec.inputs[0].module
for i, fs in enumerate(config.frame_sizes[:-1]):
if isinstance(spec_input_module, FramedIO) and i == 0: # only the top-tier has no proj of the input
if i == 0: # the top-tier never has proj of the input
input_module = FramedIO() \
.set(class_size=spec_input_module.class_size, frame_size=fs, hop_length=fs).module()
in_dim = fs
Expand Down Expand Up @@ -230,9 +230,16 @@ def from_config(cls, config: "SampleRNN.Config") -> "SampleRNN":
else 1)
)]

modules = [spec_input_module.copy()
.set(frame_size=config.frame_sizes[-1],
hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module()]
if isinstance(config.frame_sizes[-1], tuple):
# TODO: would be nice! needs support in batch_items, generate...
modules = [spec_input_module.copy()
.set(frame_size=fs,
hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module()
for fs in config.frame_sizes[-1]]
else:
modules = [spec_input_module.copy()
.set(frame_size=config.frame_sizes[-1],
hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module()]
input_module = ZipReduceVariables(mode=config.inputs_mode, modules=modules)
tiers += [
SampleRNNTier(
Expand Down

0 comments on commit eb61251

Please sign in to comment.