add read me

This commit is contained in:
2026-01-09 10:28:44 +11:00
commit edaf914b73
13417 changed files with 2952119 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,434 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities used by both the sync and async inference clients."""
import base64
import io
import json
import logging
import mimetypes
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, AsyncIterable, BinaryIO, Iterable, Literal, NoReturn, Optional, Union, overload
import httpx
from huggingface_hub.errors import (
GenerationError,
HfHubHTTPError,
IncompleteGenerationError,
OverloadedError,
TextGenerationError,
UnknownError,
ValidationError,
)
from ..utils import get_session, is_numpy_available, is_pillow_available
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
if TYPE_CHECKING:
from PIL.Image import Image
# TYPES
# A URL is carried around as a plain string.
UrlT = str
# A local path, given either as a string or as a `pathlib.Path`.
PathT = Union[str, Path]
# Every input form accepted as raw content: in-memory bytes-like data, a binary file object, a local path,
# a URL, or a PIL image ("Image" is a forward reference because PIL is only imported under TYPE_CHECKING).
ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]
# Used to set an Accept: image/png header
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class RequestParameters:
    """Plain container grouping the pieces of one inference request.

    All six fields are required at construction time; `model`, `json` and `data` accept `None`.
    """

    # Where the request is sent and which task it targets.
    url: str
    task: str
    # Model identifier, if any.
    model: Optional[str]
    # JSON-style payload (string, dict or list), if any.
    json: Optional[Union[str, dict, list]]
    # Raw binary payload, if any.
    data: Optional[bytes]
    # Header name -> header value.
    headers: dict[str, Any]
class MimeBytes(bytes):
    """
    A `bytes` subclass that additionally carries an optional mime type.

    To be returned by `_prepare_payload_open_as_mime_bytes` in subclasses.

    When built from another `MimeBytes` without an explicit `mime_type`, the source's mime type is reused.

    Example:
    ```python
    >>> b = MimeBytes(b"hello", "text/plain")
    >>> isinstance(b, bytes)
    True
    >>> b.mime_type
    'text/plain'
    ```
    """

    mime_type: Optional[str]

    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
        instance = super().__new__(cls, data)
        # An explicit mime type always wins; otherwise inherit from a MimeBytes source.
        inherit_from_source = mime_type is None and isinstance(data, MimeBytes)
        instance.mime_type = data.mime_type if inherit_from_source else mime_type
        return instance
## IMPORT UTILS
def _import_numpy():
    """Return the `numpy` module, raising `ImportError` if it is not available."""
    if is_numpy_available():
        import numpy

        return numpy
    raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
def _import_pil_image():
    """Return the `PIL.Image` module, raising `ImportError` if Pillow is not available."""
    if is_pillow_available():
        from PIL import Image

        return Image
    raise ImportError(
        "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
        " post-processed, use `client.post(...)` and get the raw response from the server."
    )
## ENCODING / DECODING UTILS
@overload
def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ...  # means "if input is not None, output is not None"


@overload
def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ...  # means "if input is None, output is None"


def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
    """Load `content` into a `MimeBytes`, whatever form it was given in.

    Supported inputs, checked in this order: `None` (returned unchanged), `bytes`, `bytearray`/`memoryview`,
    any object exposing `read()`, a string (URL starting with `http://`/`https://`, otherwise a local path),
    a `Path`, and finally a PIL image when Pillow is available.

    Raises:
        TypeError: If a file-like object yields text instead of bytes, or if `content` matches no supported form.
        FileNotFoundError: If `content` is a non-URL string that does not point to an existing path.
    """
    # Nothing to open
    if content is None:
        return None

    # In-memory data: bytes are wrapped directly, other buffers are copied into bytes first
    if isinstance(content, bytes):
        return MimeBytes(content)
    if isinstance(content, (bytearray, memoryview)):
        return MimeBytes(bytes(content))

    # File-like object (duck-typing instead of isinstance(content, BinaryIO))
    if hasattr(content, "read"):
        logger.debug("Reading content from BinaryIO")
        payload = content.read()
        # The mime type is guessed from the stream's name, when it has one
        guessed_type = mimetypes.guess_type(str(content.name))[0] if hasattr(content, "name") else None
        if isinstance(payload, str):
            raise TypeError("Expected binary stream (bytes), but got text stream")
        return MimeBytes(payload, mime_type=guessed_type)

    # A string is either a URL to download or a local path (handled by the Path branch below)
    if isinstance(content, str):
        if content.startswith(("https://", "http://")):
            logger.debug(f"Downloading content from {content}")
            response = get_session().get(content)
            # Prefer the server-declared type, fall back to guessing from the URL
            declared_type = response.headers.get("Content-Type")
            if declared_type is None:
                declared_type = mimetypes.guess_type(content)[0]
            return MimeBytes(response.content, mime_type=declared_type)
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # Local file: read it entirely and guess the mime type from its name
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])

    # PIL image: serialize it using its own format, defaulting to PNG
    if is_pillow_available():
        from PIL import Image

        if isinstance(content, Image.Image):
            logger.debug("Converting PIL Image to bytes")
            image_format = content.format or "PNG"
            stream = io.BytesIO()
            content.save(stream, format=image_format)
            return MimeBytes(stream.getvalue(), mime_type=f"image/{image_format.lower()}")

    # No supported form matched
    raise TypeError(
        f"Unsupported content type: {type(content)}. "
        "Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
    )
def _b64_encode(content: ContentT) -> str:
    """Return the base64 text of a raw file (image, audio) given as bytes, an opened file, a path or a URL."""
    return base64.b64encode(_open_as_mime_bytes(content)).decode()
def _as_url(content: ContentT, default_mime_type: str) -> str:
    """Return `content` as a URL string.

    Strings that already start with `http://`, `https://` or `data:` are returned unchanged. Anything else is
    loaded with `_open_as_mime_bytes` and embedded in a base64 `data:` URL, using the loaded content's mime type
    or `default_mime_type` when none is known.
    """
    is_already_url = isinstance(content, str) and content.startswith(("http://", "https://", "data:"))
    if is_already_url:
        return content

    loaded = _open_as_mime_bytes(content)
    resolved_mime_type = loaded.mime_type or default_mime_type
    b64_payload = base64.b64encode(loaded).decode()
    return f"data:{resolved_mime_type};base64,{b64_payload}"
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64 string and open the resulting bytes as a PIL Image."""
    pil_image = _import_pil_image()
    decoded = base64.b64decode(encoded_image)
    return pil_image.open(io.BytesIO(decoded))
def _bytes_to_list(content: bytes) -> list:
"""Parse bytes from a Response object into a Python list.
Expects the response body to be JSON-encoded data.
NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
"""
return json.loads(content.decode())
def _bytes_to_dict(content: bytes) -> dict:
"""Parse bytes from a Response object into a Python dictionary.
Expects the response body to be JSON-encoded data.
NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
"""
return json.loads(content.decode())
def _bytes_to_image(content: bytes) -> "Image":
    """Open raw response bytes as a PIL Image.

    The bytes are used as-is. For base64-encoded images, use `_b64_to_image` instead.
    """
    pil_image = _import_pil_image()
    stream = io.BytesIO(content)
    return pil_image.open(stream)
def _as_dict(response: Union[bytes, dict]) -> dict:
return json.loads(response) if isinstance(response, bytes) else response
## STREAMING UTILS
def _stream_text_generation_response(
    output_lines: Iterable[str], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`.

    Parses each incoming line with `_format_text_generation_stream_output`, yields every non-`None` result,
    and stops as soon as that helper raises `StopIteration`.
    """
    for raw_line in output_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal from the parser
            return
        if parsed is None:
            continue
        yield parsed
async def _async_stream_text_generation_response(
    output_lines: AsyncIterable[str], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`.

    Parses each incoming line with `_format_text_generation_stream_output`, yields every non-`None` result,
    and stops as soon as that helper raises `StopIteration`.
    """
    async for raw_line in output_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal from the parser
            return
        if parsed is None:
            continue
        yield parsed
def _format_text_generation_stream_output(
    line: str, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse one line of a text-generation event stream.

    Args:
        line: A single line of the stream.
        details: If `True`, return the full parsed `TextGenerationStreamOutput`; otherwise only the token text.

    Returns:
        `None` when the line does not start with `data:`; otherwise the token text or the parsed output.

    Raises:
        StopIteration: When the line is the `data: [DONE]` end-of-stream marker.
        TextGenerationError: When the payload carries a non-null `error` entry
            (built by `_parse_text_generation_error`).
    """
    if not line.startswith("data:"):
        return None  # empty line

    if line.strip() == "data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload.
    # Slice off the exact "data:" prefix: `str.lstrip`/`str.rstrip` take a *set of characters*, so
    # `lstrip("data:")` could eat leading payload characters and `rstrip("/n")` strips "/" and "n"
    # instead of a trailing newline.
    payload = line[len("data:") :].strip()
    json_payload = json.loads(payload)

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output if details else output.token.text
def _stream_chat_completion_response(
    lines: Iterable[str],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI.

    Parses each incoming line with `_format_chat_completion_stream_output`, yields every non-`None` result,
    and stops as soon as that helper raises `StopIteration`.
    """
    for raw_line in lines:
        try:
            parsed = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal from the parser
            return
        if parsed is None:
            continue
        yield parsed
async def _async_stream_chat_completion_response(
    lines: AsyncIterable[str],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`.

    Parses each incoming line with `_format_chat_completion_stream_output`, yields every non-`None` result,
    and stops as soon as that helper raises `StopIteration`.
    """
    async for raw_line in lines:
        try:
            parsed = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal from the parser
            return
        if parsed is None:
            continue
        yield parsed
def _format_chat_completion_stream_output(
    line: str,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse one line of a chat-completion event stream.

    Args:
        line: A single line of the stream.

    Returns:
        `None` when the line does not start with `data:`; otherwise the parsed `ChatCompletionStreamOutput`.

    Raises:
        StopIteration: When the line is the `data: [DONE]` end-of-stream marker.
        TextGenerationError: When the payload carries a non-null `error` entry
            (built by `_parse_text_generation_error`).
    """
    if not line.startswith("data:"):
        return None  # empty line

    if line.strip() == "data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload.
    # Slice off the exact "data:" prefix: `str.lstrip("data:")` treats its argument as a set of
    # characters and could eat leading payload characters.
    json_payload = json.loads(line[len("data:") :].strip())

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
async def _async_yield_from(client: httpx.AsyncClient, response: httpx.Response) -> AsyncIterable[str]:
    """Yield every line of `response`, stripped of surrounding whitespace.

    `client` is accepted but not used inside this function.
    """
    async for raw_line in response.aiter_lines():
        stripped_line = raw_line.strip()
        yield stripped_line
# "TGI servers" are servers running with the `text-generation-inference` backend.
# This backend is the go-to solution to run large language models at scale. However,
# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
# solution is still in use.
#
# Both approaches have very similar APIs, but not exactly the same. What we do first in
# the `text_generation` method is to assume the model is served via TGI. If we realize
# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fall back to the
# default API with a warning message. When that's the case, we remember the unsupported
# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
#
# In addition, TGI servers have a built-in API route for chat-completion, which is not
# available on the default API. We use this route to provide a more consistent behavior
# when available.
#
# For more details, see https://github.com/huggingface/text-generation-inference and
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
# Maps a model (or None) to the list of text-generation kwarg names recorded as unsupported for it.
_UNSUPPORTED_TEXT_GENERATION_KWARGS: dict[Optional[str], list[str]] = {}
def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: list[str]) -> None:
    """Record `unsupported_kwargs` for `model`, appending to whatever was already recorded for it."""
    recorded = _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, [])
    recorded.extend(unsupported_kwargs)
def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> list[str]:
    """Return the kwargs recorded as unsupported for `model`, or an empty list when none were recorded."""
    nothing_recorded: list[str] = []
    return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, nothing_recorded)
# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
# inference project (https://github.com/huggingface/text-generation-inference).
# ----------------------
def raise_text_generation_error(http_error: HfHubHTTPError) -> NoReturn:
    """
    Try to parse a text-generation-inference error message; an exception is raised in every case.

    Args:
        http_error (`HfHubHTTPError`):
            The HTTP error that has been raised.

    Raises:
        TextGenerationError: When the response payload provides an `error_type`; chained from `http_error`.
        HfHubHTTPError: `http_error` itself, when there is no response, no readable payload or no `error_type`.
    """
    response = http_error.response
    if response is None:
        raise http_error

    # Try to parse a Text Generation Inference error
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or response.json()
        error, error_type = payload.get("error"), payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # Without an error_type there is nothing more specific to report: fallback to default error
    if error_type is None:
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    raise _parse_text_generation_error(error, error_type) from http_error
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Build (without raising) the exception matching `error_type`, carrying `error` as its message.

    Any `error_type` other than the four known values yields an `UnknownError`.
    """
    known_error_types = (
        ("generation", GenerationError),
        ("incomplete_generation", IncompleteGenerationError),
        ("overloaded", OverloadedError),
        ("validation", ValidationError),
    )
    # Compared with `==` one by one so that any value of `error_type` is tolerated
    for type_name, exception_class in known_error_types:
        if error_type == type_name:
            return exception_class(error)  # type: ignore
    return UnknownError(error)  # type: ignore

View File

@@ -0,0 +1,192 @@
# This file is auto-generated by `utils/generate_inference_types.py`.
# Do not modify it manually.
#
# ruff: noqa: F401
from .audio_classification import (
AudioClassificationInput,
AudioClassificationOutputElement,
AudioClassificationOutputTransform,
AudioClassificationParameters,
)
from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement
from .automatic_speech_recognition import (
AutomaticSpeechRecognitionEarlyStoppingEnum,
AutomaticSpeechRecognitionGenerationParameters,
AutomaticSpeechRecognitionInput,
AutomaticSpeechRecognitionOutput,
AutomaticSpeechRecognitionOutputChunk,
AutomaticSpeechRecognitionParameters,
)
from .base import BaseInferenceType
from .chat_completion import (
ChatCompletionInput,
ChatCompletionInputFunctionDefinition,
ChatCompletionInputFunctionName,
ChatCompletionInputGrammarType,
ChatCompletionInputJSONSchema,
ChatCompletionInputMessage,
ChatCompletionInputMessageChunk,
ChatCompletionInputMessageChunkType,
ChatCompletionInputResponseFormatJSONObject,
ChatCompletionInputResponseFormatJSONSchema,
ChatCompletionInputResponseFormatText,
ChatCompletionInputStreamOptions,
ChatCompletionInputTool,
ChatCompletionInputToolCall,
ChatCompletionInputToolChoiceClass,
ChatCompletionInputToolChoiceEnum,
ChatCompletionInputURL,
ChatCompletionOutput,
ChatCompletionOutputComplete,
ChatCompletionOutputFunctionDefinition,
ChatCompletionOutputLogprob,
ChatCompletionOutputLogprobs,
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputTopLogprob,
ChatCompletionOutputUsage,
ChatCompletionStreamOutput,
ChatCompletionStreamOutputChoice,
ChatCompletionStreamOutputDelta,
ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction,
ChatCompletionStreamOutputLogprob,
ChatCompletionStreamOutputLogprobs,
ChatCompletionStreamOutputTopLogprob,
ChatCompletionStreamOutputUsage,
)
from .depth_estimation import DepthEstimationInput, DepthEstimationOutput
from .document_question_answering import (
DocumentQuestionAnsweringInput,
DocumentQuestionAnsweringInputData,
DocumentQuestionAnsweringOutputElement,
DocumentQuestionAnsweringParameters,
)
from .feature_extraction import FeatureExtractionInput, FeatureExtractionInputTruncationDirection
from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters
from .image_classification import (
ImageClassificationInput,
ImageClassificationOutputElement,
ImageClassificationOutputTransform,
ImageClassificationParameters,
)
from .image_segmentation import (
ImageSegmentationInput,
ImageSegmentationOutputElement,
ImageSegmentationParameters,
ImageSegmentationSubtask,
)
from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize
from .image_to_text import (
ImageToTextEarlyStoppingEnum,
ImageToTextGenerationParameters,
ImageToTextInput,
ImageToTextOutput,
ImageToTextParameters,
)
from .image_to_video import ImageToVideoInput, ImageToVideoOutput, ImageToVideoParameters, ImageToVideoTargetSize
from .object_detection import (
ObjectDetectionBoundingBox,
ObjectDetectionInput,
ObjectDetectionOutputElement,
ObjectDetectionParameters,
)
from .question_answering import (
QuestionAnsweringInput,
QuestionAnsweringInputData,
QuestionAnsweringOutputElement,
QuestionAnsweringParameters,
)
from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData
from .summarization import (
SummarizationInput,
SummarizationOutput,
SummarizationParameters,
SummarizationTruncationStrategy,
)
from .table_question_answering import (
Padding,
TableQuestionAnsweringInput,
TableQuestionAnsweringInputData,
TableQuestionAnsweringOutputElement,
TableQuestionAnsweringParameters,
)
from .text2text_generation import (
Text2TextGenerationInput,
Text2TextGenerationOutput,
Text2TextGenerationParameters,
Text2TextGenerationTruncationStrategy,
)
from .text_classification import (
TextClassificationInput,
TextClassificationOutputElement,
TextClassificationOutputTransform,
TextClassificationParameters,
)
from .text_generation import (
TextGenerationInput,
TextGenerationInputGenerateParameters,
TextGenerationInputGrammarType,
TextGenerationOutput,
TextGenerationOutputBestOfSequence,
TextGenerationOutputDetails,
TextGenerationOutputFinishReason,
TextGenerationOutputPrefillToken,
TextGenerationOutputToken,
TextGenerationStreamOutput,
TextGenerationStreamOutputStreamDetails,
TextGenerationStreamOutputToken,
TypeEnum,
)
from .text_to_audio import (
TextToAudioEarlyStoppingEnum,
TextToAudioGenerationParameters,
TextToAudioInput,
TextToAudioOutput,
TextToAudioParameters,
)
from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters
from .text_to_speech import (
TextToSpeechEarlyStoppingEnum,
TextToSpeechGenerationParameters,
TextToSpeechInput,
TextToSpeechOutput,
TextToSpeechParameters,
)
from .text_to_video import TextToVideoInput, TextToVideoOutput, TextToVideoParameters
from .token_classification import (
TokenClassificationAggregationStrategy,
TokenClassificationInput,
TokenClassificationOutputElement,
TokenClassificationParameters,
)
from .translation import TranslationInput, TranslationOutput, TranslationParameters, TranslationTruncationStrategy
from .video_classification import (
VideoClassificationInput,
VideoClassificationOutputElement,
VideoClassificationOutputTransform,
VideoClassificationParameters,
)
from .visual_question_answering import (
VisualQuestionAnsweringInput,
VisualQuestionAnsweringInputData,
VisualQuestionAnsweringOutputElement,
VisualQuestionAnsweringParameters,
)
from .zero_shot_classification import (
ZeroShotClassificationInput,
ZeroShotClassificationOutputElement,
ZeroShotClassificationParameters,
)
from .zero_shot_image_classification import (
ZeroShotImageClassificationInput,
ZeroShotImageClassificationOutputElement,
ZeroShotImageClassificationParameters,
)
from .zero_shot_object_detection import (
ZeroShotObjectDetectionBoundingBox,
ZeroShotObjectDetectionInput,
ZeroShotObjectDetectionOutputElement,
ZeroShotObjectDetectionParameters,
)

View File

@@ -0,0 +1,43 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Allowed values for `AudioClassificationParameters.function_to_apply`.
AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]
@dataclass_with_extra
class AudioClassificationParameters(BaseInferenceType):
    """Additional inference parameters for Audio Classification"""

    # The function to apply to the model outputs in order to retrieve the scores.
    function_to_apply: Optional["AudioClassificationOutputTransform"] = None
    # When specified, limits the output to the top K most probable classes.
    top_k: Optional[int] = None
@dataclass_with_extra
class AudioClassificationInput(BaseInferenceType):
    """Inputs for Audio Classification inference"""

    # The input audio data as a base64-encoded string. If no `parameters` are provided, you can
    # also provide the audio data as a raw bytes payload.
    inputs: str
    # Additional inference parameters for Audio Classification
    parameters: Optional[AudioClassificationParameters] = None
@dataclass_with_extra
class AudioClassificationOutputElement(BaseInferenceType):
    """Outputs for Audio Classification inference"""

    # The predicted class label.
    label: str
    # The corresponding probability.
    score: float

View File

@@ -0,0 +1,30 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class AudioToAudioInput(BaseInferenceType):
    """Inputs for Audio to Audio inference"""

    # The input audio data
    inputs: Any
@dataclass_with_extra
class AudioToAudioOutputElement(BaseInferenceType):
    """Outputs of inference for the Audio To Audio task

    A generated audio file with its label.
    """

    # The generated audio file.
    blob: Any
    # The content type of audio file.
    content_type: str
    # The label of the audio file.
    label: str

View File

@@ -0,0 +1,113 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
# String value accepted, besides booleans, by `AutomaticSpeechRecognitionGenerationParameters.early_stopping`.
AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"]
@dataclass_with_extra
class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process"""

    # Whether to use sampling instead of greedy decoding when generating new tokens.
    do_sample: Optional[bool] = None
    # Controls the stopping condition for beam-based methods.
    early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None
    # If set to float strictly between 0 and 1, only tokens with a conditional probability
    # greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    # 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    # Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    epsilon_cutoff: Optional[float] = None
    # Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    # float strictly between 0 and 1, a token is only considered if it is greater than either
    # eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    # term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    # the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    # See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    # for more details.
    eta_cutoff: Optional[float] = None
    # The maximum length (in tokens) of the generated text, including the input.
    max_length: Optional[int] = None
    # The maximum number of tokens to generate. Takes precedence over max_length.
    max_new_tokens: Optional[int] = None
    # The minimum length (in tokens) of the generated text, including the input.
    min_length: Optional[int] = None
    # The minimum number of tokens to generate. Takes precedence over min_length.
    min_new_tokens: Optional[int] = None
    # Number of groups to divide num_beams into in order to ensure diversity among different
    # groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    num_beam_groups: Optional[int] = None
    # Number of beams to use for beam search.
    num_beams: Optional[int] = None
    # The value balances the model confidence and the degeneration penalty in contrastive
    # search decoding.
    penalty_alpha: Optional[float] = None
    # The value used to modulate the next token probabilities.
    temperature: Optional[float] = None
    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int] = None
    # If set to float < 1, only the smallest set of most probable tokens with probabilities
    # that add up to top_p or higher are kept for generation.
    top_p: Optional[float] = None
    # Local typicality measures how similar the conditional probability of predicting a target
    # token next is to the expected conditional probability of predicting a random token next,
    # given the partial text already generated. If set to float < 1, the smallest set of the
    # most locally typical tokens with probabilities that add up to typical_p or higher are
    # kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    typical_p: Optional[float] = None
    # Whether the model should use the past last key/values attentions to speed up decoding
    use_cache: Optional[bool] = None
@dataclass_with_extra
class AutomaticSpeechRecognitionParameters(BaseInferenceType):
    """Additional inference parameters for Automatic Speech Recognition"""

    # Parametrization of the text generation process
    generation_parameters: Optional[AutomaticSpeechRecognitionGenerationParameters] = None
    # Whether to output corresponding timestamps with the generated text
    return_timestamps: Optional[bool] = None
@dataclass_with_extra
class AutomaticSpeechRecognitionInput(BaseInferenceType):
    """Inputs for Automatic Speech Recognition inference"""

    # The input audio data as a base64-encoded string. If no `parameters` are provided, you can
    # also provide the audio data as a raw bytes payload.
    inputs: str
    # Additional inference parameters for Automatic Speech Recognition
    parameters: Optional[AutomaticSpeechRecognitionParameters] = None
@dataclass_with_extra
class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType):
    # A chunk of text identified by the model
    text: str
    # The start and end timestamps corresponding with the text
    timestamp: list[float]
@dataclass_with_extra
class AutomaticSpeechRecognitionOutput(BaseInferenceType):
    """Outputs of inference for the Automatic Speech Recognition task"""

    # The recognized text.
    text: str
    # When returnTimestamps is enabled, chunks contains a list of audio chunks identified by
    # the model.
    chunks: Optional[list[AutomaticSpeechRecognitionOutputChunk]] = None

View File

@@ -0,0 +1,164 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a base class for all inference types."""
import inspect
import json
import types
from dataclasses import asdict, dataclass
from typing import Any, TypeVar, Union, get_args
# Type variable bound to `BaseInferenceType`; lets the classmethods and the decorator in this module be typed
# in terms of the concrete subclass they are used with.
T = TypeVar("T", bound="BaseInferenceType")
def _repr_with_extra(self):
fields = list(self.__dataclass_fields__.keys())
other_fields = list(k for k in self.__dict__ if k not in fields)
return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})"
def dataclass_with_extra(cls: type[T]) -> type[T]:
    """Turn `cls` into a dataclass whose `__repr__` shows all fields, including extra (non-declared) ones.

    This decorator only works with dataclasses that inherit from `BaseInferenceType`.
    """
    decorated = dataclass(cls)
    decorated.__repr__ = _repr_with_extra  # type: ignore[method-assign]
    return decorated
