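"""Converter for Marian NMT models (.npz checkpoints) into the CTranslate2 format."""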
import argparse
import re

from typing import List

import numpy as np
import yaml

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

_SUPPORTED_ACTIVATIONS = {
    "gelu": common_spec.Activation.GELUSigmoid,
    "relu": common_spec.Activation.RELU,
    "swish": common_spec.Activation.SWISH,
}

_SUPPORTED_POSTPROCESS_EMB = {"", "d", "n", "nd"}


class MarianConverter(Converter):
    """Converts models trained with Marian."""

    def __init__(self, model_path: str, vocab_paths: List[str]):
        """Initializes the Marian converter.

        Arguments:
          model_path: Path to the Marian model (.npz file).
          vocab_paths: Paths to the vocabularies (.yml files).
        """
        self._model_path = model_path
        self._vocab_paths = vocab_paths

    def _load(self):
        model = np.load(self._model_path)
        config = _get_model_config(model)
        vocabs = list(map(load_vocab, self._vocab_paths))

        activation = config["transformer-ffn-activation"]
        pre_norm = "n" in config["transformer-preprocess"]
        postprocess_emb = config["transformer-postprocess-emb"]

        check = utils.ConfigurationChecker()
        check(config["type"] == "transformer", "Option --type must be 'transformer'")
        check(
            config["transformer-decoder-autoreg"] == "self-attention",
            "Option --transformer-decoder-autoreg must be 'self-attention'",
        )
        check(
            not config["transformer-no-projection"],
            "Option --transformer-no-projection is not supported",
        )
        check(
            activation in _SUPPORTED_ACTIVATIONS,
            "Option --transformer-ffn-activation %s is not supported "
            "(supported activations are: %s)"
            % (activation, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
        )
        check(
            postprocess_emb in _SUPPORTED_POSTPROCESS_EMB,
            "Option --transformer-postprocess-emb %s is not supported (supported values are: %s)"
            % (postprocess_emb, ", ".join(_SUPPORTED_POSTPROCESS_EMB)),
        )

        if pre_norm:
            check(
                config["transformer-preprocess"] == "n"
                and config["transformer-postprocess"] == "da"
                and config.get("transformer-postprocess-top", "") == "n",
                "Unsupported pre-norm Transformer architecture, expected the following "
                "combination of options: "
                "--transformer-preprocess n "
                "--transformer-postprocess da "
                "--transformer-postprocess-top n",
            )
        else:
            check(
                config["transformer-preprocess"] == ""
                and config["transformer-postprocess"] == "dan"
                and config.get("transformer-postprocess-top", "") == "",
                "Unsupported post-norm Transformer architecture, expected the following "
                "combination of options: "
                "--transformer-preprocess '' "
                "--transformer-postprocess dan "
                "--transformer-postprocess-top ''",
            )

        check.validate()

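        # Convert Marian's --transformer-guided-alignment-layer value ("last" or a
        # 1-based layer index) to a 0-based index, where -1 selects the last
        # decoder layer.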
        alignment_layer = config["transformer-guided-alignment-layer"]
        alignment_layer = -1 if alignment_layer == "last" else int(alignment_layer) - 1
        layernorm_embedding = "n" in postprocess_emb

        model_spec = transformer_spec.TransformerSpec.from_config(
            (config["enc-depth"], config["dec-depth"]),
            config["transformer-heads"],
            pre_norm=pre_norm,
            activation=_SUPPORTED_ACTIVATIONS[activation],
            alignment_layer=alignment_layer,
            alignment_heads=1,
            layernorm_embedding=layernorm_embedding,
        )
        set_transformer_spec(model_spec, model)
        model_spec.register_source_vocabulary(vocabs[0])
        model_spec.register_target_vocabulary(vocabs[-1])
        model_spec.config.add_source_eos = True
        return model_spec


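# The Marian .npz checkpoint embeds the training configuration as a YAML document
# under the key "special:model.yml"; the last byte (a trailing terminator) is
# dropped before parsing.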
def _get_model_config(model):
    config = model["special:model.yml"]
    config = config[:-1].tobytes()
    config = yaml.safe_load(config)
    return config


def load_vocab(path):
    # pyyaml skips some entries so we manually parse the vocabulary file.
    with open(path, encoding="utf-8") as vocab:
        tokens = []
        token = None
        idx = None
        for i, line in enumerate(vocab):
            line = line.rstrip("\n\r")
            if not line:
                continue

            if line.startswith("? "):  # Complex key mapping (key)
                token = line[2:]
            elif token is not None:  # Complex key mapping (value)
                idx = line[2:]
            else:
                token, idx = line.rsplit(":", 1)

            if token is not None:
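                # Tokens may appear as quoted YAML scalars: double-quoted tokens can
                # contain backslash escapes (e.g. "\x0a"), while single-quoted tokens
                # escape quotes by doubling them.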
                if token.startswith('"') and token.endswith('"'):
                    # Unescape characters and remove quotes.
                    token = re.sub(r"\\([^x])", r"\1", token)
                    token = token[1:-1]
                    if token.startswith("\\x"):
                        # Convert the digraph \x to the actual escaped sequence.
                        token = chr(int(token[2:], base=16))
                elif token.startswith("'") and token.endswith("'"):
                    token = token[1:-1]
                    token = token.replace("''", "'")

            if idx is not None:
                try:
                    idx = int(idx.strip())
                except ValueError as e:
                    raise ValueError(
                        "Unexpected format at line %d: '%s'" % (i + 1, line)
                    ) from e

                tokens.append((idx, token))

                token = None
                idx = None

    return [token for _, token in sorted(tokens, key=lambda item: item[0])]


def set_transformer_spec(spec, weights):
    set_transformer_encoder(spec.encoder, weights, "encoder")
    set_transformer_decoder(spec.decoder, weights, "decoder")


def set_transformer_encoder(spec, weights, scope):
    set_common_layers(spec, weights, scope)
    for i, layer_spec in enumerate(spec.layer):
        set_transformer_encoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))


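# Marian decoders start from a zero embedding, and the output projection falls
# back to the target embedding matrix when the checkpoint has no dedicated
# "decoder_ff_logit_out" weight (see reuse_weight in set_linear).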
def set_transformer_decoder(spec, weights, scope):
    spec.start_from_zero_embedding = True
    set_common_layers(spec, weights, scope)
    for i, layer_spec in enumerate(spec.layer):
        set_transformer_decoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))

    set_linear(
        spec.projection,
        weights,
        "%s_ff_logit_out" % scope,
        reuse_weight=spec.embeddings.weight,
    )


def set_common_layers(spec, weights, scope):
    embeddings_specs = spec.embeddings
    if not isinstance(embeddings_specs, list):
        embeddings_specs = [embeddings_specs]

    set_embeddings(embeddings_specs[0], weights, scope)
    set_position_encodings(
        spec.position_encodings, weights, dim=embeddings_specs[0].weight.shape[1]
    )
    if hasattr(spec, "layernorm_embedding"):
        set_layer_norm(
            spec.layernorm_embedding,
            weights,
            "%s_emb" % scope,
            pre_norm=True,
        )
    if hasattr(spec, "layer_norm"):
        set_layer_norm(spec.layer_norm, weights, "%s_top" % scope)


def set_transformer_encoder_layer(spec, weights, scope):
    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
    set_multi_head_attention(
        spec.self_attention, weights, "%s_self" % scope, self_attention=True
    )


def set_transformer_decoder_layer(spec, weights, scope):
    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
    set_multi_head_attention(
        spec.self_attention, weights, "%s_self" % scope, self_attention=True
    )
    set_multi_head_attention(spec.attention, weights, "%s_context" % scope)


def set_multi_head_attention(spec, weights, scope, self_attention=False):
    split_layers = [common_spec.LinearSpec() for _ in range(3)]
    set_linear(split_layers[0], weights, scope, "q")
    set_linear(split_layers[1], weights, scope, "k")
    set_linear(split_layers[2], weights, scope, "v")

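    # Self-attention fuses the Q/K/V projections into a single linear layer;
    # cross-attention keeps the query projection separate and only fuses K/V,
    # which are computed from the encoder output.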
    if self_attention:
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        spec.linear[0].weight = split_layers[0].weight
        spec.linear[0].bias = split_layers[0].bias
        utils.fuse_linear(spec.linear[1], split_layers[1:])

    set_linear(spec.linear[-1], weights, scope, "o")
    set_layer_norm_auto(spec.layer_norm, weights, "%s_Wo" % scope)


def set_ffn(spec, weights, scope):
    set_layer_norm_auto(spec.layer_norm, weights, "%s_ffn" % scope)
    set_linear(spec.linear_0, weights, scope, "1")
    set_linear(spec.linear_1, weights, scope, "2")


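# Pre-norm checkpoints store the layer norm parameters with a "_pre" suffix;
# try that naming first and fall back to the post-norm names on a KeyError.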
def set_layer_norm_auto(spec, weights, scope):
    try:
        set_layer_norm(spec, weights, scope, pre_norm=True)
    except KeyError:
        set_layer_norm(spec, weights, scope)


def set_layer_norm(spec, weights, scope, pre_norm=False):
    suffix = "_pre" if pre_norm else ""
    spec.gamma = weights["%s_ln_scale%s" % (scope, suffix)].squeeze()
    spec.beta = weights["%s_ln_bias%s" % (scope, suffix)].squeeze()


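# A weight may be stored as "<scope>_W<suffix>" (transposed before loading) or as
# "<scope>_Wt<suffix>" (used as-is); when neither is present, fall back to
# reuse_weight, which lets the output projection share the embedding matrix.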
def set_linear(spec, weights, scope, suffix="", reuse_weight=None):
    weight = weights.get("%s_W%s" % (scope, suffix))

    if weight is None:
        weight = weights.get("%s_Wt%s" % (scope, suffix), reuse_weight)
    else:
        weight = weight.transpose()

    spec.weight = weight

    bias = weights.get("%s_b%s" % (scope, suffix))
    if bias is not None:
        spec.bias = bias.squeeze()


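# A model trained with tied source/target embeddings stores a single "Wemb"
# matrix instead of separate "encoder_Wemb"/"decoder_Wemb" entries.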
def set_embeddings(spec, weights, scope):
    spec.weight = weights.get("%s_Wemb" % scope)
    if spec.weight is None:
        spec.weight = weights.get("Wemb")


def set_position_encodings(spec, weights, dim=None):
    spec.encodings = weights.get("Wpos", _make_sinusoidal_position_encodings(dim))


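# Marian's sinusoidal position encodings place the sine components in the first
# half of the feature dimension and the cosine components in the second half
# (rather than interleaving them).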
def _make_sinusoidal_position_encodings(dim, num_positions=2048):
    positions = np.arange(num_positions)
    timescales = np.power(10000, 2 * (np.arange(dim) // 2) / dim)
    position_enc = np.expand_dims(positions, 1) / np.expand_dims(timescales, 0)
    table = np.zeros_like(position_enc)
    table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
    table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
    return table


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model_path", required=True, help="Path to the model .npz file."
    )
    parser.add_argument(
        "--vocab_paths",
        required=True,
        nargs="+",
        help="List of paths to the YAML vocabularies.",
    )
    Converter.declare_arguments(parser)
    args = parser.parse_args()
    converter = MarianConverter(args.model_path, args.vocab_paths)
    converter.convert_from_args(args)


if __name__ == "__main__":
    main()