# Source code for voicebox.tts.cache

from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Mapping,
    MutableMapping,
    Type,
    Union,
)

from cachetools import Cache, LRUCache

from voicebox.audio import Audio
from voicebox.tts import TTS
from voicebox.tts.utils import get_audio_from_wav_file
from voicebox.types import StrOrSSML, FileOrPath

# Numeric size of a single cached item or of a whole cache. The unit
# (seconds, bytes, item count, ...) is whatever the size function measures.
Size = Union[int, float]

# Signature of a callable that measures one cached item.
SizeFunc = Callable[[Any], Size]
"""Returns the size of the given item."""


@dataclass
class CachedTTS(TTS):
    """Wraps a ``TTS`` instance in a cache to reduce calls to the ``TTS``.

    Args:
        tts: The wrapped ``TTS`` instance. Only called for text that is not
            already in ``cache``.
        cache: Mutable mapping from text to the ``Audio`` generated for it.
    """

    tts: TTS
    cache: MutableMapping

    @classmethod
    def build(
        cls,
        tts: TTS,
        max_size: Size = 60,
        size_func: Union[Literal["bytes", "count", "seconds"], SizeFunc] = "seconds",
        cache_class: Type[Cache] = LRUCache,
    ) -> "CachedTTS":
        """
        Constructs a cache that by default will keep the most recently used
        60 seconds of audio, and wraps the given ``TTS`` instance in the cache
        so the ``TTS`` is only called for text not contained in the cache.

        Args:
            tts:
                The TTS instance to wrap. Will be called for text not
                contained in the cache.
            max_size:
                The maximum size of the cache, as determined by
                ``size_func``. Defaults to 60 (seconds).
            size_func:
                The function that measures the size of each item.
                If set to 'seconds' (default), then ``max_size`` will be in
                units of audio seconds. If set to 'bytes', then ``max_size``
                will be in units of audio bytes. If set to 'count', then
                ``max_size`` will be simply the number of audio clips to
                cache. Alternatively, any function that takes an ``Audio``
                instance as input and returns a size value can be passed in.
            cache_class:
                The ``Cache`` class used to construct the cache. Defaults to
                ``cachetools.LRUCache``, a Least Recently Used cache.

        Returns:
            An instance of ``CachedTTS``.

        Raises:
            ValueError: If ``size_func`` is a string other than 'bytes',
                'count' or 'seconds'.
        """
        if isinstance(size_func, str):
            named_size_funcs = {
                "bytes": lambda audio: audio.len_bytes,
                "count": lambda audio: 1,
                "seconds": lambda audio: audio.len_seconds,
            }
            try:
                size_func = named_size_funcs[size_func]
            except KeyError:
                # Fail here rather than on the first cache insertion, where
                # an unknown name would surface as "'str' is not callable".
                raise ValueError(
                    f"Unknown size_func {size_func!r}; expected one of "
                    f"{sorted(named_size_funcs)} or a callable."
                ) from None

        cache = cache_class(maxsize=max_size, getsizeof=size_func)
        return cls(tts, cache)

    def get_speech(self, text: StrOrSSML) -> Audio:
        """Returns the audio for ``text``, calling the wrapped ``TTS`` only
        when the text is not already in the cache.

        Args:
            text: The text (or SSML) to get audio for.

        Returns:
            The cached ``Audio`` if present, otherwise freshly generated
            ``Audio`` (which is then offered to the cache).
        """
        try:
            return self.cache[text]
        except KeyError:
            # Cache miss. The TTS call is made outside this handler so that
            # a TTS failure is not reported as chained to the KeyError.
            pass

        audio = self.tts.get_speech(text)
        return self._add_to_cache(text, audio)

    def _add_to_cache(self, text: StrOrSSML, audio: Audio) -> Audio:
        """Stores ``audio`` under ``text`` if the cache accepts it.

        A ``ValueError`` whose message is exactly "value too large" is
        swallowed and the audio is returned uncached. ``cachetools`` caches
        raise that error when a single item exceeds the cache's ``maxsize``.
        Any other ``ValueError`` is re-raised.

        Returns:
            The given ``audio``, whether or not it was cached.
        """
        try:
            self.cache[text] = audio
        except ValueError as e:
            if str(e) != "value too large":
                raise

        return audio
@dataclass
class PrerecordedTTS(TTS):
    """
    Returns audio from a map of message texts to ``Audio`` instances.
    Useful for playing back pre-recorded messages.

    Also supports an optional fallback ``TTS`` instance for messages not
    in the map.

    Args:
        texts_to_audios:
            Mapping of message texts to ``Audio`` instances.
        fallback_tts:
            Optional fallback ``TTS`` instance that will be used if a text is
            not found in ``texts_to_audios``.
    """

    texts_to_audios: Mapping[StrOrSSML, Audio]
    fallback_tts: Union[TTS, None] = None

    @classmethod
    def from_tts(
        cls,
        tts: TTS,
        texts: Iterable[StrOrSSML],
        use_as_fallback: bool = True,
    ) -> "PrerecordedTTS":
        """
        Returns a ``PrerecordedTTS`` instance using audio generated by the
        given ``TTS`` instance.

        Args:
            tts:
                The ``TTS`` instance to use.
            texts:
                The texts to generate audio of.
            use_as_fallback:
                If ``True``, then the given ``TTS`` instance will be used as a
                fallback for texts not in the map.

        Returns:
            An instance of ``PrerecordedTTS``.

        Example:

            >>> tts = ...  # some TTS instance
            >>> tts = PrerecordedTTS.from_tts(
            ...     tts,
            ...     texts=[
            ...         'I say this all the time.',
            ...         'Hello there!',
            ...     ],
            ... )
        """
        texts_to_audios = {text: tts.get_speech(text) for text in texts}
        fallback_tts = tts if use_as_fallback else None
        return cls(texts_to_audios, fallback_tts=fallback_tts)

    @classmethod
    def from_wav_files(
        cls,
        texts_to_files: Mapping[StrOrSSML, FileOrPath],
        fallback_tts: Union[TTS, None] = None,
    ) -> "PrerecordedTTS":
        """
        Returns a ``PrerecordedTTS`` instance using audio from the specified
        wav files.

        Args:
            texts_to_files:
                Mapping of texts to wav files.
            fallback_tts:
                Optional fallback ``TTS`` instance.

        Returns:
            An instance of ``PrerecordedTTS``.

        Example:

            >>> tts = PrerecordedTTS.from_wav_files({
            ...     'startup': 'audio/startup.wav',
            ...     'no_internet': 'audio/no_internet.wav',
            ... })
        """
        texts_to_audios = {
            text: get_audio_from_wav_file(file)
            for text, file in texts_to_files.items()
        }
        return cls(texts_to_audios, fallback_tts=fallback_tts)

    def get_speech(self, text: StrOrSSML) -> Audio:
        """Returns the pre-recorded audio for ``text``, or defers to the
        fallback ``TTS`` when the text is not in the map.

        Args:
            text: The text (or SSML) to get audio for.

        Returns:
            The ``Audio`` mapped to ``text``, or the fallback's audio.

        Raises:
            KeyError: If ``text`` is not in ``texts_to_audios`` and no
                fallback ``TTS`` is configured.
        """
        try:
            return self.texts_to_audios[text]
        except KeyError:
            if self.fallback_tts is None:
                raise

        # Called outside the handler so that a fallback failure is not
        # reported as chained to the lookup KeyError.
        return self.fallback_tts.get_speech(text)