@dataclass
class BaseInferenceType(dict):
"""Base class for all inference types.
Object is a dataclass and a dict for backward compatibility but plan is to remove the dict part in the future.
Handle parsing from dict, list and json strings in a permissive way to ensure future-compatibility (e.g. all fields
are made optional, and non-expected fields are added as dict attributes).
"""
@classmethod
def parse_obj_as_list(cls: type[T], data: Union[bytes, str, list, dict]) -> list[T]:
"""Alias to parse server response and return a single instance.
See `parse_obj` for more details.
"""
output = cls.parse_obj(data)
if not isinstance(output, list):
raise ValueError(f"Invalid input data for {cls}. Expected a list, but got {type(output)}.")
return output
@classmethod
def parse_obj_as_instance(cls: type[T], data: Union[bytes, str, list, dict]) -> T:
"""Alias to parse server response and return a single instance.
See `parse_obj` for more details.
"""
output = cls.parse_obj(data)
if isinstance(output, list):
raise ValueError(f"Invalid input data for {cls}. Expected a single instance, but got a list.")
return output
@classmethod
def parse_obj(cls: type[T], data: Union[bytes, str, list, dict]) -> Union[list[T], T]:
"""Parse server response as a dataclass or list of dataclasses.
To enable future-compatibility, we want to handle cases where the server return more fields than expected.
In such cases, we don't want to raise an error but still create the dataclass object. Remaining fields are
added as dict attributes.
"""
# Parse server response (from bytes)
if isinstance(data, bytes):
data = data.decode()
if isinstance(data, str):
data = json.loads(data)
# If a list, parse each item individually
if isinstance(data, list):
return [cls.parse_obj(d) for d in data] # type: ignore [misc]
# At this point, we expect a dict
if not isinstance(data, dict):
raise ValueError(f"Invalid data type: {type(data)}")
init_values = {}
other_values = {}
for key, value in data.items():
key = normalize_key(key)
if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init:
if isinstance(value, dict) or isinstance(value, list):
field_type = cls.__dataclass_fields__[key].type
# if `field_type` is a `BaseInferenceType`, parse it
if inspect.isclass(field_type) and issubclass(field_type, BaseInferenceType):
value = field_type.parse_obj(value)
# otherwise, recursively parse nested dataclasses (if possible)
# `get_args` returns handle Union and Optional for us
else:
expected_types = get_args(field_type)
for expected_type in expected_types:
if (
isinstance(expected_type, types.GenericAlias) and expected_type.__origin__ is list
) or getattr(expected_type, "_name", None) == "List":
expected_type = get_args(expected_type)[
0
] # assume same type for all items in the list
if inspect.isclass(expected_type) and issubclass(expected_type, BaseInferenceType):
value = expected_type.parse_obj(value)
break
init_values[key] = value
else:
other_values[key] = value
# Make all missing fields default to None
# => ensure that dataclass initialization will never fail even if the server does not return all fields.
for key in cls.__dataclass_fields__:
if key not in init_values:
init_values[key] = None
# Initialize dataclass with expected values
item = cls(**init_values)
# Add remaining fields as dict attributes
item.update(other_values)
# Add remaining fields as extra dataclass fields.
# They won't be part of the dataclass fields but will be accessible as attributes.
# Use @dataclass_with_extra to show them in __repr__.
item.__dict__.update(other_values)
return item
def __post_init__(self):
    """Copy the freshly initialised dataclass fields into the mapping side of the object."""
    # `asdict` snapshots every dataclass field (recursing into nested dataclasses);
    # `update` then stores that snapshot as items of this object.
    field_snapshot = asdict(self)
    self.update(field_snapshot)
def __setitem__(self, __key: Any, __value: Any) -> None:
    """Store an item and mirror it onto the dataclass attribute of the same name.

    Args:
        __key: Key being assigned.
        __value: Value stored under `__key`.
    """
    super().__setitem__(__key, __value)
    # Keys that are not declared dataclass fields only live in the mapping.
    if __key not in self.__dataclass_fields__:
        return
    # Hacky way to keep dataclass values in sync when the dict is updated.
    # `__setattr__` writes back into the mapping, so the attribute is only touched
    # when it differs; this check is what stops the two setters from ping-ponging.
    if getattr(self, __key, None) != __value:
        self.__setattr__(__key, __value)
def __setattr__(self, __name: str, __value: Any) -> None:
    """Assign an attribute and mirror it onto the item of the same name.

    Args:
        __name: Attribute name being assigned.
        __value: Value assigned to the attribute.
    """
    super().__setattr__(__name, __value)
    # Hacky way to keep dict values in sync when the dataclass is updated.
    # A missing key reads as `None`, so assigning `None` to a name that is absent
    # from the mapping leaves the mapping untouched.
    stored_item = self.get(__name)
    if stored_item != __value:
        self[__name] = __value
def normalize_key(key: str) -> str:
    """Return `key` with dashes and spaces replaced by underscores, then lowercased.

    Examples: "content-type" -> "content_type", "Accept" -> "accept".
    """
    normalized = key
    for separator in ("-", " "):
        normalized = normalized.replace(separator, "_")
    return normalized.lower()

View File

@@ -0,0 +1,347 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ChatCompletionInputURL(BaseInferenceType):
    """URL holder; the type of `ChatCompletionInputMessageChunk.image_url`."""

    url: str
# Allowed values for `ChatCompletionInputMessageChunk.type`.
ChatCompletionInputMessageChunkType = Literal[
    "text",
    "image_url",
]
@dataclass_with_extra
class ChatCompletionInputMessageChunk(BaseInferenceType):
    """One chunk of a message's list-form `content`.

    `type` is annotated with the `ChatCompletionInputMessageChunkType` literals;
    `image_url` and `text` are both optional and default to `None`.
    """

    type: "ChatCompletionInputMessageChunkType"
    image_url: Optional[ChatCompletionInputURL] = None
    text: Optional[str] = None
@dataclass_with_extra
class ChatCompletionInputFunctionDefinition(BaseInferenceType):
    """Function definition used by `ChatCompletionInputTool` and `ChatCompletionInputToolCall`.

    `name` and `parameters` are required; `description` is optional.
    """

    name: str
    parameters: Any
    description: Optional[str] = None
@dataclass_with_extra
class ChatCompletionInputToolCall(BaseInferenceType):
    """Element type of `ChatCompletionInputMessage.tool_calls`; all three fields are required."""

    function: ChatCompletionInputFunctionDefinition
    id: str
    type: str
@dataclass_with_extra
class ChatCompletionInputMessage(BaseInferenceType):
    """Element type of `ChatCompletionInput.messages`.

    Only `role` is required. `content` may be a plain string or a list of
    `ChatCompletionInputMessageChunk`.
    """

    role: str
    content: Optional[Union[list[ChatCompletionInputMessageChunk], str]] = None
    name: Optional[str] = None
    tool_calls: Optional[list[ChatCompletionInputToolCall]] = None
@dataclass_with_extra
class ChatCompletionInputJSONSchema(BaseInferenceType):
    """JSON-schema description; the type of `ChatCompletionInputResponseFormatJSONSchema.json_schema`."""

    name: str
    """The name of the response format."""
    description: Optional[str] = None
    """A description of what the response format is for. The model uses it to determine
    how to respond in the format.
    """
    schema: Optional[dict[str, object]] = None
    """The schema for the response format, expressed as a JSON Schema object. Learn how
    to build JSON schemas [here](https://json-schema.org/).
    """
    strict: Optional[bool] = None
    """Whether to enable strict schema adherence when generating the output. If set to
    true, the model will always follow the exact schema defined in the `schema` field.
    """
