import argparse
import copy
import os

from typing import Optional, Union

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

_SUPPORTED_ACTIVATIONS = {
    "gelu": common_spec.Activation.GELUTanh,
    "relu": common_spec.Activation.RELU,
    "swish": common_spec.Activation.SWISH,
}
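# Note: "gelu" maps to the tanh approximation (GELUTanh); as far as I can
# tell this matches the GELU variant OpenNMT-tf applies, but the exact
# variant is an assumption worth revisiting if the upstream default changes.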


class OpenNMTTFConverter(Converter):
    """Converts OpenNMT-tf models."""

    @classmethod
    def from_config(
        cls,
        config: Union[str, dict],
        auto_config: bool = False,
        checkpoint_path: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """Creates the converter from the configuration.

        Arguments:
          config: Path to the YAML configuration, or a dictionary with the loaded configuration.
          auto_config: Whether the model automatic configuration values should be used.
          checkpoint_path: Path to the checkpoint or checkpoint directory to load. If not set,
            the latest checkpoint from the model directory is loaded.
          model: If the model instance cannot be resolved from the model directory, this argument
            can be set to either the name of the model in the catalog or the path to the model
            configuration.

        Returns:
          A :class:`ctranslate2.converters.OpenNMTTFConverter` instance.
        """
        from opennmt import config as config_util
        from opennmt.utils.checkpoint import Checkpoint

        if isinstance(config, str):
            config = config_util.load_config([config])
        else:
            config = copy.deepcopy(config)

        if model is None:
            model = config_util.load_model(config["model_dir"])
        elif os.path.exists(model):
            model = config_util.load_model_from_file(model)
        else:
            model = config_util.load_model_from_catalog(model)

        if auto_config:
            config_util.merge_config(config, model.auto_config())

        data_config = config_util.try_prefix_paths(config["model_dir"], config["data"])
        model.initialize(data_config)

        checkpoint = Checkpoint.from_config(config, model)
        checkpoint_path = checkpoint.restore(checkpoint_path=checkpoint_path)
        if checkpoint_path is None:
            raise RuntimeError("No checkpoint was restored")

        model.create_variables()
        return cls(model)
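
    # A minimal usage sketch (hypothetical paths), assuming a standard
    # OpenNMT-tf run directory with its YAML configuration on disk:
    #
    #     converter = OpenNMTTFConverter.from_config("config.yml", auto_config=True)
    #     converter.convert("ende_ctranslate2", quantization="int8")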

    def __init__(self, model):
        """Initializes the converter.

        Arguments:
          model: An initialized and fully-built ``opennmt.models.Model`` instance.
        """
        self._model = model

    def _load(self):
        import opennmt

        if isinstance(self._model, opennmt.models.LanguageModel):
            spec_builder = TransformerDecoderSpecBuilder()
        else:
            spec_builder = TransformerSpecBuilder()

        return spec_builder(self._model)
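

# Decoder-only specs are built for language models (see
# TransformerDecoderSpecBuilder below); all other supported models go through
# this full encoder-decoder builder.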
class TransformerSpecBuilder:
    def __call__(self, model):
        import opennmt

        check = utils.ConfigurationChecker()
        check(
            isinstance(model, opennmt.models.Transformer),
            "Only Transformer models are supported",
        )
        check.validate()

        check(
            isinstance(model.encoder, opennmt.encoders.SelfAttentionEncoder),
            "Parallel encoders are not supported",
        )
        check(
            isinstance(
                model.features_inputter,
                (opennmt.inputters.WordEmbedder, opennmt.inputters.ParallelInputter),
            ),
            "Source inputter must be a WordEmbedder or a ParallelInputter",
        )
        check.validate()

        mha = model.encoder.layers[0].self_attention.layer
        ffn = model.encoder.layers[0].ffn.layer
        with_relative_position = mha.maximum_relative_position is not None
        activation_name = ffn.inner.activation.__name__

        check(
            activation_name in _SUPPORTED_ACTIVATIONS,
            "Activation %s is not supported (supported activations are: %s)"
            % (activation_name, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
        )
        check(
            with_relative_position != bool(model.encoder.position_encoder),
            "Relative position representation and position encoding cannot be both enabled "
            "or both disabled",
        )
        check(
            model.decoder.attention_reduction
            != opennmt.layers.MultiHeadAttentionReduction.AVERAGE_ALL_LAYERS,
            "Averaging all multi-head attention matrices is not supported",
        )

        source_inputters = _get_inputters(model.features_inputter)
        target_inputters = _get_inputters(model.labels_inputter)
        num_source_embeddings = len(source_inputters)
        if num_source_embeddings == 1:
            embeddings_merge = common_spec.EmbeddingsMerge.CONCAT
        else:
            reducer = model.features_inputter.reducer
            embeddings_merge = None
            if reducer is not None:
                if isinstance(reducer, opennmt.layers.ConcatReducer):
                    embeddings_merge = common_spec.EmbeddingsMerge.CONCAT
                elif isinstance(reducer, opennmt.layers.SumReducer):
                    embeddings_merge = common_spec.EmbeddingsMerge.ADD

        check(
            all(
                isinstance(inputter, opennmt.inputters.WordEmbedder)
                for inputter in source_inputters
            ),
"All source inputters must WordEmbedders",
        )
        check(
            embeddings_merge is not None,
            "Unsupported embeddings reducer %s" % reducer,
        )

        alignment_layer = -1
        alignment_heads = 1
        if (
            model.decoder.attention_reduction
            == opennmt.layers.MultiHeadAttentionReduction.AVERAGE_LAST_LAYER
        ):
            alignment_heads = 0
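
        # As I read the spec defaults: alignment_layer=-1 selects the last
        # decoder layer for the returned attention, alignment_heads=1 keeps a
        # single head, and alignment_heads=0 averages all heads of that layer.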

        check.validate()

        encoder_spec = transformer_spec.TransformerEncoderSpec(
            len(model.encoder.layers),
            model.encoder.layers[0].self_attention.layer.num_heads,
            pre_norm=model.encoder.layer_norm is not None,
            activation=_SUPPORTED_ACTIVATIONS[activation_name],
            num_source_embeddings=num_source_embeddings,
            embeddings_merge=embeddings_merge,
            relative_position=with_relative_position,
        )

        decoder_spec = transformer_spec.TransformerDecoderSpec(
            len(model.decoder.layers),
            model.decoder.layers[0].self_attention.layer.num_heads,
            pre_norm=model.decoder.layer_norm is not None,
            activation=_SUPPORTED_ACTIVATIONS[activation_name],
            relative_position=with_relative_position,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        spec = transformer_spec.TransformerSpec(encoder_spec, decoder_spec)

        spec.config.add_source_bos = bool(source_inputters[0].mark_start)
        spec.config.add_source_eos = bool(source_inputters[0].mark_end)
        for inputter in source_inputters:
            spec.register_source_vocabulary(_load_vocab(inputter.vocabulary_file))
        for inputter in target_inputters:
            spec.register_target_vocabulary(_load_vocab(inputter.vocabulary_file))

        self.set_transformer_encoder(
            spec.encoder,
            model.encoder,
            model.features_inputter,
        )
        self.set_transformer_decoder(
            spec.decoder,
            model.decoder,
            model.labels_inputter,
        )

        return spec

    def set_transformer_encoder(self, spec, module, inputter):
        for embedding_spec, inputter in zip(spec.embeddings, _get_inputters(inputter)):
            self.set_embeddings(embedding_spec, inputter)
        if module.position_encoder is not None:
            self.set_position_encodings(
                spec.position_encodings,
                module.position_encoder,
            )

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_multi_head_attention(
                layer_spec.self_attention,
                layer.self_attention,
                self_attention=True,
            )

            self.set_ffn(layer_spec.ffn, layer.ffn)

        if module.layer_norm is not None:
            self.set_layer_norm(spec.layer_norm, module.layer_norm)

    def set_transformer_decoder(self, spec, module, inputter):
        self.set_embeddings(spec.embeddings, inputter)
        if module.position_encoder is not None:
            self.set_position_encodings(
                spec.position_encodings,
                module.position_encoder,
            )

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_multi_head_attention(
                layer_spec.self_attention,
                layer.self_attention,
                self_attention=True,
            )

            if layer.attention:
                self.set_multi_head_attention(
                    layer_spec.attention,
                    layer.attention[0],
                    self_attention=False,
                )

            self.set_ffn(layer_spec.ffn, layer.ffn)

        if module.layer_norm is not None:
            self.set_layer_norm(spec.layer_norm, module.layer_norm)

        self.set_linear(spec.projection, module.output_layer)

    def set_ffn(self, spec, module):
        self.set_linear(spec.linear_0, module.layer.inner)
        self.set_linear(spec.linear_1, module.layer.outer)
        self.set_layer_norm_from_wrapper(spec.layer_norm, module)

    def set_multi_head_attention(self, spec, module, self_attention=False):
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        self.set_linear(split_layers[0], module.layer.linear_queries)
        self.set_linear(split_layers[1], module.layer.linear_keys)
        self.set_linear(split_layers[2], module.layer.linear_values)
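
        # The Q/K/V projections are converted separately, then fused below:
        # self-attention fuses all three into one linear (a single matmul at
        # runtime), while cross-attention fuses only keys and values, since
        # queries come from the decoder side.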
        if self_attention:
            utils.fuse_linear(spec.linear[0], split_layers)
            if module.layer.maximum_relative_position is not None:
                spec.relative_position_keys = (
                    module.layer.relative_position_keys.numpy()
                )
                spec.relative_position_values = (
                    module.layer.relative_position_values.numpy()
                )
        else:
            utils.fuse_linear(spec.linear[0], split_layers[:1])
            utils.fuse_linear(spec.linear[1], split_layers[1:])

        self.set_linear(spec.linear[-1], module.layer.linear_output)
        self.set_layer_norm_from_wrapper(spec.layer_norm, module)

    def set_layer_norm_from_wrapper(self, spec, module):
        self.set_layer_norm(
            spec,
            (
                module.output_layer_norm
                if module.input_layer_norm is None
                else module.input_layer_norm
            ),
        )

    def set_layer_norm(self, spec, module):
        spec.gamma = module.gamma.numpy()
        spec.beta = module.beta.numpy()

    def set_linear(self, spec, module):
        spec.weight = module.kernel.numpy()
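        # TensorFlow Dense kernels are stored as (input_dim, output_dim) while
        # the CTranslate2 spec expects (output_dim, input_dim), so transpose
        # unless the OpenNMT-tf layer already keeps a transposed kernel.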
        if not module.transpose:
            spec.weight = spec.weight.transpose()
        if module.bias is not None:
            spec.bias = module.bias.numpy()

    def set_embeddings(self, spec, module):
        spec.weight = module.embedding.numpy()

    def set_position_encodings(self, spec, module):
        import opennmt

        if isinstance(module, opennmt.layers.PositionEmbedder):
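            # The first row of the embedding table is skipped; in OpenNMT-tf's
            # PositionEmbedder, index 0 appears to be reserved for padding (an
            # assumption based on the upstream implementation).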
            spec.encodings = module.embedding.numpy()[1:]


class TransformerDecoderSpecBuilder(TransformerSpecBuilder):
    def __call__(self, model):
        import opennmt

        check = utils.ConfigurationChecker()
        check(
            isinstance(model.decoder, opennmt.decoders.SelfAttentionDecoder),
            "Only self-attention decoders are supported",
        )
        check.validate()

        mha = model.decoder.layers[0].self_attention.layer
        ffn = model.decoder.layers[0].ffn.layer
        activation_name = ffn.inner.activation.__name__

        check(
            activation_name in _SUPPORTED_ACTIVATIONS,
            "Activation %s is not supported (supported activations are: %s)"
            % (activation_name, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
        )
        check.validate()

        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            len(model.decoder.layers),
            mha.num_heads,
            pre_norm=model.decoder.layer_norm is not None,
            activation=_SUPPORTED_ACTIVATIONS[activation_name],
        )

        spec.register_vocabulary(_load_vocab(model.features_inputter.vocabulary_file))
        self.set_transformer_decoder(
            spec.decoder,
            model.decoder,
            model.features_inputter,
        )

        return spec
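
    # Weight loading is inherited from TransformerSpecBuilder via
    # set_transformer_decoder; only the validation and spec construction
    # differ for decoder-only (language model) checkpoints.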


def _get_inputters(inputter):
    import opennmt

    return (
        inputter.inputters
        if isinstance(inputter, opennmt.inputters.MultiInputter)
        else [inputter]
    )


def _load_vocab(vocab, unk_token="<unk>"):
    import opennmt

    if isinstance(vocab, opennmt.data.Vocab):
        tokens = list(vocab.words)
    elif isinstance(vocab, list):
        tokens = list(vocab)
    elif isinstance(vocab, str):
        tokens = opennmt.data.Vocab.from_file(vocab).words
    else:
        raise TypeError("Invalid vocabulary type")
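
    # The converted model needs an unknown token in its vocabulary
    # (presumably the runtime expects one), so the OpenNMT-tf default
    # "<unk>" is appended when the vocabulary does not already define it.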
    if unk_token not in tokens:
        tokens.append(unk_token)
    return tokens


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--config", help="Path to the YAML configuration.")
    parser.add_argument(
        "--auto_config",
        action="store_true",
        help="Use the model automatic configuration values.",
    )
    parser.add_argument(
        "--model_path",
        help=(
            "Path to the checkpoint or checkpoint directory to load. If not set, "
            "the latest checkpoint from the model directory is loaded."
        ),
    )
    parser.add_argument(
        "--model_type",
        help=(
            "If the model instance cannot be resolved from the model directory, "
            "this argument can be set to either the name of the model in the catalog "
            "or the path to the model configuration."
        ),
    )
    parser.add_argument(
        "--src_vocab",
        help="Path to the source vocabulary (required if no configuration is set).",
    )
    parser.add_argument(
        "--tgt_vocab",
        help="Path to the target vocabulary (required if no configuration is set).",
    )
    Converter.declare_arguments(parser)
    args = parser.parse_args()

    config = args.config
    if not config:
        if not args.model_path or not args.src_vocab or not args.tgt_vocab:
            raise ValueError(
                "Options --model_path, --src_vocab, --tgt_vocab are required "
                "when a configuration is not set"
            )

        model_dir = (
            args.model_path
            if os.path.isdir(args.model_path)
            else os.path.dirname(args.model_path)
        )
        config = {
            "model_dir": model_dir,
            "data": {
                "source_vocabulary": args.src_vocab,
                "target_vocabulary": args.tgt_vocab,
            },
        }

    converter = OpenNMTTFConverter.from_config(
        config,
        auto_config=args.auto_config,
        checkpoint_path=args.model_path,
        model=args.model_type,
    )
    converter.convert_from_args(args)
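

# A typical invocation when no YAML configuration is available (hypothetical
# paths; the package normally exposes this module as the
# ct2-opennmt-tf-converter entry point):
#
#     ct2-opennmt-tf-converter --model_path averaged-ende-export500k \
#         --model_type Transformer \
#         --src_vocab src.vocab --tgt_vocab tgt.vocab \
#         --output_dir ende_ctranslate2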
if __name__ == "__main__":
    main()