Source code for voicebox.tts.tts

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Sequence, Type, Tuple, Optional

from voicebox.audio import Audio
from voicebox.tts.utils import get_audio_from_mp3, get_audio_from_wav_file
from voicebox.types import StrOrSSML

log = logging.getLogger(__name__)


class TTS(ABC):
    """Abstract base for every text-to-speech engine.

    Concrete engines implement :meth:`get_speech`. Instances are also
    callable, so ``tts(text)`` is shorthand for ``tts.get_speech(text)``.
    """

    @abstractmethod
    def get_speech(self, text: StrOrSSML) -> Audio:
        """Synthesize ``text`` and return the resulting audio."""
        ...  # pragma: no cover

    def __call__(self, text: StrOrSSML) -> Audio:
        """Alias for :meth:`get_speech`."""
        return self.get_speech(text)
class AudioFileTTS(TTS, ABC):
    """Base class for text-to-speech engines that generate audio files.

    Subclasses write speech to a temporary file (via
    :meth:`generate_speech_audio_file`) and decode it back into an
    :class:`Audio` (via :meth:`get_audio_from_file`). The temporary file is
    removed after it has been read.

    Args:
        temp_file_dir: Directory in which temporary audio files are created;
            ``None`` lets :mod:`tempfile` pick its default directory.
        temp_file_prefix: Filename prefix for temporary audio files.
    """

    temp_file_dir: Optional[str]
    temp_file_prefix: str

    def __init__(
        self,
        temp_file_dir: Optional[str],
        temp_file_prefix: str,
    ):
        self.temp_file_dir = temp_file_dir
        self.temp_file_prefix = temp_file_prefix

    def get_speech(self, text: StrOrSSML) -> Audio:
        """Generate speech into a temporary audio file and return it as Audio.

        The temporary file is deleted afterwards, whether or not generation
        or decoding succeeded.
        """
        with NamedTemporaryFile(
            prefix=self.temp_file_prefix,
            suffix="." + self.get_audio_file_type(),
            dir=self.temp_file_dir,
        ) as audio_file:
            # Only a unique path is wanted. Closing this (delete=True) temp
            # file removes the placeholder so the engine creates the file
            # itself -- but it also means the wrapper will NOT clean up
            # whatever the engine later writes at this path, so we must.
            audio_file.close()
            audio_file_path = Path(audio_file.name)

        try:
            self.generate_speech_audio_file(text, audio_file_path)
            # NOTE(review): assumes get_audio_from_file fully reads the file
            # into memory, since the file is deleted right after.
            return self.get_audio_from_file(audio_file_path)
        finally:
            self._remove_temp_file(audio_file_path)

    @staticmethod
    def _remove_temp_file(path: Path) -> None:
        """Best-effort removal of a generated temporary audio file."""
        try:
            path.unlink()
        except FileNotFoundError:
            # The engine may have failed before creating the file.
            pass
        except OSError:
            # Don't let a cleanup failure mask the result or original error.
            log.warning(
                "Could not remove temporary audio file %s", path, exc_info=True
            )

    @abstractmethod
    def get_audio_file_type(self) -> str:
        """Returns the file type of the audio files generated by this TTS."""
        ...  # pragma: no cover

    @abstractmethod
    def generate_speech_audio_file(
        self, text: StrOrSSML, audio_file_path: Path
    ) -> None:
        """Generates a speech audio file from the given text."""
        ...  # pragma: no cover

    @abstractmethod
    def get_audio_from_file(self, file_path: Path) -> Audio:
        """Returns an Audio instance from the given file path."""
        ...  # pragma: no cover
class Mp3FileTTS(AudioFileTTS, ABC):
    """AudioFileTTS for engines that produce MP3 files.

    Subclasses still provide ``generate_speech_audio_file``; this class
    supplies the file type and the MP3 decoding step.
    """

    def get_audio_from_file(self, file_path: Path) -> Audio:
        """Decode the MP3 at ``file_path`` via ``get_audio_from_mp3``."""
        return get_audio_from_mp3(file_path)

    def get_audio_file_type(self) -> str:
        """Return ``"mp3"``."""
        return "mp3"