@dataclass_with_extra
class ChatCompletionInputResponseFormatText(BaseInferenceType):
    """Member of `ChatCompletionInputGrammarType` whose `type` is annotated as the literal "text"."""

    type: Literal["text"]
@dataclass_with_extra
class ChatCompletionInputResponseFormatJSONSchema(BaseInferenceType):
    """Member of `ChatCompletionInputGrammarType` whose `type` is the literal "json_schema".

    The schema itself is carried by the required `json_schema` field.
    """

    type: Literal["json_schema"]
    json_schema: ChatCompletionInputJSONSchema
@dataclass_with_extra
class ChatCompletionInputResponseFormatJSONObject(BaseInferenceType):
    """Member of `ChatCompletionInputGrammarType` whose `type` is annotated as the literal "json_object"."""

    type: Literal["json_object"]
# Accepted types for `ChatCompletionInput.response_format`: one dataclass per `type`
# literal ("text", "json_schema", "json_object").
ChatCompletionInputGrammarType = Union[
    ChatCompletionInputResponseFormatText,
    ChatCompletionInputResponseFormatJSONSchema,
    ChatCompletionInputResponseFormatJSONObject,
]
@dataclass_with_extra
class ChatCompletionInputStreamOptions(BaseInferenceType):
    """Streaming options; the type of `ChatCompletionInput.stream_options`."""

    include_usage: Optional[bool] = None
    """If set, an additional chunk will be streamed before the data: [DONE] message. The
    usage field on this chunk shows the token usage statistics for the entire request, and
    the choices field will always be an empty array. All other chunks will also include a
    usage field, but with a null value.
    """
@dataclass_with_extra
class ChatCompletionInputFunctionName(BaseInferenceType):
    """Function name holder; the type of `ChatCompletionInputToolChoiceClass.function`."""

    name: str
@dataclass_with_extra
class ChatCompletionInputToolChoiceClass(BaseInferenceType):
    """Object form of `ChatCompletionInput.tool_choice`, wrapping a `ChatCompletionInputFunctionName`."""

    function: ChatCompletionInputFunctionName
# String form of `ChatCompletionInput.tool_choice`.
ChatCompletionInputToolChoiceEnum = Literal[
    "auto",
    "none",
    "required",
]
@dataclass_with_extra
class ChatCompletionInputTool(BaseInferenceType):
    """Element type of `ChatCompletionInput.tools`; both fields are required."""

    function: ChatCompletionInputFunctionDefinition
    type: str
@dataclass_with_extra
class ChatCompletionInput(BaseInferenceType):
    """Input payload for Chat Completion.

    Auto-generated from TGI specs. For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    messages: list[ChatCompletionInputMessage]
    """A list of messages comprising the conversation so far."""
    frequency_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    existing frequency in the text so far, decreasing the model's likelihood to repeat the
    same line verbatim.
    """
    logit_bias: Optional[list[float]] = None
    """UNUSED
    Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON
    object that maps tokens (specified by their token ID in the tokenizer) to an associated
    bias value from -100 to 100. Mathematically, the bias is added to the logits generated
    by the model prior to sampling. The exact effect will vary per model, but values between
    -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
    should result in a ban or exclusive selection of the relevant token.
    """
    logprobs: Optional[bool] = None
    """Whether to return log probabilities of the output tokens or not. If true, returns the
    log probabilities of each output token returned in the content of message.
    """
    max_tokens: Optional[int] = None
    """The maximum number of tokens that can be generated in the chat completion."""
    model: Optional[str] = None
    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for
    details on which models work with the Chat API.
    """
    n: Optional[int] = None
    """UNUSED
    How many chat completion choices to generate for each input message. Note that you will
    be charged based on the number of generated tokens across all of the choices. Keep n as
    1 to minimize costs.
    """
    presence_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether
    they appear in the text so far, increasing the model's likelihood to talk about new
    topics.
    """
    response_format: Optional[ChatCompletionInputGrammarType] = None
    seed: Optional[int] = None
    stop: Optional[list[str]] = None
    """Up to 4 sequences where the API will stop generating further tokens."""
    stream: Optional[bool] = None
    stream_options: Optional[ChatCompletionInputStreamOptions] = None
    temperature: Optional[float] = None
    """What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make
    the output more random, while lower values like 0.2 will make it more focused and
    deterministic.
    We generally recommend altering this or `top_p` but not both.
    """
    tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
    tool_prompt: Optional[str] = None
    """A prompt to be appended before the tools."""
    tools: Optional[list[ChatCompletionInputTool]] = None
    """A list of tools the model may call. Currently, only functions are supported as a
    tool. Use this to provide a list of functions the model may generate JSON inputs for.
    """
    top_logprobs: Optional[int] = None
    """An integer between 0 and 5 specifying the number of most likely tokens to return at
    each token position, each with an associated log probability. logprobs must be set to
    true if this parameter is used.
    """
    top_p: Optional[float] = None
    """An alternative to sampling with temperature, called nucleus sampling, where the model
    considers the results of the tokens with top_p probability mass. So 0.1 means only the
    tokens comprising the top 10% probability mass are considered.
    """
@dataclass_with_extra
class ChatCompletionOutputTopLogprob(BaseInferenceType):
    """Element type of `ChatCompletionOutputLogprob.top_logprobs`: a `token` with its `logprob`."""

    logprob: float
    token: str
@dataclass_with_extra
class ChatCompletionOutputLogprob(BaseInferenceType):
    """Element type of `ChatCompletionOutputLogprobs.content`.

    Holds a `token`, its `logprob`, and a list of `ChatCompletionOutputTopLogprob` entries.
    """

    logprob: float
    token: str
    top_logprobs: list[ChatCompletionOutputTopLogprob]
@dataclass_with_extra
class ChatCompletionOutputLogprobs(BaseInferenceType):
    """Log-probability container; the type of `ChatCompletionOutputComplete.logprobs`."""

    content: list[ChatCompletionOutputLogprob]
@dataclass_with_extra
class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
    """Function payload; the type of `ChatCompletionOutputToolCall.function`.

    `arguments` and `name` are required strings; `description` is optional.
    """

    arguments: str
    name: str
    description: Optional[str] = None
@dataclass_with_extra
class ChatCompletionOutputToolCall(BaseInferenceType):
    """Element type of `ChatCompletionOutputMessage.tool_calls`; all three fields are required."""

    function: ChatCompletionOutputFunctionDefinition
    id: str
    type: str
@dataclass_with_extra
class ChatCompletionOutputMessage(BaseInferenceType):
    """Message payload; the type of `ChatCompletionOutputComplete.message`.

    Only `role` is required; every other field defaults to `None`.
    """

    role: str
    content: Optional[str] = None
    reasoning: Optional[str] = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[list[ChatCompletionOutputToolCall]] = None
@dataclass_with_extra
class ChatCompletionOutputComplete(BaseInferenceType):
    """Element type of `ChatCompletionOutput.choices`.

    `finish_reason`, `index` and `message` are required; `logprobs` is optional.
    """

    finish_reason: str
    index: int
    message: ChatCompletionOutputMessage
    logprobs: Optional[ChatCompletionOutputLogprobs] = None
@dataclass_with_extra
class ChatCompletionOutputUsage(BaseInferenceType):
    """Token counters; the type of `ChatCompletionOutput.usage`."""

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
@dataclass_with_extra
class ChatCompletionOutput(BaseInferenceType):
    """Output payload for Chat Completion.

    Auto-generated from TGI specs. For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    choices: list[ChatCompletionOutputComplete]
    created: int
    id: str
    model: str
    system_fingerprint: str
    usage: ChatCompletionOutputUsage
@dataclass_with_extra
class ChatCompletionStreamOutputFunction(BaseInferenceType):
    """Function payload; the type of `ChatCompletionStreamOutputDeltaToolCall.function`.

    `arguments` is required; `name` is optional.
    """

    arguments: str
    name: Optional[str] = None
@dataclass_with_extra
class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
    """Element type of `ChatCompletionStreamOutputDelta.tool_calls`; all four fields are required."""

    function: ChatCompletionStreamOutputFunction
    id: str
    index: int
    type: str
@dataclass_with_extra
class ChatCompletionStreamOutputDelta(BaseInferenceType):
    """Delta payload; the type of `ChatCompletionStreamOutputChoice.delta`.

    Only `role` is required; every other field defaults to `None`.
    """

    role: str
    content: Optional[str] = None
    reasoning: Optional[str] = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None
@dataclass_with_extra
class ChatCompletionStreamOutputTopLogprob(BaseInferenceType):
    """Element type of `ChatCompletionStreamOutputLogprob.top_logprobs`: a `token` with its `logprob`."""

    logprob: float
    token: str
@dataclass_with_extra
class ChatCompletionStreamOutputLogprob(BaseInferenceType):
    """Element type of `ChatCompletionStreamOutputLogprobs.content`.

    Holds a `token`, its `logprob`, and a list of `ChatCompletionStreamOutputTopLogprob` entries.
    """

    logprob: float
    token: str
    top_logprobs: list[ChatCompletionStreamOutputTopLogprob]
@dataclass_with_extra
class ChatCompletionStreamOutputLogprobs(BaseInferenceType):
    """Log-probability container; the type of `ChatCompletionStreamOutputChoice.logprobs`."""

    content: list[ChatCompletionStreamOutputLogprob]
@dataclass_with_extra
class ChatCompletionStreamOutputChoice(BaseInferenceType):
    """Element type of `ChatCompletionStreamOutput.choices`.

    `delta` and `index` are required; `finish_reason` and `logprobs` default to `None`.
    """

    delta: ChatCompletionStreamOutputDelta
    index: int
    finish_reason: Optional[str] = None
    logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None
@dataclass_with_extra
class ChatCompletionStreamOutputUsage(BaseInferenceType):
    """Token counters; the type of `ChatCompletionStreamOutput.usage`."""

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
@dataclass_with_extra
class ChatCompletionStreamOutput(BaseInferenceType):
    """Streamed output payload for Chat Completion.

    Auto-generated from TGI specs. For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    choices: list[ChatCompletionStreamOutputChoice]
    created: int
    id: str
    model: str
    system_fingerprint: str
    usage: Optional[ChatCompletionStreamOutputUsage] = None

View File

@@ -0,0 +1,28 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class DepthEstimationInput(BaseInferenceType):
    """Input payload for Depth Estimation inference."""

    inputs: Any
    """The input image data."""
    parameters: Optional[dict[str, Any]] = None
    """Additional inference parameters for Depth Estimation."""
@dataclass_with_extra
class DepthEstimationOutput(BaseInferenceType):
    """Output payload of the Depth Estimation task."""

    depth: Any
    """The predicted depth as an image."""
    predicted_depth: Any
    """The predicted depth as a tensor."""

View File

@@ -0,0 +1,80 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class DocumentQuestionAnsweringInputData(BaseInferenceType):
    """A single (document, question) pair to answer."""

    image: Any
    """The image on which the question is asked."""
    question: str
    """A question to ask of the document."""
@dataclass_with_extra
class DocumentQuestionAnsweringParameters(BaseInferenceType):
    """Optional inference parameters for Document Question Answering; every field defaults to `None`."""

    doc_stride: Optional[int] = None
    """If the words in the document are too long to fit with the question for the model, it
    will be split in several chunks with some overlap. This argument controls the size of
    that overlap.
    """
    handle_impossible_answer: Optional[bool] = None
    """Whether to accept impossible as an answer."""
    lang: Optional[str] = None
    """Language to use while running OCR. Defaults to english."""
    max_answer_len: Optional[int] = None
    """The maximum length of predicted answers (e.g., only answers with a shorter length
    are considered).
    """
    max_question_len: Optional[int] = None
    """The maximum length of the question after tokenization. It will be truncated if needed."""
    max_seq_len: Optional[int] = None
    """The maximum length of the total sentence (context + question) in tokens of each
    chunk passed to the model. The context will be split in several chunks (using
    doc_stride as overlap) if needed.
    """
    top_k: Optional[int] = None
    """The number of answers to return (will be chosen by order of likelihood). Can return
    less than top_k answers if there are not enough options available within the context.
    """
    word_boxes: Optional[list[Union[list[float], str]]] = None
    """A list of words and bounding boxes (normalized 0->1000). If provided, the inference
    will skip the OCR step and use the provided bounding boxes instead.
    """
@dataclass_with_extra
class DocumentQuestionAnsweringInput(BaseInferenceType):
    """Input payload for Document Question Answering inference."""

    inputs: DocumentQuestionAnsweringInputData
    """One (document, question) pair to answer."""
    parameters: Optional[DocumentQuestionAnsweringParameters] = None
    """Additional inference parameters for Document Question Answering."""
@dataclass_with_extra
class DocumentQuestionAnsweringOutputElement(BaseInferenceType):
    """One output element of the Document Question Answering task."""

    answer: str
    """The answer to the question."""
    end: int
    """The end word index of the answer (in the OCR'd version of the input or the provided
    word boxes).
    """
    score: float
    """The probability associated to the answer."""
    start: int
    """The start word index of the answer (in the OCR'd version of the input or the provided
    word boxes).
    """

View File

@@ -0,0 +1,36 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
# Allowed values for `FeatureExtractionInput.truncation_direction`.
FeatureExtractionInputTruncationDirection = Literal[
    "left",
    "right",
]
@dataclass_with_extra
class FeatureExtractionInput(BaseInferenceType):
    """Input payload for Feature Extraction.

    Auto-generated from TEI specs. For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
    """

    inputs: Union[list[str], str]
    """The text or list of texts to embed."""
    normalize: Optional[bool] = None
    prompt_name: Optional[str] = None
    """The name of the prompt that should be used for encoding. If not set, no prompt
    will be applied.
    Must be a key in the `sentence-transformers` configuration `prompts` dictionary.
    For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...},
    then the sentence "What is the capital of France?" will be encoded as
    "query: What is the capital of France?" because the prompt text will be prepended
    before any text to encode.
    """
    truncate: Optional[bool] = None
    truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None

View File

@@ -0,0 +1,47 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class FillMaskParameters(BaseInferenceType):
    """Optional inference parameters for Fill Mask; every field defaults to `None`."""

    targets: Optional[list[str]] = None
    """When passed, the model will limit the scores to the passed targets instead of
    looking up in the whole vocabulary. If the provided targets are not in the model vocab,
    they will be tokenized and the first resulting token will be used (with a warning, and
    that might be slower).
    """
    top_k: Optional[int] = None
    """When passed, overrides the number of predictions to return."""
@dataclass_with_extra
class FillMaskInput(BaseInferenceType):
    """Input payload for Fill Mask inference."""

    inputs: str
    """The text with masked tokens."""
    parameters: Optional[FillMaskParameters] = None
    """Additional inference parameters for Fill Mask."""
@dataclass_with_extra
class FillMaskOutputElement(BaseInferenceType):
    """One output element of the Fill Mask task."""

    score: float
    """The corresponding probability."""
    sequence: str
    """The corresponding input with the mask token prediction."""
    token: int
    """The predicted token id (to replace the masked one)."""
    # NOTE(review): `token_str` is typed `Any` and carries no description, while the
    # description sits on `fill_mask_output_token_str` below. This looks like a codegen
    # artefact; confirm against the JSON schema spec before relying on either field.
    token_str: Any
    fill_mask_output_token_str: Optional[str] = None
    """The predicted token (to replace the masked one)."""

View File

