-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_from_unconditional_model.py
172 lines (132 loc) · 5.65 KB
/
generate_from_unconditional_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import argparse
import os
import time
import note_seq
import numpy as np
from magenta.models.score2perf import score2perf
from midi2audio import FluidSynth
from tensor2tensor.utils import decoding, trainer_lib
from commons import SAMPLE_RATE, SF2_PATH, decode
# Command-line interface: where to write output, plus optional priming controls.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    required=True,
    help="Midi output directory.",
)
parser.add_argument(
    "-p",
    "--primer",
    type=str,
    help="Midi file path for priming if not provided model will generate sample without priming.",
)
parser.add_argument(
    "--max_primer_seconds",
    type=int,
    default=20,
    help="Maximum number of time in seconds for priming.",
)
# Parsed once at import time; the rest of the script reads this module-level object.
args = parser.parse_args()
class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
    """Unconditional piano-performance language-modeling problem.

    Identical to ``score2perf.Score2PerfProblem`` except that every target
    sequence is terminated with an explicit EOS symbol.
    """

    @property
    def add_eos_symbol(self):
        # Append an end-of-sequence token to each encoded performance.
        return True
def input_generator(targets, decode_length):
    """Yield an endless stream of identical decoding requests.

    Wrapping the request in a generator lets the priming tokens and the
    decode length be adjusted on the fly between predictions, while the
    Estimator's predict loop keeps pulling fresh examples.

    Args:
        targets: Sequence of priming event ids (may be empty).
        decode_length: Number of tokens the model should generate.

    Yields:
        Dicts with int32 "targets" (shaped as a batch of one) and a scalar
        int32 "decode_length".
    """
    while True:
        example = {
            "targets": np.array([targets], dtype=np.int32),
            "decode_length": np.array(decode_length, dtype=np.int32),
        }
        yield example
def generate_ns_from_scratch(estimator, ckpt_path, unconditional_encoders):
    """Sample a brand-new performance from the unconditional model.

    Args:
        estimator: Estimator wrapping the trained transformer model.
        ckpt_path: Checkpoint path to restore the model weights from.
        unconditional_encoders: Feature encoders; the "targets" encoder is
            used to decode sampled event ids back into MIDI.

    Returns:
        The sampled performance as a NoteSequence.
    """
    # No priming: decoding starts from an empty target sequence.
    empty_primer = []
    num_decode_steps = 1024
    # Restore the model from the checkpoint and run prediction.
    predictions = estimator.predict(
        decoding.make_input_fn_from_generator(
            input_generator(empty_primer, num_decode_steps)
        ),
        checkpoint_path=ckpt_path,
    )
    # Take one sample and convert the event ids back into a NoteSequence.
    sampled_ids = next(predictions)["outputs"]
    midi_path = decode(sampled_ids, encoder=unconditional_encoders["targets"])
    return note_seq.midi_file_to_note_sequence(midi_path)
def generate_ns_continuation(primer_ns, estimator, ckpt_path, unconditional_encoders):
    """Continue a primer performance with the unconditional model.

    Args:
        primer_ns: NoteSequence used to prime the model.
        estimator: Estimator wrapping the trained transformer model.
        ckpt_path: Checkpoint path to restore the model weights from.
        unconditional_encoders: Feature encoders; the "targets" encoder
            encodes the primer and decodes the sampled continuation.

    Returns:
        The primer with the generated continuation appended. If the primer
        already fills the maximum sequence length, the primer is returned
        unchanged.
    """
    targets = unconditional_encoders["targets"].encode_note_sequence(primer_ns)
    # Remove the end token from the encoded primer.
    targets = targets[:-1]
    if len(targets) >= 4096:
        # Fix: previously this only warned and then ran a pointless
        # zero-length decode through the model anyway. Skip the model call
        # and hand the primer back untouched.
        print("Primer has more events than maximum sequence length; nothing will be generated.")
        return primer_ns
    decode_length = 4096 - len(targets)
    # Start the Estimator, loading from the specified checkpoint.
    input_fn = decoding.make_input_fn_from_generator(input_generator(targets, decode_length))
    unconditional_samples = estimator.predict(input_fn, checkpoint_path=ckpt_path)
    # Generate sample events.
    sample_ids = next(unconditional_samples)["outputs"]
    # Decode to NoteSequence.
    midi_filename = decode(sample_ids, encoder=unconditional_encoders["targets"])
    ns = note_seq.midi_file_to_note_sequence(midi_filename)
    # Append continuation to primer.
    continuation_ns = note_seq.concatenate_sequences([primer_ns, ns])
    return continuation_ns
