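"""Declares the specification of the Wav2Vec2Bert model."""
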
import numpy as np

from ctranslate2.specs import attention_spec, common_spec, model_spec


class Wav2Vec2BertConfig(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2Bert model."""

    def __init__(self):
        return


class Wav2Vec2BertSpec(model_spec.LanguageModelSpec):
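    """Spec for a Wav2Vec2Bert model: a feature projection and a Conformer
    encoder with adapter layers, plus an optional output projection."""
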
    def __init__(
        self,
        num_hidden_layers,
        num_adapter_layers,
        vocab_size,
        return_hidden,
    ):
        super().__init__()
        # Stored as a NumPy scalar so it is saved with the model variables.
        self.vocab_size = np.dtype("int16").type(vocab_size)
        # Argument order matches Wav2Vec2BertEncoderSpec.__init__.
        self.encoder = Wav2Vec2BertEncoderSpec(
            num_hidden_layers,
            num_adapter_layers,
            return_hidden,
        )
    @property
    def name(self):
        return "Wav2Vec2BertSpec"

    @property
    def revision(self):
        return 1

    def get_default_config(self):
        return Wav2Vec2BertConfig()

    def get_vocabulary_size(self):
        # vocab_size is a NumPy scalar; int() converts it to a Python int.
        return int(self.vocab_size)


class Wav2Vec2BertFeedForwardSpec(model_spec.LayerSpec):
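    """Position-wise feed-forward layer, optionally gated (GLU)."""
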
    def __init__(self, glu=False, rms_norm=False):
        self.linear_0 = common_spec.LinearSpec()
        self.linear_1 = common_spec.LinearSpec()
        if glu:
            # Gated variant: a second input projection applied without activation.
            self.linear_0_noact = common_spec.LinearSpec()


class EncoderSpec(model_spec.LayerSpec):
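    """One Conformer-style encoder block: feed-forward, relative-position
    self-attention, a convolution module, a second feed-forward, and a final
    layer norm, each sub-block preceded by its own layer norm."""
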
    def __init__(self):
        self.enc_ffn1_layer_norm = common_spec.LayerNormSpec()
        self.enc_ffn1 = Wav2Vec2BertFeedForwardSpec()
        self.enc_attn_layer_norm = common_spec.LayerNormSpec()
        self.enc_attn = attention_spec.MultiHeadAttentionSpec(
            self_attention=True,
            relative_asymmetric_position=True,
        )
        # enc_attn_layer_norm above replaces the attention spec's built-in norm.
        del self.enc_attn.layer_norm
        # Convolution module: pointwise -> depthwise -> norm -> pointwise, no biases.
        self.enc_conv_layer_norm = common_spec.LayerNormSpec()
        self.enc_conv_pointwise_conv1 = common_spec.Conv1DSpec()
        del self.enc_conv_pointwise_conv1.bias
        self.enc_conv_depthwise_conv = common_spec.Conv1DSpec()
        del self.enc_conv_depthwise_conv.bias
        self.enc_conv_depthwise_layer_norm = common_spec.LayerNormSpec()
        self.enc_conv_pointwise_conv2 = common_spec.Conv1DSpec()
        del self.enc_conv_pointwise_conv2.bias
        self.enc_ffn2_layer_norm = common_spec.LayerNormSpec()
        self.enc_ffn2 = Wav2Vec2BertFeedForwardSpec()
        self.enc_final_layer_norm = common_spec.LayerNormSpec()


class AdapterSpec(model_spec.LayerSpec):
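    """Adapter block: a residual convolution branch and a convolution-fed
    self-attention branch, followed by a feed-forward layer."""
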
    def __init__(self):
        self.adpt_residual_layer_norm = common_spec.LayerNormSpec()
        self.adpt_residual_conv = common_spec.Conv1DSpec()
        self.adpt_attn_layer_norm = common_spec.LayerNormSpec()
        self.adpt_attn_conv = common_spec.Conv1DSpec()
        self.adpt_attn_layer = attention_spec.MultiHeadAttentionSpec(
            self_attention=True,
            relative_asymmetric_position=False,
        )
        # adpt_attn_layer_norm above replaces the attention spec's built-in norm.
        del self.adpt_attn_layer.layer_norm
        self.adpt_ffn_layer_norm = common_spec.LayerNormSpec()
        self.adpt_ffn = Wav2Vec2BertFeedForwardSpec()


class Wav2Vec2BertEncoderSpec(model_spec.LayerSpec):
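    """Complete encoder: feature projection, `num_hidden_layers` Conformer
    blocks, `num_adapter_layers` adapter blocks, and an optional `lm_head`."""
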
    def __init__(self, num_hidden_layers, num_adapter_layers, return_hidden):
        self.fp_layer_norm = common_spec.LayerNormSpec()
        self.fp_projection = common_spec.LinearSpec()
        self.encoder_layers = [EncoderSpec() for _ in range(num_hidden_layers)]
        self.adapter_layers = [AdapterSpec() for _ in range(num_adapter_layers)]
        if not return_hidden:
            # Declare the output projection only when logits are requested.
            self.lm_head = common_spec.LinearSpec()
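# Usage sketch (illustrative; the layer counts and vocabulary size below are
# assumptions for a hypothetical checkpoint, not values taken from this file):
#
#     spec = Wav2Vec2BertSpec(
#         num_hidden_layers=24,
#         num_adapter_layers=1,
#         vocab_size=10000,
#         return_hidden=False,
#     )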