@@ -0,0 +1,43 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Allowed values for `ImageClassificationParameters.function_to_apply`.
ImageClassificationOutputTransform = Literal[
    "sigmoid",
    "softmax",
    "none",
]
@dataclass_with_extra
class ImageClassificationParameters(BaseInferenceType):
    """Optional inference parameters for Image Classification; every field defaults to `None`."""

    function_to_apply: Optional["ImageClassificationOutputTransform"] = None
    """The function to apply to the model outputs in order to retrieve the scores."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""
@dataclass_with_extra
class ImageClassificationInput(BaseInferenceType):
    """Input payload for Image Classification inference."""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided,
    you can also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageClassificationParameters] = None
    """Additional inference parameters for Image Classification."""
@dataclass_with_extra
class ImageClassificationOutputElement(BaseInferenceType):
    """One output element of the Image Classification task."""

    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""

View File

@@ -0,0 +1,51 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Allowed values for `ImageSegmentationParameters.subtask`.
ImageSegmentationSubtask = Literal[
    "instance",
    "panoptic",
    "semantic",
]
@dataclass_with_extra
class ImageSegmentationParameters(BaseInferenceType):
    """Optional inference parameters for Image Segmentation; every field defaults to `None`."""

    mask_threshold: Optional[float] = None
    """Threshold to use when turning the predicted masks into binary values."""
    overlap_mask_area_threshold: Optional[float] = None
    """Mask overlap threshold to eliminate small, disconnected segments."""
    subtask: Optional["ImageSegmentationSubtask"] = None
    """Segmentation task to be performed, depending on model capabilities."""
    threshold: Optional[float] = None
    """Probability threshold to filter out predicted masks."""
@dataclass_with_extra
class ImageSegmentationInput(BaseInferenceType):
    """Input payload for Image Segmentation inference."""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided,
    you can also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageSegmentationParameters] = None
    """Additional inference parameters for Image Segmentation."""
@dataclass_with_extra
class ImageSegmentationOutputElement(BaseInferenceType):
    """One output element of the Image Segmentation task: a predicted mask / segment."""

    label: str
    """The label of the predicted segment."""
    mask: str
    """The corresponding mask as a black-and-white image (base64-encoded)."""
    score: Optional[float] = None
    """The score or confidence degree the model has."""

View File

@@ -0,0 +1,60 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ImageToImageTargetSize(BaseInferenceType):
    """Size in pixels of the output image.

    This parameter is only supported by some providers and for specific models. It will be
    ignored when unsupported.
    """

    height: int
    width: int
@dataclass_with_extra
class ImageToImageParameters(BaseInferenceType):
    """Optional inference parameters for Image To Image; every field defaults to `None`."""

    guidance_scale: Optional[float] = None
    """For diffusion models. A higher guidance scale value encourages the model to generate
    images closely linked to the text prompt at the expense of lower image quality.
    """
    negative_prompt: Optional[str] = None
    """One prompt to guide what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """For diffusion models. The number of denoising steps. More denoising steps usually
    lead to a higher quality image at the expense of slower inference.
    """
    prompt: Optional[str] = None
    """The text prompt to guide the image generation."""
    target_size: Optional[ImageToImageTargetSize] = None
    """The size in pixels of the output image. This parameter is only supported by some
    providers and for specific models. It will be ignored when unsupported.
    """
@dataclass_with_extra
class ImageToImageInput(BaseInferenceType):
    """Input payload for Image To Image inference."""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided,
    you can also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageToImageParameters] = None
    """Additional inference parameters for Image To Image."""
@dataclass_with_extra
class ImageToImageOutput(BaseInferenceType):
    """Output payload of the Image To Image task."""

    image: Any
    """The output image returned as raw bytes in the payload."""

View File

@@ -0,0 +1,100 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
# String alternative accepted by `ImageToTextGenerationParameters.early_stopping`
# (the field also accepts a bool).
ImageToTextEarlyStoppingEnum = Literal["never"]
@dataclass_with_extra
class ImageToTextGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process; every field defaults to `None`."""

    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among
    different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more
    details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with
    probabilities that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a
    target token next is to the expected conditional probability of predicting a random
    token next, given the partial text already generated. If set to float < 1, the smallest
    set of the most locally typical tokens with probabilities that add up to typical_p or
    higher are kept for generation. See [this paper](https://hf.co/papers/2202.00666) for
    more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding."""
@dataclass_with_extra
class ImageToTextParameters(BaseInferenceType):
    """Optional inference parameters for Image To Text; every field defaults to `None`."""

    generation_parameters: Optional[ImageToTextGenerationParameters] = None
    """Parametrization of the text generation process."""
    max_new_tokens: Optional[int] = None
    """The amount of maximum tokens to generate."""
@dataclass_with_extra
class ImageToTextInput(BaseInferenceType):
    """Input payload for Image To Text inference."""

    inputs: Any
    """The input image data."""
    parameters: Optional[ImageToTextParameters] = None
    """Additional inference parameters for Image To Text."""
@dataclass_with_extra
class ImageToTextOutput(BaseInferenceType):
    """Output payload of the Image To Text task."""

    # NOTE(review): `generated_text` is typed `Any` and carries no description, while the
    # description sits on `image_to_text_output_generated_text` below. This looks like a
    # codegen artefact; confirm against the JSON schema spec before relying on either field.
    generated_text: Any
    image_to_text_output_generated_text: Optional[str] = None
    """The generated text."""

View File

@@ -0,0 +1,60 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ImageToVideoTargetSize(BaseInferenceType):
    """Size in pixels of the output video frames."""

    height: int
    width: int
@dataclass_with_extra
class ImageToVideoParameters(BaseInferenceType):
    """Optional inference parameters for Image To Video; every field defaults to `None`."""

    guidance_scale: Optional[float] = None
    """For diffusion models. A higher guidance scale value encourages the model to generate
    videos closely linked to the text prompt at the expense of lower image quality.
    """
    negative_prompt: Optional[str] = None
    """One prompt to guide what NOT to include in video generation."""
    # NOTE(review): typed `float` although it is described as a frame count; presumably
    # this mirrors the schema's `number` type. Confirm against the spec.
    num_frames: Optional[float] = None
    """The num_frames parameter determines how many video frames are generated."""
    num_inference_steps: Optional[int] = None
    """The number of denoising steps. More denoising steps usually lead to a higher quality
    video at the expense of slower inference.
    """
    prompt: Optional[str] = None
    """The text prompt to guide the video generation."""
    seed: Optional[int] = None
    """Seed for the random number generator."""
    target_size: Optional[ImageToVideoTargetSize] = None
    """The size in pixels of the output video frames."""
@dataclass_with_extra
class ImageToVideoInput(BaseInferenceType):
    """Input payload for Image To Video inference."""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided,
    you can also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ImageToVideoParameters] = None
    """Additional inference parameters for Image To Video."""
@dataclass_with_extra
class ImageToVideoOutput(BaseInferenceType):
    """Output payload of the Image To Video task."""

    video: Any
    """The generated video returned as raw bytes in the payload."""

View File

@@ -0,0 +1,58 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ObjectDetectionParameters(BaseInferenceType):
    """Optional inference parameters for Object Detection."""

    threshold: Optional[float] = None
    """The probability necessary to make a prediction."""
@dataclass_with_extra
class ObjectDetectionInput(BaseInferenceType):
    """Input payload for Object Detection inference."""

    inputs: str
    """The input image data as a base64-encoded string. If no `parameters` are provided,
    you can also provide the image data as a raw bytes payload.
    """
    parameters: Optional[ObjectDetectionParameters] = None
    """Additional inference parameters for Object Detection."""
@dataclass_with_extra
class ObjectDetectionBoundingBox(BaseInferenceType):
    """Predicted bounding box.

    Coordinates are relative to the top left corner of the input image.
    """

    xmax: int
    """The x-coordinate of the bottom-right corner of the bounding box."""
    xmin: int
    """The x-coordinate of the top-left corner of the bounding box."""
    ymax: int
    """The y-coordinate of the bottom-right corner of the bounding box."""
    ymin: int
    """The y-coordinate of the top-left corner of the bounding box."""
@dataclass_with_extra
class ObjectDetectionOutputElement(BaseInferenceType):
    """One output element of the Object Detection task."""

    box: ObjectDetectionBoundingBox
    """The predicted bounding box. Coordinates are relative to the top left corner of the
    input image.
    """
    label: str
    """The predicted label for the bounding box."""
    score: float
    """The associated score / probability."""

View File

@@ -0,0 +1,74 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class QuestionAnsweringInputData(BaseInferenceType):
    """A single (context, question) pair to be answered."""

    context: str
    """Context used to answer the question."""
    question: str
    """Question to answer."""
@dataclass_with_extra
class QuestionAnsweringParameters(BaseInferenceType):
    """Optional tuning parameters for a Question Answering request."""

    align_to_words: Optional[bool] = None
    """Try to align the answer with real words. This improves quality for space-separated
    languages but might hurt for non-space-separated ones (such as Japanese or Chinese).
    """
    doc_stride: Optional[int] = None
    """When the context is too long to fit with the question, it is split into several
    overlapping chunks; this value controls the size of that overlap.
    """
    handle_impossible_answer: Optional[bool] = None
    """Whether "impossible" is accepted as an answer."""
    max_answer_len: Optional[int] = None
    """Maximum length of predicted answers (e.g., only shorter answers are considered)."""
    max_question_len: Optional[int] = None
    """Maximum length of the question after tokenization; it is truncated if needed."""
    max_seq_len: Optional[int] = None
    """Maximum length, in tokens, of the total sentence (context + question) in each chunk
    passed to the model. If needed, the context is split into several chunks that overlap
    by `doc_stride`.
    """
    top_k: Optional[int] = None
    """Number of answers to return, chosen by order of likelihood. Fewer than `top_k`
    answers are returned when the context does not offer enough options.
    """
@dataclass_with_extra
class QuestionAnsweringInput(BaseInferenceType):
    """Request payload for Question Answering inference."""

    inputs: QuestionAnsweringInputData
    """The (context, question) pair to answer."""
    parameters: Optional[QuestionAnsweringParameters] = None
    """Optional tuning parameters for Question Answering."""
@dataclass_with_extra
class QuestionAnsweringOutputElement(BaseInferenceType):
    """One output element of Question Answering inference."""

    answer: str
    """Answer to the question."""
    end: int
    """Character position in the input at which the answer ends."""
    score: float
    """Probability associated with the answer."""
    start: int
    """Character position in the input at which the answer begins."""

View File

@@ -0,0 +1,27 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class SentenceSimilarityInputData(BaseInferenceType):
    """Sentences to compare, supplied as the `inputs` of a `SentenceSimilarityInput`."""

    sentences: list[str]
    """Strings that will each be compared against `source_sentence`."""
    source_sentence: str
    """String the other strings are compared with. Depending on the model being used, this
    can be a phrase, a sentence, or a longer passage.
    """
@dataclass_with_extra
class SentenceSimilarityInput(BaseInferenceType):
    """Request payload for Sentence Similarity inference."""

    inputs: SentenceSimilarityInputData
    parameters: Optional[dict[str, Any]] = None
    """Optional tuning parameters for Sentence Similarity."""

View File

@@ -0,0 +1,41 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `SummarizationParameters.truncation`.
SummarizationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]
@dataclass_with_extra
class SummarizationParameters(BaseInferenceType):
    """Optional tuning parameters for a summarization request."""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether potential extra spaces in the text output are cleaned up."""
    generate_parameters: Optional[dict[str, Any]] = None
    """Extra parametrization of the text generation algorithm."""
    truncation: Optional["SummarizationTruncationStrategy"] = None
    """Truncation strategy to use."""
@dataclass_with_extra
class SummarizationInput(BaseInferenceType):
    """Request payload for Summarization inference."""

    inputs: str
    """Text to summarize."""
    parameters: Optional[SummarizationParameters] = None
    """Optional tuning parameters for summarization."""
@dataclass_with_extra
class SummarizationOutput(BaseInferenceType):
    """Result returned by Summarization inference."""

    summary_text: str
    """Summarized text."""

View File

@@ -0,0 +1,62 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class TableQuestionAnsweringInputData(BaseInferenceType):
    """A single (table, question) pair to be answered."""

    question: str
    """Question to answer about the table."""
    table: dict[str, list[str]]
    """Table serving as context for the questions."""
# Accepted values for `TableQuestionAnsweringParameters.padding`.
Padding = Literal["do_not_pad", "longest", "max_length"]
@dataclass_with_extra
class TableQuestionAnsweringParameters(BaseInferenceType):
    """Optional tuning parameters for a Table Question Answering request."""

    padding: Optional["Padding"] = None
    """Activates and controls padding."""
    sequential: Optional[bool] = None
    """Whether inference runs sequentially or as a batch. Batching is faster, but models
    like SQA need sequential inference to extract relations within sequences, given their
    conversational nature.
    """
    truncation: Optional[bool] = None
    """Activates and controls truncation."""
@dataclass_with_extra
class TableQuestionAnsweringInput(BaseInferenceType):
    """Request payload for Table Question Answering inference."""

    inputs: TableQuestionAnsweringInputData
    """The (table, question) pair to answer."""
    parameters: Optional[TableQuestionAnsweringParameters] = None
    """Optional tuning parameters for Table Question Answering."""
@dataclass_with_extra
class TableQuestionAnsweringOutputElement(BaseInferenceType):
    """One output element of Table Question Answering inference."""

    answer: str
    """Answer to the question given the table. When there is an aggregator, the answer is
    preceded by `AGGREGATOR >`.
    """
    cells: list[str]
    """List of strings made up of the answer cell values."""
    coordinates: list[list[int]]
    """Coordinates of the cells of the answers."""
    aggregator: Optional[str] = None
    """The aggregator, returned when the model has one."""

View File

@@ -0,0 +1,42 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `Text2TextGenerationParameters.truncation`.
Text2TextGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]
@dataclass_with_extra
class Text2TextGenerationParameters(BaseInferenceType):
    """Optional tuning parameters for a Text2text Generation request."""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether potential extra spaces in the text output are cleaned up."""
    generate_parameters: Optional[dict[str, Any]] = None
    """Extra parametrization of the text generation algorithm."""
    truncation: Optional["Text2TextGenerationTruncationStrategy"] = None
    """Truncation strategy to use."""
@dataclass_with_extra
class Text2TextGenerationInput(BaseInferenceType):
    """Request payload for Text2text Generation inference."""

    inputs: str
    """Input text data."""
    parameters: Optional[Text2TextGenerationParameters] = None
    """Optional tuning parameters for Text2text Generation."""
@dataclass_with_extra
class Text2TextGenerationOutput(BaseInferenceType):
    """Result returned by Text2text Generation inference."""

    # NOTE(review): both fields below appear to describe the generated text; the second,
    # longer name looks like a code-generation artifact -- confirm against the task spec
    # before relying on it.
    generated_text: Any
    text2_text_generation_output_generated_text: Optional[str] = None
    """Generated text."""

View File

@@ -0,0 +1,41 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `TextClassificationParameters.function_to_apply`.
TextClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]
@dataclass_with_extra
class TextClassificationParameters(BaseInferenceType):
    """Optional tuning parameters for a Text Classification request."""

    function_to_apply: Optional["TextClassificationOutputTransform"] = None
    """Function applied to the model outputs in order to retrieve the scores."""
    top_k: Optional[int] = None
    """When specified, restricts the output to the top K most probable classes."""
@dataclass_with_extra
class TextClassificationInput(BaseInferenceType):
    """Request payload for Text Classification inference."""

    inputs: str
    """Text to classify."""
    parameters: Optional[TextClassificationParameters] = None
    """Optional tuning parameters for Text Classification."""
@dataclass_with_extra
class TextClassificationOutputElement(BaseInferenceType):
    """One output element of Text Classification inference."""

    label: str
    """Predicted class label."""
    score: float
    """Probability corresponding to the label."""

View File

@@ -0,0 +1,168 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `TextGenerationInputGrammarType.type`.
TypeEnum = Literal["json", "regex", "json_schema"]
@dataclass_with_extra
class TextGenerationInputGrammarType(BaseInferenceType):
    """Grammar specification accepted by `TextGenerationInputGenerateParameters.grammar`."""

    type: "TypeEnum"
    value: Any
    """A string representing a [JSON Schema](https://json-schema.org/).
    JSON Schema is a declarative language for annotating JSON documents
    with types and descriptions.
    """
