78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from ctranslate2.specs import common_spec, model_spec, transformer_spec
|
|
|
|
|
|
class WhisperConfig(model_spec.ModelConfig):
    """Configuration values attached to a converted Whisper model.

    Every argument is optional and defaults to ``None``; the constructor
    performs no processing of its own and simply hands the values to
    ``model_spec.ModelConfig`` as keyword arguments.
    """

    def __init__(
        self,
        suppress_ids: Optional[List[int]] = None,
        suppress_ids_begin: Optional[List[int]] = None,
        lang_ids: Optional[List[int]] = None,
        alignment_heads: Optional[List[Tuple[int, int]]] = None,
    ):
        """Stores the Whisper-specific configuration fields.

        Args:
          suppress_ids: Optional list of integer ids (presumably token ids to
            suppress during generation -- confirm against the consumer).
          suppress_ids_begin: Optional list of integer ids (presumably ids
            suppressed only at the first generation step -- confirm).
          lang_ids: Optional list of integer ids (presumably the language
            token ids -- confirm).
          alignment_heads: Optional list of integer pairs (presumably
            ``(layer, head)`` indices -- confirm).
        """
        # Forward everything untouched; the parent class owns the storage.
        fields = {
            "suppress_ids": suppress_ids,
            "suppress_ids_begin": suppress_ids_begin,
            "lang_ids": lang_ids,
            "alignment_heads": alignment_heads,
        }
        super().__init__(**fields)
|
|
|
|
|
|
class WhisperSpec(model_spec.LanguageModelSpec):
    """Model specification for Whisper: an encoder plus a Transformer decoder."""

    def __init__(
        self,
        num_encoder_layers,
        num_encoder_heads,
        num_decoder_layers,
        num_decoder_heads,
    ):
        """Builds the encoder and decoder sub-specifications.

        Args:
          num_encoder_layers: Number of layers in the encoder.
          num_encoder_heads: Number of attention heads in the encoder.
          num_decoder_layers: Number of layers in the decoder.
          num_decoder_heads: Number of attention heads in the decoder.
        """
        super().__init__()

        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)

        # The decoder is a standard Transformer decoder spec using GELU
        # activations, with its ``scale_embeddings`` flag turned off.
        decoder = transformer_spec.TransformerDecoderSpec(
            num_decoder_layers,
            num_decoder_heads,
            activation=common_spec.Activation.GELU,
        )
        decoder.scale_embeddings = False
        self.decoder = decoder

    @property
    def name(self):
        """The identifier string of this specification."""
        return "WhisperSpec"

    @property
    def revision(self):
        """The integer revision number of this specification."""
        return 3

    def get_default_config(self):
        """Returns a ``WhisperConfig`` built with all default values."""
        return WhisperConfig()

    def get_vocabulary_size(self):
        """Returns the first dimension of the decoder embedding weight."""
        embedding_weight = self.decoder.embeddings.weight
        return embedding_weight.shape[0]
|
|
|
|
|
|
class WhisperEncoderSpec(model_spec.LayerSpec):
    """Layer specification for the Whisper encoder.

    Holds two 1D convolution specs, a position encoder spec, a layer
    normalization spec and a list of Transformer encoder layer specs.
    """

    def __init__(self, num_layers, num_heads):
        """Creates the encoder sub-specifications.

        Args:
          num_layers: How many Transformer encoder layer specs to create.
          num_heads: The number of attention heads, stored as a NumPy
            ``int16`` scalar.
        """
        # np.int16 is the scalar type that np.dtype("int16").type resolves to.
        self.num_heads = np.int16(num_heads)

        self.conv1 = common_spec.Conv1DSpec()
        self.conv2 = common_spec.Conv1DSpec()
        self.position_encodings = transformer_spec.PositionEncoderSpec()
        self.layer_norm = common_spec.LayerNormSpec()

        # One independent layer spec instance per encoder layer.
        encoder_layers = [
            transformer_spec.TransformerEncoderLayerSpec()
            for _layer_index in range(num_layers)
        ]
        self.layer = encoder_layers
|