# Source code for nnmnkwii.datasets.ljspeech

from __future__ import absolute_import, print_function, with_statement

from os.path import exists, join

import numpy as np
from nnmnkwii.datasets import FileDataSource


class LJSpeechDataSource(FileDataSource):
    """Base data source that loads the LJSpeech ``metadata.csv`` index.

    ``<data_root>/metadata.csv`` is read as UTF-8. Every non-blank line is
    stripped of surrounding whitespace and split on ``|`` into exactly three
    fields. Within this module, column 0 is used to build wav file names,
    column 1 is served as the raw transcription and column 2 as the
    normalized transcription.

    Args:
        data_root (str): Data root; the directory containing ``metadata.csv``.

    Attributes:
        data_root (str): The ``data_root`` passed to the constructor.
        metadata (numpy.ndarray): Metadata, shape (``num_files x 3``). An
            empty ``metadata.csv`` yields an empty array of shape ``(0, 3)``.

    Raises:
        RuntimeError: If ``metadata.csv`` does not exist under ``data_root``,
            or if a non-blank line does not contain exactly three
            ``|``-separated fields.
    """

    def __init__(self, data_root):
        self.data_root = data_root
        metadata_path = join(data_root, "metadata.csv")
        if not exists(metadata_path):
            raise RuntimeError(
                'metadata.csv doesn\'t exists at "{}"'.format(metadata_path)
            )

        metadata = []
        # Read bytes and decode explicitly so parsing does not depend on the
        # platform's default text encoding.
        with open(metadata_path, "rb") as f:
            for lineno, raw_line in enumerate(f, 1):
                line = raw_line.decode("utf-8").strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF)
                    # instead of rejecting the whole file.
                    continue
                parts = line.split("|")
                # Explicit check rather than ``assert``: assertions are
                # removed when Python runs with ``-O``.
                if len(parts) != 3:
                    raise RuntimeError(
                        'Expected 3 "|"-separated fields but got {} at '
                        'line {} of "{}"'.format(
                            len(parts), lineno, metadata_path
                        )
                    )
                metadata.append(parts)

        if metadata:
            self.metadata = np.asarray(metadata)
        else:
            # ``np.asarray([])`` would be 1-D and break ``metadata[:, idx]``
            # in subclasses; keep the documented 2-D layout.
            self.metadata = np.empty((0, 3), dtype="U1")


class TranscriptionDataSource(LJSpeechDataSource):
    """Transcription data source for LJSpeech dataset.

    The data source collects text transcriptions from LJSpeech. Users are
    expected to inherit the class and implement ``collect_features`` method,
    which defines how features are computed given a transcription.

    Args:
        data_root (str): Data root.
        normalized (bool): Collect normalized transcriptions or not.

    Attributes:
        metadata (numpy.ndarray): Metadata, shape (``num_files x 3``).
        normalized (bool): Whether ``collect_files`` returns the normalized
            transcriptions instead of the raw ones.
    """

    def __init__(self, data_root, normalized=False):
        super(TranscriptionDataSource, self).__init__(data_root)
        self.normalized = normalized

    def collect_files(self):
        """Collect text transcriptions.

        .. warning::

            Note that it returns list of transcriptions (str), not file paths.

        Returns:
            list: List of text transcription.
        """
        # Column 1 of the metadata holds the raw text, column 2 the
        # normalized text.
        if self.normalized:
            column = 2
        else:
            column = 1
        transcriptions = self.metadata[:, column]
        return list(transcriptions)


class NormalizedTranscriptionDataSource(TranscriptionDataSource):
    """Normalized transcription data source for LJSpeech dataset.

    .. warning::

        Deprecated. Use TranscriptionDataSource with ``normalized=True``
        instead.

    Similar to ``TranscriptionDataSource``, but this collects normalized
    transcriptions instead of raw ones.

    Args:
        data_root (str): Data root.

    Attributes:
        metadata (numpy.ndarray): Metadata, shape (``num_files x 3``).
    """

    def __init__(self, data_root):
        # Pin ``normalized=True`` on the parent; that is the only difference
        # from TranscriptionDataSource.
        parent = super(NormalizedTranscriptionDataSource, self)
        parent.__init__(data_root, normalized=True)


class WavFileDataSource(LJSpeechDataSource):
    """Wav file data source for LJSpeech dataset.

    The data source collects wav files from LJSpeech. Users are expected to
    inherit the class and implement ``collect_features`` method, which
    defines how features are computed given a wav file path.

    Args:
        data_root (str): Data root.

    Attributes:
        metadata (numpy.ndarray): Metadata, shape (``num_files x 3``).
    """

    def __init__(self, data_root):
        super(WavFileDataSource, self).__init__(data_root)

    def collect_files(self):
        """Collect wav files.

        Each path has the form ``<data_root>/wavs/<name>.wav`` where
        ``<name>`` is taken from the first metadata column. The paths are
        built from the metadata only; the files are not checked for
        existence.

        Returns:
            list: List of wav files.
        """
        names = self.metadata[:, 0]
        return [
            join(self.data_root, "wavs", name + ".wav") for name in names
        ]