@dataclass_with_extra
class TextGenerationInputGenerateParameters(BaseInferenceType):
    """Generation parameters accepted by `TextGenerationInput.parameters`."""

    adapter_id: Optional[str] = None
    """Lora adapter id."""
    best_of: Optional[int] = None
    """Generate `best_of` sequences and return the one with the highest token logprobs."""
    decoder_input_details: Optional[bool] = None
    """Whether decoder input token logprobs and ids are returned."""
    details: Optional[bool] = None
    """Whether generation details are returned."""
    do_sample: Optional[bool] = None
    """Activate logits sampling."""
    frequency_penalty: Optional[float] = None
    """Frequency penalty parameter; 1.0 means no penalty.
    New tokens are penalized based on their existing frequency in the text so far,
    which decreases the model's likelihood of repeating the same line verbatim.
    """
    grammar: Optional[TextGenerationInputGrammarType] = None
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    repetition_penalty: Optional[float] = None
    """Repetition penalty parameter; 1.0 means no penalty.
    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """
    return_full_text: Optional[bool] = None
    """Whether the prompt is prepended to the generated text."""
    seed: Optional[int] = None
    """Random sampling seed."""
    stop: Optional[list[str]] = None
    """Stop generating tokens as soon as a member of `stop` is generated."""
    temperature: Optional[float] = None
    """Value used to modulate the logits distribution."""
    top_k: Optional[int] = None
    """Number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_n_tokens: Optional[int] = None
    """Number of highest probability vocabulary tokens to keep for top-n-filtering."""
    top_p: Optional[float] = None
    """Top-p value for nucleus sampling."""
    truncate: Optional[int] = None
    """Truncate input tokens to the given size."""
    typical_p: Optional[float] = None
    """Typical Decoding mass.
    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
    for more information.
    """
    watermark: Optional[bool] = None
    """Watermarking with [A Watermark for Large Language
    Models](https://arxiv.org/abs/2301.10226).
    """
@dataclass_with_extra
class TextGenerationInput(BaseInferenceType):
    """Input of the Text Generation task.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    inputs: str
    parameters: Optional[TextGenerationInputGenerateParameters] = None
    stream: Optional[bool] = None
# Accepted values for the `finish_reason` field of the text-generation output detail types.
TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"]
@dataclass_with_extra
class TextGenerationOutputPrefillToken(BaseInferenceType):
    """Element type of the `prefill` lists found in the text-generation output details."""

    id: int
    logprob: float
    text: str
@dataclass_with_extra
class TextGenerationOutputToken(BaseInferenceType):
    """Element type of the `tokens` and `top_tokens` lists found in the text-generation
    output details.
    """

    id: int
    logprob: float
    special: bool
    text: str
@dataclass_with_extra
class TextGenerationOutputBestOfSequence(BaseInferenceType):
    """Element type of `TextGenerationOutputDetails.best_of_sequences`."""

    finish_reason: "TextGenerationOutputFinishReason"
    generated_text: str
    generated_tokens: int
    prefill: list[TextGenerationOutputPrefillToken]
    tokens: list[TextGenerationOutputToken]
    seed: Optional[int] = None
    top_tokens: Optional[list[list[TextGenerationOutputToken]]] = None
@dataclass_with_extra
class TextGenerationOutputDetails(BaseInferenceType):
    """Type of the optional `details` field of `TextGenerationOutput`."""

    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    prefill: list[TextGenerationOutputPrefillToken]
    tokens: list[TextGenerationOutputToken]
    best_of_sequences: Optional[list[TextGenerationOutputBestOfSequence]] = None
    seed: Optional[int] = None
    top_tokens: Optional[list[list[TextGenerationOutputToken]]] = None
@dataclass_with_extra
class TextGenerationOutput(BaseInferenceType):
    """Output of the Text Generation task.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    generated_text: str
    details: Optional[TextGenerationOutputDetails] = None
@dataclass_with_extra
class TextGenerationStreamOutputStreamDetails(BaseInferenceType):
    """Type of the optional `details` field of `TextGenerationStreamOutput`."""

    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    input_length: int
    seed: Optional[int] = None
@dataclass_with_extra
class TextGenerationStreamOutputToken(BaseInferenceType):
    """Type of `TextGenerationStreamOutput.token` and of the entries of its `top_tokens`."""

    id: int
    logprob: float
    special: bool
    text: str
@dataclass_with_extra
class TextGenerationStreamOutput(BaseInferenceType):
    """Stream output of the Text Generation task.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    index: int
    token: TextGenerationStreamOutputToken
    details: Optional[TextGenerationStreamOutputStreamDetails] = None
    generated_text: Optional[str] = None
    top_tokens: Optional[list[TextGenerationStreamOutputToken]] = None

View File

@@ -0,0 +1,99 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
# String value accepted, alongside a bool, by `TextToAudioGenerationParameters.early_stopping`.
TextToAudioEarlyStoppingEnum = Literal["never"]
@dataclass_with_extra
class TextToAudioGenerationParameters(BaseInferenceType):
    """Parameters controlling the text generation process."""

    do_sample: Optional[bool] = None
    """Whether new tokens are sampled rather than chosen by greedy decoding."""
    early_stopping: Optional[Union[bool, "TextToAudioEarlyStoppingEnum"]] = None
    """Stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """When set to a float strictly between 0 and 1, only tokens whose conditional
    probability is greater than epsilon_cutoff are sampled. The paper suggests values
    ranging from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling
    as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. When set
    to a float strictly between 0 and 1, a token is only considered if it is greater than
    either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The
    latter term is intuitively the expected next token probability, scaled by
    sqrt(eta_cutoff). The paper suggests values ranging from 3e-4 to 2e-3, depending on the
    size of the model. See [Truncation Sampling as Language Model
    Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    max_length: Optional[int] = None
    """Maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """Minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """Minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups num_beams is divided into in order to ensure diversity among
    different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more
    details.
    """
    num_beams: Optional[int] = None
    """Number of beams used for beam search."""
    penalty_alpha: Optional[float] = None
    """Balances the model confidence and the degeneration penalty in contrastive search
    decoding.
    """
    temperature: Optional[float] = None
    """Value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """Number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """When set to a float < 1, only the smallest set of most probable tokens whose
    probabilities add up to top_p or higher is kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a
    target token next is to the expected conditional probability of predicting a random
    token next, given the partial text already generated. When set to a float < 1, the
    smallest set of the most locally typical tokens whose probabilities add up to typical_p
    or higher is kept for generation. See [this paper](https://hf.co/papers/2202.00666) for
    more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model uses the past last key/values attentions to speed up decoding."""
@dataclass_with_extra
class TextToAudioParameters(BaseInferenceType):
    """Optional tuning parameters for a Text To Audio request."""

    generation_parameters: Optional[TextToAudioGenerationParameters] = None
    """Parameters controlling the text generation process."""
@dataclass_with_extra
class TextToAudioInput(BaseInferenceType):
    """Request payload for Text To Audio inference."""

    inputs: str
    """Input text data."""
    parameters: Optional[TextToAudioParameters] = None
    """Optional tuning parameters for Text To Audio."""
@dataclass_with_extra
class TextToAudioOutput(BaseInferenceType):
    """Result returned by Text To Audio inference."""

    audio: Any
    """Generated audio waveform."""
    sampling_rate: float
    """Sampling rate of the generated audio waveform."""

View File

@@ -0,0 +1,50 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class TextToImageParameters(BaseInferenceType):
    """Optional tuning parameters for a Text To Image request."""

    guidance_scale: Optional[float] = None
    """A higher guidance scale pushes the model to generate images closely linked to the
    text prompt, but values that are too high may cause saturation and other artifacts.
    """
    height: Optional[int] = None
    """Height of the output image, in pixels."""
    negative_prompt: Optional[str] = None
    """One prompt guiding what NOT to include in image generation."""
    num_inference_steps: Optional[int] = None
    """Number of denoising steps. More denoising steps usually lead to a higher quality
    image at the expense of slower inference.
    """
    scheduler: Optional[str] = None
    """Override the scheduler with a compatible one."""
    seed: Optional[int] = None
    """Seed for the random number generator."""
    width: Optional[int] = None
    """Width of the output image, in pixels."""
@dataclass_with_extra
class TextToImageInput(BaseInferenceType):
    """Request payload for Text To Image inference."""

    inputs: str
    """Input text data (sometimes called "prompt")."""
    parameters: Optional[TextToImageParameters] = None
    """Optional tuning parameters for Text To Image."""
@dataclass_with_extra
class TextToImageOutput(BaseInferenceType):
    """Result returned by Text To Image inference."""

    image: Any
    """Raw bytes of the generated image, as delivered in the response payload."""

View File

@@ -0,0 +1,99 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional, Union
from .base import BaseInferenceType, dataclass_with_extra
# String value accepted, alongside a bool, by `TextToSpeechGenerationParameters.early_stopping`.
TextToSpeechEarlyStoppingEnum = Literal["never"]
@dataclass_with_extra
class TextToSpeechGenerationParameters(BaseInferenceType):
    """Parameters controlling the text generation process."""

    do_sample: Optional[bool] = None
    """Whether new tokens are sampled rather than chosen by greedy decoding."""
    early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None
    """Stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """When set to a float strictly between 0 and 1, only tokens whose conditional
    probability is greater than epsilon_cutoff are sampled. The paper suggests values
    ranging from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling
    as Language Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. When set
    to a float strictly between 0 and 1, a token is only considered if it is greater than
    either eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The
    latter term is intuitively the expected next token probability, scaled by
    sqrt(eta_cutoff). The paper suggests values ranging from 3e-4 to 2e-3, depending on the
    size of the model. See [Truncation Sampling as Language Model
    Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    max_length: Optional[int] = None
    """Maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """Maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """Minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """Minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups num_beams is divided into in order to ensure diversity among
    different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more
    details.
    """
    num_beams: Optional[int] = None
    """Number of beams used for beam search."""
    penalty_alpha: Optional[float] = None
    """Balances the model confidence and the degeneration penalty in contrastive search
    decoding.
    """
    temperature: Optional[float] = None
    """Value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """Number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """When set to a float < 1, only the smallest set of most probable tokens whose
    probabilities add up to top_p or higher is kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a
    target token next is to the expected conditional probability of predicting a random
    token next, given the partial text already generated. When set to a float < 1, the
    smallest set of the most locally typical tokens whose probabilities add up to typical_p
    or higher is kept for generation. See [this paper](https://hf.co/papers/2202.00666) for
    more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model uses the past last key/values attentions to speed up decoding."""
@dataclass_with_extra
class TextToSpeechParameters(BaseInferenceType):
    """Optional tuning parameters for a Text To Speech request."""

    generation_parameters: Optional[TextToSpeechGenerationParameters] = None
    """Parameters controlling the text generation process."""
@dataclass_with_extra
class TextToSpeechInput(BaseInferenceType):
    """Request payload for Text To Speech inference."""

    inputs: str
    """Input text data."""
    parameters: Optional[TextToSpeechParameters] = None
    """Optional tuning parameters for Text To Speech."""
@dataclass_with_extra
class TextToSpeechOutput(BaseInferenceType):
    """Result returned by Text To Speech inference."""

    audio: Any
    """Generated audio."""
    sampling_rate: Optional[float] = None
    """Sampling rate of the generated audio waveform."""

View File

@@ -0,0 +1,46 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class TextToVideoParameters(BaseInferenceType):
    """Optional tuning parameters for a Text To Video request."""

    guidance_scale: Optional[float] = None
    """A higher guidance scale pushes the model to generate videos closely linked to the
    text prompt, but values that are too high may cause saturation and other artifacts.
    """
    negative_prompt: Optional[list[str]] = None
    """One or several prompts guiding what NOT to include in video generation."""
    num_frames: Optional[float] = None
    """Determines how many video frames are generated."""
    num_inference_steps: Optional[int] = None
    """Number of denoising steps. More denoising steps usually lead to a higher quality
    video at the expense of slower inference.
    """
    seed: Optional[int] = None
    """Seed for the random number generator."""
@dataclass_with_extra
class TextToVideoInput(BaseInferenceType):
    """Request payload for Text To Video inference."""

    inputs: str
    """Input text data (sometimes called "prompt")."""
    parameters: Optional[TextToVideoParameters] = None
    """Optional tuning parameters for Text To Video."""
@dataclass_with_extra
class TextToVideoOutput(BaseInferenceType):
    """Result returned by Text To Video inference."""

    video: Any
    """Raw bytes of the generated video, as delivered in the response payload."""

View File

@@ -0,0 +1,51 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `TokenClassificationParameters.aggregation_strategy`.
TokenClassificationAggregationStrategy = Literal["none", "simple", "first", "average", "max"]
@dataclass_with_extra
class TokenClassificationParameters(BaseInferenceType):
    """Optional tuning parameters for a Token Classification request."""

    aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None
    """Strategy used to fuse tokens based on model predictions."""
    ignore_labels: Optional[list[str]] = None
    """List of labels to ignore."""
    stride: Optional[int] = None
    """Number of overlapping tokens between chunks when the input text is split."""
@dataclass_with_extra
class TokenClassificationInput(BaseInferenceType):
    """Request payload for Token Classification inference."""

    inputs: str
    """Input text data."""
    parameters: Optional[TokenClassificationParameters] = None
    """Optional tuning parameters for Token Classification."""
@dataclass_with_extra
class TokenClassificationOutputElement(BaseInferenceType):
    """One output element of Token Classification inference."""

    end: int
    """Character position in the input at which this group ends."""
    score: float
    """Associated score / probability."""
    start: int
    """Character position in the input at which this group begins."""
    word: str
    """Corresponding text."""
    entity: Optional[str] = None
    """Label predicted for a single token."""
    entity_group: Optional[str] = None
    """Label predicted for a group of one or more tokens."""

View File

@@ -0,0 +1,49 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Accepted values for `TranslationParameters.truncation`.
TranslationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"]
@dataclass_with_extra
class TranslationParameters(BaseInferenceType):
    """Optional tuning parameters for a Translation request."""

    clean_up_tokenization_spaces: Optional[bool] = None
    """Whether potential extra spaces in the text output are cleaned up."""
    generate_parameters: Optional[dict[str, Any]] = None
    """Extra parametrization of the text generation algorithm."""
    src_lang: Optional[str] = None
    """Source language of the text. Required for models that can translate from multiple
    languages.
    """
    tgt_lang: Optional[str] = None
    """Language to translate to. Required for models that can translate to multiple
    languages.
    """
    truncation: Optional["TranslationTruncationStrategy"] = None
    """Truncation strategy to use."""
@dataclass_with_extra
class TranslationInput(BaseInferenceType):
    """Request payload for Translation inference."""

    inputs: str
    """Text to translate."""
    parameters: Optional[TranslationParameters] = None
    """Optional tuning parameters for Translation."""
@dataclass_with_extra
class TranslationOutput(BaseInferenceType):
    """Result returned by Translation inference."""

    translation_text: str
    """Translated text."""

View File

@@ -0,0 +1,45 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Literal, Optional
from .base import BaseInferenceType, dataclass_with_extra
# Closed set of function names accepted by `VideoClassificationParameters.function_to_apply`.
VideoClassificationOutputTransform = Literal[
    "sigmoid",
    "softmax",
    "none",
]
@dataclass_with_extra
class VideoClassificationParameters(BaseInferenceType):
    """Optional parameters that tune a Video Classification request."""

    frame_sampling_rate: Optional[int] = None
    """Sampling rate used to select frames from the video."""
    function_to_apply: Optional["VideoClassificationOutputTransform"] = None
    """Function applied to the model outputs in order to retrieve the scores."""
    num_frames: Optional[int] = None
    """Number of sampled frames to consider for classification."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""
@dataclass_with_extra
class VideoClassificationInput(BaseInferenceType):
    """Request payload for Video Classification inference."""

    inputs: Any
    """Input video data."""
    parameters: Optional[VideoClassificationParameters] = None
    """Optional Video Classification parameters."""
@dataclass_with_extra
class VideoClassificationOutputElement(BaseInferenceType):
    """One element of the Video Classification output."""

    label: str
    """Predicted class label."""
    score: float
    """Probability associated with `label`."""

View File

@@ -0,0 +1,49 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Any, Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class VisualQuestionAnsweringInputData(BaseInferenceType):
    """A single (image, question) pair to answer."""

    image: Any
    """The image the question is about."""
    question: str
    """Question to answer based on the image."""
@dataclass_with_extra
class VisualQuestionAnsweringParameters(BaseInferenceType):
    """Optional parameters that tune a Visual Question Answering request."""

    top_k: Optional[int] = None
    """Number of answers to return, chosen by order of likelihood. Fewer than `top_k`
    answers are returned if there are not enough options available within the context.
    """
@dataclass_with_extra
class VisualQuestionAnsweringInput(BaseInferenceType):
    """Request payload for Visual Question Answering inference."""

    inputs: VisualQuestionAnsweringInputData
    """The (image, question) pair to answer."""
    parameters: Optional[VisualQuestionAnsweringParameters] = None
    """Optional Visual Question Answering parameters."""
@dataclass_with_extra
class VisualQuestionAnsweringOutputElement(BaseInferenceType):
    """One element of the Visual Question Answering output."""

    score: float
    """Score / probability associated with the answer."""
    answer: Optional[str] = None
    """Answer to the question."""

View File

@@ -0,0 +1,45 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ZeroShotClassificationParameters(BaseInferenceType):
    """Parameters of a Zero Shot Classification request."""

    candidate_labels: list[str]
    """Possible class labels to classify the text into."""
    hypothesis_template: Optional[str] = None
    """Sentence used together with `candidate_labels` to attempt the text classification:
    its placeholder is replaced with each candidate label.
    """
    multi_label: Optional[bool] = None
    """Whether multiple candidate labels can be true. If false, the scores are normalized such
    that the sum of the label likelihoods for each sequence is 1. If true, the labels are
    considered independent and probabilities are normalized for each candidate.
    """
@dataclass_with_extra
class ZeroShotClassificationInput(BaseInferenceType):
    """Request payload for Zero Shot Classification inference."""

    inputs: str
    """Text to classify."""
    parameters: ZeroShotClassificationParameters
    """Zero Shot Classification parameters (mandatory: no default is declared)."""
@dataclass_with_extra
class ZeroShotClassificationOutputElement(BaseInferenceType):
    """One element of the Zero Shot Classification output."""

    label: str
    """Predicted class label."""
    score: float
    """Probability associated with `label`."""

View File

@@ -0,0 +1,40 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from typing import Optional
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ZeroShotImageClassificationParameters(BaseInferenceType):
    """Parameters of a Zero Shot Image Classification request."""

    candidate_labels: list[str]
    """Candidate labels for the image."""
    hypothesis_template: Optional[str] = None
    """Sentence used together with `candidate_labels` to attempt the image classification:
    its placeholder is replaced with each candidate label.
    """
@dataclass_with_extra
class ZeroShotImageClassificationInput(BaseInferenceType):
    """Request payload for Zero Shot Image Classification inference."""

    inputs: str
    """Image data to classify, as a base64-encoded string."""
    parameters: ZeroShotImageClassificationParameters
    """Zero Shot Image Classification parameters (mandatory: no default is declared)."""
@dataclass_with_extra
class ZeroShotImageClassificationOutputElement(BaseInferenceType):
    """One element of the Zero Shot Image Classification output."""

    label: str
    """Predicted class label."""
    score: float
    """Probability associated with `label`."""

View File

@@ -0,0 +1,50 @@
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from .base import BaseInferenceType, dataclass_with_extra
@dataclass_with_extra
class ZeroShotObjectDetectionParameters(BaseInferenceType):
    """Parameters of a Zero Shot Object Detection request."""

    candidate_labels: list[str]
    """Candidate labels for the image."""
@dataclass_with_extra
class ZeroShotObjectDetectionInput(BaseInferenceType):
    """Request payload for Zero Shot Object Detection inference."""

    inputs: str
    """Input image data, as a base64-encoded string."""
    parameters: ZeroShotObjectDetectionParameters
    """Zero Shot Object Detection parameters (mandatory: no default is declared)."""
@dataclass_with_extra
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
    """Predicted bounding box, with coordinates relative to the top left corner of the
    input image.
    """

    # The four integer fields below are all mandatory (no defaults are declared).
    xmax: int
    xmin: int
    ymax: int
    ymin: int
@dataclass_with_extra
class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
    """One element of the Zero Shot Object Detection output."""

    box: ZeroShotObjectDetectionBoundingBox
    """Predicted bounding box, with coordinates relative to the top left corner of the
    input image.
    """
    label: str
    """One of the candidate labels."""
    score: float
    """Score / probability associated with the detection."""

View File

@@ -0,0 +1,88 @@
import asyncio
import sys
from functools import partial
import typer
def _patch_anyio_open_process():
    """
    Replace `anyio.open_process` with a wrapper that starts child processes detached.

    On Windows the wrapper defaults `creationflags` to `CREATE_NEW_PROCESS_GROUP`; on other
    platforms it defaults `start_new_session` to True. Values passed explicitly by the caller
    are left untouched. The patch is applied at most once: a marker attribute set on the
    `anyio` module makes later calls return immediately.

    This is necessary to prevent the MCP client from being interrupted by Ctrl+C when running in the CLI.
    """
    import subprocess

    import anyio

    # Idempotency guard: a previous call already installed the wrapper.
    if getattr(anyio, "_tiny_agents_patched", False):
        return
    anyio._tiny_agents_patched = True  # ty: ignore[invalid-assignment]

    original_open_process = anyio.open_process

    # Pick the platform-specific keyword that detaches the child from our process group/session.
    if sys.platform == "win32":
        detach_key, detach_value = "creationflags", subprocess.CREATE_NEW_PROCESS_GROUP
    else:
        detach_key, detach_value = "start_new_session", True

    async def open_process_in_new_group(*args, **kwargs):
        """
        Wrapper for open_process that applies the platform-specific detach default.
        """
        kwargs.setdefault(detach_key, detach_value)
        return await original_open_process(*args, **kwargs)

    anyio.open_process = open_process_in_new_group  # ty: ignore[invalid-assignment]
async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str:
    """
    Asynchronous prompt function that reads input from stdin without blocking.

    This function is designed to work in an asynchronous context, allowing the event loop to gracefully stop it (e.g. on Ctrl+C).

    Alternatively, we could use https://github.com/vxgmichel/aioconsole but that would be an additional dependency.

    Args:
        exit_event (`asyncio.Event`):
            Event that interrupts the prompt. On UNIX-like systems, an empty string is returned as soon
            as it is set. It is not consulted on Windows.
        prompt (`str`, *optional*):
            Text displayed before reading the line.

    Returns:
        `str`: On UNIX-like systems, the line read from stdin with surrounding whitespace stripped, or
        `""` if `exit_event` was set. On Windows, whatever `typer.prompt` returns.
    """
    # We are inside a coroutine, so a loop is necessarily running.
    loop = asyncio.get_running_loop()

    if sys.platform == "win32":
        # Windows: Use run_in_executor to avoid blocking the event loop
        # Degraded solution: this is not ideal as user will have to CTRL+C once more to stop the prompt (and it'll not be graceful)
        return await loop.run_in_executor(None, partial(typer.prompt, prompt, prompt_suffix=" "))

    # UNIX-like: Use loop.add_reader for non-blocking stdin read
    future = loop.create_future()

    def on_input():
        line = sys.stdin.readline()
        loop.remove_reader(sys.stdin)
        # The future may already be cancelled (exit requested); setting a result would then raise.
        if not future.done():
            future.set_result(line)

    print(prompt, end=" ", flush=True)
    loop.add_reader(sys.stdin, on_input)  # not supported on Windows

    # Wait until either the user hits enter or exit_event is set
    exit_task = asyncio.create_task(exit_event.wait())
    try:
        await asyncio.wait(
            [future, exit_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        # Always unregister the stdin callback and drop the watcher task, whichever side won
        # (or if we were cancelled), so nothing is left dangling on the loop.
        loop.remove_reader(sys.stdin)
        exit_task.cancel()

    # Check which one has been triggered
    if exit_event.is_set():
        future.cancel()
        return ""

    line = await future
    return line.strip()

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import asyncio
from typing import AsyncGenerator, Iterable, Optional, Union
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionStreamOutput, MCPClient
from .._providers import PROVIDER_OR_POLICY_T
from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS
from .types import ServerConfig
class Agent(MCPClient):
    """
    Implementation of a Simple Agent, which is a simple while loop built right on top of an [`MCPClient`].

    > [!WARNING]
    > This class is experimental and might be subject to breaking changes in the future without prior notice.

    Args:
        model (`str`, *optional*):
            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
            or a URL to a deployed Inference Endpoint or other local or remote endpoint.
        servers (`Iterable[dict]`):
            MCP servers to connect to. Each server is a flat dictionary that is passed as keyword arguments to
            [`MCPClient.add_mcp_server`]: a `type` key (`"stdio"`, `"sse"` or `"http"`) plus the parameters
            accepted for that server type (e.g. `command`/`args`/`env` for stdio, `url`/`headers` for sse and http).
        provider (`str`, *optional*):
            Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
            If model is a URL or `base_url` is passed, then `provider` is not used.
        base_url (`str`, *optional*):
            The base URL to run inference. Defaults to None.
        api_key (`str`, *optional*):
            Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service.
        prompt (`str`, *optional*):
            The system prompt to use for the agent. Defaults to the default system prompt in `constants.py`.
    """

    def __init__(
        self,
        *,
        model: Optional[str] = None,
        servers: Iterable[ServerConfig],
        provider: Optional[PROVIDER_OR_POLICY_T] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        prompt: Optional[str] = None,
    ):
        super().__init__(model=model, provider=provider, base_url=base_url, api_key=api_key)
        # Materialize the iterable once so `load_tools` can walk it later.
        self._servers_cfg = list(servers)
        # The conversation always starts with the system prompt.
        self.messages: list[Union[dict, ChatCompletionInputMessage]] = [
            {"role": "system", "content": prompt or DEFAULT_SYSTEM_PROMPT}
        ]

    async def load_tools(self) -> None:
        """Connect to every configured MCP server, one after the other."""
        for cfg in self._servers_cfg:
            await self.add_mcp_server(**cfg)

    async def run(
        self,
        user_input: str,
        *,
        abort_event: Optional[asyncio.Event] = None,
    ) -> AsyncGenerator[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage], None]:
        """
        Run the agent with the given user input.

        The user message is appended to `self.messages`, then turns are processed in a loop and every
        item produced by `process_single_turn_with_tools` is yielded as-is.

        Args:
            user_input (`str`):
                The user input to run the agent with.
            abort_event (`asyncio.Event`, *optional*):
                An event that can be used to abort the agent. If the event is set, the agent will stop running.
                It is checked before each turn starts.
        """
        self.messages.append({"role": "user", "content": user_input})

        # Loop-invariant: names of the tools whose call ends the run.
        exit_tool_names = {tool.function.name for tool in EXIT_LOOP_TOOLS}

        num_turns: int = 0
        next_turn_should_call_tools = True

        while True:
            if abort_event is not None and abort_event.is_set():
                return

            async for item in self.process_single_turn_with_tools(
                self.messages,
                exit_loop_tools=EXIT_LOOP_TOOLS,
                exit_if_first_chunk_no_tool=(num_turns > 0 and next_turn_should_call_tools),
            ):
                yield item

            num_turns += 1
            last = self.messages[-1]
            last_is_tool = last.get("role") == "tool"

            # An exit-loop tool was called: the run is over.
            if last_is_tool and last.get("name") in exit_tool_names:
                return
            # Turn budget exhausted and the last message is not a tool result.
            if not last_is_tool and num_turns > MAX_NUM_TURNS:
                return
            # A tool call was expected this turn but the last message is not a tool result.
            if not last_is_tool and next_turn_should_call_tools:
                return

            next_turn_should_call_tools = not last_is_tool

View File

@@ -0,0 +1,255 @@
import asyncio
import os
import signal
import traceback
from typing import Optional
import typer
from ...utils import ANSI
from ._cli_hacks import _async_prompt, _patch_anyio_open_process
from .agent import Agent
from .utils import _load_agent_config
# Root command group of the tiny-agents CLI.
app = typer.Typer(
    help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
    rich_markup_mode="rich",
)

# `run` sub-application; its callback is invoked even when no sub-command is given.
run_cli = typer.Typer(name="run", help="Run the Agent in the CLI", invoke_without_command=True)
app.add_typer(run_cli, name="run")
async def run_agent(
    agent_path: Optional[str],
) -> None:
    """
    Tiny Agent loop.

    Loads the agent configuration, asks the user for any declared input that is actually referenced
    (falling back to environment variables), substitutes `${input:<id>}` placeholders in the servers'
    `env`/`headers` and in `apiKey`, then runs an interactive prompt/answer loop.

    Ctrl+C handling: the first SIGINT sets the abort event (interrupting the current agent run), a
    second one sets the exit event (leaving the loop). The SIGINT handling installed here is undone
    before returning.

    Args:
        agent_path (`str`, *optional*):
            Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.
    """
    _patch_anyio_open_process()  # Hacky way to prevent stdio connections to be stopped by Ctrl+C

    config, prompt = _load_agent_config(agent_path)

    inputs = config.get("inputs", [])
    servers = config.get("servers", [])

    abort_event = asyncio.Event()
    exit_event = asyncio.Event()
    first_sigint = True

    loop = asyncio.get_running_loop()
    original_sigint_handler = signal.getsignal(signal.SIGINT)
    sigint_registered_in_loop = False

    def _sigint_handler() -> None:
        nonlocal first_sigint
        if first_sigint:
            first_sigint = False
            abort_event.set()
            print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
            return

        print(ANSI.red("\nExiting..."), flush=True)
        exit_event.set()

    def _placeholder_targets(server) -> dict:
        """Return the mapping that may hold placeholders: stdio's "env" or http/sse's "headers"."""
        return server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})

    try:
        try:
            loop.add_signal_handler(signal.SIGINT, _sigint_handler)
            sigint_registered_in_loop = True
        except (AttributeError, NotImplementedError):
            # Windows (or any loop that doesn't support it) : fall back to sync
            signal.signal(signal.SIGINT, lambda *_: _sigint_handler())

        # Handle inputs (i.e. env variables injection)
        resolved_inputs: dict[str, str] = {}

        if len(inputs) > 0:
            print(
                ANSI.bold(
                    ANSI.blue(
                        "Some initial inputs are required by the agent. "
                        "Please provide a value or leave empty to load from env."
                    )
                )
            )
            for input_item in inputs:
                input_id = input_item["id"]
                description = input_item["description"]
                env_special_value = f"${{input:{input_id}}}"

                # Check if the input is used by any server or as an apiKey
                input_usages = set()
                for server in servers:
                    for key, value in _placeholder_targets(server).items():
                        # Values are not guaranteed to be strings; `in` on e.g. an int would raise.
                        if isinstance(value, str) and env_special_value in value:
                            input_usages.add(key)

                raw_api_key = config.get("apiKey")
                if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
                    input_usages.add("apiKey")

                if not input_usages:
                    print(
                        ANSI.yellow(
                            f"Input '{input_id}' defined in config but not used by any server or as an API key."
                            " Skipping."
                        )
                    )
                    continue

                # Prompt user for input
                env_variable_key = input_id.replace("-", "_").upper()
                print(
                    ANSI.blue(f"{input_id}") + f": {description}. (default: load from {env_variable_key}).",
                    end=" ",
                )
                user_input = (await _async_prompt(exit_event=exit_event)).strip()
                if exit_event.is_set():
                    return

                # Fallback to environment variable when user left blank
                final_value = user_input
                if not final_value:
                    final_value = os.getenv(env_variable_key, "")
                    if final_value:
                        print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
                    else:
                        print(
                            ANSI.yellow(
                                f"No value found for '{env_variable_key}' in environment variables. Continuing."
                            )
                        )
                resolved_inputs[input_id] = final_value

                # Inject resolved value (can be empty) into stdio's env or http/sse's headers
                for server in servers:
                    targets = _placeholder_targets(server)
                    for key, value in targets.items():
                        if isinstance(value, str) and env_special_value in value:
                            targets[key] = value.replace(env_special_value, final_value)

            print()

        raw_api_key = config.get("apiKey")
        if isinstance(raw_api_key, str):
            substituted_api_key = raw_api_key
            for input_id, val in resolved_inputs.items():
                substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
            config["apiKey"] = substituted_api_key

        # Main agent loop
        async with Agent(
            provider=config.get("provider"),  # type: ignore[arg-type]
            model=config.get("model"),
            base_url=config.get("endpointUrl"),  # type: ignore[arg-type]
            api_key=config.get("apiKey"),
            servers=servers,  # type: ignore[arg-type]
            prompt=prompt,
        ) as agent:
            await agent.load_tools()
            print(ANSI.bold(ANSI.blue(f"Agent loaded with {len(agent.available_tools)} tools:")))
            for t in agent.available_tools:
                print(ANSI.blue(f"{t.function.name}"))

            while True:
                abort_event.clear()

                # Check if we should exit
                if exit_event.is_set():
                    return

                try:
                    user_input = await _async_prompt(exit_event=exit_event)
                    first_sigint = True
                except EOFError:
                    print(ANSI.red("\nEOF received, exiting."), flush=True)
                    break
                except KeyboardInterrupt:
                    if not first_sigint and abort_event.is_set():
                        continue
                    else:
                        print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
                        break

                try:
                    async for chunk in agent.run(user_input, abort_event=abort_event):
                        if abort_event.is_set() and not first_sigint:
                            break
                        if exit_event.is_set():
                            return

                        if hasattr(chunk, "choices"):
                            # Streamed model output: print text and tool-call fragments as they arrive.
                            delta = chunk.choices[0].delta
                            if delta.content:
                                print(delta.content, end="", flush=True)
                            if delta.tool_calls:
                                for call in delta.tool_calls:
                                    if call.id:
                                        print(f"<Tool {call.id}>", end="")
                                    if call.function.name:
                                        print(f"{call.function.name}", end=" ")
                                    if call.function.arguments:
                                        print(f"{call.function.arguments}", end="")
                        else:
                            # Anything without `choices` is printed as a tool result message.
                            print(
                                ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
                                flush=True,
                            )

                    print()

                except Exception as e:
                    tb_str = traceback.format_exc()
                    print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
                    first_sigint = True  # Allow graceful interrupt for the next command

    except Exception as e:
        tb_str = traceback.format_exc()
        print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
        raise

    finally:
        if sigint_registered_in_loop:
            try:
                loop.remove_signal_handler(signal.SIGINT)
            except (AttributeError, NotImplementedError):
                pass
        elif original_sigint_handler is not None:
            # `signal.getsignal` returns None when the previous handler was not installed from
            # Python; such a handler cannot be passed back to `signal.signal`.
            signal.signal(signal.SIGINT, original_sigint_handler)
# NOTE: deliberately documented with comments rather than a docstring, so that the help text
# typer displays for this callback cannot be affected.
@run_cli.callback()
def run(
    path: Optional[str] = typer.Argument(
        None,
        help=(
            "Path to a local folder containing an agent.json file or a built-in agent "
            "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
            "(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
        ),
        show_default=False,
    ),
):
    # Synchronous entry point: drive the async agent loop to completion.
    try:
        asyncio.run(run_agent(path))
    except KeyboardInterrupt:
        # Report the interruption and exit with status 130.
        print(ANSI.red("\nApplication terminated by KeyboardInterrupt."), flush=True)
        raise typer.Exit(code=130)
    except Exception as exc:
        # Surface the error message, then let the exception propagate.
        print(ANSI.red(f"\nAn unexpected error occurred: {exc}"), flush=True)
        raise exc
if __name__ == "__main__":
app()

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import sys
from pathlib import Path
from huggingface_hub import ChatCompletionInputTool
# Name of the agent configuration file.
FILENAME_CONFIG = "agent.json"
# File names recognized as custom prompt files.
PROMPT_FILENAMES = ("PROMPT.md", "AGENTS.md")

# Default agent configuration: one model/provider pair and two stdio MCP servers launched with `npx`.
DEFAULT_AGENT = {
    "model": "Qwen/Qwen2.5-72B-Instruct",
    "provider": "nebius",
    "servers": [
        {
            "type": "stdio",
            "command": "npx",
            "args": [
                "-y",
                "@modelcontextprotocol/server-filesystem",
                # The Desktop folder on macOS, the home directory itself everywhere else.
                str(Path.home() / ("Desktop" if sys.platform == "darwin" else "")),
            ],
        },
        {
            "type": "stdio",
            "command": "npx",
            "args": ["@playwright/mcp@latest"],
        },
    ],
}

# System prompt used when no custom prompt is provided.
DEFAULT_SYSTEM_PROMPT = """
You are an agent - please keep going until the users query is completely
resolved, before ending your turn and yielding back to the user. Only terminate
your turn when you are sure that the problem is solved, or if you need more
info from the user to solve the problem.
If you are not sure about anything pertaining to the users request, use your
tools to read files and gather the relevant information: do NOT guess or make
up an answer.
You MUST plan extensively before each function call, and reflect extensively
on the outcomes of the previous function calls. DO NOT do this entire process
by making function calls only, as this can impair your ability to solve the
problem and think insightfully.
""".strip()

# Turn budget of the agent loop.
MAX_NUM_TURNS = 10
# Tool the model calls to signal that the user's task is done. It takes no arguments.
# `parse_obj_as_instance` is used (rather than `parse_obj`) so the result is typed as a single
# instance, which removes the need for a `type: ignore`.
TASK_COMPLETE_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(
    {
        "type": "function",
        "function": {
            "name": "task_complete",
            "description": "Call this tool when the task given by the user is complete",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    }
)

# Tool the model calls to hand control back to the user with a question. It takes no arguments.
ASK_QUESTION_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(
    {
        "type": "function",
        "function": {
            "name": "ask_question",
            "description": "Ask the user for more info required to solve or clarify their problem.",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    }
)

# Tools whose invocation ends the agent loop.
EXIT_LOOP_TOOLS: list[ChatCompletionInputTool] = [TASK_COMPLETE_TOOL, ASK_QUESTION_TOOL]

# Hugging Face repo id used by default for agents.
# NOTE(review): per the CLI help text this is a dataset holding built-in agents — confirm against the loader.
DEFAULT_REPO_ID = "tiny-agents/tiny-agents"

View File

@@ -0,0 +1,395 @@
import json
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, AsyncIterable, Literal, Optional, TypedDict, Union, overload
from typing_extensions import NotRequired, TypeAlias, Unpack
from ...utils._runtime import get_hf_hub_version
from .._generated._async_client import AsyncInferenceClient
from .._generated.types import (
ChatCompletionInputMessage,
ChatCompletionInputTool,
ChatCompletionStreamOutput,
ChatCompletionStreamOutputDeltaToolCall,
)
from .._providers import PROVIDER_OR_POLICY_T
from .utils import format_result
if TYPE_CHECKING:
from mcp import ClientSession
logger = logging.getLogger(__name__)
# Type alias for tool names
ToolName: TypeAlias = str
ServerType: TypeAlias = Literal["stdio", "sse", "http"]
class StdioServerParameters_T(TypedDict):
command: str
args: NotRequired[list[str]]
env: NotRequired[dict[str, str]]
cwd: NotRequired[Union[str, Path, None]]
class SSEServerParameters_T(TypedDict):
url: str
headers: NotRequired[dict[str, Any]]
timeout: NotRequired[float]
sse_read_timeout: NotRequired[float]
class StreamableHTTPParameters_T(TypedDict):
url: str
headers: NotRequired[dict[str, Any]]
timeout: NotRequired[timedelta]
sse_read_timeout: NotRequired[timedelta]
terminate_on_close: NotRequired[bool]
class MCPClient:
"""
Client for connecting to one or more MCP servers and processing chat completions with tools.
> [!WARNING]
> This class is experimental and might be subject to breaking changes in the future without prior notice.
Args:
model (`str`, `optional`):
The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
or a URL to a deployed Inference Endpoint or other local or remote endpoint.
provider (`str`, *optional*):
Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
base_url (`str`, *optional*):
The base URL to run inference. Defaults to None.
api_key (`str`, `optional`):
Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service.
"""
def __init__(
    self,
    *,
    model: Optional[str] = None,
    provider: Optional[PROVIDER_OR_POLICY_T] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
):
    """Set up empty MCP state and the async inference client.

    Raises:
        `ValueError`: If neither `model` nor `base_url` is provided.
    """
    # One MCP session per tool name, filled in by `add_mcp_server`.
    self.sessions: dict[ToolName, "ClientSession"] = {}
    # Owns every server connection / session opened later on.
    self.exit_stack = AsyncExitStack()
    # Tools advertised to the model in chat completions.
    self.available_tools: list[ChatCompletionInputTool] = []

    if model is None and base_url is None:
        raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.")

    # To be able to send the model in the payload if `base_url` is provided
    self.payload_model = model
    # When a base URL is given, the client itself is created without a model.
    client_model = model if base_url is None else None
    self.client = AsyncInferenceClient(
        model=client_model,
        provider=provider,
        api_key=api_key,
        base_url=base_url,
    )
async def __aenter__(self):
    """Enter the inference client, then the exit stack, and return this instance."""
    for resource in (self.client, self.exit_stack):
        await resource.__aenter__()
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the context manager.

    The inference client is exited first; `cleanup` then runs even if that step raises, so
    that the resources it releases are not leaked.
    """
    try:
        await self.client.__aexit__(exc_type, exc_val, exc_tb)
    finally:
        await self.cleanup()
async def cleanup(self):
    """Clean up resources.

    Closes the inference client, then the exit stack. The exit stack is closed even if
    closing the client raises, so whatever was registered on it is always released.
    """
    try:
        await self.client.close()
    finally:
        await self.exit_stack.aclose()
@overload
async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ...

@overload
async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ...

@overload
async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ...

async def add_mcp_server(self, type: ServerType, **params: Any):
    """Connect to an MCP server

    Opens the connection and an MCP session on `self.exit_stack`, lists the server's tools and
    registers them in `self.sessions` and `self.available_tools`. A tool whose name is already
    registered is skipped with a warning.

    Args:
        type (`str`):
            Type of the server to connect to. Can be one of:
            - "stdio": Standard input/output server (local)
            - "sse": Server-sent events (SSE) server
            - "http": StreamableHTTP server
        **params (`dict[str, Any]`):
            Server parameters that can be either:
                - For stdio servers:
                    - command (str): The command to run the MCP server
                    - args (list[str], optional): Arguments for the command
                    - env (dict[str, str], optional): Environment variables for the command
                    - cwd (Union[str, Path, None], optional): Working directory for the command
                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                - For SSE servers:
                    - url (str): The URL of the SSE server
                    - headers (dict[str, Any], optional): Headers for the SSE connection
                    - timeout (float, optional): Connection timeout
                    - sse_read_timeout (float, optional): SSE read timeout
                    - allowed_tools (list[str], optional): List of tool names to allow from this server
                - For StreamableHTTP servers:
                    - url (str): The URL of the StreamableHTTP server
                    - headers (dict[str, Any], optional): Headers for the StreamableHTTP connection
                    - timeout (timedelta, optional): Connection timeout
                    - sse_read_timeout (timedelta, optional): SSE read timeout
                    - terminate_on_close (bool, optional): Whether to terminate on close
                    - allowed_tools (list[str], optional): List of tool names to allow from this server

    Raises:
        `ValueError`: If `type` is not one of the supported server types.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp import types as mcp_types

    # Extract allowed_tools configuration if provided
    allowed_tools = params.pop("allowed_tools", None)

    # Determine server type and create appropriate parameters
    if type == "stdio":
        # Handle stdio server
        from mcp.client.stdio import stdio_client

        logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}")

        client_kwargs = {"command": params["command"]}
        # Only forward optional parameters that were actually provided.
        for key in ["args", "env", "cwd"]:
            if params.get(key) is not None:
                client_kwargs[key] = params[key]
        server_params = StdioServerParameters(**client_kwargs)
        read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
    elif type == "sse":
        # Handle SSE server
        from mcp.client.sse import sse_client

        logger.info(f"Connecting to SSE MCP server at: {params['url']}")

        client_kwargs = {"url": params["url"]}
        for key in ["headers", "timeout", "sse_read_timeout"]:
            if params.get(key) is not None:
                client_kwargs[key] = params[key]
        read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs))
    elif type == "http":
        # Handle StreamableHTTP server
        from mcp.client.streamable_http import streamablehttp_client

        logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}")

        client_kwargs = {"url": params["url"]}
        for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]:
            if params.get(key) is not None:
                client_kwargs[key] = params[key]
        read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs))
        # ^ TODO: should we handle `get_session_id_callback`? (function to retrieve the current session ID)
    else:
        raise ValueError(f"Unsupported server type: {type}")

    session = await self.exit_stack.enter_async_context(
        ClientSession(
            read_stream=read,
            write_stream=write,
            client_info=mcp_types.Implementation(
                name="huggingface_hub.MCPClient",
                version=get_hf_hub_version(),
            ),
        )
    )

    logger.debug("Initializing session...")
    await session.initialize()

    # List available tools
    response = await session.list_tools()
    # The list is passed as a %-style argument, so the message needs a matching placeholder.
    logger.debug("Connected to server with tools: %s", [tool.name for tool in response.tools])

    # Filter tools based on allowed_tools configuration
    filtered_tools = response.tools

    if allowed_tools is not None:
        filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]
        logger.debug(
            f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
        )

    for tool in filtered_tools:
        if tool.name in self.sessions:
            logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
            continue

        # Map tool names to their server for later lookup
        self.sessions[tool.name] = session

        # Add tool to the list of available tools (for use in chat completions)
        self.available_tools.append(
            ChatCompletionInputTool.parse_obj_as_instance(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
            )
        )
