128 lines
3.6 KiB
Python
128 lines
3.6 KiB
Python
import numpy as np
|
|
|
|
|
|
def fuse_linear(spec, layers):
    """Fuses multiple linear layers into a single spec by concatenating weights.

    Arguments:
      spec: The target spec object; ``spec.weight`` (and ``spec.bias`` when any
        layer has a bias) are set by this function.
      layers: A non-empty sequence of layer objects exposing ``weight``,
        ``has_bias()``, and ``bias``.

    Raises:
      ValueError: If ``layers`` is empty.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    # Pick the backend matching the weight storage (NumPy array vs. Torch tensor).
    if isinstance(layers[0].weight, np.ndarray):
        concat_fn, zeros_fn = np.concatenate, np.zeros
    else:
        import torch

        concat_fn, zeros_fn = torch.cat, torch.zeros

    spec.weight = concat_fn([layer.weight for layer in layers])

    # dtype of the first bias found, or None when no layer has a bias.
    bias_dtype = next(
        (layer.bias.dtype for layer in layers if layer.has_bias()), None
    )

    if bias_dtype is not None:
        # Layers without a bias contribute zeros so the fused bias stays aligned
        # with the fused weight rows.
        spec.bias = concat_fn(
            [
                layer.bias
                if layer.has_bias()
                else zeros_fn([layer.weight.shape[0]], dtype=bias_dtype)
                for layer in layers
            ]
        )
|
|
|
|
|
|
def fuse_linear_prequant(spec, layers, axis):
    """Fuses pre-quantized linear layers by concatenating their quantization
    parameters along ``axis``.

    Arguments:
      spec: The target spec object; ``weight``, ``weight_scale`` and
        ``weight_zero`` attributes are set by this function.
      layers: A non-empty sequence of layer objects carrying those same
        three attributes.
      axis: Concatenation axis.

    Raises:
      ValueError: If ``layers`` is empty.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    # Pick the backend matching the weight storage (NumPy array vs. Torch tensor).
    if isinstance(layers[0].weight, np.ndarray):
        concat_fn = np.concatenate
    else:
        import torch

        concat_fn = torch.cat

    for name in ("weight", "weight_scale", "weight_zero"):
        fused = concat_fn([getattr(layer, name) for layer in layers], axis=axis)
        setattr(spec, name, fused)
|
|
|
|
|
|
def permute_for_sliced_rotary(weight, num_heads, rotary_dim=None):
    """Permutes the weight to use the sliced rotary implementation.

    With ``rotary_dim=None`` the full per-head dimension is permuted: each
    head's rows are split into two halves that are interleaved pairwise.
    With a ``rotary_dim``, only the first ``rotary_dim`` rows of each head are
    permuted (recursively, via the full-dimension case) and the rest are kept
    as-is. NOTE(review): in that branch the permuted slice is written back
    through a reshape of ``weight``, so the input may be modified in place
    when the reshape yields a view.
    """
    if rotary_dim is None:
        # Full permutation: split each head into two halves and interleave them.
        half = weight.shape[0] // num_heads // 2
        return (
            weight.reshape(num_heads, half, 2, weight.shape[1])
            .swapaxes(1, 2)
            .reshape(weight.shape[0], weight.shape[1])
        )

    # Partial permutation: only the rotary slice of each head is reordered.
    per_head = weight.reshape(num_heads, weight.shape[0] // num_heads, -1)

    rotary_slice = per_head[:, :rotary_dim]
    permuted = permute_for_sliced_rotary(
        rotary_slice.reshape(num_heads * rotary_dim, -1), num_heads
    ).reshape(num_heads, rotary_dim, -1)

    per_head[:, :rotary_dim] = permuted

    return per_head.reshape(-1, per_head.shape[-1])
|
|
|
|
|
|
def smooth_activation(layer_norm, linear, activation_scales):
    """Applies the activation smoothing technique described in
    https://github.com/mit-han-lab/smoothquant.

    Scales are computed in NumPy, then ``layer_norm.gamma``/``beta`` are
    divided and ``linear.weight`` is multiplied in place by those scales.
    """
    is_numpy = isinstance(linear.weight, np.ndarray)
    if is_numpy:
        weight = linear.weight
    else:
        # Torch path: compute in NumPy, convert back to a tensor at the end.
        weight = linear.weight.numpy()
        activation_scales = activation_scales.numpy()

    # Per-column max magnitude of the weight, clamped away from zero.
    weight_scales = np.maximum(np.amax(np.absolute(weight), axis=0), 1e-5)

    activation_scales = activation_scales.astype(weight_scales.dtype)

    # Balance activation and weight magnitudes, clamped away from zero.
    scales = np.maximum(np.sqrt(activation_scales / weight_scales), 1e-5)

    if not is_numpy:
        import torch

        scales = torch.from_numpy(scales)

    layer_norm.gamma /= scales
    layer_norm.beta /= scales

    linear.weight *= scales.reshape(1, -1)
|
|
|
|
|
|
def raise_unsupported(reasons):
    """Raises a ValueError listing why the model cannot be converted.

    Arguments:
      reasons: An iterable of human-readable reason strings, each rendered
        as a bullet line in the error message.

    Raises:
      ValueError: Always.
    """
    header = (
        "The model you are trying to convert is not supported by CTranslate2. "
        "We identified the following reasons:\n"
    )
    details = "".join("\n- " + reason for reason in reasons)
    raise ValueError(header + details)
|
|
|
|
|
|
class ConfigurationChecker:
    """Accumulates reasons why a model configuration is unsupported and
    reports them all at once via :func:`raise_unsupported`."""

    def __init__(self):
        # Messages for every failed check, in the order they were recorded.
        self._unsupported_reasons = []

    def __call__(self, assert_condition, error_message):
        """Records ``error_message`` when ``assert_condition`` is falsy."""
        if assert_condition:
            return
        self._unsupported_reasons.append(error_message)

    def validate(self):
        """Raises with every recorded reason; does nothing if all checks passed."""
        if not self._unsupported_reasons:
            return
        raise_unsupported(self._unsupported_reasons)
|