from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from ctranslate2.specs import common_spec, model_spec, transformer_spec
|
|
|
|
|
|
class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2 model.

    The model currently exposes no configurable options, so this class is an
    empty placeholder satisfying the model-config interface.
    """

    def __init__(self):
        # No configuration fields to set.
        # NOTE(review): the base initializer is deliberately not called here,
        # matching the original code — confirm ModelConfig tolerates this.
        pass
class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """CTranslate2 model specification for Wav2Vec2 speech models."""

    def __init__(
        self,
        feat_layers,
        num_layers,
        num_heads,
        vocab_size,
        return_hidden,
    ):
        """Initializes the Wav2Vec2 model specification.

        Args:
          feat_layers: Number of convolutional feature-extraction layers.
          num_layers: Number of Transformer encoder layers.
          num_heads: Number of attention heads per encoder layer.
          vocab_size: Size of the output vocabulary.
          return_hidden: If True, the encoder returns hidden states and no
            LM head is created.
        """
        super().__init__()
        # Stored as a NumPy int16 scalar so it is serialized with the model.
        # NOTE(review): int16 caps the vocabulary at 32767 entries — fine for
        # character-level vocabularies; confirm before reusing for larger ones.
        self.vocab_size = np.dtype("int16").type(vocab_size)
        self.encoder = Wav2Vec2EncoderSpec(
            feat_layers,
            num_layers,
            num_heads,
            return_hidden,
        )

    @property
    def name(self):
        """Name under which this specification is registered."""
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        """Binary revision of the serialized specification."""
        return 3

    def get_default_config(self):
        """Returns the default (empty) model configuration."""
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        """Returns the vocabulary size as a plain Python int.

        Fix: the previous implementation called ``self.vocab_size.numpy()``,
        but ``vocab_size`` is assigned a NumPy scalar in ``__init__``, and
        NumPy scalars have no ``numpy()`` method (that is a ``torch.Tensor``
        API), so the call would raise AttributeError. ``int()`` converts both
        NumPy scalars and 0-dim torch tensors correctly.
        """
        return int(self.vocab_size)
class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
    """Spec for one feature-extractor block: a 1D convolution followed by layer norm."""

    def __init__(self):
        # 1D convolution weights of the feature-extraction block.
        self.conv = common_spec.Conv1DSpec()
        # Layer normalization applied within the block.
        self.layer_norm = common_spec.LayerNormSpec()
class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
    """Spec for the convolutional positional embedding layer."""

    def __init__(self):
        # 1D convolution used for the positional embedding.
        self.conv = common_spec.Conv1DSpec()
class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Layer specification for the Wav2Vec2 encoder stack."""

    def __init__(self, feat_layers, num_layers, num_heads, return_hidden):
        # Attention head count, serialized as an int16 scalar with the weights.
        self.num_heads = np.dtype("int16").type(num_heads)
        # Convolutional feature extractor: the first layer is declared
        # separately from the remaining ``feat_layers - 1`` layers.
        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
        self.feat_layer = [
            Wav2Vec2LayerNormConvLayer() for _ in range(feat_layers - 1)
        ]
        # Feature projection: layer norm followed by a linear projection.
        self.fp_layer_norm = common_spec.LayerNormSpec()
        self.fp_projection = common_spec.LinearSpec()
        # Convolutional positional embedding and final encoder layer norm.
        self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
        self.layer_norm = common_spec.LayerNormSpec()
        # Transformer encoder stack.
        self.layer = [
            transformer_spec.TransformerEncoderLayerSpec()
            for _ in range(num_layers)
        ]
        # The output head is only declared when hidden states are not requested.
        if not return_hidden:
            self.lm_head = common_spec.LinearSpec()