async def process_single_turn_with_tools(
    self,
    messages: list[Union[dict, ChatCompletionInputMessage]],
    exit_loop_tools: Optional[list[ChatCompletionInputTool]] = None,
    exit_if_first_chunk_no_tool: bool = False,
) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
    """Run one streamed chat-completion turn with `self.model` and execute the tool calls it requests.

    The stream chunks are yielded as they arrive. Once the stream is exhausted, the assistant
    message (text and/or tool calls) is appended to `messages`, then each requested tool is
    executed in order; every tool result is appended to `messages` and yielded.

    Note: `messages` is mutated in place.

    Args:
        messages (`list[dict]`):
            List of message objects representing the conversation history
        exit_loop_tools (`list[ChatCompletionInputTool]`, *optional*):
            List of tools that should exit the generator when called
        exit_if_first_chunk_no_tool (`bool`, *optional*):
            Exit if no tool is present in the first chunks. Default to False.

    Yields:
        [`ChatCompletionStreamOutput`] chunks or [`ChatCompletionInputMessage`] objects
    """

    def _record(payload: dict[str, Any]) -> ChatCompletionInputMessage:
        # Parse a tool message, store it in the conversation history and hand it back for yielding.
        parsed = ChatCompletionInputMessage.parse_obj_as_instance(payload)
        messages.append(parsed)
        return parsed

    # Exit-loop tools (when provided) are advertised ahead of the regular tools.
    tools = self.available_tools if exit_loop_tools is None else [*exit_loop_tools, *self.available_tools]

    # Open the streaming request
    stream = await self.client.chat.completions.create(
        model=self.payload_model,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        stream=True,
    )

    assistant_message: dict[str, Any] = {"role": "unknown", "content": ""}
    pending_calls: dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
    chunk_count = 0

    # Consume the stream, accumulating text and tool-call fragments
    async for chunk in stream:
        chunk_count += 1
        delta = chunk.choices[0].delta if chunk.choices and len(chunk.choices) > 0 else None
        if not delta:
            # Chunks without a delta are neither processed nor yielded.
            continue

        if delta.role:
            assistant_message["role"] = delta.role
        if delta.content:
            assistant_message["content"] += delta.content

        for fragment in delta.tool_calls or ():
            slot = fragment.index
            is_first_fragment = slot not in pending_calls
            # The first fragment seen for an index becomes the accumulator for that tool call.
            accumulated = pending_calls.setdefault(slot, fragment)
            # Make sure `.function.arguments` is a string before any concatenation.
            if accumulated.function.arguments is None:
                accumulated.function.arguments = ""
            if not is_first_fragment and fragment.function.arguments:
                accumulated.function.arguments += fragment.function.arguments

        # Optionally stop right away when the first chunks carry no tool call
        if exit_if_first_chunk_no_tool and chunk_count <= 2 and not pending_calls:
            return

        yield chunk

    # Store the assistant message (with its tool calls, if any) in the history
    if assistant_message["content"] or pending_calls:
        if assistant_message.get("role") == "unknown":
            # No role was streamed: fall back to "assistant".
            assistant_message["role"] = "assistant"
        if pending_calls:
            # Tool calls are stored in the format expected by OpenAI
            assistant_message["tool_calls"] = [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments or "{}",
                    },
                }
                for call in pending_calls.values()
            ]
        messages.append(assistant_message)

    # Run the requested tools sequentially
    for call in pending_calls.values():
        function_name = call.function.name
        if function_name is None:
            yield _record(
                {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "content": "Invalid tool call with no function name.",
                }
            )
            continue  # move to next tool call

        try:
            function_args = json.loads(call.function.arguments or "{}")
        except json.JSONDecodeError as err:
            yield _record(
                {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": function_name,
                    "content": f"Invalid JSON generated by the model: {err}",
                }
            )
            continue  # move to next tool call

        tool_message = {"role": "tool", "tool_call_id": call.id, "content": "", "name": function_name}

        # An exit-loop tool ends the generator after an empty tool message has been emitted
        if exit_loop_tools:
            exit_tool_names = [t.function.name for t in exit_loop_tools]
            if function_name in exit_tool_names:
                yield _record(tool_message)
                return

        # Dispatch the call to the session registered for this tool name
        session = self.sessions.get(function_name)
        if session is None:
            tool_message["content"] = f"Error: No session found for tool: {function_name}"
        else:
            try:
                result = await session.call_tool(function_name, function_args)
                tool_message["content"] = format_result(result)
            except Exception as err:
                # Failures are reported back as the tool message content instead of being raised.
                tool_message["content"] = f"Error: MCP tool call failed with error message: {err}"

        yield _record(tool_message)

