from collections import OrderedDict
from glob import glob
from os.path import basename, exists, join, splitext
import numpy as np
from nnmnkwii.datasets import FileDataSource
# List of available speakers.
available_speakers = [
"225",
"226",
"227",
"228",
"229",
"230",
"231",
"232",
"233",
"234",
"236",
"237",
"238",
"239",
"240",
"241",
"243",
"244",
"245",
"246",
"247",
"248",
"249",
"250",
"251",
"252",
"253",
"254",
"255",
"256",
"257",
"258",
"259",
"260",
"261",
"262",
"263",
"264",
"265",
"266",
"267",
"268",
"269",
"270",
"271",
"272",
"273",
"274",
"275",
"276",
"277",
"278",
"279",
"280",
"281",
"282",
"283",
"284",
"285",
"286",
"287",
"288",
"292",
"293",
"294",
"295",
"297",
"298",
"299",
"300",
"301",
"302",
"303",
"304",
"305",
"306",
"307",
"308",
"310",
"311",
"312",
"313",
"314",
# "315", transcriptions are missing, so excludes it here
"316",
"317",
"318",
"323",
"326",
"329",
"330",
"333",
"334",
"335",
"336",
"339",
"340",
"341",
"343",
"345",
"347",
"351",
"360",
"361",
"362",
"363",
"364",
"374",
"376",
]
assert len(available_speakers) == 108
def _parse_speaker_info(data_root):
speaker_info_path = join(data_root, "speaker-info.txt")
if not exists(speaker_info_path):
raise RuntimeError(
'speaker-info.txt doesn\'t exist at "{}"'.format(speaker_info_path)
)
speaker_info = OrderedDict()
# filed_names = ["ID", "AGE", "GENDER", "ACCENTS", "REGION"]
with open(speaker_info_path, "rb") as f:
for line in f:
line = line.decode("utf-8")
fields = line.split()
if fields[0] == "ID":
continue
assert len(fields) == 4 or len(fields) == 5 or len(fields) == 6
ID = fields[0]
speaker_info[ID] = {}
speaker_info[ID]["AGE"] = int(fields[1])
speaker_info[ID]["GENDER"] = fields[2]
speaker_info[ID]["ACCENTS"] = fields[3]
if len(fields) > 4:
speaker_info[ID]["REGION"] = " ".join(fields[4:])
else:
speaker_info[ID]["REGION"] = ""
return speaker_info
class _VCTKBaseDataSource(FileDataSource):
def __init__(self, data_root, speakers, labelmap, max_files):
self.data_root = data_root
# accept both e.g., "225" and "p225"
for idx in range(len(speakers)):
if speakers[idx][0] == "p":
speakers[idx] = speakers[idx][1:]
if speakers == "all":
speakers = available_speakers
for speaker in speakers:
if speaker not in available_speakers:
raise ValueError(
"Unknown speaker '{}'. It should be one of {}".format(
speaker, available_speakers
)
)
self.speakers = speakers
if labelmap is None:
labelmap = {}
for idx, speaker in enumerate(speakers):
labelmap[speaker] = idx
self.labelmap = labelmap
self.labels = None
self.max_files = max_files
self.speaker_info = _parse_speaker_info(data_root)
self._validate()
def _validate(self):
# should have pair of transcription and wav files
for _, speaker in enumerate(self.speakers):
txt_files = sorted(
glob(
join(
self.data_root,
"txt",
"p" + speaker,
"p{}_*.txt".format(speaker),
)
)
)
wav_files = sorted(
glob(
join(
self.data_root,
"wav48",
"p" + speaker,
"p{}_*.wav".format(speaker),
)
)
)
assert len(txt_files) > 0
for txt_path, wav_path in zip(txt_files, wav_files):
assert (
splitext(basename(txt_path))[0] == splitext(basename(wav_path))[0]
)
def collect_files(self, is_wav):
if is_wav:
root = join(self.data_root, "wav48")
ext = ".wav"
else:
root = join(self.data_root, "txt")
ext = ".txt"
paths = []
labels = []
if self.max_files is None:
max_files_per_speaker = None
else:
max_files_per_speaker = self.max_files // len(self.speakers)
for idx, speaker in enumerate(self.speakers):
speaker_dir = join(root, "p" + speaker)
files = sorted(glob(join(speaker_dir, "p{}_*{}".format(speaker, ext))))
files = files[:max_files_per_speaker]
if not is_wav:
files = list(
map(lambda s: open(s, "rb").read().decode("utf-8")[:-1], files)
)
for f in files:
paths.append(f)
labels.append(self.labelmap[self.speakers[idx]])
self.labels = np.array(labels, dtype=np.int16)
return paths
[docs]class TranscriptionDataSource(_VCTKBaseDataSource):
"""Transcription data source for VCTK dataset.
The data source collects text transcriptions from VCTK.
Users are expected to inherit the class and implement ``collect_features``
method, which defines how features are computed given a transcription.
Args:
data_root (str): Data root.
speakers (list): List of speakers to find. Speaker id must be ``str``.
For supported names of speaker, please refer to ``available_speakers``
defined in the module.
labelmap (dict[optional]): Dict of speaker labels. If None,
it's assigned as incrementally (i.e., 0, 1, 2) for specified
speakers.
max_files (int): Total number of files to be collected.
Attributes:
speaker_info (dict): Dict of speaker information dict. Keyes are speaker
ids (str) and each value is speaker information consists of ``AGE``,
``GENDER`` and ``REGION``.
labels (numpy.ndarray): Speaker labels paired with collected files.
Stored in ``collect_files``. This is useful to build multi-speaker
models.
"""
def __init__(
self, data_root, speakers=available_speakers, labelmap=None, max_files=None
):
super(TranscriptionDataSource, self).__init__(
data_root, speakers, labelmap, max_files
)
[docs] def collect_files(self):
return super(TranscriptionDataSource, self).collect_files(False)
[docs]class WavFileDataSource(_VCTKBaseDataSource):
"""Transcription data source for VCTK dataset.
The data source collects text transcriptions from VCTK.
Users are expected to inherit the class and implement ``collect_features``
method, which defines how features are computed given a transcription.
Args:
data_root (str): Data root.
speakers (list): List of speakers to find. Speaker id must be ``str``.
For supported names of speaker, please refer to ``available_speakers``
defined in the module.
labelmap (dict[optional]): Dict of speaker labels. If None,
it's assigned as incrementally (i.e., 0, 1, 2) for specified
speakers.
max_files (int): Total number of files to be collected.
Attributes:
speaker_info (dict): Dict of speaker information dict. Keyes are speaker
ids (str) and each value is speaker information consists of ``AGE``,
``GENDER`` and ``REGION``.
labels (numpy.ndarray): Speaker labels paired with collected files.
Stored in ``collect_files``. This is useful to build multi-speaker
models.
"""
def __init__(
self, data_root, speakers=available_speakers, labelmap=None, max_files=None
):
super(WavFileDataSource, self).__init__(
data_root, speakers, labelmap, max_files
)
[docs] def collect_files(self):
return super(WavFileDataSource, self).collect_files(True)