def _validate_args():
    """Fail fast on unusable command-line arguments."""
    if args.primer is not None:
        if not os.path.isfile(args.primer):
            raise ValueError(f"'{args.primer}' is not a file path.")
        if args.max_primer_seconds <= 0:
            raise ValueError("Max primer seconds must be > 0.")


def _prepare_primer(primer_path, max_primer_seconds):
    """Load and sanitize a primer MIDI file for the model.

    Applies sustain-pedal control changes, truncates the sequence to
    ``max_primer_seconds``, strips drum notes, and forces every note onto a
    single instrument/program (piano).

    Args:
        primer_path: Path to the primer MIDI file.
        max_primer_seconds: Maximum primer length in seconds.

    Returns:
        The sanitized primer NoteSequence.
    """
    primer_ns = note_seq.midi_file_to_note_sequence(primer_path)
    # Handle sustain pedal in the primer.
    primer_ns = note_seq.apply_sustain_control_changes(primer_ns)
    # Trim to desired number of seconds.
    if primer_ns.total_time > max_primer_seconds:
        print("Primer is longer than %d seconds, truncating." % max_primer_seconds)
        primer_ns = note_seq.extract_subsequence(primer_ns, 0, max_primer_seconds)
    # Remove drums from primer if present.
    if any(note.is_drum for note in primer_ns.notes):
        print("Primer contains drums; they will be removed.")
        notes = [note for note in primer_ns.notes if not note.is_drum]
        del primer_ns.notes[:]
        primer_ns.notes.extend(notes)
    # Set primer instrument and program.
    for note in primer_ns.notes:
        note.instrument = 1
        note.program = 0
    return primer_ns


def _build_estimator():
    """Create the transformer Estimator plus the feature encoders it needs.

    Returns:
        A (estimator, ckpt_path, unconditional_encoders) tuple.
    """
    model_name = "transformer"
    hparams_set = "transformer_tpu"
    ckpt_path = "checkpoints/unconditional_model_16.ckpt"
    problem = PianoPerformanceLanguageModelProblem()
    unconditional_encoders = problem.get_feature_encoders()
    # Set up HParams.
    hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
    trainer_lib.add_problem_hparams(hparams, problem)
    hparams.num_hidden_layers = 16
    hparams.sampling_method = "random"
    # Set up decoding HParams.
    decode_hparams = decoding.decode_hparams()
    decode_hparams.alpha = 0.0
    decode_hparams.beam_size = 1
    # Create Estimator.
    run_config = trainer_lib.create_run_config(hparams)
    estimator = trainer_lib.create_estimator(model_name, hparams, run_config, decode_hparams=decode_hparams)
    return estimator, ckpt_path, unconditional_encoders


def main():
    """Generate a performance (optionally primed) and write MIDI + WAV files."""
    _validate_args()
    estimator, ckpt_path, unconditional_encoders = _build_estimator()
    if args.primer is None:
        generated_ns = generate_ns_from_scratch(
            estimator,
            ckpt_path,
            unconditional_encoders,
        )
    else:
        # Use one of the provided primers.
        primer_ns = _prepare_primer(args.primer, args.max_primer_seconds)
        generated_ns = generate_ns_continuation(
            primer_ns,
            estimator,
            ckpt_path,
            unconditional_encoders,
        )
    # Write to output_dir.
    stem = f"unconditional_{time.strftime('%Y-%m-%d_%H%M%S')}"
    os.makedirs(args.output_dir, exist_ok=True)
    midi_path = os.path.join(args.output_dir, f"{stem}.mid")
    note_seq.note_sequence_to_midi_file(generated_ns, midi_path)
    # Convert midi file to wave file.
    fs = FluidSynth(SF2_PATH, SAMPLE_RATE)
    fs.midi_to_audio(
        midi_path,
        os.path.join(args.output_dir, f"{stem}.wav"),
    )


if __name__ == "__main__":
    main()