View File

@@ -0,0 +1,45 @@
from typing import Literal, TypedDict, Union
from typing_extensions import NotRequired
class InputConfig(TypedDict, total=False):
    """Typed dictionary for one entry of the `inputs` list of an agent configuration.

    Declared with `total=False`: every key is optional.
    """

    id: str
    description: str
    type: str
    # NOTE(review): presumably flags the input as a secret value - confirm against the consumer.
    password: bool
class StdioServerConfig(TypedDict):
    """Typed dictionary describing a server of type `"stdio"`.

    Only `allowed_tools` is declared as not required.
    """

    type: Literal["stdio"]
    command: str
    args: list[str]
    env: dict[str, str]
    cwd: str
    # NOTE(review): presumably the names of the tools to keep from this server - confirm.
    allowed_tools: NotRequired[list[str]]
class HTTPServerConfig(TypedDict):
    """Typed dictionary describing a server of type `"http"`.

    Only `allowed_tools` is declared as not required.
    """

    type: Literal["http"]
    url: str
    headers: dict[str, str]
    # NOTE(review): presumably the names of the tools to keep from this server - confirm.
    allowed_tools: NotRequired[list[str]]
class SSEServerConfig(TypedDict):
    """Typed dictionary describing a server of type `"sse"`.

    Only `allowed_tools` is declared as not required.
    """

    type: Literal["sse"]
    url: str
    headers: dict[str, str]
    # NOTE(review): presumably the names of the tools to keep from this server - confirm.
    allowed_tools: NotRequired[list[str]]
