Skip to content

Commit

Permalink
Merge heh and zhehuai's initial version of frozen am+llm (#5)
Browse files Browse the repository at this point in the history
* Merge heh and zhehuai's initial version of frozen am+llm

The previous differences are summarized here:
https://docs.google.com/document/d/1zNI4hC6vJtUfcHbrUSPaMuYWRBQdN_36H0P2NiBiuPY/edit

This PR includes
1. Finish merging the model, dataset, and config code
2. Previous tests are still enabled and passed (prepare_llm_input, training_step,
    validation_step)
3. the example training script with LS960 has been run to make sure the training
pipeline works

The major remaining works are listed here
https://docs.google.com/document/d/1o0AM7v4gcTQkPZjE0Vl9TTX4vYnGTrbXEFGWh0UhGlk/edit#bookmark=id.pzvdadt5oxyw

---------

Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
  • Loading branch information
zhehuaichen and stevehuang52 committed Oct 4, 2023
1 parent c7e5774 commit 11eb478
Show file tree
Hide file tree
Showing 19 changed files with 2,625 additions and 603 deletions.
294 changes: 233 additions & 61 deletions examples/multimodel/conf/speechllm/modularized_speech_gpt_config.yaml

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions nemo/collections/common/parts/preprocessing/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,200 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
)


class AudioQuestAns(_Collection):
"""List of audio-transcript text correspondence with preprocessing."""

OUTPUT_TYPE = collections.namedtuple(
typename='AudioQAEntity', field_names='id audio_file duration question answer offset speaker orig_sr lang',
)

def __init__(
self,
ids: List[int],
audio_files: List[str],
durations: List[float],
questions: List[str],
answers: List[str],
offsets: List[str],
speakers: List[Optional[int]],
orig_sampling_rates: List[Optional[int]],
langs: List[Optional[str]],
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
max_number: Optional[int] = None,
do_sort_by_duration: bool = False,
index_by_file_id: bool = False,
):
"""Instantiates audio-question-answer manifest with filters and preprocessing.
Args:
ids: List of examples positions.
audio_files: List of audio files.
durations: List of float durations.
questions: List of raw text transcripts.
answers: List of raw text transcripts.
offsets: List of duration offsets or None.
speakers: List of optional speakers ids.
orig_sampling_rates: List of original sampling rates of audio files.
langs: List of language ids, one for eadh sample, or None.
min_duration: Minimum duration to keep entry with (default: None).
max_duration: Maximum duration to keep entry with (default: None).
max_number: Maximum number of samples to collect.
do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id.
index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
"""

output_type = self.OUTPUT_TYPE
data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
if index_by_file_id:
self.mapping = {}

for id_, audio_file, duration, offset, question, answer, speaker, orig_sr, lang in zip(
ids, audio_files, durations, offsets, questions, answers, speakers, orig_sampling_rates, langs
):
# Duration filters.
if min_duration is not None and duration < min_duration:
duration_filtered += duration
num_filtered += 1
continue

if max_duration is not None and duration > max_duration:
duration_filtered += duration
num_filtered += 1
continue

if answer is None:
duration_filtered += duration
num_filtered += 1
continue

total_duration += duration

data.append(output_type(id_, audio_file, duration, question, answer, offset, speaker, orig_sr, lang))
if index_by_file_id:
file_id, _ = os.path.splitext(os.path.basename(audio_file))
if file_id not in self.mapping:
self.mapping[file_id] = []
self.mapping[file_id].append(len(data) - 1)

# Max number of entities filter.
if len(data) == max_number:
break

if do_sort_by_duration:
if index_by_file_id:
logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
else:
data.sort(key=lambda entity: entity.duration)

logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

super().__init__(data)


class ALMAudioQA(AudioQuestAns):
"""`AudioQuestAns` collector from audio-LM json files."""

def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs):
"""Parse lists of audio files, durations and transcripts texts.
Args:
manifests_files: Either single string file or list of such -
manifests to yield items from.
*args: Args to pass to `AudioText` constructor.
**kwargs: Kwargs to pass to `AudioText` constructor.
"""

ids, audio_files, durations, questions, answers, offsets, = (
[],
[],
[],
[],
[],
[],
)
speakers, orig_srs, langs = (
[],
[],
[],
)
for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item):
ids.append(item['id'])
audio_files.append(item['audio_file'])
durations.append(item['duration'])
questions.append(item['question'])
answers.append(item['answer'])
offsets.append(item['offset'])
speakers.append(item['speaker'])
orig_srs.append(item['orig_sr'])
langs.append(item['lang'])
super().__init__(
ids, audio_files, durations, questions, answers, offsets, speakers, orig_srs, langs, *args, **kwargs
)

def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
item = json.loads(line)

# Audio file
if 'audio_filename' in item:
item['audio_file'] = item.pop('audio_filename')
elif 'audio_filepath' in item:
item['audio_file'] = item.pop('audio_filepath')
elif 'audio_file' not in item:
raise ValueError(
f"Manifest file {manifest_file} has invalid json line structure: {line} without proper audio file key."
)

# If the audio path is a relative path and does not exist,
# try to attach the parent directory of manifest to the audio path.
# Revert to the original path if the new path still doesn't exist.
# Assume that the audio path is like "wavs/xxxxxx.wav".
item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file)

# Duration.
if 'duration' not in item:
raise ValueError(
f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key."
)

# Question.
if 'question' in item:
pass
elif 'question_filepath' in item:
with open(item.pop('text_filepath'), 'r') as f:
item['question'] = f.read().replace('\n', '')
elif 'normalized_text' in item:
item['question'] = item['normalized_text']
else:
item['question'] = "what does this audio mean"

# Answer.
if 'answer' in item:
pass
elif 'text' in item:
item['answer'] = item.pop('text')
elif 'text_filepath' in item:
with open(item.pop('text_filepath'), 'r') as f:
item['answer'] = f.read().replace('\n', '')
elif 'normalized_text' in item:
item['answer'] = item['normalized_text']
else:
item['answer'] = ""

item = dict(
audio_file=item['audio_file'],
duration=item['duration'],
question=item['question'],
answer=item['answer'],
offset=item.get('offset', None),
speaker=item.get('speaker', None),
orig_sr=item.get('orig_sample_rate', None),
lang=item.get('lang', None),
)
return item


class SpeechLabel(_Collection):
"""List of audio-label correspondence with preprocessing."""

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/multimodal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.multimodal import models, modules
from nemo.collections.multimodal import models, modules
13 changes: 13 additions & 0 deletions nemo/collections/multimodal/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 11eb478

Please sign in to comment.