class WavFileTTS(AudioFileTTS, ABC):
    """AudioFileTTS for engines that produce WAV files.

    Subclasses still provide ``generate_speech_audio_file``; this class
    supplies the file type and the WAV decoding step.
    """

    def get_audio_from_file(self, file_path: Path) -> Audio:
        """Decode the WAV at ``file_path`` via ``get_audio_from_wav_file``."""
        return get_audio_from_wav_file(file_path)

    def get_audio_file_type(self) -> str:
        """Return ``"wav"``."""
        return "wav"
@dataclass
class FallbackTTS(TTS):
    """
    Attempts to call the TTSs in order, returning results from the first TTS
    that does not raise an exception.

    Useful if you have e.g. an online TTS that you want to use primarily, and
    want to fall back to an offline TTS in case something goes wrong.

    Args:
        ttss: The TTSs to try, in order.
        exceptions_to_catch: The exceptions to catch and log when calling the
            TTSs. If an exception is raised that is not in this tuple, then it
            will not be caught (nor logged) and propagates immediately.
        log: The logger to use for logging exceptions.

    Raises:
        ValueError: If ``ttss`` is empty.
    """

    ttss: Sequence[TTS]
    exceptions_to_catch: Tuple[Type[BaseException], ...] = (Exception,)
    log: Logger = log

    def get_speech(self, text: StrOrSSML) -> Audio:
        """Return speech from the first TTS in ``ttss`` that succeeds.

        A caught exception from the last TTS is re-raised after logging.
        """
        for i, tts in enumerate(self.ttss):
            try:
                return tts.get_speech(text)
            except self.exceptions_to_catch as e:
                # Only exceptions we intend to recover from are logged;
                # anything else (e.g. KeyboardInterrupt) propagates untouched.
                self.handle_exception(e, tts, i)
                if i + 1 >= len(self.ttss):
                    raise
        raise ValueError("self.ttss is empty")

    def handle_exception(
        self, e: BaseException, tts: TTS, tts_index: int
    ) -> None:
        """Log a caught exception raised by ``tts`` (at ``tts_index``)."""
        message = f"Exception occurred calling TTS={tts} (index {tts_index})"
        self.log.exception(message, exc_info=e)
@dataclass
class RetryTTS(TTS):
    """
    If an exception occurs while getting speech from the given TTS, retry
    until ``max_attempts`` is reached.

    Args:
        tts: The TTS to call.
        max_attempts: The maximum number of attempts to make.
        exceptions_to_catch: The exceptions to catch and log when calling the
            TTS. If an exception is raised that is not in this tuple, then it
            will not be caught (nor logged) and propagates immediately.
        log: The logger to use for logging exceptions.

    Raises:
        ValueError: If ``max_attempts`` is not positive.
    """

    tts: TTS
    max_attempts: int = 3
    exceptions_to_catch: Tuple[Type[BaseException], ...] = (Exception,)
    log: Logger = log

    def get_speech(self, text: StrOrSSML) -> Audio:
        """Return speech from ``tts``, retrying on caught exceptions.

        A caught exception on the final attempt is re-raised after logging.
        """
        for attempt in range(1, self.max_attempts + 1):
            try:
                return self.tts.get_speech(text)
            except self.exceptions_to_catch as e:
                # Only exceptions we intend to retry are logged; anything
                # else (e.g. KeyboardInterrupt) propagates untouched.
                self.handle_exception(e, attempt)
                if attempt >= self.max_attempts:
                    raise
        # Only reachable when the loop never ran, i.e. max_attempts <= 0.
        raise ValueError(
            f"self.max_attempts must be > 0; "
            f"max_attempts={self.max_attempts}"
        )

    def handle_exception(self, e: BaseException, attempt: int) -> None:
        """Log a caught exception from attempt number ``attempt``."""
        message = f"TTS attempt {attempt}/{self.max_attempts} failed"
        self.log.exception(message, exc_info=e)