# Any supported server configuration; the variants are told apart by their `type` key
# ("stdio", "http" or "sse").
ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
class AgentConfig(TypedDict):
    """Root object of an agent configuration.

    `apiKey` is the only key declared as not required.
    """

    model: str
    provider: str
    apiKey: NotRequired[str]
    inputs: list[InputConfig]
    servers: list[ServerConfig]

View File

@@ -0,0 +1,128 @@
"""
Utility functions for MCPClient and Tiny Agents.
Formatting utilities taken from the JS SDK: https://github.com/huggingface/huggingface.js/blob/main/packages/mcp-client/src/ResultFormatter.ts.
"""
import json
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from huggingface_hub import snapshot_download
from huggingface_hub.errors import EntryNotFoundError
from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, PROMPT_FILENAMES
from .types import AgentConfig
if TYPE_CHECKING:
from mcp import types as mcp_types
def format_result(result: "mcp_types.CallToolResult") -> str:
    """
    Render the content of a mcp.types.CallToolResult as a human-readable string.

    Text items (and resources carrying text) are included verbatim. Image, audio and blob
    resources are replaced by a short placeholder giving their mime type and estimated size.
    Items of any other type are ignored.

    Args:
        result (CallToolResult)
            Object returned by mcp.ClientSession.call_tool.

    Returns:
        str
            The formatted parts joined with newlines, or "[No content]" when the result is empty.
    """
    items = result.content
    if len(items) == 0:
        return "[No content]"

    # Sentence appended to every binary placeholder.
    completion_note = "The task is complete and the content accessible to the User"

    pieces: list[str] = []
    for item in items:
        kind = item.type
        if kind == "text":
            pieces.append(item.text)
        elif kind in ("image", "audio"):
            label = "Image" if kind == "image" else "Audio"
            pieces.append(
                f"[Binary Content: {label} {item.mimeType}, {_get_base64_size(item.data)} bytes]\n{completion_note}"
            )
        elif kind == "resource":
            resource = item.resource
            if isinstance(getattr(resource, "text", None), str):
                pieces.append(resource.text)
            elif isinstance(getattr(resource, "blob", None), str):
                pieces.append(
                    f"[Binary Content ({resource.uri}): {resource.mimeType}, {_get_base64_size(resource.blob)} bytes]\n"
                    f"{completion_note}"
                )

    return "\n".join(pieces)
def _get_base64_size(base64_str: str) -> int:
"""Estimate the byte size of a base64-encoded string."""
# Remove any prefix like "data:image/png;base64,"
if "," in base64_str:
base64_str = base64_str.split(",")[1]
padding = 0
if base64_str.endswith("=="):
padding = 2
elif base64_str.endswith("="):
padding = 1
return (len(base64_str) * 3) // 4 - padding
def _load_agent_config(agent_path: Optional[str]) -> tuple[AgentConfig, Optional[str]]:
    """Load an agent configuration and its optional prompt.

    Resolution order:
        1. `agent_path` is None: return `DEFAULT_AGENT` with no prompt.
        2. `agent_path` is a local file: parse it as JSON, no prompt.
        3. `agent_path` is a local directory: read the config file and prompt from it.
        4. Otherwise: download `<agent_path>/*` from the `DEFAULT_REPO_ID` dataset repo and read from there.

    Args:
        agent_path (`str`, *optional*):
            Local file, local directory, or name of an agent folder on the Hub.

    Returns:
        `tuple[AgentConfig, Optional[str]]`: the parsed config and the prompt text (None if no prompt was found).

    Raises:
        FileNotFoundError: If a local directory does not contain the config file.
        EntryNotFoundError: If fetching or reading the agent from the Hub fails for any reason.
    """

    def _read_dir(directory: Path) -> tuple[AgentConfig, Optional[str]]:
        # Read the JSON config plus the first existing prompt file of `directory`.
        config_path = directory / FILENAME_CONFIG
        if not config_path.exists():
            raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")
        config: AgentConfig = json.loads(config_path.read_text(encoding="utf-8"))

        # Prompt filenames are tried in the order of PROMPT_FILENAMES; the first existing one wins.
        candidates = (directory / filename for filename in PROMPT_FILENAMES)
        first_existing = next((candidate for candidate in candidates if candidate.exists()), None)
        prompt: Optional[str] = None if first_existing is None else first_existing.read_text(encoding="utf-8")
        return config, prompt

    if agent_path is None:
        return DEFAULT_AGENT, None  # type: ignore[return-value]

    local_path = Path(agent_path).expanduser()
    if local_path.is_file():
        return json.loads(local_path.read_text(encoding="utf-8")), None
    if local_path.is_dir():
        return _read_dir(local_path)

    # Not a local path: fetch from the Hub
    try:
        snapshot_root = snapshot_download(
            repo_id=DEFAULT_REPO_ID,
            allow_patterns=f"{agent_path}/*",
            repo_type="dataset",
        )
        return _read_dir(Path(snapshot_root) / agent_path)
    except Exception as err:
        # Any failure (download or missing config) is reported as a missing agent.
        raise EntryNotFoundError(
            f" Agent {agent_path} not found in tiny-agents/tiny-agents! Please make sure it exists in https://huggingface.co/datasets/tiny-agents/tiny-agents."
        ) from err

View File

@@ -0,0 +1,266 @@
from typing import Literal, Optional, Union
from huggingface_hub.inference._providers.featherless_ai import (
FeatherlessConversationalTask,
FeatherlessTextGenerationTask,
)
from huggingface_hub.utils import logging
from ._common import AutoRouterConversationalTask, TaskProviderHelper, _fetch_inference_provider_mapping
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .clarifai import ClarifaiConversationalTask
from .cohere import CohereConversationalTask
from .fal_ai import (
FalAIAutomaticSpeechRecognitionTask,
FalAIImageSegmentationTask,
FalAIImageToImageTask,
FalAIImageToVideoTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
)
from .fireworks_ai import FireworksAIConversationalTask
from .groq import GroqConversationalTask
from .hf_inference import (
HFInferenceBinaryInputTask,
HFInferenceConversational,
HFInferenceFeatureExtractionTask,
HFInferenceTask,
)
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
from .nebius import (
NebiusConversationalTask,
NebiusFeatureExtractionTask,
NebiusTextGenerationTask,
NebiusTextToImageTask,
)
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
from .openai import OpenAIConversationalTask
from .ovhcloud import OVHcloudConversationalTask
from .publicai import PublicAIConversationalTask
from .replicate import (
ReplicateAutomaticSpeechRecognitionTask,
ReplicateImageToImageTask,
ReplicateTask,
ReplicateTextToImageTask,
ReplicateTextToSpeechTask,
)
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
from .wavespeed import (
WavespeedAIImageToImageTask,
WavespeedAIImageToVideoTask,
WavespeedAITextToImageTask,
WavespeedAITextToVideoTask,
)
from .zai_org import ZaiConversationalTask
# Module-level logger.
logger = logging.get_logger(__name__)
# Names of the supported inference providers; these are the keys of the `PROVIDERS` registry.
PROVIDER_T = Literal[
"black-forest-labs",
"cerebras",
"clarifai",
"cohere",
"fal-ai",
"featherless-ai",
"fireworks-ai",
"groq",
"hf-inference",
"hyperbolic",
"nebius",
"novita",
"nscale",
"openai",
"ovhcloud",
"publicai",
"replicate",
"sambanova",
"scaleway",
"together",
"wavespeed",
"zai-org",
]
# Either an explicit provider name or the "auto" selection policy.
PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
# Shared helper instance returned by `get_provider_helper` when the provider is "auto" and the task is "conversational".
CONVERSATIONAL_AUTO_ROUTER = AutoRouterConversationalTask()
# Registry of helpers: provider name -> task name -> helper instance.
# `get_provider_helper` looks providers and tasks up in this mapping; a (provider, task) pair
# missing from it is reported there as unsupported.
PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
"black-forest-labs": {
"text-to-image": BlackForestLabsTextToImageTask(),
},
"cerebras": {
"conversational": CerebrasConversationalTask(),
},
"clarifai": {
"conversational": ClarifaiConversationalTask(),
},
"cohere": {
"conversational": CohereConversationalTask(),
},
"fal-ai": {
"automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
"text-to-image": FalAITextToImageTask(),
"text-to-speech": FalAITextToSpeechTask(),
"text-to-video": FalAITextToVideoTask(),
"image-to-video": FalAIImageToVideoTask(),
"image-to-image": FalAIImageToImageTask(),
"image-segmentation": FalAIImageSegmentationTask(),
},
"featherless-ai": {
"conversational": FeatherlessConversationalTask(),
"text-generation": FeatherlessTextGenerationTask(),
},
"fireworks-ai": {
"conversational": FireworksAIConversationalTask(),
},
"groq": {
"conversational": GroqConversationalTask(),
},
# The generic HFInferenceTask / HFInferenceBinaryInputTask helpers are instantiated with the task name.
"hf-inference": {
"text-to-image": HFInferenceTask("text-to-image"),
"conversational": HFInferenceConversational(),
"text-generation": HFInferenceTask("text-generation"),
"text-classification": HFInferenceTask("text-classification"),
"question-answering": HFInferenceTask("question-answering"),
"audio-classification": HFInferenceBinaryInputTask("audio-classification"),
"automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
"fill-mask": HFInferenceTask("fill-mask"),
"feature-extraction": HFInferenceFeatureExtractionTask(),
"image-classification": HFInferenceBinaryInputTask("image-classification"),
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
"document-question-answering": HFInferenceTask("document-question-answering"),
"image-to-text": HFInferenceBinaryInputTask("image-to-text"),
"object-detection": HFInferenceBinaryInputTask("object-detection"),
"audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
"zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
"zero-shot-classification": HFInferenceTask("zero-shot-classification"),
"image-to-image": HFInferenceBinaryInputTask("image-to-image"),
"sentence-similarity": HFInferenceTask("sentence-similarity"),
"table-question-answering": HFInferenceTask("table-question-answering"),
"tabular-classification": HFInferenceTask("tabular-classification"),
"text-to-speech": HFInferenceTask("text-to-speech"),
"token-classification": HFInferenceTask("token-classification"),
"translation": HFInferenceTask("translation"),
"summarization": HFInferenceTask("summarization"),
"visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
},
"hyperbolic": {
"text-to-image": HyperbolicTextToImageTask(),
"conversational": HyperbolicTextGenerationTask("conversational"),
"text-generation": HyperbolicTextGenerationTask("text-generation"),
},
"nebius": {
"text-to-image": NebiusTextToImageTask(),
"conversational": NebiusConversationalTask(),
"text-generation": NebiusTextGenerationTask(),
"feature-extraction": NebiusFeatureExtractionTask(),
},
"novita": {
"text-generation": NovitaTextGenerationTask(),
"conversational": NovitaConversationalTask(),
"text-to-video": NovitaTextToVideoTask(),
},
"nscale": {
"conversational": NscaleConversationalTask(),
"text-to-image": NscaleTextToImageTask(),
},
"openai": {
"conversational": OpenAIConversationalTask(),
},
"ovhcloud": {
"conversational": OVHcloudConversationalTask(),
},
"publicai": {
"conversational": PublicAIConversationalTask(),
},
"replicate": {
"automatic-speech-recognition": ReplicateAutomaticSpeechRecognitionTask(),
"image-to-image": ReplicateImageToImageTask(),
"text-to-image": ReplicateTextToImageTask(),
"text-to-speech": ReplicateTextToSpeechTask(),
"text-to-video": ReplicateTask("text-to-video"),
},
"sambanova": {
"conversational": SambanovaConversationalTask(),
"feature-extraction": SambanovaFeatureExtractionTask(),
},
"scaleway": {
"conversational": ScalewayConversationalTask(),
"feature-extraction": ScalewayFeatureExtractionTask(),
},
"together": {
"text-to-image": TogetherTextToImageTask(),
"conversational": TogetherConversationalTask(),
"text-generation": TogetherTextGenerationTask(),
},
"wavespeed": {
"text-to-image": WavespeedAITextToImageTask(),
"text-to-video": WavespeedAITextToVideoTask(),
"image-to-image": WavespeedAIImageToImageTask(),
"image-to-video": WavespeedAIImageToVideoTask(),
},
"zai-org": {
"conversational": ZaiConversationalTask(),
},
}
def get_provider_helper(
    provider: Optional[PROVIDER_OR_POLICY_T], task: str, model: Optional[str]
) -> TaskProviderHelper:
    """Get provider helper instance by name and task.

    Resolution rules:
        - No model and no explicit provider (None or "auto"), or a model given as an
          `http://` / `https://` URL: the "hf-inference" provider is used.
        - `provider=None` otherwise: treated as "auto".
        - `provider="auto"`: the dedicated auto-router is returned for the "conversational" task;
          for any other task the first entry of the model's provider mapping is used.

    Args:
        provider (`str`, *optional*): name of the provider, or "auto" to automatically select the provider for the model.
        task (`str`): Name of the task
        model (`str`, *optional*): Name of the model

    Returns:
        TaskProviderHelper: Helper instance for the specified provider and task

    Raises:
        ValueError: If provider or task is not supported, if no model is given with provider "auto",
            or if the provider mapping fetched for the model is empty.
    """
    if (model is None and provider in (None, "auto")) or (
        model is not None and model.startswith(("http://", "https://"))
    ):
        provider = "hf-inference"

    if provider is None:
        logger.info(
            "No provider specified for task `conversational`. Defaulting to server-side auto routing."
            if task == "conversational"
            else "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
        )
        provider = "auto"

    if provider == "auto":
        if model is None:
            raise ValueError("Specifying a model is required when provider is 'auto'")
        if task == "conversational":
            # Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
            return CONVERSATIONAL_AUTO_ROUTER

        provider_mapping = _fetch_inference_provider_mapping(model)
        # Guard against an empty mapping: a bare `next(iter(...))` would leak a StopIteration
        # instead of the documented ValueError.
        first_mapping = next(iter(provider_mapping), None)
        if first_mapping is None:
            raise ValueError(
                f"No inference provider available for model '{model}'. Please specify a provider explicitly."
            )
        provider = first_mapping.provider

    provider_tasks = PROVIDERS.get(provider)  # type: ignore
    if provider_tasks is None:
        raise ValueError(
            f"Provider '{provider}' not supported. Available values: 'auto' or any provider from {list(PROVIDERS.keys())}. "
            "Passing 'auto' (default value) will automatically select the first provider available for the model, sorted "
            "by the user's order in https://hf.co/settings/inference-providers."
        )

    if task not in provider_tasks:
        raise ValueError(
            f"Task '{task}' not supported for provider '{provider}'. Available tasks: {list(provider_tasks.keys())}"
        )
    return provider_tasks[task]

Some files were not shown because too many files have changed in this diff Show More