
Commit

GSS over frequency not time
popcornell committed Feb 4, 2024
1 parent 5274343 commit 0b6d6b2
Showing 2 changed files with 16 additions and 17 deletions.
32 changes: 15 additions & 17 deletions scripts/chime7/pipeline/gss_process/chime7_enhancers.py
@@ -243,13 +243,13 @@ def enhance_cuts(
                 break  # succesfully processed the batch
             except torch.cuda.OutOfMemoryError as e:
                 # try again with more chunks
-                logger.warning('OOM exception: %s', e)
+                # logger.warning('OOM exception: %s', e)
                 num_chunks = num_chunks + 1
                 logging.warning(
                     f'Out of memory error while processing the batch. Trying again with {num_chunks} chunks.'
                 )
             except Exception as e:
-                logging.error(f'Error enhancing batch: {e}')
+                logging.error(f'Error enhancing batch: {e}, using channel 0 instead of enhanced signal.')
                 num_errors += 1
                 # Keep the original signal (only load channel 0)
                 x_hat = batch.audio[0].cpu().numpy()
@@ -369,29 +369,29 @@ def enhance_batch(self, audio, activity, speaker_id, num_chunks=1, left_context=
         a_enc = activity_time_to_timefreq(activity, win_length=self.fft_length, hop_length=self.hop_length)
 
         # processing is running in chunks
-        T = x_enc.size(-1)
-        chunk_size = int(math.ceil(T / num_chunks))
+        F = x_enc.size(-2)
+        chunk_size = int(math.ceil(F / num_chunks))
 
         # run dereverb and gss on chunks
         mask = []
         for n in range(num_chunks):
             n_start = n * chunk_size
-            n_end = min(T, (n + 1) * chunk_size)
+            n_end = min(F, (n + 1) * chunk_size)
 
-            x_enc_n = x_enc[..., n_start:n_end]
+            x_enc_n = x_enc[..., n_start:n_end, :]
 
             # dereverb
             x_enc_n, _ = self.dereverb(input=x_enc_n)
-            x_enc[..., n_start:n_end] = x_enc_n
+            x_enc[..., n_start:n_end, :] = x_enc_n
 
             # mask estimator
-            mask_n = self.gss(x_enc_n, a_enc[..., n_start:n_end])
+            mask_n = self.gss(x_enc_n, a_enc)
 
             # append mask to the list
             mask.append(mask_n)
 
         # concatenate estimated masks
-        mask = torch.concatenate(mask, dim=-1)
+        mask = torch.concatenate(mask, dim=-2)
 
         # drop context
         mask[..., :left_context_frames] = 0
@@ -406,26 +406,24 @@ def enhance_batch(self, audio, activity, speaker_id, num_chunks=1, left_context=
         target_enc = []
         for n in range(num_chunks):
             n_start = n * chunk_size
-            n_end = min(T, (n + 1) * chunk_size)
+            n_end = min(F, (n + 1) * chunk_size)
 
             # multichannel filter
             target_enc_n, _ = self.mc(
-                input=x_enc[..., n_start:n_end],
-                mask=mask_target[..., n_start:n_end],
-                mask_undesired=mask_undesired[..., n_start:n_end],
+                input=x_enc[..., n_start:n_end, :],
+                mask=mask_target[..., n_start:n_end, :],
+                mask_undesired=mask_undesired[..., n_start:n_end, :],
             )
 
             # append target to the list
             target_enc.append(target_enc_n)
 
         # concatenate estimates
-        target_enc = torch.concatenate(target_enc, axis=-1)
+        target_enc = torch.concatenate(target_enc, axis=-2)
         target, _ = self.synthesis(input=target_enc)
 
         # drop context from the estimated audio
-        target = target[0].detach().cpu().numpy().squeeze()
+        target = target[0, 0].detach().cpu().numpy().squeeze()
         target = target[left_context:]
         if right_context > 0:
             target = target[:-right_context]
 
         return target
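
For context on what this diff changes: enhance_batch now splits the STFT tensor into chunks along the frequency axis (dim -2) instead of the time axis (dim -1), so every chunk keeps the full time context of the segment while only a band of frequency bins is dereverberated and masked at once. Below is a minimal standalone sketch of that chunking pattern, not part of the commit; the helper name and the (batch, channels, F, T) layout are assumptions inferred from the indexing in the diff.

import math

import torch


def iter_frequency_chunks(x_enc: torch.Tensor, num_chunks: int):
    """Yield (start, end, chunk) slices of an STFT tensor shaped (..., F, T) along the frequency axis."""
    num_freq = x_enc.size(-2)  # frequency bins live on dim -2; time stays on dim -1
    chunk_size = int(math.ceil(num_freq / num_chunks))
    for n in range(num_chunks):
        n_start = n * chunk_size
        n_end = min(num_freq, (n + 1) * chunk_size)
        # slicing dim -2 keeps the full time axis in every chunk
        yield n_start, n_end, x_enc[..., n_start:n_end, :]


# Usage: per-chunk results are concatenated back along dim -2,
# mirroring torch.concatenate(..., dim=-2) in the diff above.
x = torch.randn(1, 4, 257, 1000)  # hypothetical (batch, channels, F, T) shape
chunks = [c for _, _, c in iter_frequency_chunks(x, num_chunks=3)]
assert torch.equal(torch.concatenate(chunks, dim=-2), x)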
1 change: 1 addition & 0 deletions scripts/chime7/pipeline/gss_process/run_gss_process.py
@@ -36,6 +36,7 @@ def convert_diar_results_to_falign(scenarios: list, diarization_dir: str, output
     # Output of diarization is organized in 3 subdirectories, with each subdirectory corresponding to one scenario (chime6, dipco, mixer6)
     diar_json_dir = os.path.join(diarization_dir, "pred_jsons_T0.55")
 
+
     # assert len(scenario_dirs) == 3, f'Expected 3 subdirectories, found {len(scenario_dirs)}'
     none_useful_fields = ['audio_filepath', 'words', 'text', 'duration', 'offset']
     for scenario in scenarios:
