fix: optimize the UI

丹尼尔
2026-03-11 00:22:41 +08:00
parent 0655410134
commit 0e8639fde1
4268 changed files with 1224126 additions and 92 deletions


@@ -0,0 +1,396 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import os as _os
import typing as _t
from typing_extensions import override
from . import types
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from ._exceptions import (
APIError,
OpenAIError,
ConflictError,
NotFoundError,
APIStatusError,
RateLimitError,
APITimeoutError,
BadRequestError,
APIConnectionError,
AuthenticationError,
InternalServerError,
PermissionDeniedError,
LengthFinishReasonError,
UnprocessableEntityError,
APIResponseValidationError,
InvalidWebhookSignatureError,
ContentFilterFinishReasonError,
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent
__all__ = [
"types",
"__version__",
"__title__",
"NoneType",
"Transport",
"ProxiesTypes",
"NotGiven",
"NOT_GIVEN",
"not_given",
"Omit",
"omit",
"OpenAIError",
"APIError",
"APIStatusError",
"APITimeoutError",
"APIConnectionError",
"APIResponseValidationError",
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"LengthFinishReasonError",
"ContentFilterFinishReasonError",
"InvalidWebhookSignatureError",
"Timeout",
"RequestOptions",
"Client",
"AsyncClient",
"Stream",
"AsyncStream",
"OpenAI",
"AsyncOpenAI",
"file_from_path",
"BaseModel",
"DEFAULT_TIMEOUT",
"DEFAULT_MAX_RETRIES",
"DEFAULT_CONNECTION_LIMITS",
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
]
if not _t.TYPE_CHECKING:
from ._utils._resources_proxy import resources as resources
from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
from .version import VERSION as VERSION
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
from .lib._old_api import *
from .lib.streaming import (
AssistantEventHandler as AssistantEventHandler,
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
)
_setup_logging()
# Update the __module__ attribute for exported symbols so that
# error messages point to this module instead of the module
# it was originally defined in, e.g.
# openai._exceptions.NotFoundError -> openai.NotFoundError
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
try:
__locals[__name].__module__ = "openai"
except (TypeError, AttributeError):
# Some of our exported symbols are builtins which we can't set attributes for.
pass
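As a quick sketch of the effect (assuming the package is importable as `openai`), the loop above makes exported classes report the public module path:

```py
import openai

# defined in openai._exceptions, but reported as the public module
# after the __module__ rewrite above
print(openai.NotFoundError.__module__)  # "openai"
```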
# ------ Module level client ------
import typing as _t
import typing_extensions as _te
import httpx as _httpx
from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
api_key: str | None = None
organization: str | None = None
project: str | None = None
webhook_secret: str | None = None
base_url: str | _httpx.URL | None = None
timeout: float | Timeout | None = DEFAULT_TIMEOUT
max_retries: int = DEFAULT_MAX_RETRIES
default_headers: _t.Mapping[str, str] | None = None
default_query: _t.Mapping[str, object] | None = None
http_client: _httpx.Client | None = None
_ApiType = _te.Literal["openai", "azure"]
api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))
api_version: str | None = _os.environ.get("OPENAI_API_VERSION")
azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")
azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
class _ModuleClient(OpenAI):
# Note: we have to use type: ignores here as overriding class members
# with properties is technically unsafe but it is fine for our use case
@property # type: ignore
@override
def api_key(self) -> str | None:
return api_key
@api_key.setter # type: ignore
def api_key(self, value: str | None) -> None: # type: ignore
global api_key
api_key = value
@property # type: ignore
@override
def organization(self) -> str | None:
return organization
@organization.setter # type: ignore
def organization(self, value: str | None) -> None: # type: ignore
global organization
organization = value
@property # type: ignore
@override
def project(self) -> str | None:
return project
@project.setter # type: ignore
def project(self, value: str | None) -> None: # type: ignore
global project
project = value
@property # type: ignore
@override
def webhook_secret(self) -> str | None:
return webhook_secret
@webhook_secret.setter # type: ignore
def webhook_secret(self, value: str | None) -> None: # type: ignore
global webhook_secret
webhook_secret = value
@property
@override
def base_url(self) -> _httpx.URL:
if base_url is not None:
return _httpx.URL(base_url)
return super().base_url
@base_url.setter
def base_url(self, url: _httpx.URL | str) -> None:
super().base_url = url # type: ignore[misc]
@property # type: ignore
@override
def timeout(self) -> float | Timeout | None:
return timeout
@timeout.setter # type: ignore
def timeout(self, value: float | Timeout | None) -> None: # type: ignore
global timeout
timeout = value
@property # type: ignore
@override
def max_retries(self) -> int:
return max_retries
@max_retries.setter # type: ignore
def max_retries(self, value: int) -> None: # type: ignore
global max_retries
max_retries = value
@property # type: ignore
@override
def _custom_headers(self) -> _t.Mapping[str, str] | None:
return default_headers
@_custom_headers.setter # type: ignore
def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore
global default_headers
default_headers = value
@property # type: ignore
@override
def _custom_query(self) -> _t.Mapping[str, object] | None:
return default_query
@_custom_query.setter # type: ignore
def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore
global default_query
default_query = value
@property # type: ignore
@override
def _client(self) -> _httpx.Client:
return http_client or super()._client
@_client.setter # type: ignore
def _client(self, value: _httpx.Client) -> None: # type: ignore
global http_client
http_client = value
class _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore
...
class _AmbiguousModuleClientUsageError(OpenAIError):
def __init__(self) -> None:
super().__init__(
"Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`"
)
def _has_openai_credentials() -> bool:
return _os.environ.get("OPENAI_API_KEY") is not None
def _has_azure_credentials() -> bool:
return azure_endpoint is not None or _os.environ.get("AZURE_OPENAI_API_KEY") is not None
def _has_azure_ad_credentials() -> bool:
return (
_os.environ.get("AZURE_OPENAI_AD_TOKEN") is not None
or azure_ad_token is not None
or azure_ad_token_provider is not None
)
_client: OpenAI | None = None
def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
global _client
if _client is None:
global api_type, azure_endpoint, azure_ad_token, api_version
if azure_endpoint is None:
azure_endpoint = _os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_ad_token is None:
azure_ad_token = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_version is None:
api_version = _os.environ.get("OPENAI_API_VERSION")
if api_type is None:
has_openai = _has_openai_credentials()
has_azure = _has_azure_credentials()
has_azure_ad = _has_azure_ad_credentials()
if has_openai and (has_azure or has_azure_ad):
raise _AmbiguousModuleClientUsageError()
if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(
"AZURE_OPENAI_API_KEY"
) is not None:
raise _AmbiguousModuleClientUsageError()
if has_azure or has_azure_ad:
api_type = "azure"
else:
api_type = "openai"
if api_type == "azure":
_client = _AzureModuleClient( # type: ignore
api_version=api_version,
azure_endpoint=azure_endpoint,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
)
return _client
_client = _ModuleClient(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
)
return _client
return _client
def _reset_client() -> None: # type: ignore[reportUnusedFunction]
global _client
_client = None
from ._module_client import (
beta as beta,
chat as chat,
audio as audio,
evals as evals,
files as files,
images as images,
models as models,
skills as skills,
videos as videos,
batches as batches,
uploads as uploads,
realtime as realtime,
webhooks as webhooks,
responses as responses,
containers as containers,
embeddings as embeddings,
completions as completions,
fine_tuning as fine_tuning,
moderations as moderations,
conversations as conversations,
vector_stores as vector_stores,
)
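A hedged usage sketch of the module-level client wired up above; the key and model name are placeholders, and `api_type` only needs pinning when both OpenAI and Azure credentials are present:

```py
import openai

openai.api_key = "sk-..."  # or rely on the OPENAI_API_KEY env var
openai.timeout = 30.0
openai.max_retries = 3

# with both OPENAI_* and AZURE_OPENAI_* credentials set, _load_client()
# raises _AmbiguousModuleClientUsageError unless api_type is pinned:
openai.api_type = "openai"

completion = openai.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
)
print(completion.choices[0].message.content)
```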


@@ -0,0 +1,3 @@
from .cli import main
main()

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,231 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self, Literal
import pydantic
from pydantic.fields import FieldInfo
from ._types import IncEx, StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2, v3 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
...
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
...
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
...
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
...
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
...
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
else:
# v1 re-exports
if PYDANTIC_V1:
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
from ._utils import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
parse_date as parse_date,
is_typeddict as is_typeddict,
parse_datetime as parse_datetime,
is_literal_type as is_literal_type,
)
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
else:
if PYDANTIC_V1:
# TODO: provide an error message here?
ConfigDict = None
else:
from pydantic import ConfigDict as ConfigDict
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V1:
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
else:
return model.model_validate(value)
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V1:
return field.required # type: ignore
return field.is_required()
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V1:
return value
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V1:
return field.outer_type_ # type: ignore
return field.annotation
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V1:
return model.__config__ # type: ignore
return model.model_config
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V1:
return model.__fields__ # type: ignore
return model.model_fields
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
if PYDANTIC_V1:
return model.copy(deep=deep) # type: ignore
return model.model_copy(deep=deep)
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V1:
return model.json(indent=indent) # type: ignore
return model.model_dump_json(indent=indent)
def model_dump(
model: pydantic.BaseModel,
*,
exclude: IncEx | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
warnings: bool = True,
mode: Literal["json", "python"] = "python",
by_alias: bool | None = None,
) -> dict[str, Any]:
if (not PYDANTIC_V1) or hasattr(model, "model_dump"):
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
# warnings are not supported in Pydantic v1
warnings=True if PYDANTIC_V1 else warnings,
by_alias=by_alias,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, by_alias=bool(by_alias)
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V1:
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
return model.model_validate(data)
def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
if PYDANTIC_V1:
return model.parse_raw(data) # pyright: ignore[reportDeprecated]
return model.model_validate_json(data)
def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
if PYDANTIC_V1:
return model.schema() # pyright: ignore[reportDeprecated]
return model.model_json_schema()
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel): ...
else:
if PYDANTIC_V1:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
else:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# to reflect that `cached_property` is settable
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special-case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None: ...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
from functools import cached_property as cached_property
typed_cached_property = cached_property
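A small sketch of the shims in use: the same call sites work whether Pydantic v1 or v2+ is installed (the `User` model is hypothetical):

```py
import pydantic

from openai._compat import PYDANTIC_V1, model_dump, model_parse

class User(pydantic.BaseModel):
    name: str
    age: int

user = model_parse(User, {"name": "ada", "age": 36})  # parse_obj on v1, model_validate on v2
data = model_dump(user)  # .dict() on v1, .model_dump() on v2
print(PYDANTIC_V1, data)
```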


@@ -0,0 +1,14 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
# default timeout is 10 minutes
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0)
DEFAULT_MAX_RETRIES = 2
DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=1000, max_keepalive_connections=100)
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
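`INITIAL_RETRY_DELAY` and `MAX_RETRY_DELAY` bound the retry schedule; a minimal backoff sketch under the assumption of capped exponential growth with jitter (the real schedule lives in `_base_client`, which is not shown here):

```py
import random

INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0

def retry_delay(retries_taken: int) -> float:
    # capped exponential backoff: 0.5, 1.0, 2.0, ... up to 8.0 seconds,
    # scaled by full jitter (assumed; the SDK's exact formula may differ)
    delay = min(INITIAL_RETRY_DELAY * 2**retries_taken, MAX_RETRY_DELAY)
    return delay * random.random()

for attempt in range(4):
    print(attempt, round(retry_delay(attempt), 3))
```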


@@ -0,0 +1,161 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, cast
from typing_extensions import Literal
import httpx
from ._utils import is_dict
from ._models import construct_type
if TYPE_CHECKING:
from .types.chat import ChatCompletion
__all__ = [
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"LengthFinishReasonError",
"ContentFilterFinishReasonError",
"InvalidWebhookSignatureError",
]
class OpenAIError(Exception):
pass
class APIError(OpenAIError):
message: str
request: httpx.Request
body: object | None
"""The API response body.
If the API responded with a valid JSON structure then this property will be the
decoded result.
If it isn't a valid JSON structure then this will be the raw response.
If there was no response associated with this error then it will be `None`.
"""
code: Optional[str] = None
param: Optional[str] = None
type: Optional[str]
def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:
super().__init__(message)
self.request = request
self.message = message
self.body = body
if is_dict(body):
self.code = cast(Any, construct_type(type_=Optional[str], value=body.get("code")))
self.param = cast(Any, construct_type(type_=Optional[str], value=body.get("param")))
self.type = cast(Any, construct_type(type_=str, value=body.get("type")))
else:
self.code = None
self.param = None
self.type = None
class APIResponseValidationError(APIError):
response: httpx.Response
status_code: int
def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:
super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body)
self.response = response
self.status_code = response.status_code
class APIStatusError(APIError):
"""Raised when an API response has a status code of 4xx or 5xx."""
response: httpx.Response
status_code: int
request_id: str | None
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
super().__init__(message, response.request, body=body)
self.response = response
self.status_code = response.status_code
self.request_id = response.headers.get("x-request-id")
class APIConnectionError(APIError):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, body=None)
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
class BadRequestError(APIStatusError):
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
class AuthenticationError(APIStatusError):
status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
class PermissionDeniedError(APIStatusError):
status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
class NotFoundError(APIStatusError):
status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
class ConflictError(APIStatusError):
status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
class UnprocessableEntityError(APIStatusError):
status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
class RateLimitError(APIStatusError):
status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
class InternalServerError(APIStatusError):
pass
class LengthFinishReasonError(OpenAIError):
completion: ChatCompletion
"""The completion that caused this error.
Note: this will *not* be a complete `ChatCompletion` object when streaming as `usage`
will not be included.
"""
def __init__(self, *, completion: ChatCompletion) -> None:
msg = "Could not parse response content as the length limit was reached"
if completion.usage:
msg += f" - {completion.usage}"
super().__init__(msg)
self.completion = completion
class ContentFilterFinishReasonError(OpenAIError):
def __init__(self) -> None:
super().__init__(
f"Could not parse response content as the request was rejected by the content filter",
)
class InvalidWebhookSignatureError(ValueError):
"""Raised when a webhook signature is invalid, meaning the computed signature does not match the expected signature."""


@@ -0,0 +1,3 @@
from .numpy_proxy import numpy as numpy, has_numpy as has_numpy
from .pandas_proxy import pandas as pandas
from .sounddevice_proxy import sounddevice as sounddevice


@@ -0,0 +1,21 @@
from .._exceptions import OpenAIError
INSTRUCTIONS = """
OpenAI error:
missing `{library}`
This feature requires additional dependencies:
$ pip install openai[{extra}]
"""
def format_instructions(*, library: str, extra: str) -> str:
return INSTRUCTIONS.format(library=library, extra=extra)
class MissingDependencyError(OpenAIError):
pass
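For example (module path assumed from the relative imports above):

```py
from openai._extras._common import format_instructions

print(format_instructions(library="numpy", extra="voice_helpers"))
# prints the INSTRUCTIONS template with `numpy` and
# `pip install openai[voice_helpers]` substituted in
```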


@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import numpy as numpy
NUMPY_INSTRUCTIONS = format_instructions(library="numpy", extra="voice_helpers")
class NumpyProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import numpy
except ImportError as err:
raise MissingDependencyError(NUMPY_INSTRUCTIONS) from err
return numpy
if not TYPE_CHECKING:
numpy = NumpyProxy()
def has_numpy() -> bool:
try:
import numpy # noqa: F401 # pyright: ignore[reportUnusedImport]
except ImportError:
return False
return True
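A usage sketch of the proxy above: the real import (and the hard dependency) is only paid on first attribute access, and `has_numpy()` lets callers branch beforehand:

```py
from openai._extras import has_numpy, numpy

if has_numpy():
    # first attribute access triggers NumpyProxy.__load__()
    print(numpy.zeros(3))
else:
    print("numpy missing; install with: pip install openai[voice_helpers]")
```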


@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import pandas as pandas
PANDAS_INSTRUCTIONS = format_instructions(library="pandas", extra="datalib")
class PandasProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import pandas
except ImportError as err:
raise MissingDependencyError(PANDAS_INSTRUCTIONS) from err
return pandas
if not TYPE_CHECKING:
pandas = PandasProxy()


@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import sounddevice as sounddevice # type: ignore
SOUNDDEVICE_INSTRUCTIONS = format_instructions(library="sounddevice", extra="voice_helpers")
class SounddeviceProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import sounddevice # type: ignore
except ImportError as err:
raise MissingDependencyError(SOUNDDEVICE_INSTRUCTIONS) from err
return sounddevice
if not TYPE_CHECKING:
sounddevice = SounddeviceProxy()


@@ -0,0 +1,123 @@
from __future__ import annotations
import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard
import anyio
from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@overload
def to_httpx_files(files: None) -> None: ...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
return files
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
def read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file
@overload
async def async_to_httpx_files(files: None) -> None: ...
@overload
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: await _async_transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
return files
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = anyio.Path(file)
return (path.name, await path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], await async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
async def async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()
return file
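A sketch of the accepted shapes (the file name is hypothetical and created inline so the example runs): `PathLike` values are read eagerly into `(name, bytes)` tuples, while explicit tuples already matching httpx's format pass through unchanged:

```py
import pathlib

from openai._files import to_httpx_files

path = pathlib.Path("audio.wav")  # hypothetical file
path.write_bytes(b"RIFF....")

# mapping form: {field_name: file}
print(to_httpx_files({"file": path}))
# {'file': ('audio.wav', b'RIFF....')}

# sequence-of-pairs form with an explicit (filename, content) tuple
print(to_httpx_files([("file", ("audio.wav", b"RIFF...."))]))
```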


@@ -0,0 +1,491 @@
from __future__ import annotations
import os
import inspect
import logging
import datetime
import functools
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
import anyio
import httpx
import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type
from ._models import BaseModel, is_basemodel, add_request_id
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
log: logging.Logger = logging.getLogger(__name__)
class LegacyAPIResponse(Generic[R]):
"""This is a legacy class as it will be replaced by `APIResponse`
and `AsyncAPIResponse` in the `_response.py` file in the next major
release.
For the sync client this will mostly be the same, with the exception
that `content` & `text` will be methods instead of properties. In the
async client, all methods will be async.
A migration script will be provided & the migration in general should
be smooth.
"""
_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""
def __init__(
self,
*,
raw: httpx.Response,
cast_to: type[R],
client: BaseClient[Any, Any],
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
NOTE: For the async client: this will become a coroutine in the next major version.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
"""Return the binary response content.
NOTE: this will be removed in favour of `.read()` in the
next major version.
"""
return self.http_response.content
@property
def text(self) -> str:
"""Return the decoded response content.
NOTE: this will be turned into a method in the next major version.
"""
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def is_closed(self) -> bool:
return self.http_response.is_closed
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
cast_to = to if to is not None else self._cast_to
# unwrap `TypeAlias('Name', T)` -> `T`
if is_type_alias_type(cast_to):
cast_to = cast_to.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if cast_to and is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)
origin = get_origin(cast_to) or cast_to
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_to=cast_to,
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if cast_to is NoneType:
return cast(R, None)
response = self.http_response
if cast_to == str:
return cast(R, response.text)
if cast_to == int:
return cast(R, int(response.text))
if cast_to == float:
return cast(R, float(response.text))
if cast_to == bool:
return cast(R, response.text.lower() == "true")
if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent):
return cast(R, cast_to(response)) # type: ignore
if origin == LegacyAPIResponse:
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
if inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
if (
inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
)
and not issubclass(origin, BaseModel)
and issubclass(origin, pydantic.BaseModel)
):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_to is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if not content_type.endswith("json"):
if is_basemodel(cast_to):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
@override
def __repr__(self) -> str:
return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
)
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], func(*args, **kwargs))
return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[LegacyAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], await func(*args, **kwargs))
return wrapped
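These wrappers back the `.with_raw_response` accessors on resource objects; a usage sketch (assuming a configured client):

```py
import openai

client = openai.OpenAI()

response = client.models.with_raw_response.list()
print(response.status_code, response.request_id)
print(response.headers.get("x-ratelimit-remaining-requests"))

models = response.parse()  # same object the plain `.list()` call returns
print(models.data[0].id)
```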
class HttpxBinaryResponseContent:
response: httpx.Response
def __init__(self, response: httpx.Response) -> None:
self.response = response
@property
def content(self) -> bytes:
return self.response.content
@property
def text(self) -> str:
return self.response.text
@property
def encoding(self) -> str | None:
return self.response.encoding
@property
def charset_encoding(self) -> str | None:
return self.response.charset_encoding
def json(self, **kwargs: Any) -> Any:
return self.response.json(**kwargs)
def read(self) -> bytes:
return self.response.read()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_bytes(chunk_size)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
return self.response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
return self.response.iter_lines()
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_raw(chunk_size)
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
"""
with open(file, mode="wb") as f:
for data in self.response.iter_bytes():
f.write(data)
@deprecated(
"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
)
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
with open(file, mode="wb") as f:
for data in self.response.iter_bytes(chunk_size):
f.write(data)
def close(self) -> None:
return self.response.close()
async def aread(self) -> bytes:
return await self.response.aread()
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_bytes(chunk_size)
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
return self.response.aiter_text(chunk_size)
async def aiter_lines(self) -> AsyncIterator[str]:
return self.response.aiter_lines()
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_raw(chunk_size)
@deprecated(
"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
)
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.response.aiter_bytes(chunk_size):
await f.write(data)
async def aclose(self) -> None:
return await self.response.aclose()


@@ -0,0 +1,915 @@
from __future__ import annotations
import os
import inspect
import weakref
from typing import (
IO,
TYPE_CHECKING,
Any,
Type,
Tuple,
Union,
Generic,
TypeVar,
Callable,
Iterable,
Optional,
AsyncIterable,
cast,
)
from datetime import date, datetime
from typing_extensions import (
List,
Unpack,
Literal,
ClassVar,
Protocol,
Required,
Sequence,
ParamSpec,
TypedDict,
TypeGuard,
final,
override,
runtime_checkable,
)
import pydantic
from pydantic.fields import FieldInfo
from ._types import (
Body,
IncEx,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
AnyMapping,
HttpxRequestFiles,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
json_safe,
lru_cache,
is_mapping,
parse_date,
coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
is_type_alias_type,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V1,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
)
from ._constants import RAW_RESPONSE_HEADER
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_T = TypeVar("_T")
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
P = ParamSpec("P")
ReprArgs = Sequence[Tuple[Optional[str], Any]]
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V1:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
@override
def __repr_args__(self) -> ReprArgs:
# we don't want these attributes to be included when something like `rich.print` is used
return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
else:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
if TYPE_CHECKING:
_request_id: Optional[str] = None
"""The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
This will **only** be set for the top-level response object, it will not be defined for nested objects. For example:
```py
completion = await client.chat.completions.create(...)
completion._request_id # req_id_xxx
completion.usage._request_id # raises `AttributeError`
```
Note: unlike other properties that use an `_` prefix, this property
*is* public. Unless documented otherwise, all other `_` prefix properties,
methods and modules are *private*.
"""
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
"""
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
"""
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f"{self.__repr_name__()}({self.__repr_str__(', ')})" # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct( # pyright: ignore[reportIncompatibleMethodOverride]
__cls: Type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = __cls.__new__(__cls)
fields_values: dict[str, object] = {}
config = get_model_config(__cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(__cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
extra_field_type = _get_extra_fields_type(__cls)
_extra = {}
for key, value in values.items():
if key not in model_fields:
parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
if PYDANTIC_V1:
_fields_set.add(key)
fields_values[key] = parsed
else:
_extra[key] = parsed
object.__setattr__(m, "__dict__", fields_values)
if PYDANTIC_V1:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set)
else:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None)
object.__setattr__(m, "__pydantic_extra__", _extra)
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if PYDANTIC_V1:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the output will only contain JSON serializable types.
If mode is 'python', the output may contain non-JSON-serializable Python objects.
include: A set of fields to include in the output.
exclude: A set of fields to exclude from the output.
context: Additional context to pass to the serializer.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value.
exclude_none: Whether to exclude fields that have a value of `None`.
exclude_computed_fields: Whether to exclude computed fields.
While this can be useful for round-tripping, it is usually recommended to use the dedicated
`round_trip` parameter instead.
round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
fallback: A function to call when an unknown value is encountered. If not provided,
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
Returns:
A dictionary representation of the model.
"""
if mode not in {"json", "python"}:
raise ValueError("mode must be either 'json' or 'python'")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
if exclude_computed_fields != False:
raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped
@override
def model_dump_json(
self,
*,
indent: int | None = None,
ensure_ascii: bool = False,
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
if ensure_ascii != False:
raise ValueError("ensure_ascii is only supported in Pydantic v2")
if exclude_computed_fields != False:
raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V1:
type_ = cast(type, field.outer_type_) # type: ignore
else:
type_ = field.annotation # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
if PYDANTIC_V1:
# TODO
return None
schema = cls.__pydantic_core_schema__
if schema["type"] == "model":
fields = schema["schema"]
if fields["type"] == "model-fields":
extras = fields.get("extras_schema")
if extras and "cls" in extras:
# mypy can't narrow the type
return extras["cls"] # type: ignore[no-any-return]
return None
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
return True
return False
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
if not inspect.isclass(origin):
return False
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
This is useful for cases where you need to instantiate a `BaseModel`
from an API response, as this provides type-safe params, which isn't
supported by helpers like `construct_type()`.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
"""Loose coercion to the expected type with construction of nested values.
Note: the returned value from this function is not guaranteed to match the
given type.
"""
return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# store a reference to the original type we were given before we extract any inner
# types so that we can properly resolve forward references in `TypeAliasType` annotations
original_type = None
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
if is_type_alias_type(type_):
original_type = type_ # type: ignore[unreachable]
type_ = type_.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if metadata is not None and len(metadata) > 0:
meta: tuple[Any, ...] = tuple(metadata)
elif is_annotated_type(type_):
meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", original_type or type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if (
not is_literal_type(type_)
and inspect.isclass(origin)
and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
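# A minimal illustrative sketch of the loose coercion above (the `Foo` model
# is hypothetical, not part of this module):
#
#     class Foo(BaseModel):
#         bar: int
#
#     construct_type(value={"bar": 1}, type_=Foo)  # -> Foo(bar=1)
#     construct_type(value="oops", type_=Foo)  # -> "oops", returned as-is
#     construct_type(value="2024-01-02T03:04:05Z", type_=datetime)  # -> datetime(2024, 1, 2, ...)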
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
cached = DISCRIMINATOR_CACHE.get(union)
if cached is not None:
return cached
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V1:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
else:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
DISCRIMINATOR_CACHE.setdefault(union, details)
return details
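# Illustrative sketch of how a discriminated union is annotated so the code
# above can pick the right variant (`FooType`/`BarType` are hypothetical;
# `PropertyInfo` is the internal annotation checked above):
#
#     ShapeLike = Annotated[Union[FooType, BarType], PropertyInfo(discriminator="kind")]
#
# `construct_type(value={"kind": "bar", ...}, type_=ShapeLike)` then consults
# the cached mapping {"foo": FooType, "bar": BarType} and constructs `BarType`
# directly, even when the payload would fail strict validation.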
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] == "definitions":
schema = schema["schema"]
if schema["type"] != "model":
return None
schema = cast("ModelSchema", schema)
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
"""Add a pydantic config for the given type.
Note: this is a no-op on Pydantic v1.
"""
setattr(typ, "__pydantic_config__", config) # noqa: B010
def add_request_id(obj: BaseModel, request_id: str | None) -> None:
obj._request_id = request_id
# in Pydantic v1, using setattr like we do above causes the attribute
# to be included when serializing the model, which we don't want in this
# case, so we need to explicitly exclude it
if PYDANTIC_V1:
try:
exclude_fields = obj.__exclude_fields__ # type: ignore
except AttributeError:
cast(Any, obj).__exclude_fields__ = {"_request_id", "__exclude_fields__"}
else:
cast(Any, obj).__exclude_fields__ = {*(exclude_fields or {}), "_request_id", "__exclude_fields__"}
# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if not PYDANTIC_V1:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
if TYPE_CHECKING:
from pydantic import TypeAdapter
else:
TypeAdapter = _CachedTypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING: # TODO: condition is weird
class RootModel(GenericModel, Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
__root__: _T
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
model = _create_pydantic_model(type_).validate(value)
return cast(_T, model.__root__)
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore
class FinalRequestOptionsInput(TypedDict, total=False):
method: Required[str]
url: Required[str]
params: Query
headers: Headers
max_retries: int
timeout: float | Timeout | None
files: HttpxRequestFiles | None
idempotency_key: str
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None]
json_data: Body
extra_json: AnyMapping
follow_redirects: bool
synthesize_event_and_data: bool
@final
class FinalRequestOptions(pydantic.BaseModel):
method: str
url: str
params: Query = {}
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
follow_redirects: Union[bool, None] = None
synthesize_event_and_data: Optional[bool] = None
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
else:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
def _strip_raw_response_header(self) -> None:
if not is_given(self.headers):
return
if self.headers.get(RAW_RESPONSE_HEADER):
self.headers = {**self.headers}
self.headers.pop(RAW_RESPONSE_HEADER)
# override the `construct` method so that we can run custom transformations.
# this is necessary as we don't want to do any actual runtime type checking
# (which means we can't use validators) but we do want to ensure that `NotGiven`
# values are not present
#
# type ignore required because we're adding explicit types to `**values`
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[FinalRequestOptionsInput],
) -> FinalRequestOptions:
kwargs: dict[str, Any] = {
# we unconditionally call `strip_not_given` on any value
# as it will just ignore any non-mapping types
key: strip_not_given(value)
for key, value in values.items()
}
if PYDANTIC_V1:
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
return super().model_construct(_fields_set, **kwargs)
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
model_construct = construct
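# Illustrative sketch of the custom `construct` behaviour (values are
# hypothetical): `NotGiven` entries nested inside mappings are stripped via
# `strip_not_given`, so
#
#     opts = FinalRequestOptions.construct(
#         method="post", url="/files", headers={"X-Trace": not_given}
#     )
#
# is expected to yield `opts.headers == {}`, letting client-level defaults apply.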

View File

@@ -0,0 +1,181 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
if TYPE_CHECKING:
from .resources.files import Files
from .resources.images import Images
from .resources.models import Models
from .resources.videos import Videos
from .resources.batches import Batches
from .resources.beta.beta import Beta
from .resources.chat.chat import Chat
from .resources.embeddings import Embeddings
from .resources.audio.audio import Audio
from .resources.completions import Completions
from .resources.evals.evals import Evals
from .resources.moderations import Moderations
from .resources.skills.skills import Skills
from .resources.uploads.uploads import Uploads
from .resources.realtime.realtime import Realtime
from .resources.webhooks.webhooks import Webhooks
from .resources.responses.responses import Responses
from .resources.containers.containers import Containers
from .resources.fine_tuning.fine_tuning import FineTuning
from .resources.conversations.conversations import Conversations
from .resources.vector_stores.vector_stores import VectorStores
from . import _load_client
from ._utils import LazyProxy
class ChatProxy(LazyProxy["Chat"]):
@override
def __load__(self) -> Chat:
return _load_client().chat
class BetaProxy(LazyProxy["Beta"]):
@override
def __load__(self) -> Beta:
return _load_client().beta
class FilesProxy(LazyProxy["Files"]):
@override
def __load__(self) -> Files:
return _load_client().files
class AudioProxy(LazyProxy["Audio"]):
@override
def __load__(self) -> Audio:
return _load_client().audio
class EvalsProxy(LazyProxy["Evals"]):
@override
def __load__(self) -> Evals:
return _load_client().evals
class ImagesProxy(LazyProxy["Images"]):
@override
def __load__(self) -> Images:
return _load_client().images
class ModelsProxy(LazyProxy["Models"]):
@override
def __load__(self) -> Models:
return _load_client().models
class SkillsProxy(LazyProxy["Skills"]):
@override
def __load__(self) -> Skills:
return _load_client().skills
class VideosProxy(LazyProxy["Videos"]):
@override
def __load__(self) -> Videos:
return _load_client().videos
class BatchesProxy(LazyProxy["Batches"]):
@override
def __load__(self) -> Batches:
return _load_client().batches
class UploadsProxy(LazyProxy["Uploads"]):
@override
def __load__(self) -> Uploads:
return _load_client().uploads
class WebhooksProxy(LazyProxy["Webhooks"]):
@override
def __load__(self) -> Webhooks:
return _load_client().webhooks
class RealtimeProxy(LazyProxy["Realtime"]):
@override
def __load__(self) -> Realtime:
return _load_client().realtime
class ResponsesProxy(LazyProxy["Responses"]):
@override
def __load__(self) -> Responses:
return _load_client().responses
class EmbeddingsProxy(LazyProxy["Embeddings"]):
@override
def __load__(self) -> Embeddings:
return _load_client().embeddings
class ContainersProxy(LazyProxy["Containers"]):
@override
def __load__(self) -> Containers:
return _load_client().containers
class CompletionsProxy(LazyProxy["Completions"]):
@override
def __load__(self) -> Completions:
return _load_client().completions
class ModerationsProxy(LazyProxy["Moderations"]):
@override
def __load__(self) -> Moderations:
return _load_client().moderations
class FineTuningProxy(LazyProxy["FineTuning"]):
@override
def __load__(self) -> FineTuning:
return _load_client().fine_tuning
class VectorStoresProxy(LazyProxy["VectorStores"]):
@override
def __load__(self) -> VectorStores:
return _load_client().vector_stores
class ConversationsProxy(LazyProxy["Conversations"]):
@override
def __load__(self) -> Conversations:
return _load_client().conversations
chat: Chat = ChatProxy().__as_proxied__()
beta: Beta = BetaProxy().__as_proxied__()
files: Files = FilesProxy().__as_proxied__()
audio: Audio = AudioProxy().__as_proxied__()
evals: Evals = EvalsProxy().__as_proxied__()
images: Images = ImagesProxy().__as_proxied__()
models: Models = ModelsProxy().__as_proxied__()
skills: Skills = SkillsProxy().__as_proxied__()
videos: Videos = VideosProxy().__as_proxied__()
batches: Batches = BatchesProxy().__as_proxied__()
uploads: Uploads = UploadsProxy().__as_proxied__()
webhooks: Webhooks = WebhooksProxy().__as_proxied__()
realtime: Realtime = RealtimeProxy().__as_proxied__()
responses: Responses = ResponsesProxy().__as_proxied__()
embeddings: Embeddings = EmbeddingsProxy().__as_proxied__()
containers: Containers = ContainersProxy().__as_proxied__()
completions: Completions = CompletionsProxy().__as_proxied__()
moderations: Moderations = ModerationsProxy().__as_proxied__()
fine_tuning: FineTuning = FineTuningProxy().__as_proxied__()
vector_stores: VectorStores = VectorStoresProxy().__as_proxied__()
conversations: Conversations = ConversationsProxy().__as_proxied__()
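# Illustrative usage of the lazily-proxied module-level client (key is
# hypothetical): the default client is only instantiated on first attribute
# access.
#
#     import openai
#
#     openai.api_key = "sk-..."
#     openai.models.list()  # ModelsProxy resolves to _load_client().models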

View File

@@ -0,0 +1,150 @@
from __future__ import annotations
from typing import Any, List, Tuple, Union, Mapping, TypeVar
from urllib.parse import parse_qs, urlencode
from typing_extensions import Literal, get_args
from ._types import NotGiven, not_given
from ._utils import flatten
_T = TypeVar("_T")
ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
NestedFormat = Literal["dots", "brackets"]
PrimitiveData = Union[str, int, float, bool, None]
# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
# https://github.com/microsoft/pyright/issues/3555
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
Params = Mapping[str, Data]
class Querystring:
array_format: ArrayFormat
nested_format: NestedFormat
def __init__(
self,
*,
array_format: ArrayFormat = "repeat",
nested_format: NestedFormat = "brackets",
) -> None:
self.array_format = array_format
self.nested_format = nested_format
def parse(self, query: str) -> Mapping[str, object]:
# Note: custom format syntax is not supported yet
return parse_qs(query)
def stringify(
self,
params: Params,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> str:
return urlencode(
self.stringify_items(
params,
array_format=array_format,
nested_format=nested_format,
)
)
def stringify_items(
self,
params: Params,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> list[tuple[str, str]]:
opts = Options(
qs=self,
array_format=array_format,
nested_format=nested_format,
)
return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
def _stringify_item(
self,
key: str,
value: Data,
opts: Options,
) -> list[tuple[str, str]]:
if isinstance(value, Mapping):
items: list[tuple[str, str]] = []
nested_format = opts.nested_format
for subkey, subvalue in value.items():
items.extend(
self._stringify_item(
# TODO: error if unknown format
f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
subvalue,
opts,
)
)
return items
if isinstance(value, (list, tuple)):
array_format = opts.array_format
if array_format == "comma":
return [
(
key,
",".join(self._primitive_value_to_str(item) for item in value if item is not None),
),
]
elif array_format == "repeat":
items = []
for item in value:
items.extend(self._stringify_item(key, item, opts))
return items
elif array_format == "indices":
raise NotImplementedError("The array indices format is not supported yet")
elif array_format == "brackets":
items = []
key = key + "[]"
for item in value:
items.extend(self._stringify_item(key, item, opts))
return items
else:
raise NotImplementedError(
f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
)
serialised = self._primitive_value_to_str(value)
if not serialised:
return []
return [(key, serialised)]
def _primitive_value_to_str(self, value: PrimitiveData) -> str:
# copied from httpx
if value is True:
return "true"
elif value is False:
return "false"
elif value is None:
return ""
return str(value)
_qs = Querystring()
parse = _qs.parse
stringify = _qs.stringify
stringify_items = _qs.stringify_items
class Options:
array_format: ArrayFormat
nested_format: NestedFormat
def __init__(
self,
qs: Querystring = _qs,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> None:
self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format
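# A few illustrative (hypothetical) inputs and their stringified forms:
#
#     qs = Querystring()
#     qs.stringify({"a": {"b": 1}})  # -> "a%5Bb%5D=1", i.e. a[b]=1
#     qs.stringify({"ids": [1, 2]})  # -> "ids=1&ids=2" (array_format="repeat")
#     Querystring(array_format="comma").stringify({"ids": [1, 2]})  # -> "ids=1%2C2"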

View File

@@ -0,0 +1,43 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import time
from typing import TYPE_CHECKING
import anyio
if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI
class SyncAPIResource:
_client: OpenAI
def __init__(self, client: OpenAI) -> None:
self._client = client
self._get = client.get
self._post = client.post
self._patch = client.patch
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
def _sleep(self, seconds: float) -> None:
time.sleep(seconds)
class AsyncAPIResource:
_client: AsyncOpenAI
def __init__(self, client: AsyncOpenAI) -> None:
self._client = client
self._get = client.get
self._post = client.post
self._patch = client.patch
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
async def _sleep(self, seconds: float) -> None:
await anyio.sleep(seconds)
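# Illustrative sketch of how a concrete resource builds on these base classes
# (simplified; the generated resources pass additional request options):
#
#     class Models(SyncAPIResource):
#         def retrieve(self, model: str) -> Model:
#             return self._get(f"/models/{model}", cast_to=Model)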

View File

@@ -0,0 +1,851 @@
from __future__ import annotations
import os
import inspect
import logging
import datetime
import functools
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, get_origin
import anyio
import httpx
import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel, add_request_id
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import OpenAIError, APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed_by_type: dict[type[Any], Any]
_is_sse_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""
def __init__(
self,
*,
raw: httpx.Response,
cast_to: type[R],
client: BaseClient[Any, Any],
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed_by_type = {}
self._is_sse_stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
"""Returns the httpx Request instance associated with the current response."""
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
"""Returns the URL for which the request was made."""
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
@property
def is_closed(self) -> bool:
"""Whether or not the response body has been closed.
If this is False then there is response data that has not been read yet.
You must either fully consume the response body or call `.close()`
before discarding the response to prevent resource leaks.
"""
return self.http_response.is_closed
@override
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
)
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
cast_to = to if to is not None else self._cast_to
# unwrap `TypeAlias('Name', T)` -> `T`
if is_type_alias_type(cast_to):
cast_to = cast_to.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if cast_to and is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)
origin = get_origin(cast_to) or cast_to
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_to=cast_to,
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if cast_to is NoneType:
return cast(R, None)
response = self.http_response
if cast_to == str:
return cast(R, response.text)
if cast_to == bytes:
return cast(R, response.content)
if cast_to == int:
return cast(R, int(response.text))
if cast_to == float:
return cast(R, float(response.text))
if cast_to == bool:
return cast(R, response.text.lower() == "true")
# handle the legacy binary response case
if inspect.isclass(cast_to) and cast_to.__name__ == "HttpxBinaryResponseContent":
return cast(R, cast_to(response)) # type: ignore
if origin == APIResponse:
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
if (
inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
)
and not issubclass(origin, BaseModel)
and issubclass(origin, pydantic.BaseModel)
):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_to is not object
and origin is not list
and origin is not dict
and origin is not Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if not content_type.endswith("json"):
if is_basemodel(cast_to):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
class APIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return self.http_response.read()
except httpx.StreamConsumed as exc:
# The default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message.
raise StreamAlreadyConsumed() from exc
def text(self) -> str:
"""Read and decode the response content into a string."""
self.read()
return self.http_response.text
def json(self) -> object:
"""Read and decode the JSON response content."""
self.read()
return self.http_response.json()
def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.http_response.close()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
for chunk in self.http_response.iter_bytes(chunk_size):
yield chunk
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
for chunk in self.http_response.iter_text(chunk_size):
yield chunk
def iter_lines(self) -> Iterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
for chunk in self.http_response.iter_lines():
yield chunk
class AsyncAPIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
async def parse(self, *, to: type[_T]) -> _T: ...
@overload
async def parse(self) -> R: ...
async def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
await self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
async def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return await self.http_response.aread()
except httpx.StreamConsumed as exc:
# the default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message
raise StreamAlreadyConsumed() from exc
async def text(self) -> str:
"""Read and decode the response content into a string."""
await self.read()
return self.http_response.text
async def json(self) -> object:
"""Read and decode the JSON response content."""
await self.read()
return self.http_response.json()
async def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.http_response.aclose()
async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
async for chunk in self.http_response.aiter_bytes(chunk_size):
yield chunk
async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
async for chunk in self.http_response.aiter_text(chunk_size):
yield chunk
async def iter_lines(self) -> AsyncIterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
async for chunk in self.http_response.aiter_lines():
yield chunk
class BinaryAPIResponse(APIResponse[bytes]):
"""Subclass of APIResponse providing helpers for dealing with binary data.
Note: If you want to stream the response data instead of eagerly reading it
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
with open(file, mode="wb") as f:
for data in self.iter_bytes():
f.write(data)
class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
"""Subclass of APIResponse providing helpers for dealing with binary data.
Note: If you want to stream the response data instead of eagerly reading it
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
async def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.iter_bytes():
await f.write(data)
class StreamedBinaryAPIResponse(APIResponse[bytes]):
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
"""Streams the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
"""
with open(file, mode="wb") as f:
for data in self.iter_bytes(chunk_size):
f.write(data)
class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
async def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
"""Streams the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
"""
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.iter_bytes(chunk_size):
await f.write(data)
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
)
class StreamAlreadyConsumed(OpenAIError):
"""
Attempted to read or stream content, but the content has already
been streamed.
This can happen if you use a method like `.iter_lines()` and then attempt
to read the entire response body afterwards, e.g.
```py
response = await client.post(...)
async for line in response.iter_lines():
... # do something with `line`
content = await response.read()
# ^ error
```
If you want this behaviour you'll need to either manually accumulate the response
content or call `await response.read()` before iterating over the stream.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream some content, but the content has "
"already been streamed. "
"This could be due to attempting to stream the response "
"content more than once."
"\n\n"
"You can fix this by manually accumulating the response content while streaming "
"or by calling `.read()` before starting to stream."
)
super().__init__(message)
class ResponseContextManager(Generic[_APIResponseT]):
"""Context manager for ensuring that a request is not made
until it is entered and that the response will always be closed
when the context manager exits
"""
def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
self._request_func = request_func
self.__response: _APIResponseT | None = None
def __enter__(self) -> _APIResponseT:
self.__response = self._request_func()
return self.__response
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__response is not None:
self.__response.close()
class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
"""Context manager for ensuring that a request is not made
until it is entered and that the response will always be closed
when the context manager exits
"""
def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
self._api_request = api_request
self.__response: _AsyncAPIResponseT | None = None
async def __aenter__(self) -> _AsyncAPIResponseT:
self.__response = await self._api_request
return self.__response
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__response is not None:
await self.__response.close()
def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support streaming and returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
kwargs["extra_headers"] = extra_headers
make_request = functools.partial(func, *args, **kwargs)
return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))
return wrapped
def async_to_streamed_response_wrapper(
func: Callable[P, Awaitable[R]],
) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support streaming and returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
kwargs["extra_headers"] = extra_headers
make_request = func(*args, **kwargs)
return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))
return wrapped
def to_custom_streamed_response_wrapper(
func: Callable[P, object],
response_cls: type[_APIResponseT],
) -> Callable[P, ResponseContextManager[_APIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support streaming and returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
make_request = functools.partial(func, *args, **kwargs)
return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))
return wrapped
def async_to_custom_streamed_response_wrapper(
func: Callable[P, Awaitable[object]],
response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support streaming and returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
make_request = func(*args, **kwargs)
return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))
return wrapped
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
kwargs["extra_headers"] = extra_headers
return cast(APIResponse[R], func(*args, **kwargs))
return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
kwargs["extra_headers"] = extra_headers
return cast(AsyncAPIResponse[R], await func(*args, **kwargs))
return wrapped
def to_custom_raw_response_wrapper(
func: Callable[P, object],
response_cls: type[_APIResponseT],
) -> Callable[P, _APIResponseT]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
return cast(_APIResponseT, func(*args, **kwargs))
return wrapped
def async_to_custom_raw_response_wrapper(
func: Callable[P, Awaitable[object]],
response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))
return wrapped
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(APIResponse[bytes]):
...
extract_response_type(MyResponse) -> bytes
```
"""
return extract_type_var_from_base(
typ,
generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),
index=0,
)
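# Illustrative usage of the wrappers above through the public client surface
# (`params` stands in for the usual request arguments):
#
#     response = client.chat.completions.with_raw_response.create(**params)
#     completion = response.parse()  # parsed model, request id attached
#
#     with client.chat.completions.with_streaming_response.create(**params) as response:
#         for line in response.iter_lines():
#             ...  # the request is only sent once the context manager is entered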

View File

@@ -0,0 +1,427 @@
# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
from __future__ import annotations
import json
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, Optional, AsyncIterator, cast
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
import httpx
from ._utils import is_mapping, extract_type_var_from_base
from ._exceptions import APIError
if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI
from ._models import FinalRequestOptions
_T = TypeVar("_T")
class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""
response: httpx.Response
_options: Optional[FinalRequestOptions] = None
_decoder: SSEBytesDecoder
def __init__(
self,
*,
cast_to: type[_T],
response: httpx.Response,
client: OpenAI,
options: Optional[FinalRequestOptions] = None,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client
self._options = options
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
def __next__(self) -> _T:
return self._iterator.__next__()
def __iter__(self) -> Iterator[_T]:
for item in self._iterator:
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter_bytes(self.response.iter_bytes())
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
try:
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
else data,
cast_to=cast_to,
response=response,
)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
response.close()
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.response.close()
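# Illustrative usage (request params elided): endpoints that accept
# `stream=True` return a `Stream`, which can be iterated directly or used as
# a context manager so the connection is always released:
#
#     with client.chat.completions.create(..., stream=True) as stream:
#         for chunk in stream:
#             ...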
class AsyncStream(Generic[_T]):
"""Provides the core interface to iterate over an asynchronous stream response."""
response: httpx.Response
_options: Optional[FinalRequestOptions] = None
_decoder: SSEDecoder | SSEBytesDecoder
def __init__(
self,
*,
cast_to: type[_T],
response: httpx.Response,
client: AsyncOpenAI,
options: Optional[FinalRequestOptions] = None,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client
self._options = options
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
async def __anext__(self) -> _T:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[_T]:
async for item in self._iterator:
yield item
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
try:
async for sse in iterator:
if sse.data.startswith("[DONE]"):
break
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
else data,
cast_to=cast_to,
response=response,
)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
await response.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.response.aclose()
class ServerSentEvent:
def __init__(
self,
*,
event: str | None = None,
data: str | None = None,
id: str | None = None,
retry: int | None = None,
) -> None:
if data is None:
data = ""
self._id = id
self._data = data
self._event = event or None
self._retry = retry
@property
def event(self) -> str | None:
return self._event
@property
def id(self) -> str | None:
return self._id
@property
def retry(self) -> int | None:
return self._retry
@property
def data(self) -> str:
return self._data
def json(self) -> Any:
return json.loads(self.data)
@override
def __repr__(self) -> str:
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
_data: list[str]
_event: str | None
_retry: int | None
_last_event_id: str | None
def __init__(self) -> None:
self._event = None
self._data = []
self._last_event_id = None
self._retry = None
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
for chunk in self._iter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
sse = self.decode(line)
if sse:
yield sse
def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
data = b""
for chunk in iterator:
for line in chunk.splitlines(keepends=True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
async for chunk in self._aiter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
sse = self.decode(line)
if sse:
yield sse
async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
data = b""
async for chunk in iterator:
for line in chunk.splitlines(keepends=True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
def decode(self, line: str) -> ServerSentEvent | None:
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
if not line:
if not self._event and not self._data and not self._last_event_id and self._retry is None:
return None
sse = ServerSentEvent(
event=self._event,
data="\n".join(self._data),
id=self._last_event_id,
retry=self._retry,
)
# NOTE: as per the SSE spec, do not reset last_event_id.
self._event = None
self._data = []
self._retry = None
return sse
if line.startswith(":"):
return None
fieldname, _, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
if fieldname == "event":
self._event = value
elif fieldname == "data":
self._data.append(value)
elif fieldname == "id":
if "\0" in value:
pass
else:
self._last_event_id = value
elif fieldname == "retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
else:
pass # Field is ignored.
return None
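# Illustrative decode sequence: field lines accumulate until a blank line
# flushes the event.
#
#     decoder = SSEDecoder()
#     decoder.decode("event: thread.created")  # -> None, event not complete yet
#     decoder.decode('data: {"id": "1"}')  # -> None
#     sse = decoder.decode("")  # blank line -> ServerSentEvent(...)
#     (sse.event, sse.data)  # -> ("thread.created", '{"id": "1"}')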
@runtime_checkable
class SSEBytesDecoder(Protocol):
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
...
def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
...
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
def extract_stream_chunk_type(
stream_cls: type,
*,
failure_message: str | None = None,
) -> type:
"""Given a type like `Stream[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyStream(Stream[bytes]):
...
extract_stream_chunk_type(MyStream) -> bytes
```
"""
from ._base_client import Stream, AsyncStream
return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
failure_message=failure_message,
)

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
AsyncIterable,
)
from typing_extensions import (
Set,
Literal,
Protocol,
TypeAlias,
TypedDict,
SupportsIndex,
overload,
override,
runtime_checkable,
)
import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
if TYPE_CHECKING:
from ._models import BaseModel
from ._response import APIResponse, AsyncAPIResponse
from ._legacy_response import HttpxBinaryResponseContent
Transport = BaseTransport
AsyncTransport = AsyncBaseTransport
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
# Used for sending raw binary data / streaming data in request bodies
# e.g. for file uploads without multipart encoding
BinaryTypes = Union[bytes, bytearray, IO[bytes], Iterable[bytes]]
AsyncBinaryTypes = Union[bytes, bytearray, IO[bytes], AsyncIterable[bytes]]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
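For illustration, each of the following satisfies the `FileTypes` union above (the filenames and contents are placeholders):
```py
import io
from openai._types import FileTypes

f1: FileTypes = b"raw file content"                                     # bare content
f2: FileTypes = io.BytesIO(b"raw file content")                         # file-like object
f3: FileTypes = ("report.pdf", b"raw file content")                     # (filename, content)
f4: FileTypes = ("report.pdf", b"raw file content", "application/pdf")  # + content type
```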
# duplicate of the above but without our custom file support
HttpxFileContent = Union[IO[bytes], bytes]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
Tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
# passing `None`, overloads would have to be defined for every
# method that uses `ResponseT` which would lead to an unacceptable
# amount of code duplication and make it unreadable. See _base_client.py
# for example usage.
#
# This unfortunately means that you will either have
# to import this type and pass it explicitly:
#
# from openai import NoneType
# client.get('/foo', cast_to=NoneType)
#
# or build it yourself:
#
# client.get('/foo', cast_to=type(None))
if TYPE_CHECKING:
NoneType: Type[None]
else:
NoneType = type(None)
class RequestOptions(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
params: Query
extra_json: AnyMapping
idempotency_key: str
follow_redirects: bool
synthesize_event_and_data: bool
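A minimal sketch of populating these per-request options; how the client consumes them lives in `_base_client.py`, outside this section:
```py
from openai import RequestOptions

options: RequestOptions = {
    "headers": {"X-Request-Id": "abc-123"},  # placeholder header
    "max_retries": 2,
    "timeout": 30.0,
}
```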
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
For parameters with a meaningful None value, we need to distinguish between
the user explicitly passing None, and the user not passing the parameter at
all.
User code shouldn't need to use not_given directly.
For example:
```py
def create(timeout: Timeout | None | NotGiven = not_given): ...
create(timeout=1) # 1s timeout
create(timeout=None) # No timeout
create() # Default timeout behavior
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
not_given = NotGiven()
# for backwards compatibility:
NOT_GIVEN = NotGiven()
class Omit:
"""
To explicitly omit something from being sent in a request, use `omit`.
```py
# the default `Content-Type` header (`application/json`) would be sent with this request
client.post("/upload/files", files={"file": b"my raw file content"})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={"Content-Type": "multipart/form-data"})
# instead you can remove the default `application/json` header by passing omit
client.post(..., headers={"Content-Type": omit})
```
"""
def __bool__(self) -> Literal[False]:
return False
omit = Omit()
Omittable = Union[_T, Omit]
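A short sketch contrasting the two sentinels, using the `is_given` helper defined later in this diff:
```py
from openai import NOT_GIVEN, omit
from openai._utils import is_given

def fetch(timeout: object = NOT_GIVEN) -> str:
    # Distinguishes "parameter not passed" from an explicit None.
    if not is_given(timeout):
        return "default timeout"
    return f"timeout={timeout}"

assert fetch() == "default timeout"
assert fetch(timeout=None) == "timeout=None"
assert not is_given(omit)  # `omit` is likewise treated as "not given"
```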
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None: ...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
"APIResponse[Any]",
"AsyncAPIResponse[Any]",
"HttpxBinaryResponseContent",
],
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79
IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]]
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
follow_redirects: bool
_T_co = TypeVar("_T_co", covariant=True)
if TYPE_CHECKING:
# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
# https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
#
# Note: index() and count() methods are intentionally omitted to allow pyright to properly
# infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr.
class SequenceNotStr(Protocol[_T_co]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
def __reversed__(self) -> Iterator[_T_co]: ...
else:
# just point this to a normal `Sequence` at runtime to avoid having to special case
# deserializing our custom sequence type
SequenceNotStr = Sequence

View File

@@ -0,0 +1,67 @@
from ._logs import SensitiveHeadersFilter as SensitiveHeadersFilter
from ._sync import asyncify as asyncify
from ._proxy import LazyProxy as LazyProxy
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
is_azure_client as is_azure_client,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
is_async_azure_client as is_async_azure_client,
)
from ._compat import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_sequence_type as is_sequence_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import (
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)
from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import sys
import typing_extensions
from typing import Any, Type, Union, Literal, Optional
from datetime import date, datetime
from typing_extensions import get_args as _get_args, get_origin as _get_origin
from .._types import StrBytesIntFloat
from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime
_LITERAL_TYPES = {Literal, typing_extensions.Literal}
def get_args(tp: type[Any]) -> tuple[Any, ...]:
return _get_args(tp)
def get_origin(tp: type[Any]) -> type[Any] | None:
return _get_origin(tp)
def is_union(tp: Optional[Type[Any]]) -> bool:
if sys.version_info < (3, 10):
return tp is Union # type: ignore[comparison-overlap]
else:
import types
return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap]
def is_typeddict(tp: Type[Any]) -> bool:
return typing_extensions.is_typeddict(tp)
def is_literal_type(tp: Type[Any]) -> bool:
return get_origin(tp) in _LITERAL_TYPES
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
return _parse_date(value)
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
return _parse_datetime(value)
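For example, under these wrappers:
```py
from typing import Literal, Optional, Union
from openai._utils import get_origin, is_union, is_literal_type

assert is_union(get_origin(Union[int, str]))
assert is_union(get_origin(Optional[int]))  # Optional[int] is Union[int, None]
assert is_literal_type(Literal["a", "b"])
assert not is_literal_type(int)
```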

View File

@@ -0,0 +1,136 @@
"""
This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py
without the Pydantic v1 specific errors.
"""
from __future__ import annotations
import re
from typing import Dict, Union, Optional
from datetime import date, datetime, timezone, timedelta
from .._types import StrBytesIntFloat
date_expr = r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
time_expr = (
r"(?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
r"(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$"
)
date_re = re.compile(f"{date_expr}$")
datetime_re = re.compile(f"{date_expr}[T ]{time_expr}")
EPOCH = datetime(1970, 1, 1)
# if greater than this, the number is in ms, if less than or equal it's in seconds
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
if isinstance(value, (int, float)):
return value
try:
return float(value)
except ValueError:
return None
except TypeError:
raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None
def _from_unix_seconds(seconds: Union[int, float]) -> datetime:
if seconds > MAX_NUMBER:
return datetime.max
elif seconds < -MAX_NUMBER:
return datetime.min
while abs(seconds) > MS_WATERSHED:
seconds /= 1000
dt = EPOCH + timedelta(seconds=seconds)
return dt.replace(tzinfo=timezone.utc)
def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]:
if value == "Z":
return timezone.utc
elif value is not None:
offset_mins = int(value[-2:]) if len(value) > 3 else 0
offset = 60 * int(value[1:3]) + offset_mins
if value[0] == "-":
offset = -offset
return timezone(timedelta(minutes=offset))
else:
return None
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
"""
Parse a datetime/int/float/string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
Raise ValueError if the input is well formatted but not a valid datetime.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, datetime):
return value
number = _get_numeric(value, "datetime")
if number is not None:
return _from_unix_seconds(number)
if isinstance(value, bytes):
value = value.decode()
assert not isinstance(value, (float, int))
match = datetime_re.match(value)
if match is None:
raise ValueError("invalid datetime format")
kw = match.groupdict()
if kw["microsecond"]:
kw["microsecond"] = kw["microsecond"].ljust(6, "0")
tzinfo = _parse_timezone(kw.pop("tzinfo"))
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_["tzinfo"] = tzinfo
return datetime(**kw_) # type: ignore
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
"""
Parse a date/int/float/string and return a datetime.date.
Raise ValueError if the input is well formatted but not a valid date.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, date):
if isinstance(value, datetime):
return value.date()
else:
return value
number = _get_numeric(value, "date")
if number is not None:
return _from_unix_seconds(number).date()
if isinstance(value, bytes):
value = value.decode()
assert not isinstance(value, (float, int))
match = date_re.match(value)
if match is None:
raise ValueError("invalid date format")
kw = {k: int(v) for k, v in match.groupdict().items()}
try:
return date(**kw)
except ValueError:
raise ValueError("invalid date format") from None
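A quick sketch of the parsers' accepted inputs:
```py
from datetime import timezone
from openai._utils import parse_date, parse_datetime

dt = parse_datetime("2024-01-02T03:04:05Z")
assert dt.tzinfo == timezone.utc

# Numbers are treated as unix timestamps; values above MS_WATERSHED are milliseconds.
assert parse_datetime(0).year == 1970

assert parse_date("2024-01-02").day == 2
```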

View File

@@ -0,0 +1,35 @@
import json
from typing import Any
from datetime import datetime
from typing_extensions import override
import pydantic
from .._compat import model_dump
def openapi_dumps(obj: Any) -> bytes:
"""
Serialize an object to UTF-8 encoded JSON bytes.
Extends the standard json.dumps with support for additional types
commonly used in the SDK, such as `datetime`, `pydantic.BaseModel`, etc.
"""
return json.dumps(
obj,
cls=_CustomEncoder,
# Uses the same defaults as httpx's JSON serialization
ensure_ascii=False,
separators=(",", ":"),
allow_nan=False,
).encode()
class _CustomEncoder(json.JSONEncoder):
@override
def default(self, o: Any) -> Any:
if isinstance(o, datetime):
return o.isoformat()
if isinstance(o, pydantic.BaseModel):
return model_dump(o, exclude_unset=True, mode="json", by_alias=True)
return super().default(o)
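Illustrative behavior, run with `openapi_dumps` from this module in scope (its import path is not shown in the diff):
```py
from datetime import datetime, timezone

payload = {"created_at": datetime(2024, 1, 2, tzinfo=timezone.utc), "n": 1}

# Compact separators, ISO-8601 datetimes, UTF-8 bytes out:
assert openapi_dumps(payload) == b'{"created_at":"2024-01-02T00:00:00+00:00","n":1}'
```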

View File

@@ -0,0 +1,42 @@
import os
import logging
from typing_extensions import override
from ._utils import is_dict
logger: logging.Logger = logging.getLogger("openai")
httpx_logger: logging.Logger = logging.getLogger("httpx")
SENSITIVE_HEADERS = {"api-key", "authorization"}
def _basic_config() -> None:
# e.g. [2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
logging.basicConfig(
format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def setup_logging() -> None:
env = os.environ.get("OPENAI_LOG")
if env == "debug":
_basic_config()
logger.setLevel(logging.DEBUG)
httpx_logger.setLevel(logging.DEBUG)
elif env == "info":
_basic_config()
logger.setLevel(logging.INFO)
httpx_logger.setLevel(logging.INFO)
class SensitiveHeadersFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
if is_dict(record.args) and "headers" in record.args and is_dict(record.args["headers"]):
headers = record.args["headers"] = {**record.args["headers"]}
for header in headers:
if str(header).lower() in SENSITIVE_HEADERS:
headers[header] = "<redacted>"
return True
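A sketch of wiring the filter onto a handler; the token below is a placeholder:
```py
import logging
from openai._utils import SensitiveHeadersFilter

log = logging.getLogger("openai")
log.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.addFilter(SensitiveHeadersFilter())
log.addHandler(handler)

# logging's single-dict special case makes this mapping `record.args`,
# so the filter rewrites the Authorization value to "<redacted>".
log.debug("request %s", {"headers": {"Authorization": "Bearer sk-placeholder"}})
```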

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Iterable, cast
from typing_extensions import override
T = TypeVar("T")
class LazyProxy(Generic[T], ABC):
"""Implements data methods to pretend that an instance is another instance.
This includes forwarding attribute access and other methods.
"""
# Note: we have to special case proxies that themselves return proxies
# to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`
def __getattr__(self, attr: str) -> object:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied # pyright: ignore
return getattr(proxied, attr)
@override
def __repr__(self) -> str:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied.__class__.__name__
return repr(self.__get_proxied__())
@override
def __str__(self) -> str:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied.__class__.__name__
return str(proxied)
@override
def __dir__(self) -> Iterable[str]:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return []
return proxied.__dir__()
@property # type: ignore
@override
def __class__(self) -> type: # pyright: ignore
try:
proxied = self.__get_proxied__()
except Exception:
return type(self)
if issubclass(type(proxied), LazyProxy):
return type(proxied)
return proxied.__class__
def __get_proxied__(self) -> T:
return self.__load__()
def __as_proxied__(self) -> T:
"""Helper method that returns the current proxy, typed as the loaded object"""
return cast(T, self)
@abstractmethod
def __load__(self) -> T: ...
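A minimal concrete subclass, mirroring how `ResourcesProxy` (later in this diff) uses the class:
```py
import importlib
from typing_extensions import override
from openai._utils import LazyProxy

class JsonProxy(LazyProxy[object]):
    """Defers importing `json` until first attribute access."""

    @override
    def __load__(self) -> object:
        return importlib.import_module("json")

json_mod = JsonProxy().__as_proxied__()
print(json_mod.dumps({"ok": True}))  # type: ignore[attr-defined]  # forwarded to the real module
```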

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import inspect
from typing import Any, Callable
def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
"""Returns whether or not the given function has a specific parameter"""
sig = inspect.signature(func)
return arg_name in sig.parameters
def assert_signatures_in_sync(
source_func: Callable[..., Any],
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
description: str = "",
) -> None:
"""Ensure that the signature of the second function matches the first."""
check_sig = inspect.signature(check_func)
source_sig = inspect.signature(source_func)
errors: list[str] = []
for name, source_param in source_sig.parameters.items():
if name in exclude_params:
continue
custom_param = check_sig.parameters.get(name)
if not custom_param:
errors.append(f"the `{name}` param is missing")
continue
if custom_param.annotation != source_param.annotation:
errors.append(
f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}"
)
continue
if errors:
raise AssertionError(
f"{len(errors)} errors encountered when comparing signatures{description}:\n\n" + "\n\n".join(errors)
)
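For example:
```py
from openai._utils import assert_signatures_in_sync

def source(a: int, b: str = "x") -> None: ...
def in_sync(a: int, b: str = "x") -> None: ...
def drifted(a: int) -> None: ...

assert_signatures_in_sync(source, in_sync)  # passes silently

try:
    assert_signatures_in_sync(source, drifted, description=" for drifted()")
except AssertionError as err:
    print(err)  # reports that the `b` param is missing
```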

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Any
from typing_extensions import override
from ._proxy import LazyProxy
class ResourcesProxy(LazyProxy[Any]):
"""A proxy for the `openai.resources` module.
This is used so that we can lazily import `openai.resources` only when
needed *and* so that users can just import `openai` and reference `openai.resources`
"""
@override
def __load__(self) -> Any:
import importlib
mod = importlib.import_module("openai.resources")
return mod
resources = ResourcesProxy().__as_proxied__()

View File

@@ -0,0 +1,12 @@
from typing import Any
from typing_extensions import Iterator, AsyncIterator
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
for _ in iterator:
...
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
async for _ in iterator:
...

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import asyncio
import functools
from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
import anyio
import sniffio
import anyio.to_thread
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
async def to_thread(
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
) -> T_Retval:
if sniffio.current_async_library() == "asyncio":
return await asyncio.to_thread(func, *args, **kwargs)
return await anyio.to_thread.run_sync(
functools.partial(func, *args, **kwargs),
)
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
positional and keyword arguments.
Usage:
```python
def blocking_func(arg1, arg2, kwarg1=None):
# blocking code
return result
result = await asyncify(blocking_func)(arg1, arg2, kwarg1=value1)
```
## Arguments
`function`: a blocking regular callable (e.g. a function)
## Return
An async function that takes the same positional and keyword arguments as the
original one, that when called runs the same original function in a thread worker
and returns the result.
"""
async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
return await to_thread(function, *args, **kwargs)
return wrapper
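A runnable sketch under asyncio:
```py
import time
import asyncio
from openai._utils import asyncify

def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)  # stand-in for blocking work
    return a + b

async def main() -> None:
    # Runs slow_add in a worker thread without blocking the event loop.
    assert await asyncify(slow_add)(1, 2) == 3

asyncio.run(main())
```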

View File

@@ -0,0 +1,457 @@
from __future__ import annotations
import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
is_given,
lru_cache,
is_mapping,
is_iterable,
is_sequence,
)
from .._files import is_base64_file_input
from ._compat import get_origin, is_typeddict
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_sequence_type,
is_annotated_type,
strip_annotated_type,
)
_T = TypeVar("_T")
# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
"""Metadata class to be used in Annotated types to provide information about a given type.
For example:
class MyParams(TypedDict):
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
"""
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None
def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
def maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `transform()` that allows `None` to be passed.
See `transform()` for more details.
"""
if data is None:
return None
return transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information given will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
"""
if is_required_type(type_):
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
type_ = get_args(type_)[0]
if is_annotated_type(type_):
return type_
return None
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
# no `Annotated` definition for this type, no transformation needed
return key
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
return annotation.alias
return key
def _no_transform_needed(annotation: type) -> bool:
return annotation == float or annotation == int
def _transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
from .._compat import model_dump
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
def _transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
# be stripped out before the request is sent anyway
continue
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)
async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = await async_transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that do not have type information given will be included as-is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
from .._compat import model_dump
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()}
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
# be stripped out before the request is sent anyway
continue
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result
@lru_cache(maxsize=8096)
def get_type_hints(
obj: Any,
globalns: dict[str, Any] | None = None,
localns: Mapping[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
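A sketch combining the alias and format features described above:
```py
from datetime import date
from typing_extensions import Annotated, Required, TypedDict
from openai._utils import PropertyInfo, transform

class Params(TypedDict, total=False):
    card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
    expires: Annotated[date, PropertyInfo(format="iso8601")]

assert transform({"card_id": "abc", "expires": date(2030, 1, 1)}, Params) == {
    "cardID": "abc",
    "expires": "2030-01-01",
}
```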

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import sys
import typing
import typing_extensions
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import (
TypeIs,
Required,
Annotated,
get_args,
get_origin,
)
from ._utils import lru_cache
from .._types import InheritsGeneric
from ._compat import is_union as _is_union
def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated
def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
def is_sequence_type(typ: type) -> bool:
origin = get_origin(typ) or typ
return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin == Iterable or origin == _c_abc.Iterable
def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
"""Return whether the provided argument is an instance of `TypeAliasType`.
```python
type Int = int
is_type_alias_type(Int)
# > True
Str = TypeAliasType("Str", str)
is_type_alias_type(Str)
# > True
```
"""
return isinstance(tp, _TYPE_ALIAS_TYPES)
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
return typ
def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var_from_base(MyResponse, generic_bases=(Foo,), index=0) -> bytes
```
And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...
extract_type_var_from_base(MyResponse[bytes], generic_bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
# we're given the class directly
return extract_type_arg(typ, index)
# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break
if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
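For example, resolving a type variable through a concrete subclass:
```py
from typing import Generic, List, TypeVar
from openai._utils import extract_type_arg, extract_type_var_from_base

T = TypeVar("T")

class Page(Generic[T]):
    ...

class IntPage(Page[int]):
    ...

assert extract_type_arg(List[str], 0) is str
assert extract_type_var_from_base(IntPage, generic_bases=(Page,), index=0) is int
```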

View File

@@ -0,0 +1,437 @@
from __future__ import annotations
import os
import re
import inspect
import functools
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Sequence,
cast,
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard
import sniffio
from .._types import Omit, NotGiven, FileTypes, HeadersLike
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
if TYPE_CHECKING:
from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
def extract_files(
# TODO: this needs to take Dict but variance issues.....
# create protocol type ?
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
A path may look like this ['foo', 'files', '<array>', 'data'].
Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
def _extract_items(
obj: object,
path: Sequence[str],
*,
index: int,
flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if not is_given(obj):
# no value was provided - we can safely ignore
return []
# cyclical import
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
assert flattened_key is not None
if is_list(obj):
files: list[tuple[str, FileTypes]] = []
for entry in obj:
assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
files.append((flattened_key + "[]", cast(FileTypes, entry)))
return files
assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]
index += 1
if is_dict(obj):
try:
# We are at the last entry in the path so we must remove the field
if len(path) == index:
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []
return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
)
# Something unexpected was passed, just ignore it.
return []
def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven) and not isinstance(obj, Omit)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
return isinstance(obj, tuple)
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
return isinstance(obj, tuple)
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
return isinstance(obj, Sequence)
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
return isinstance(obj, Sequence)
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(obj, Mapping)
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
return isinstance(obj, Mapping)
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
- mappings, e.g. `dict`
- list
This is done for performance reasons.
"""
if is_mapping(item):
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
if is_list(item):
return cast(_T, [deepcopy_minimal(entry) for entry in item])
return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""
if size == 1:
return seq[0]
if size == 2:
return f"{seq[0]} {final} {seq[1]}"
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str: ...
@overload
def foo(*, b: bool) -> str: ...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
```
"""
def inner(func: CallableT) -> CallableT:
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
raise TypeError(
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
) from None
for key in kwargs.keys():
given_params.add(key)
for variant in variants:
matches = all((param in given_params for param in variant))
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
assert len(variants) > 0
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)
return wrapper # type: ignore
return inner
_K = TypeVar("_K")
_V = TypeVar("_V")
@overload
def strip_not_given(obj: None) -> None: ...
@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
@overload
def strip_not_given(obj: object) -> object: ...
def strip_not_given(obj: object | None) -> object:
"""Remove all top-level keys where their values are instances of `NotGiven`"""
if obj is None:
return None
if not is_mapping(obj):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
def coerce_integer(val: str) -> int:
return int(val, base=10)
def coerce_float(val: str) -> float:
return float(val)
def coerce_boolean(val: str) -> bool:
return val == "true" or val == "1" or val == "on"
def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)
def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Backport of `str.removeprefix` for Python < 3.9
"""
if string.startswith(prefix):
return string[len(prefix) :]
return string
def removesuffix(string: str, suffix: str) -> str:
"""Remove a suffix from a string.
Backport of `str.removesuffix` for Python < 3.9
"""
if string.endswith(suffix):
return string[: -len(suffix)]
return string
def file_from_path(path: str) -> FileTypes:
contents = Path(path).read_bytes()
file_name = os.path.basename(path)
return (file_name, contents)
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if is_mapping_t(headers):
# mypy doesn't understand the type narrowing here
for k, v in headers.items(): # type: ignore
if k.lower() == lower_header and isinstance(v, str):
return v
# to deal with the case where the header looks like Stainless-Event-Id
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)
if value:
return value
raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
try:
return sniffio.current_async_library()
except Exception:
return "false"
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
"""A version of functools.lru_cache that retains the type signature
for the wrapped function arguments.
"""
wrapper = functools.lru_cache( # noqa: TID251
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]
def json_safe(data: object) -> object:
"""Translates a mapping / sequence recursively in the same fashion
as `pydantic` v2's `model_dump(mode="json")`.
"""
if is_mapping(data):
return {json_safe(key): json_safe(value) for key, value in data.items()}
if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
return [json_safe(item) for item in data]
if isinstance(data, (datetime, date)):
return data.isoformat()
return data
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
from ..lib.azure import AzureOpenAI
return isinstance(client, AzureOpenAI)
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
from ..lib.azure import AsyncAzureOpenAI
return isinstance(client, AsyncAzureOpenAI)
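For instance, `extract_files` pulls file entries out of a request body in place:
```py
from openai._utils import extract_files

params = {"file": b"raw bytes", "purpose": "assistants"}
files = extract_files(params, paths=[["file"]])

assert files == [("file", b"raw bytes")]
assert "file" not in params  # note: the input mapping is mutated
```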

View File

@@ -0,0 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "openai"
__version__ = "2.26.0" # x-release-please-version

View File

@@ -0,0 +1 @@
from ._cli import main as main

View File

@@ -0,0 +1 @@
from ._main import register_commands as register_commands

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from argparse import ArgumentParser
from . import chat, audio, files, image, models, completions, fine_tuning
def register_commands(parser: ArgumentParser) -> None:
subparsers = parser.add_subparsers(help="All API subcommands")
chat.register(subparsers)
image.register(subparsers)
audio.register(subparsers)
files.register(subparsers)
models.register(subparsers)
completions.register(subparsers)
fine_tuning.register(subparsers)
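A sketch of driving this registration directly; the import path follows the package layout implied by this diff, and the final dispatch step belongs to the CLI driver, which is not shown here:
```py
from argparse import ArgumentParser
from openai.cli._api import register_commands  # path assumed from this diff's imports

parser = ArgumentParser(prog="openai")
register_commands(parser)

ns = parser.parse_args(["completions.create", "-m", "gpt-3.5-turbo-instruct", "-p", "Say hi"])
# The CLI driver (not shown) validates the namespace into ns.args_model and calls ns.func.
```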

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, Optional, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from ..._types import omit
from .._models import BaseModel
from .._progress import BufferReader
from ...types.audio import Transcription
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
# transcriptions
sub = subparser.add_parser("audio.transcriptions.create")
# Required
sub.add_argument("-m", "--model", type=str, default="whisper-1")
sub.add_argument("-f", "--file", type=str, required=True)
# Optional
sub.add_argument("--response-format", type=str)
sub.add_argument("--language", type=str)
sub.add_argument("-t", "--temperature", type=float)
sub.add_argument("--prompt", type=str)
sub.set_defaults(func=CLIAudio.transcribe, args_model=CLITranscribeArgs)
# translations
sub = subparser.add_parser("audio.translations.create")
# Required
sub.add_argument("-f", "--file", type=str, required=True)
# Optional
sub.add_argument("-m", "--model", type=str, default="whisper-1")
sub.add_argument("--response-format", type=str)
# TODO: doesn't seem to be supported by the API
# sub.add_argument("--language", type=str)
sub.add_argument("-t", "--temperature", type=float)
sub.add_argument("--prompt", type=str)
sub.set_defaults(func=CLIAudio.translate, args_model=CLITranslationArgs)
class CLITranscribeArgs(BaseModel):
model: str
file: str
response_format: Optional[str] = None
language: Optional[str] = None
temperature: Optional[float] = None
prompt: Optional[str] = None
class CLITranslationArgs(BaseModel):
model: str
file: str
response_format: Optional[str] = None
language: Optional[str] = None
temperature: Optional[float] = None
prompt: Optional[str] = None
class CLIAudio:
@staticmethod
def transcribe(args: CLITranscribeArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
model = cast(
"Transcription | str",
get_client().audio.transcriptions.create(
file=(args.file, buffer_reader),
model=args.model,
language=args.language or omit,
temperature=args.temperature or omit,
prompt=args.prompt or omit,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)
@staticmethod
def translate(args: CLITranslationArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
model = cast(
"Transcription | str",
get_client().audio.translations.create(
file=(args.file, buffer_reader),
model=args.model,
temperature=args.temperature or omit,
prompt=args.prompt or omit,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from . import completions
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
completions.register(subparser)

View File

@@ -0,0 +1,160 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, List, Optional, cast
from argparse import ArgumentParser
from typing_extensions import Literal, NamedTuple
from ..._utils import get_client
from ..._models import BaseModel
from ...._streaming import Stream
from ....types.chat import (
ChatCompletionRole,
ChatCompletionChunk,
CompletionCreateParams,
)
from ....types.chat.completion_create_params import (
CompletionCreateParamsStreaming,
CompletionCreateParamsNonStreaming,
)
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("chat.completions.create")
sub._action_groups.pop()
req = sub.add_argument_group("required arguments")
opt = sub.add_argument_group("optional arguments")
req.add_argument(
"-g",
"--message",
action="append",
nargs=2,
metavar=("ROLE", "CONTENT"),
help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
required=True,
)
req.add_argument(
"-m",
"--model",
help="The model to use.",
required=True,
)
opt.add_argument(
"-n",
"--n",
help="How many completions to generate for the conversation.",
type=int,
)
opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
opt.add_argument(
"-t",
"--temperature",
help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
Mutually exclusive with `top_p`.""",
type=float,
)
opt.add_argument(
"-P",
"--top_p",
help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
Mutually exclusive with `temperature`.""",
type=float,
)
opt.add_argument(
"--stop",
help="A stop sequence at which to stop generating tokens for the message.",
)
opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
class CLIMessage(NamedTuple):
role: ChatCompletionRole
content: str
class CLIChatCompletionCreateArgs(BaseModel):
message: List[CLIMessage]
model: str
n: Optional[int] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
stop: Optional[str] = None
stream: bool = False
class CLIChatCompletion:
@staticmethod
def create(args: CLIChatCompletionCreateArgs) -> None:
params: CompletionCreateParams = {
"model": args.model,
"messages": [
{"role": cast(Literal["user"], message.role), "content": message.content} for message in args.message
],
# type checkers are not good at inferring union types so we have to set stream afterwards
"stream": False,
}
if args.temperature is not None:
params["temperature"] = args.temperature
if args.stop is not None:
params["stop"] = args.stop
if args.top_p is not None:
params["top_p"] = args.top_p
if args.n is not None:
params["n"] = args.n
if args.stream:
params["stream"] = args.stream # type: ignore
if args.max_tokens is not None:
params["max_tokens"] = args.max_tokens
if args.stream:
return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
@staticmethod
def _create(params: CompletionCreateParamsNonStreaming) -> None:
completion = get_client().chat.completions.create(**params)
should_print_header = len(completion.choices) > 1
for choice in completion.choices:
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
content = choice.message.content if choice.message.content is not None else "None"
sys.stdout.write(content)
if should_print_header or not content.endswith("\n"):
sys.stdout.write("\n")
sys.stdout.flush()
@staticmethod
def _stream_create(params: CompletionCreateParamsStreaming) -> None:
# cast is required for mypy
stream = cast( # pyright: ignore[reportUnnecessaryCast]
Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
)
for chunk in stream:
should_print_header = len(chunk.choices) > 1
for choice in chunk.choices:
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
content = choice.delta.content or ""
sys.stdout.write(content)
if should_print_header:
sys.stdout.write("\n")
sys.stdout.flush()
sys.stdout.write("\n")


@@ -0,0 +1,173 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Optional, cast
from argparse import ArgumentParser
from functools import partial
from openai.types.completion import Completion
from .._utils import get_client
from ..._types import Omittable, omit
from ..._utils import is_given
from .._errors import CLIError
from .._models import BaseModel
from ..._streaming import Stream
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("completions.create")
# Required
sub.add_argument(
"-m",
"--model",
help="The model to use",
required=True,
)
# Optional
sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
sub.add_argument(
"-t",
"--temperature",
help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
Mutually exclusive with `top_p`.""",
type=float,
)
sub.add_argument(
"-P",
"--top_p",
help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
Mutually exclusive with `temperature`.""",
type=float,
)
sub.add_argument(
"-n",
"--n",
help="How many sub-completions to generate for each prompt.",
type=int,
)
sub.add_argument(
"--logprobs",
help="Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
type=int,
)
sub.add_argument(
"--best_of",
help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
type=int,
)
sub.add_argument(
"--echo",
help="Echo back the prompt in addition to the completion",
action="store_true",
)
sub.add_argument(
"--frequency_penalty",
help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
type=float,
)
sub.add_argument(
"--presence_penalty",
help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
type=float,
)
sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
sub.add_argument(
"--user",
help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
)
# TODO: add support for logit_bias
sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
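# Usage sketch (illustrative): each flag above maps onto a keyword argument of
# `client.completions.create(...)`, e.g.
#
#   openai api completions.create -m gpt-3.5-turbo-instruct -p "Say hello" -M 16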
class CLICompletionCreateArgs(BaseModel):
model: str
stream: bool = False
prompt: Optional[str] = None
n: Omittable[int] = omit
stop: Omittable[str] = omit
user: Omittable[str] = omit
echo: Omittable[bool] = omit
suffix: Omittable[str] = omit
best_of: Omittable[int] = omit
top_p: Omittable[float] = omit
logprobs: Omittable[int] = omit
max_tokens: Omittable[int] = omit
temperature: Omittable[float] = omit
presence_penalty: Omittable[float] = omit
frequency_penalty: Omittable[float] = omit
class CLICompletions:
@staticmethod
def create(args: CLICompletionCreateArgs) -> None:
if is_given(args.n) and args.n > 1 and args.stream:
raise CLIError("Can't stream completions with n>1 with the current CLI")
make_request = partial(
get_client().completions.create,
n=args.n,
echo=args.echo,
stop=args.stop,
user=args.user,
model=args.model,
top_p=args.top_p,
prompt=args.prompt,
suffix=args.suffix,
best_of=args.best_of,
logprobs=args.logprobs,
max_tokens=args.max_tokens,
temperature=args.temperature,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
)
if args.stream:
return CLICompletions._stream_create(
# mypy doesn't understand the `partial` function but pyright does
cast(Stream[Completion], make_request(stream=True)) # pyright: ignore[reportUnnecessaryCast]
)
return CLICompletions._create(make_request())
@staticmethod
def _create(completion: Completion) -> None:
should_print_header = len(completion.choices) > 1
for choice in completion.choices:
if should_print_header:
sys.stdout.write("===== Completion {} =====\n".format(choice.index))
sys.stdout.write(choice.text)
if should_print_header or not choice.text.endswith("\n"):
sys.stdout.write("\n")
sys.stdout.flush()
@staticmethod
def _stream_create(stream: Stream[Completion]) -> None:
for completion in stream:
should_print_header = len(completion.choices) > 1
for choice in sorted(completion.choices, key=lambda c: c.index):
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
sys.stdout.write(choice.text)
if should_print_header:
sys.stdout.write("\n")
sys.stdout.flush()
sys.stdout.write("\n")


@@ -0,0 +1,80 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from .._models import BaseModel
from .._progress import BufferReader
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("files.create")
sub.add_argument(
"-f",
"--file",
required=True,
help="File to upload",
)
sub.add_argument(
"-p",
"--purpose",
help="Why are you uploading this file? (see https://platform.openai.com/docs/api-reference/ for purposes)",
required=True,
)
sub.set_defaults(func=CLIFile.create, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.retrieve")
sub.add_argument("-i", "--id", required=True, help="The files ID")
sub.set_defaults(func=CLIFile.get, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.delete")
sub.add_argument("-i", "--id", required=True, help="The files ID")
sub.set_defaults(func=CLIFile.delete, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.list")
sub.set_defaults(func=CLIFile.list)
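# Usage sketch (file name and purpose shown are illustrative):
#
#   openai api files.create -f training.jsonl -p fine-tune
#   openai api files.list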
class CLIFileIDArgs(BaseModel):
id: str
class CLIFileCreateArgs(BaseModel):
file: str
purpose: str
class CLIFile:
@staticmethod
def create(args: CLIFileCreateArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
file = get_client().files.create(
file=(args.file, buffer_reader),
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
purpose=cast(Any, args.purpose),
)
print_model(file)
@staticmethod
def get(args: CLIFileIDArgs) -> None:
file = get_client().files.retrieve(file_id=args.id)
print_model(file)
@staticmethod
def delete(args: CLIFileIDArgs) -> None:
file = get_client().files.delete(file_id=args.id)
print_model(file)
@staticmethod
def list() -> None:
files = get_client().files.list()
for file in files:
print_model(file)


@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from . import jobs
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
jobs.register(subparser)


@@ -0,0 +1,170 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from ..._utils import get_client, print_model
from ...._types import Omittable, omit
from ...._utils import is_given
from ..._models import BaseModel
from ....pagination import SyncCursorPage
from ....types.fine_tuning import (
FineTuningJob,
FineTuningJobEvent,
)
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("fine_tuning.jobs.create")
sub.add_argument(
"-m",
"--model",
help="The model to fine-tune.",
required=True,
)
sub.add_argument(
"-F",
"--training-file",
help="The training file to fine-tune the model on.",
required=True,
)
sub.add_argument(
"-H",
"--hyperparameters",
help="JSON string of hyperparameters to use for fine-tuning.",
type=str,
)
sub.add_argument(
"-s",
"--suffix",
help="A suffix to add to the fine-tuned model name.",
)
sub.add_argument(
"-V",
"--validation-file",
help="The validation file to use for fine-tuning.",
)
sub.set_defaults(func=CLIFineTuningJobs.create, args_model=CLIFineTuningJobsCreateArgs)
sub = subparser.add_parser("fine_tuning.jobs.retrieve")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to retrieve.",
required=True,
)
sub.set_defaults(func=CLIFineTuningJobs.retrieve, args_model=CLIFineTuningJobsRetrieveArgs)
sub = subparser.add_parser("fine_tuning.jobs.list")
sub.add_argument(
"-a",
"--after",
help="Identifier for the last job from the previous pagination request. If provided, only jobs created after this job will be returned.",
)
sub.add_argument(
"-l",
"--limit",
help="Number of fine-tuning jobs to retrieve.",
type=int,
)
sub.set_defaults(func=CLIFineTuningJobs.list, args_model=CLIFineTuningJobsListArgs)
sub = subparser.add_parser("fine_tuning.jobs.cancel")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to cancel.",
required=True,
)
sub.set_defaults(func=CLIFineTuningJobs.cancel, args_model=CLIFineTuningJobsCancelArgs)
sub = subparser.add_parser("fine_tuning.jobs.list_events")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to list events for.",
required=True,
)
sub.add_argument(
"-a",
"--after",
help="Identifier for the last event from the previous pagination request. If provided, only events created after this event will be returned.",
)
sub.add_argument(
"-l",
"--limit",
help="Number of fine-tuning job events to retrieve.",
type=int,
)
sub.set_defaults(func=CLIFineTuningJobs.list_events, args_model=CLIFineTuningJobsListEventsArgs)
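# Usage sketch (the IDs and hyperparameters below are placeholders):
#
#   openai api fine_tuning.jobs.create -m gpt-4o-mini -F file-abc123 -H '{"n_epochs": 3}'
#   openai api fine_tuning.jobs.list_events -i ftjob-abc123 -l 10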
class CLIFineTuningJobsCreateArgs(BaseModel):
model: str
training_file: str
hyperparameters: Omittable[str] = omit
suffix: Omittable[str] = omit
validation_file: Omittable[str] = omit
class CLIFineTuningJobsRetrieveArgs(BaseModel):
id: str
class CLIFineTuningJobsListArgs(BaseModel):
after: Omittable[str] = omit
limit: Omittable[int] = omit
class CLIFineTuningJobsCancelArgs(BaseModel):
id: str
class CLIFineTuningJobsListEventsArgs(BaseModel):
id: str
after: Omittable[str] = omit
limit: Omittable[int] = omit
class CLIFineTuningJobs:
@staticmethod
def create(args: CLIFineTuningJobsCreateArgs) -> None:
hyperparameters = json.loads(str(args.hyperparameters)) if is_given(args.hyperparameters) else omit
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.create(
model=args.model,
training_file=args.training_file,
hyperparameters=hyperparameters,
suffix=args.suffix,
validation_file=args.validation_file,
)
print_model(fine_tuning_job)
@staticmethod
def retrieve(args: CLIFineTuningJobsRetrieveArgs) -> None:
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.retrieve(fine_tuning_job_id=args.id)
print_model(fine_tuning_job)
@staticmethod
def list(args: CLIFineTuningJobsListArgs) -> None:
fine_tuning_jobs: SyncCursorPage[FineTuningJob] = get_client().fine_tuning.jobs.list(
after=args.after or omit, limit=args.limit or omit
)
print_model(fine_tuning_jobs)
@staticmethod
def cancel(args: CLIFineTuningJobsCancelArgs) -> None:
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.cancel(fine_tuning_job_id=args.id)
print_model(fine_tuning_job)
@staticmethod
def list_events(args: CLIFineTuningJobsListEventsArgs) -> None:
fine_tuning_job_events: SyncCursorPage[FineTuningJobEvent] = get_client().fine_tuning.jobs.list_events(
fine_tuning_job_id=args.id,
after=args.after or omit,
limit=args.limit or omit,
)
print_model(fine_tuning_job_events)


@@ -0,0 +1,139 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from ..._types import Omit, Omittable, omit
from .._models import BaseModel
from .._progress import BufferReader
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("images.generate")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)
sub = subparser.add_parser("images.edit")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.add_argument(
"-M",
"--mask",
type=str,
required=False,
help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
)
sub.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)
sub = subparser.add_parser("images.create_variation")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
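# Usage sketch (prompt and paths are illustrative):
#
#   openai api images.generate -p "a watercolor fox" -n 2 -s 512x512
#   openai api images.edit -p "add a top hat" -I fox.png -M mask.png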
class CLIImageCreateArgs(BaseModel):
prompt: str
num_images: int
size: str
response_format: str
model: Omittable[str] = omit
class CLIImageCreateVariationArgs(BaseModel):
image: str
num_images: int
size: str
response_format: str
model: Omittable[str] = omit
class CLIImageEditArgs(BaseModel):
image: str
num_images: int
size: str
response_format: str
prompt: str
mask: Omittable[str] = omit
model: Omittable[str] = omit
class CLIImage:
@staticmethod
def create(args: CLIImageCreateArgs) -> None:
image = get_client().images.generate(
model=args.model,
prompt=args.prompt,
n=args.num_images,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)
@staticmethod
def create_variation(args: CLIImageCreateVariationArgs) -> None:
with open(args.image, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
image = get_client().images.create_variation(
model=args.model,
image=("image", buffer_reader),
n=args.num_images,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)
@staticmethod
def edit(args: CLIImageEditArgs) -> None:
with open(args.image, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Image upload progress")
if isinstance(args.mask, Omit):
mask: Omittable[BufferReader] = omit
else:
with open(args.mask, "rb") as file_reader:
mask = BufferReader(file_reader.read(), desc="Mask progress")
image = get_client().images.edit(
model=args.model,
prompt=args.prompt,
image=("image", buffer_reader),
n=args.num_images,
mask=("mask", mask) if not isinstance(mask, Omit) else mask,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)


@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from .._utils import get_client, print_model
from .._models import BaseModel
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("models.list")
sub.set_defaults(func=CLIModels.list)
sub = subparser.add_parser("models.retrieve")
sub.add_argument("-i", "--id", required=True, help="The model ID")
sub.set_defaults(func=CLIModels.get, args_model=CLIModelIDArgs)
sub = subparser.add_parser("models.delete")
sub.add_argument("-i", "--id", required=True, help="The model ID")
sub.set_defaults(func=CLIModels.delete, args_model=CLIModelIDArgs)
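# Usage sketch:
#
#   openai api models.list
#   openai api models.retrieve -i gpt-4o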
class CLIModelIDArgs(BaseModel):
id: str
class CLIModels:
@staticmethod
def get(args: CLIModelIDArgs) -> None:
model = get_client().models.retrieve(model=args.id)
print_model(model)
@staticmethod
def delete(args: CLIModelIDArgs) -> None:
model = get_client().models.delete(model=args.id)
print_model(model)
@staticmethod
def list() -> None:
models = get_client().models.list()
for model in models:
print_model(model)


@@ -0,0 +1,233 @@
from __future__ import annotations
import sys
import logging
import argparse
from typing import Any, List, Type, Optional
from typing_extensions import ClassVar
import httpx
import pydantic
import openai
from . import _tools
from .. import _ApiType, __version__
from ._api import register_commands
from ._utils import can_use_http2
from ._errors import CLIError, display_error
from .._compat import PYDANTIC_V1, ConfigDict, model_parse
from .._models import BaseModel
from .._exceptions import APIError
logger = logging.getLogger()
formatter = logging.Formatter("[%(asctime)s] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger.addHandler(handler)
class Arguments(BaseModel):
if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # type: ignore
extra: Any = pydantic.Extra.ignore # type: ignore
else:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="ignore",
)
verbosity: int
version: Optional[str] = None
api_key: Optional[str]
api_base: Optional[str]
organization: Optional[str]
proxy: Optional[List[str]]
api_type: Optional[_ApiType] = None
api_version: Optional[str] = None
# azure
azure_endpoint: Optional[str] = None
azure_ad_token: Optional[str] = None
# internal, set by subparsers to parse their specific args
args_model: Optional[Type[BaseModel]] = None
# internal, used so that subparsers can forward unknown arguments
unknown_args: List[str] = []
allow_unknown_args: bool = False
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=None, prog="openai")
parser.add_argument(
"-v",
"--verbose",
action="count",
dest="verbosity",
default=0,
help="Set verbosity.",
)
parser.add_argument("-b", "--api-base", help="What API base url to use.")
parser.add_argument("-k", "--api-key", help="What API key to use.")
parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
parser.add_argument(
"-o",
"--organization",
help="Which organization to run as (will use your default organization if not specified)",
)
parser.add_argument(
"-t",
"--api-type",
type=str,
choices=("openai", "azure"),
help="The backend API to call, must be `openai` or `azure`",
)
parser.add_argument(
"--api-version",
help="The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
)
# azure
parser.add_argument(
"--azure-endpoint",
help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
)
parser.add_argument(
"--azure-ad-token",
help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
)
# prints the package version
parser.add_argument(
"-V",
"--version",
action="version",
version="%(prog)s " + __version__,
)
def help() -> None:
parser.print_help()
parser.set_defaults(func=help)
subparsers = parser.add_subparsers()
sub_api = subparsers.add_parser("api", help="Direct API calls")
register_commands(sub_api)
sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
_tools.register_commands(sub_tools, subparsers)
return parser
def main() -> int:
try:
_main()
except (APIError, CLIError, pydantic.ValidationError) as err:
display_error(err)
return 1
except KeyboardInterrupt:
sys.stderr.write("\n")
return 1
return 0
def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
# argparse by default will strip out the `--` but we want to keep it for unknown arguments
if "--" in sys.argv:
idx = sys.argv.index("--")
known_args = sys.argv[1:idx]
unknown_args = sys.argv[idx:]
else:
known_args = sys.argv[1:]
unknown_args = []
parsed, remaining_unknown = parser.parse_known_args(known_args)
# append any remaining unknown arguments from the initial parsing
remaining_unknown.extend(unknown_args)
args = model_parse(Arguments, vars(parsed))
if not args.allow_unknown_args:
# we have to parse twice to ensure any unknown arguments
# result in an error if that behaviour is desired
parser.parse_args()
return parsed, args, remaining_unknown
def _main() -> None:
parser = _build_parser()
parsed, args, unknown = _parse_args(parser)
if args.verbosity != 0:
sys.stderr.write("Warning: --verbosity isn't supported yet\n")
proxies: dict[str, httpx.BaseTransport] = {}
if args.proxy is not None:
for proxy in args.proxy:
key = "https://" if proxy.startswith("https") else "http://"
if key in proxies:
raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
proxies[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))
http_client = httpx.Client(
mounts=proxies or None,
http2=can_use_http2(),
)
openai.http_client = http_client
if args.organization:
openai.organization = args.organization
if args.api_key:
openai.api_key = args.api_key
if args.api_base:
openai.base_url = args.api_base
# azure
if args.api_type is not None:
openai.api_type = args.api_type
if args.azure_endpoint is not None:
openai.azure_endpoint = args.azure_endpoint
if args.api_version is not None:
openai.api_version = args.api_version
if args.azure_ad_token is not None:
openai.azure_ad_token = args.azure_ad_token
try:
if args.args_model:
parsed.func(
model_parse(
args.args_model,
{
**{
                        # we omit None values so that they fall back to their
                        # `omit` defaults and are stripped from the API request
key: value
for key, value in vars(parsed).items()
if value is not None
},
"unknown_args": unknown,
},
)
)
else:
parsed.func()
finally:
try:
http_client.close()
except Exception:
pass
if __name__ == "__main__":
sys.exit(main())


@@ -0,0 +1,21 @@
from __future__ import annotations
import sys
import pydantic
from ._utils import Colors, organization_info
from .._exceptions import APIError, OpenAIError
class CLIError(OpenAIError): ...
class SilentCLIError(CLIError): ...
def display_error(err: CLIError | APIError | pydantic.ValidationError) -> None:
if isinstance(err, SilentCLIError):
return
sys.stderr.write("{}{}Error:{} {}\n".format(organization_info(), Colors.FAIL, Colors.ENDC, err))


@@ -0,0 +1,17 @@
from typing import Any
from typing_extensions import ClassVar
import pydantic
from .. import _models
from .._compat import PYDANTIC_V1, ConfigDict
class BaseModel(_models.BaseModel):
if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # type: ignore
extra: Any = pydantic.Extra.ignore # type: ignore
arbitrary_types_allowed: bool = True
else:
model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", arbitrary_types_allowed=True)


@@ -0,0 +1,59 @@
from __future__ import annotations
import io
from typing import Callable
from typing_extensions import override
class CancelledError(Exception):
def __init__(self, msg: str) -> None:
self.msg = msg
super().__init__(msg)
@override
def __str__(self) -> str:
return self.msg
__repr__ = __str__
class BufferReader(io.BytesIO):
def __init__(self, buf: bytes = b"", desc: str | None = None) -> None:
super().__init__(buf)
self._len = len(buf)
self._progress = 0
self._callback = progress(len(buf), desc=desc)
def __len__(self) -> int:
return self._len
@override
def read(self, n: int | None = -1) -> bytes:
chunk = io.BytesIO.read(self, n)
self._progress += len(chunk)
try:
self._callback(self._progress)
except Exception as e: # catches exception from the callback
raise CancelledError("The upload was cancelled: {}".format(e)) from e
return chunk
def progress(total: float, desc: str | None) -> Callable[[float], None]:
import tqdm
meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)
def incr(progress: float) -> None:
meter.n = progress
if progress == total:
meter.close()
else:
meter.refresh()
return incr
def MB(i: int) -> int:
return int(i // 1024**2)
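# Usage sketch (assumes `client` is an `openai.OpenAI` instance): wrapping the
# request body in a BufferReader renders a tqdm progress bar during the upload.
#
#   with open("data.jsonl", "rb") as f:
#       reader = BufferReader(f.read(), desc="Upload progress")
#   client.files.create(file=("data.jsonl", reader), purpose="fine-tune")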


@@ -0,0 +1 @@
from ._main import register_commands as register_commands


@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from . import migrate, fine_tunes
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register_commands(parser: ArgumentParser, subparser: _SubParsersAction[ArgumentParser]) -> None:
migrate.register(subparser)
namespaced = parser.add_subparsers(title="Tools", help="Convenience client side tools")
fine_tunes.register(namespaced)


@@ -0,0 +1,63 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from .._models import BaseModel
from ...lib._validators import (
get_validators,
write_out_file,
read_any_format,
apply_validators,
apply_necessary_remediation,
)
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("fine_tunes.prepare_data")
sub.add_argument(
"-f",
"--file",
required=True,
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
"This should be the local file path.",
)
sub.add_argument(
"-q",
"--quiet",
required=False,
action="store_true",
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
)
sub.set_defaults(func=prepare_data, args_model=PrepareDataArgs)
class PrepareDataArgs(BaseModel):
file: str
quiet: bool
def prepare_data(args: PrepareDataArgs) -> None:
sys.stdout.write("Analyzing...\n")
fname = args.file
auto_accept = args.quiet
df, remediation = read_any_format(fname)
apply_necessary_remediation(None, remediation)
validators = get_validators()
assert df is not None
apply_validators(
df,
fname,
remediation,
validators,
auto_accept,
write_out_file_func=write_out_file,
)


@@ -0,0 +1,164 @@
from __future__ import annotations
import os
import sys
import shutil
import tarfile
import platform
import subprocess
from typing import TYPE_CHECKING, List
from pathlib import Path
from argparse import ArgumentParser
import httpx
from .._errors import CLIError, SilentCLIError
from .._models import BaseModel
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("migrate")
sub.set_defaults(func=migrate, args_model=MigrateArgs, allow_unknown_args=True)
sub = subparser.add_parser("grit")
sub.set_defaults(func=grit, args_model=GritArgs, allow_unknown_args=True)
class GritArgs(BaseModel):
# internal
unknown_args: List[str] = []
def grit(args: GritArgs) -> None:
grit_path = install()
try:
subprocess.check_call([grit_path, *args.unknown_args])
except subprocess.CalledProcessError:
# stdout and stderr are forwarded by subprocess so an error will already
# have been displayed
raise SilentCLIError() from None
class MigrateArgs(BaseModel):
# internal
unknown_args: List[str] = []
def migrate(args: MigrateArgs) -> None:
grit_path = install()
try:
subprocess.check_call([grit_path, "apply", "openai", *args.unknown_args])
except subprocess.CalledProcessError:
# stdout and stderr are forwarded by subprocess so an error will already
# have been displayed
raise SilentCLIError() from None
# handles downloading the Grit CLI until they provide their own PyPI package
KEYGEN_ACCOUNT = "custodian-dev"
def _cache_dir() -> Path:
xdg = os.environ.get("XDG_CACHE_HOME")
if xdg is not None:
return Path(xdg)
return Path.home() / ".cache"
def _debug(message: str) -> None:
if not os.environ.get("DEBUG"):
return
sys.stdout.write(f"[DEBUG]: {message}\n")
def install() -> Path:
"""Installs the Grit CLI and returns the location of the binary"""
if sys.platform == "win32":
raise CLIError("Windows is not supported yet in the migration CLI")
_debug("Using Grit installer from GitHub")
platform = "apple-darwin" if sys.platform == "darwin" else "unknown-linux-gnu"
dir_name = _cache_dir() / "openai-python"
install_dir = dir_name / ".install"
target_dir = install_dir / "bin"
target_path = target_dir / "grit"
temp_file = target_dir / "grit.tmp"
if target_path.exists():
_debug(f"{target_path} already exists")
sys.stdout.flush()
return target_path
_debug(f"Using Grit CLI path: {target_path}")
target_dir.mkdir(parents=True, exist_ok=True)
if temp_file.exists():
temp_file.unlink()
arch = _get_arch()
_debug(f"Using architecture {arch}")
file_name = f"grit-{arch}-{platform}"
download_url = f"https://github.com/getgrit/gritql/releases/latest/download/{file_name}.tar.gz"
sys.stdout.write(f"Downloading Grit CLI from {download_url}\n")
with httpx.Client() as client:
download_response = client.get(download_url, follow_redirects=True)
if download_response.status_code != 200:
raise CLIError(f"Failed to download Grit CLI from {download_url}")
with open(temp_file, "wb") as file:
for chunk in download_response.iter_bytes():
file.write(chunk)
unpacked_dir = target_dir / "cli-bin"
unpacked_dir.mkdir(parents=True, exist_ok=True)
with tarfile.open(temp_file, "r:gz") as archive:
if sys.version_info >= (3, 12):
archive.extractall(unpacked_dir, filter="data")
else:
archive.extractall(unpacked_dir)
_move_files_recursively(unpacked_dir, target_dir)
shutil.rmtree(unpacked_dir)
os.remove(temp_file)
os.chmod(target_path, 0o755)
sys.stdout.flush()
return target_path
def _move_files_recursively(source_dir: Path, target_dir: Path) -> None:
for item in source_dir.iterdir():
if item.is_file():
item.rename(target_dir / item.name)
elif item.is_dir():
_move_files_recursively(item, target_dir)
def _get_arch() -> str:
architecture = platform.machine().lower()
# Map the architecture names to Grit equivalents
arch_map = {
"x86_64": "x86_64",
"amd64": "x86_64",
"armv7l": "aarch64",
"arm64": "aarch64",
}
return arch_map.get(architecture, architecture)
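# Debugging sketch: setting the DEBUG environment variable traces each install
# step above, e.g. `DEBUG=1 openai migrate`.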


@@ -0,0 +1,45 @@
from __future__ import annotations
import sys
import openai
from .. import OpenAI, _load_client
from .._compat import model_json
from .._models import BaseModel
class Colors:
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
def get_client() -> OpenAI:
return _load_client()
def organization_info() -> str:
organization = openai.organization
if organization is not None:
return "[organization={}] ".format(organization)
return ""
def print_model(model: BaseModel) -> None:
sys.stdout.write(model_json(model, indent=2) + "\n")
def can_use_http2() -> bool:
try:
import h2 # type: ignore # noqa
except ImportError:
return False
return True
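# Note: HTTP/2 support depends on the optional `h2` package, which is typically
# pulled in via the httpx extra, e.g. `pip install httpx[http2]`.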


@@ -0,0 +1,4 @@
from .microphone import Microphone
from .local_audio_player import LocalAudioPlayer
__all__ = ["Microphone", "LocalAudioPlayer"]


@@ -0,0 +1,165 @@
# mypy: ignore-errors
from __future__ import annotations
import queue
import asyncio
from typing import Any, Union, Callable, AsyncGenerator, cast
from typing_extensions import TYPE_CHECKING
from .. import _legacy_response
from .._extras import numpy as np, sounddevice as sd
from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
if TYPE_CHECKING:
import numpy.typing as npt
SAMPLE_RATE = 24000
class LocalAudioPlayer:
def __init__(
self,
should_stop: Union[Callable[[], bool], None] = None,
):
self.channels = 1
self.dtype = np.float32
self.should_stop = should_stop
async def _tts_response_to_buffer(
self,
response: Union[
_legacy_response.HttpxBinaryResponseContent,
AsyncStreamedBinaryAPIResponse,
StreamedBinaryAPIResponse,
],
) -> npt.NDArray[np.float32]:
chunks: list[bytes] = []
if isinstance(response, _legacy_response.HttpxBinaryResponseContent) or isinstance(
response, StreamedBinaryAPIResponse
):
for chunk in response.iter_bytes(chunk_size=1024):
if chunk:
chunks.append(chunk)
else:
async for chunk in response.iter_bytes(chunk_size=1024):
if chunk:
chunks.append(chunk)
audio_bytes = b"".join(chunks)
audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
audio_np = audio_np.reshape(-1, 1)
return audio_np
async def play(
self,
input: Union[
npt.NDArray[np.int16],
npt.NDArray[np.float32],
_legacy_response.HttpxBinaryResponseContent,
AsyncStreamedBinaryAPIResponse,
StreamedBinaryAPIResponse,
],
) -> None:
audio_content: npt.NDArray[np.float32]
if isinstance(input, np.ndarray):
if input.dtype == np.int16 and self.dtype == np.float32:
audio_content = (input.astype(np.float32) / 32767.0).reshape(-1, self.channels)
elif input.dtype == np.float32:
audio_content = cast("npt.NDArray[np.float32]", input)
else:
raise ValueError(f"Unsupported dtype: {input.dtype}")
else:
audio_content = await self._tts_response_to_buffer(input)
loop = asyncio.get_event_loop()
event = asyncio.Event()
idx = 0
def callback(
outdata: npt.NDArray[np.float32],
frame_count: int,
_time_info: Any,
_status: Any,
):
nonlocal idx
remainder = len(audio_content) - idx
if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
loop.call_soon_threadsafe(event.set)
raise sd.CallbackStop
valid_frames = frame_count if remainder >= frame_count else remainder
outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
outdata[valid_frames:] = 0
idx += valid_frames
stream = sd.OutputStream(
samplerate=SAMPLE_RATE,
callback=callback,
dtype=audio_content.dtype,
channels=audio_content.shape[1],
)
with stream:
await event.wait()
async def play_stream(
self,
buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
) -> None:
loop = asyncio.get_event_loop()
event = asyncio.Event()
buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)
async def buffer_producer():
async for buffer in buffer_stream:
if buffer is None:
break
await loop.run_in_executor(None, buffer_queue.put, buffer)
await loop.run_in_executor(None, buffer_queue.put, None) # Signal completion
def callback(
outdata: npt.NDArray[np.float32],
frame_count: int,
_time_info: Any,
_status: Any,
):
nonlocal current_buffer, buffer_pos
frames_written = 0
while frames_written < frame_count:
if current_buffer is None or buffer_pos >= len(current_buffer):
try:
current_buffer = buffer_queue.get(timeout=0.1)
if current_buffer is None:
loop.call_soon_threadsafe(event.set)
raise sd.CallbackStop
buffer_pos = 0
if current_buffer.dtype == np.int16 and self.dtype == np.float32:
current_buffer = (current_buffer.astype(np.float32) / 32767.0).reshape(-1, self.channels)
except queue.Empty:
outdata[frames_written:] = 0
return
remaining_frames = len(current_buffer) - buffer_pos
frames_to_write = min(frame_count - frames_written, remaining_frames)
outdata[frames_written : frames_written + frames_to_write] = current_buffer[
buffer_pos : buffer_pos + frames_to_write
]
buffer_pos += frames_to_write
frames_written += frames_to_write
current_buffer = None
buffer_pos = 0
producer_task = asyncio.create_task(buffer_producer())
with sd.OutputStream(
samplerate=SAMPLE_RATE,
channels=self.channels,
dtype=self.dtype,
callback=callback,
):
await event.wait()
await producer_task
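# Usage sketch (assumes an `AsyncOpenAI` client and the optional `numpy` and
# `sounddevice` extras; `pcm` yields the raw 24kHz int16 audio this player expects):
#
#   async with client.audio.speech.with_streaming_response.create(
#       model="gpt-4o-mini-tts", voice="alloy", input="Hello!", response_format="pcm"
#   ) as response:
#       await LocalAudioPlayer().play(response)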


@@ -0,0 +1,100 @@
# mypy: ignore-errors
from __future__ import annotations
import io
import time
import wave
import asyncio
from typing import Any, Type, Union, Generic, TypeVar, Callable, overload
from typing_extensions import TYPE_CHECKING, Literal
from .._types import FileTypes, FileContent
from .._extras import numpy as np, sounddevice as sd
if TYPE_CHECKING:
import numpy.typing as npt
SAMPLE_RATE = 24000
DType = TypeVar("DType", bound=np.generic)
class Microphone(Generic[DType]):
def __init__(
self,
channels: int = 1,
dtype: Type[DType] = np.int16,
should_record: Union[Callable[[], bool], None] = None,
timeout: Union[float, None] = None,
):
self.channels = channels
self.dtype = dtype
self.should_record = should_record
self.buffer_chunks = []
self.timeout = timeout
self.has_record_function = callable(should_record)
def _ndarray_to_wav(self, audio_data: npt.NDArray[DType]) -> FileTypes:
buffer: FileContent = io.BytesIO()
with wave.open(buffer, "w") as wav_file:
wav_file.setnchannels(self.channels)
wav_file.setsampwidth(np.dtype(self.dtype).itemsize)
wav_file.setframerate(SAMPLE_RATE)
wav_file.writeframes(audio_data.tobytes())
buffer.seek(0)
return ("audio.wav", buffer, "audio/wav")
@overload
async def record(self, return_ndarray: Literal[True]) -> npt.NDArray[DType]: ...
@overload
async def record(self, return_ndarray: Literal[False]) -> FileTypes: ...
@overload
async def record(self, return_ndarray: None = ...) -> FileTypes: ...
async def record(self, return_ndarray: Union[bool, None] = False) -> Union[npt.NDArray[DType], FileTypes]:
loop = asyncio.get_event_loop()
event = asyncio.Event()
self.buffer_chunks: list[npt.NDArray[DType]] = []
start_time = time.perf_counter()
def callback(
indata: npt.NDArray[DType],
_frame_count: int,
_time_info: Any,
_status: Any,
):
execution_time = time.perf_counter() - start_time
reached_recording_timeout = execution_time > self.timeout if self.timeout is not None else False
if reached_recording_timeout:
loop.call_soon_threadsafe(event.set)
raise sd.CallbackStop
should_be_recording = self.should_record() if callable(self.should_record) else True
if not should_be_recording:
loop.call_soon_threadsafe(event.set)
raise sd.CallbackStop
self.buffer_chunks.append(indata.copy())
stream = sd.InputStream(
callback=callback,
dtype=self.dtype,
samplerate=SAMPLE_RATE,
channels=self.channels,
)
with stream:
await event.wait()
# Concatenate all chunks into a single buffer, handle empty case
concatenated_chunks: npt.NDArray[DType] = (
np.concatenate(self.buffer_chunks, axis=0)
if len(self.buffer_chunks) > 0
else np.array([], dtype=self.dtype)
)
if return_ndarray:
return concatenated_chunks
else:
return self._ndarray_to_wav(concatenated_chunks)
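# Usage sketch (assumes the optional `numpy`/`sounddevice` extras and an OpenAI
# `client`; records until the 5 second timeout elapses, then transcribes):
#
#   mic = Microphone(timeout=5)
#   wav = await mic.record()
#   transcript = client.audio.transcriptions.create(model="whisper-1", file=wav)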


@@ -0,0 +1,4 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store custom files to expand the SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.


@@ -0,0 +1,2 @@
from ._tools import pydantic_function_tool as pydantic_function_tool
from ._parsing import ResponseFormatT as ResponseFormatT


@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from .._exceptions import OpenAIError
INSTRUCTIONS = """
You tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.
You can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface.
Alternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`
A detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742
"""
class APIRemovedInV1(OpenAIError):
def __init__(self, *, symbol: str) -> None:
super().__init__(INSTRUCTIONS.format(symbol=symbol))
class APIRemovedInV1Proxy(LazyProxy[Any]):
def __init__(self, *, symbol: str) -> None:
super().__init__()
self._symbol = symbol
@override
def __load__(self) -> Any:
# return the proxy until it is eventually called so that
# we don't break people that are just checking the attributes
# of a module
return self
def __call__(self, *_args: Any, **_kwargs: Any) -> Any:
raise APIRemovedInV1(symbol=self._symbol)
SYMBOLS = [
"Edit",
"File",
"Audio",
"Image",
"Model",
"Engine",
"Customer",
"FineTune",
"Embedding",
"Completion",
"Deployment",
"Moderation",
"ErrorObject",
"FineTuningJob",
"ChatCompletion",
]
# we explicitly tell type checkers that nothing is exported
# from this file so that when we re-export the old symbols
# in `openai/__init__.py` they aren't added to the auto-complete
# suggestions given by editors
if TYPE_CHECKING:
__all__: list[str] = []
else:
__all__ = SYMBOLS
__locals = locals()
for symbol in SYMBOLS:
__locals[symbol] = APIRemovedInV1Proxy(symbol=symbol)
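# Effect sketch: `openai.Completion` still resolves (to the proxy above), so
# attribute checks on the module keep working, but calling the old API, e.g.
# `openai.Completion(...)`, raises APIRemovedInV1 with migration instructions.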


@@ -0,0 +1,11 @@
from ._completions import (
ResponseFormatT as ResponseFormatT,
    has_parseable_input as has_parseable_input,
maybe_parse_content as maybe_parse_content,
validate_input_tools as validate_input_tools,
parse_chat_completion as parse_chat_completion,
get_input_tool_by_name as get_input_tool_by_name,
parse_function_tool_arguments as parse_function_tool_arguments,
type_to_response_format_param as type_to_response_format_param,
)


@@ -0,0 +1,288 @@
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never
import pydantic
from .._tools import PydanticFunctionTool
from ..._types import Omit, omit
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V1, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
ParsedChoice,
ChatCompletion,
ParsedFunction,
ParsedChatCompletion,
ChatCompletionMessage,
ParsedFunctionToolCall,
ParsedChatCompletionMessage,
ChatCompletionToolUnionParam,
ChatCompletionFunctionToolParam,
completion_create_params,
)
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.shared_params import FunctionDefinition
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_function_tool_call import Function
ResponseFormatT = TypeVar(
"ResponseFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
_default_response_format: None = None
log: logging.Logger = logging.getLogger("openai.lib.parsing")
def is_strict_chat_completion_tool_param(
tool: ChatCompletionToolUnionParam,
) -> TypeGuard[ChatCompletionFunctionToolParam]:
"""Check if the given tool is a strict ChatCompletionFunctionToolParam."""
if not tool["type"] == "function":
return False
if tool["function"].get("strict") is not True:
return False
return True
def select_strict_chat_completion_tools(
tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
"""Select only the strict ChatCompletionFunctionToolParams from the given tools."""
if not is_given(tools):
return omit
return [t for t in tools if is_strict_chat_completion_tool_param(t)]
def validate_input_tools(
tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
if not is_given(tools):
return omit
for tool in tools:
if tool["type"] != "function":
raise ValueError(
f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
)
strict = tool["function"].get("strict")
if strict is not True:
raise ValueError(
f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
)
return cast(Iterable[ChatCompletionFunctionToolParam], tools)
def parse_chat_completion(
*,
response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
if is_given(input_tools):
input_tools = [t for t in input_tools]
else:
input_tools = []
choices: list[ParsedChoice[ResponseFormatT]] = []
for choice in chat_completion.choices:
if choice.finish_reason == "length":
raise LengthFinishReasonError(completion=chat_completion)
if choice.finish_reason == "content_filter":
raise ContentFilterFinishReasonError()
message = choice.message
tool_calls: list[ParsedFunctionToolCall] = []
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.type == "function":
tool_call_dict = tool_call.to_dict()
tool_calls.append(
construct_type_unchecked(
value={
**tool_call_dict,
"function": {
**cast(Any, tool_call_dict["function"]),
"parsed_arguments": parse_function_tool_arguments(
input_tools=input_tools, function=tool_call.function
),
},
},
type_=ParsedFunctionToolCall,
)
)
elif tool_call.type == "custom":
# warn user that custom tool calls are not callable here
log.warning(
"Custom tool calls are not callable. Ignoring tool call: %s - %s",
tool_call.id,
tool_call.custom.name,
stacklevel=2,
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call)
else:
tool_calls.append(tool_call)
choices.append(
construct_type_unchecked(
type_=ParsedChoice[ResponseFormatT],
value={
**choice.to_dict(),
"message": {
**message.to_dict(),
"parsed": maybe_parse_content(
response_format=response_format,
message=message,
),
"tool_calls": tool_calls if tool_calls else None,
},
},
)
)
return construct_type_unchecked(
type_=ParsedChatCompletion[ResponseFormatT],
value={
**chat_completion.to_dict(),
"choices": choices,
},
)
def get_input_tool_by_name(
*, input_tools: list[ChatCompletionToolUnionParam], name: str
) -> ChatCompletionFunctionToolParam | None:
return next((t for t in input_tools if t["type"] == "function" and t.get("function", {}).get("name") == name), None)
def parse_function_tool_arguments(
*, input_tools: list[ChatCompletionToolUnionParam], function: Function | ParsedFunction
) -> object | None:
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
if not input_tool:
return None
input_fn = cast(object, input_tool.get("function"))
if isinstance(input_fn, PydanticFunctionTool):
return model_parse_json(input_fn.model, function.arguments)
input_fn = cast(FunctionDefinition, input_fn)
if not input_fn.get("strict"):
return None
return json.loads(function.arguments) # type: ignore[no-any-return]
def maybe_parse_content(
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
if has_rich_response_format(response_format) and message.content and not message.refusal:
return _parse_content(response_format, message.content)
return None
def has_parseable_input(
*,
response_format: type | ResponseFormatParam | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> bool:
if has_rich_response_format(response_format):
return True
for input_tool in input_tools or []:
if is_parseable_tool(input_tool):
return True
return False
def has_rich_response_format(
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> TypeGuard[type[ResponseFormatT]]:
if not is_given(response_format):
return False
if is_response_format_param(response_format):
return False
return True
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
return is_dict(response_format)
def is_parseable_tool(input_tool: ChatCompletionToolUnionParam) -> bool:
if input_tool["type"] != "function":
return False
input_fn = cast(object, input_tool.get("function"))
if isinstance(input_fn, PydanticFunctionTool):
return True
return cast(FunctionDefinition, input_fn).get("strict") or False
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
if is_basemodel_type(response_format):
return cast(ResponseFormatT, model_parse_json(response_format, content))
if is_dataclass_like_type(response_format):
if PYDANTIC_V1:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
return pydantic.TypeAdapter(response_format).validate_json(content)
raise TypeError(f"Unable to automatically parse response format type {response_format}")
def type_to_response_format_param(
response_format: type | completion_create_params.ResponseFormat | Omit,
) -> ResponseFormatParam | Omit:
if not is_given(response_format):
return omit
if is_response_format_param(response_format):
return response_format
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
# a safe default behaviour but we know that at this point the `response_format`
# can only be a `type`
response_format = cast(type, response_format)
json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
if is_basemodel_type(response_format):
name = response_format.__name__
json_schema_type = response_format
elif is_dataclass_like_type(response_format):
name = response_format.__name__
json_schema_type = pydantic.TypeAdapter(response_format)
else:
raise TypeError(f"Unsupported response_format type - {response_format}")
return {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(json_schema_type),
"name": name,
"strict": True,
},
}
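# A sketch of the conversion (the model is illustrative):
#
#   class Step(pydantic.BaseModel):
#       explanation: str
#       output: str
#
#   type_to_response_format_param(Step)
#   # -> {"type": "json_schema",
#   #     "json_schema": {"schema": {...}, "name": "Step", "strict": True}}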


@@ -0,0 +1,179 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, List, Iterable, cast
from typing_extensions import TypeVar, assert_never
import pydantic
from .._tools import ResponsesPydanticFunctionTool
from ..._types import Omit
from ..._utils import is_given
from ..._compat import PYDANTIC_V1, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from ._completions import type_to_response_format_param
from ...types.responses import (
Response,
ToolParam,
ParsedContent,
ParsedResponse,
FunctionToolParam,
ParsedResponseOutputItem,
ParsedResponseOutputText,
ResponseFunctionToolCall,
ParsedResponseOutputMessage,
ResponseFormatTextConfigParam,
ParsedResponseFunctionToolCall,
)
from ...types.chat.completion_create_params import ResponseFormat
TextFormatT = TypeVar(
"TextFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
response_format_dict = type_to_response_format_param(type_)
assert is_given(response_format_dict)
response_format_dict = cast(ResponseFormat, response_format_dict) # pyright: ignore[reportUnnecessaryCast]
assert response_format_dict["type"] == "json_schema"
assert "schema" in response_format_dict["json_schema"]
return {
"type": "json_schema",
"strict": True,
"name": response_format_dict["json_schema"]["name"],
"schema": response_format_dict["json_schema"]["schema"],
}
def parse_response(
*,
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit | None,
response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
for output in response.output:
if output.type == "message":
content_list: List[ParsedContent[TextFormatT]] = []
for item in output.content:
if item.type != "output_text":
content_list.append(item)
continue
content_list.append(
construct_type_unchecked(
type_=ParsedResponseOutputText[TextFormatT],
value={
**item.to_dict(),
"parsed": parse_text(item.text, text_format=text_format),
},
)
)
output_list.append(
construct_type_unchecked(
type_=ParsedResponseOutputMessage[TextFormatT],
value={
**output.to_dict(),
"content": content_list,
},
)
)
elif output.type == "function_call":
output_list.append(
construct_type_unchecked(
type_=ParsedResponseFunctionToolCall,
value={
**output.to_dict(),
"parsed_arguments": parse_function_tool_arguments(
input_tools=input_tools, function_call=output
),
},
)
)
elif (
output.type == "computer_call"
or output.type == "file_search_call"
or output.type == "web_search_call"
or output.type == "tool_search_call"
or output.type == "tool_search_output"
or output.type == "reasoning"
or output.type == "compaction"
or output.type == "mcp_call"
or output.type == "mcp_approval_request"
or output.type == "image_generation_call"
or output.type == "code_interpreter_call"
or output.type == "local_shell_call"
or output.type == "shell_call"
or output.type == "shell_call_output"
or output.type == "apply_patch_call"
or output.type == "apply_patch_call_output"
or output.type == "mcp_list_tools"
or output.type == "exec"
or output.type == "custom_tool_call"
):
output_list.append(output)
elif TYPE_CHECKING: # type: ignore
assert_never(output)
else:
output_list.append(output)
return construct_type_unchecked(
type_=ParsedResponse[TextFormatT],
value={
**response.to_dict(),
"output": output_list,
},
)
def parse_text(text: str, text_format: type[TextFormatT] | Omit) -> TextFormatT | None:
if not is_given(text_format):
return None
if is_basemodel_type(text_format):
return cast(TextFormatT, model_parse_json(text_format, text))
if is_dataclass_like_type(text_format):
if PYDANTIC_V1:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")
return pydantic.TypeAdapter(text_format).validate_json(text)
raise TypeError(f"Unable to automatically parse response format type {text_format}")
def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
for tool in input_tools:
if tool["type"] == "function" and tool.get("name") == name:
return tool
return None
def parse_function_tool_arguments(
*,
input_tools: Iterable[ToolParam] | Omit | None,
function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
) -> object:
if input_tools is None or not is_given(input_tools):
return None
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
if not input_tool:
return None
tool = cast(object, input_tool)
if isinstance(tool, ResponsesPydanticFunctionTool):
return model_parse_json(tool.model, function_call.arguments)
if not input_tool.get("strict"):
return None
return json.loads(function_call.arguments)


@@ -0,0 +1,155 @@
from __future__ import annotations
import inspect
from typing import Any, TypeVar
from typing_extensions import TypeGuard
import pydantic
from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
from .._compat import PYDANTIC_V1, model_json_schema
_T = TypeVar("_T")
def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
if inspect.isclass(model) and is_basemodel_type(model):
schema = model_json_schema(model)
elif (not PYDANTIC_V1) and isinstance(model, pydantic.TypeAdapter):
schema = model.json_schema()
else:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
return _ensure_strict_json_schema(schema, path=(), root=schema)
def _ensure_strict_json_schema(
json_schema: object,
*,
path: tuple[str, ...],
root: dict[str, object],
) -> dict[str, Any]:
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
that the API expects.
"""
if not is_dict(json_schema):
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
defs = json_schema.get("$defs")
if is_dict(defs):
for def_name, def_schema in defs.items():
_ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
definitions = json_schema.get("definitions")
if is_dict(definitions):
for definition_name, definition_schema in definitions.items():
_ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
typ = json_schema.get("type")
if typ == "object" and "additionalProperties" not in json_schema:
json_schema["additionalProperties"] = False
# object types
# { 'type': 'object', 'properties': { 'a': {...} } }
properties = json_schema.get("properties")
if is_dict(properties):
json_schema["required"] = [prop for prop in properties.keys()]
json_schema["properties"] = {
key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
for key, prop_schema in properties.items()
}
# arrays
# { 'type': 'array', 'items': {...} }
items = json_schema.get("items")
if is_dict(items):
json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
# unions
any_of = json_schema.get("anyOf")
if is_list(any_of):
json_schema["anyOf"] = [
_ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
for i, variant in enumerate(any_of)
]
# intersections
all_of = json_schema.get("allOf")
if is_list(all_of):
if len(all_of) == 1:
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
json_schema.pop("allOf")
else:
json_schema["allOf"] = [
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
for i, entry in enumerate(all_of)
]
# strip `None` defaults as there's no meaningful distinction here
# the schema will still be `nullable` and the model will default
# to using `None` anyway
if json_schema.get("default", NOT_GIVEN) is None:
json_schema.pop("default")
# we can't use `$ref`s if there are also other properties defined, e.g.
# `{"$ref": "...", "description": "my description"}`
#
# so we unravel the ref
# `{"type": "string", "description": "my description"}`
ref = json_schema.get("$ref")
if ref and has_more_than_n_keys(json_schema, 1):
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
resolved = resolve_ref(root=root, ref=ref)
if not is_dict(resolved):
raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
# properties from the json schema take priority over the ones on the `$ref`
json_schema.update({**resolved, **json_schema})
json_schema.pop("$ref")
# Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
# we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
return _ensure_strict_json_schema(json_schema, path=path, root=root)
return json_schema
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
if not ref.startswith("#/"):
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
path = ref[2:].split("/")
resolved = root
for key in path:
value = resolved[key]
assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
resolved = value
return resolved
def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
if not inspect.isclass(typ):
return False
return issubclass(typ, pydantic.BaseModel)
def is_dataclass_like_type(typ: type) -> bool:
"""Returns True if the given type likely used `@pydantic.dataclass`"""
return hasattr(typ, "__pydantic_config__")
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
# just pretend that we know there are only `str` keys
# as that check is not worth the performance cost
return _is_dict(obj)
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
i = 0
for _ in obj.keys():
i += 1
if i > n:
return True
return False
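# --- Example (editor's sketch, not part of the SDK) --------------------------
# Demonstrates the `strict` transformation: every property is marked required
# and `additionalProperties` is forced to False. `Point` is illustrative only.
if __name__ == "__main__":
    class Point(pydantic.BaseModel):
        x: int
        y: int

    schema = to_strict_json_schema(Point)
    assert schema["type"] == "object"
    assert schema["additionalProperties"] is False
    assert schema["required"] == ["x", "y"]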

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import json
from typing_extensions import override
import httpx
from openai import _legacy_response
from openai._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from openai._utils import maybe_transform, async_maybe_transform
from openai._base_client import make_request_options
from openai.resources.realtime.calls import Calls, AsyncCalls
from openai.types.realtime.realtime_session_create_request_param import RealtimeSessionCreateRequestParam
__all__ = ["_Calls", "_AsyncCalls"]
# Custom code to override the `create` method to have correct behavior with
# application/sdp and multipart/form-data.
# Ideally we can cut over to the generated code this overrides eventually and remove this.
class _Calls(Calls):
@override
def create(
self,
*,
sdp: str,
session: RealtimeSessionCreateRequestParam | Omit = omit,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> _legacy_response.HttpxBinaryResponseContent:
if session is omit:
extra_headers = {"Accept": "application/sdp", "Content-Type": "application/sdp", **(extra_headers or {})}
return self._post(
"/realtime/calls",
content=sdp.encode("utf-8"),
options=make_request_options(extra_headers=extra_headers, extra_query=extra_query, timeout=timeout),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
extra_headers = {"Accept": "application/sdp", "Content-Type": "multipart/form-data", **(extra_headers or {})}
session_payload = maybe_transform(session, RealtimeSessionCreateRequestParam)
files = [
("sdp", (None, sdp.encode("utf-8"), "application/sdp")),
("session", (None, json.dumps(session_payload).encode("utf-8"), "application/json")),
]
return self._post(
"/realtime/calls",
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
class _AsyncCalls(AsyncCalls):
@override
async def create(
self,
*,
sdp: str,
session: RealtimeSessionCreateRequestParam | Omit = omit,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> _legacy_response.HttpxBinaryResponseContent:
if session is omit:
extra_headers = {"Accept": "application/sdp", "Content-Type": "application/sdp", **(extra_headers or {})}
return await self._post(
"/realtime/calls",
content=sdp.encode("utf-8"),
options=make_request_options(extra_headers=extra_headers, extra_query=extra_query, timeout=timeout),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
extra_headers = {"Accept": "application/sdp", "Content-Type": "multipart/form-data", **(extra_headers or {})}
session_payload = await async_maybe_transform(session, RealtimeSessionCreateRequestParam)
files = [
("sdp", (None, sdp.encode("utf-8"), "application/sdp")),
("session", (None, json.dumps(session_payload).encode("utf-8"), "application/json")),
]
return await self._post(
"/realtime/calls",
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
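# --- Usage sketch (editor's note, not part of the SDK) -----------------------
# The client construction, SDP offer, and session payload below are
# illustrative assumptions; a live API key would be required, so the sketch is
# left commented.
#
#   from openai import OpenAI
#   calls = _Calls(OpenAI())
#   # no session -> the raw offer is posted with Content-Type: application/sdp
#   answer = calls.create(sdp="v=0\r\n...")
#   # with a session -> multipart/form-data carrying both the SDP and the JSON
#   answer = calls.create(sdp="v=0\r\n...", session={"type": "realtime", "model": "gpt-realtime"})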

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from typing import Any, Dict, cast
import pydantic
from ._pydantic import to_strict_json_schema
from ..types.chat import ChatCompletionFunctionToolParam
from ..types.shared_params import FunctionDefinition
from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam
class PydanticFunctionTool(Dict[str, Any]):
"""Dictionary wrapper so we can pass the given base model
throughout the entire request stack without having to special
case it.
"""
model: type[pydantic.BaseModel]
def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
super().__init__(defn)
self.model = model
def cast(self) -> FunctionDefinition:
return cast(FunctionDefinition, self)
class ResponsesPydanticFunctionTool(Dict[str, Any]):
model: type[pydantic.BaseModel]
def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
super().__init__(tool)
self.model = model
def cast(self) -> ResponsesFunctionToolParam:
return cast(ResponsesFunctionToolParam, self)
def pydantic_function_tool(
model: type[pydantic.BaseModel],
*,
name: str | None = None, # inferred from class name by default
description: str | None = None, # inferred from class docstring by default
) -> ChatCompletionFunctionToolParam:
if description is None:
# note: we intentionally don't use `.getdoc()` to avoid
# including pydantic's docstrings
description = model.__doc__
function = PydanticFunctionTool(
{
"name": name or model.__name__,
"strict": True,
"parameters": to_strict_json_schema(model),
},
model,
).cast()
if description is not None:
function["description"] = description
return {
"type": "function",
"function": function,
}
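# --- Example (editor's sketch, not part of the SDK) --------------------------
# `pydantic_function_tool` derives the tool name from the class name and the
# description from its docstring; `GetWeather` is illustrative only.
if __name__ == "__main__":
    class GetWeather(pydantic.BaseModel):
        """Look up the current weather for a city."""

        city: str

    tool = pydantic_function_tool(GetWeather)
    assert tool["type"] == "function"
    assert tool["function"]["name"] == "GetWeather"
    assert tool["function"]["description"] == "Look up the current weather for a city."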

View File

@@ -0,0 +1,809 @@
# pyright: basic
from __future__ import annotations
import os
import sys
from typing import Any, TypeVar, Callable, Optional, NamedTuple
from typing_extensions import TypeAlias
from .._extras import pandas as pd
class Remediation(NamedTuple):
name: str
immediate_msg: Optional[str] = None
necessary_msg: Optional[str] = None
necessary_fn: Optional[Callable[[Any], Any]] = None
optional_msg: Optional[str] = None
optional_fn: Optional[Callable[[Any], Any]] = None
error_msg: Optional[str] = None
OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
def num_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will only print out the number of examples and recommend that the user increase it if there are fewer than 100.
"""
MIN_EXAMPLES = 100
optional_suggestion = (
""
if len(df) >= MIN_EXAMPLES
else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
)
immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
return Remediation(name="num_examples", immediate_msg=immediate_msg)
def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
"""
This validator will ensure that the necessary column is present in the dataframe.
"""
def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
cols = [c for c in df.columns if str(c).lower() == column]
df.rename(columns={cols[0]: column.lower()}, inplace=True)
return df
immediate_msg = None
necessary_fn = None
necessary_msg = None
error_msg = None
if necessary_column not in df.columns:
if necessary_column in [str(c).lower() for c in df.columns]:
def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
return lower_case_column(df, necessary_column)
necessary_fn = lower_case_column_creator
immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
necessary_msg = f"Lower case column name to `{necessary_column}`"
else:
error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"
return Remediation(
name="necessary_column",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
error_msg=error_msg,
)
def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will remove additional columns from the dataframe.
"""
additional_columns = []
necessary_msg = None
immediate_msg = None
necessary_fn = None # type: ignore
if len(df.columns) > 2:
additional_columns = [c for c in df.columns if c not in fields]
warn_message = ""
for ac in additional_columns:
dups = [c for c in additional_columns if ac in c]
if len(dups) > 0:
warn_message += f"\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
necessary_msg = f"Remove additional columns/keys: {additional_columns}"
def necessary_fn(x: Any) -> Any:
return x[fields]
return Remediation(
name="additional_column",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
)
def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
"""
This validator will ensure that no completion is empty.
"""
necessary_msg = None
necessary_fn = None # type: ignore
immediate_msg = None
if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
empty_rows = (df[field] == "") | (df[field].isnull())
empty_indexes = df.reset_index().index[empty_rows].tolist()
immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
def necessary_fn(x: Any) -> Any:
return x[x[field] != ""].dropna(subset=[field])
necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
return Remediation(
name=f"empty_{field}",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
)
def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will suggest that the user remove duplicate rows if they exist.
"""
duplicated_rows = df.duplicated(subset=fields)
duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
if len(duplicated_indexes) > 0:
immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"
def optional_fn(x: Any) -> Any:
return x.drop_duplicates(subset=fields)
return Remediation(
name="duplicated_rows",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def long_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest that the user remove examples that are too long.
"""
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type != "open-ended generation":
def get_long_indexes(d: pd.DataFrame) -> Any:
long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
return d.reset_index().index[long_examples].tolist()
long_indexes = get_long_indexes(df)
if len(long_indexes) > 0:
immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
optional_msg = f"Remove {len(long_indexes)} long examples"
def optional_fn(x: Any) -> Any:
long_indexes_to_drop = get_long_indexes(x)
if long_indexes != long_indexes_to_drop:
sys.stdout.write(
f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
)
return x.drop(long_indexes_to_drop)
return Remediation(
name="long_examples",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest adding a common suffix to the prompt, if one doesn't already exist, in the case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
# Find a suffix which is not already contained within the prompt; otherwise keep the default below
suggested_suffix = "\n\n### =>\n\n"
suffix_options = [
" ->",
"\n\n###\n\n",
"\n\n===\n\n",
"\n\n---\n\n",
"\n\n===>\n\n",
"\n\n--->\n\n",
]
for suffix_option in suffix_options:
if suffix_option == " ->":
if df.prompt.str.contains("\n").any():
continue
if df.prompt.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
ft_type = infer_task_type(df)
if ft_type == "open-ended generation":
return Remediation(name="common_suffix")
def add_suffix(x: Any, suffix: Any) -> Any:
x["prompt"] += suffix
return x
common_suffix = get_common_xfix(df.prompt, xfix="suffix")
if (df.prompt == common_suffix).all():
error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
return Remediation(name="common_suffix", error_msg=error_msg)
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
else:
immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
if common_suffix == "":
optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
name="common_completion_suffix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
error_msg=error_msg,
)
def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest removing a common prefix from the prompt if a long one exists.
"""
MAX_PREFIX_LEN = 12
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
common_prefix = get_common_xfix(df.prompt, xfix="prefix")
if common_prefix == "":
return Remediation(name="common_prefix")
def remove_common_prefix(x: Any, prefix: Any) -> Any:
x["prompt"] = x["prompt"].str[len(prefix) :]
return x
if (df.prompt == common_prefix).all():
# already handled by common_suffix_validator
return Remediation(name="common_prefix")
if common_prefix != "":
immediate_msg = f"\n- All prompts start with prefix `{common_prefix}`"
if MAX_PREFIX_LEN < len(common_prefix):
immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
optional_msg = f"Remove prefix `{common_prefix}` from all prompts"
def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix)
return Remediation(
name="common_prompt_prefix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest removing a common prefix from the completion if a long one exists.
"""
MAX_PREFIX_LEN = 5
common_prefix = get_common_xfix(df.completion, xfix="prefix")
ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
if len(common_prefix) < MAX_PREFIX_LEN:
return Remediation(name="common_prefix")
def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
x["completion"] = x["completion"].str[len(prefix) :]
if ws_prefix:
# keep the single whitespace as prefix
x["completion"] = f" {x['completion']}"
return x
if (df.completion == common_prefix).all():
# already handled by common_suffix_validator
return Remediation(name="common_prefix")
immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
optional_msg = f"Remove prefix `{common_prefix}` from all completions"
def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix, ws_prefix)
return Remediation(
name="common_completion_prefix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest adding a common suffix to the completion, if one doesn't already exist, in the case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")
common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)
# Find a suffix which is not already contained within the completion; otherwise keep the default below
suggested_suffix = " [END]"
suffix_options = [
"\n",
".",
" END",
"***",
"+++",
"&&&",
"$$$",
"@@@",
"%%%",
]
for suffix_option in suffix_options:
if df.completion.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
def add_suffix(x: Any, suffix: Any) -> Any:
x["completion"] += suffix
return x
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
else:
immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
if common_suffix == "":
optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"
def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
name="common_completion_suffix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
error_msg=error_msg,
)
def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest adding a space at the start of the completion if one doesn't already exist. This helps with tokenization.
"""
def add_space_start(x: Any) -> Any:
x["completion"] = x["completion"].apply(lambda s: ("" if s.startswith(" ") else " ") + s)
return x
optional_msg = None
optional_fn = None
immediate_msg = None
if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
optional_msg = "Add a whitespace character to the beginning of the completion"
optional_fn = add_space_start
return Remediation(
name="completion_space_start",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
"""
This validator will suggest lowercasing the column values if more than a third of the letters are uppercase.
"""
def lower_case(x: Any) -> Any:
x[column] = x[column].str.lower()
return x
count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()
count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()
if count_upper * 2 > count_lower:
return Remediation(
name="lower_case",
immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
optional_msg=f"Lowercase all your data in column/key `{column}`",
optional_fn=lower_case,
)
return None
def read_any_format(
fname: str, fields: list[str] = ["prompt", "completion"]
) -> tuple[pd.DataFrame | None, Remediation]:
"""
This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
- for .xlsx it will read the first sheet
- for .txt it will assume completions and split on newline
"""
remediation = None
necessary_msg = None
immediate_msg = None
error_msg = None
df = None
if os.path.isfile(fname):
try:
if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
immediate_msg = (
f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
)
necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
elif fname.lower().endswith(".xlsx"):
immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
xls = pd.ExcelFile(fname)
sheets = xls.sheet_names
if len(sheets) > 1:
immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
df = pd.read_excel(fname, dtype=str).fillna("")
elif fname.lower().endswith(".txt"):
immediate_msg = "\n- Based on your file extension, you provided a text file"
necessary_msg = "Your format `TXT` will be converted to `JSONL`"
with open(fname, "r") as f:
content = f.read()
df = pd.DataFrame(
[["", line] for line in content.split("\n")],
columns=fields,
dtype=str,
).fillna("")
elif fname.lower().endswith(".jsonl"):
df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
if len(df) == 1: # type: ignore
# this is NOT what we expect for a .jsonl file
immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
necessary_msg = "Your format `JSON` will be converted to `JSONL`"
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
pass # this is what we expect for a .jsonl file
elif fname.lower().endswith(".json"):
try:
# to handle case where .json file is actually a .jsonl file
df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
if len(df) == 1: # type: ignore
# this code path corresponds to a .json file that has one line
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
# this is NOT what we expect for a .json file
immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
necessary_msg = "Your format `JSON` will be converted to `JSONL`"
except ValueError:
# this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
error_msg = (
"Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
)
if "." in fname:
error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
else:
error_msg += f" Your file `{fname}` is missing a file extension."
except (ValueError, TypeError):
file_extension_str = fname.split(".")[-1].upper()
error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."
else:
error_msg = f"File {fname} does not exist."
remediation = Remediation(
name="read_any_format",
necessary_msg=necessary_msg,
immediate_msg=immediate_msg,
error_msg=error_msg,
)
return df, remediation
def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
It will also suggest using ada and explain the benefits of a train/validation split.
"""
ft_type = infer_task_type(df)
immediate_msg = None
if ft_type == "classification":
immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
return Remediation(name="num_examples", immediate_msg=immediate_msg)
def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
"""
This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
"""
if remediation.error_msg is not None:
sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
sys.exit(1)
if remediation.immediate_msg is not None:
sys.stdout.write(remediation.immediate_msg)
if remediation.necessary_fn is not None:
df = remediation.necessary_fn(df)
return df
def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
sys.stdout.write(input_text)
if auto_accept:
sys.stdout.write("Y\n")
return True
return input().lower() != "n"
def apply_optional_remediation(
df: pd.DataFrame, remediation: Remediation, auto_accept: bool
) -> tuple[pd.DataFrame, bool]:
"""
This function will apply an optional remediation to a dataframe, based on the user input.
"""
optional_applied = False
input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
if remediation.optional_msg is not None:
if accept_suggestion(input_text, auto_accept):
assert remediation.optional_fn is not None
df = remediation.optional_fn(df)
optional_applied = True
if remediation.necessary_msg is not None:
sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
return df, optional_applied
def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
"""
Estimate the time it'll take to fine-tune the dataset
"""
ft_format = infer_task_type(df)
expected_time = 1.0
if ft_format == "classification":
num_examples = len(df)
expected_time = num_examples * 1.44
else:
size = df.memory_usage(index=True).sum()
expected_time = size * 0.0515
def format_time(time: float) -> str:
if time < 60:
return f"{round(time, 2)} seconds"
elif time < 3600:
return f"{round(time / 60, 2)} minutes"
elif time < 86400:
return f"{round(time / 3600, 2)} hours"
else:
return f"{round(time / 86400, 2)} days"
time_string = format_time(expected_time + 140)
sys.stdout.write(
f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
)
def get_outfnames(fname: str, split: bool) -> list[str]:
suffixes = ["_train", "_valid"] if split else [""]
i = 0
while True:
index_suffix = f" ({i})" if i > 0 else ""
candidate_fnames = [f"{os.path.splitext(fname)[0]}_prepared{suffix}{index_suffix}.jsonl" for suffix in suffixes]
if not any(os.path.isfile(f) for f in candidate_fnames):
return candidate_fnames
i += 1
def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
n_classes = df.completion.nunique()
pos_class = None
if n_classes == 2:
pos_class = df.completion.value_counts().index[0]
return n_classes, pos_class
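# --- Example (editor's sketch, not part of the SDK; requires pandas) ---------
# In the binary case the most frequent completion is chosen as the positive
# class for classification metrics. The toy dataframe is an assumption.
if __name__ == "__main__":
    _df = pd.DataFrame({"completion": [" yes", " no", " yes"]})
    n_classes, pos_class = get_classification_hyperparams(_df)
    assert n_classes == 2
    assert pos_class == " yes"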
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
"""
This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
"""
ft_format = infer_task_type(df)
common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")
split = False
input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
if ft_format == "classification":
if accept_suggestion(input_text, auto_accept):
split = True
additional_params = ""
common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
optional_ending_string = (
f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated text ends at the expected place.'
if len(common_completion_suffix_new_line_handled) > 0
else ""
)
input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
if not any_remediations and not split:
sys.stdout.write(
f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
)
estimate_fine_tuning_time(df)
elif accept_suggestion(input_text, auto_accept):
fnames = get_outfnames(fname, split)
if split:
assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
MAX_VALID_EXAMPLES = 1000
n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
df_train = df.sample(n=n_train, random_state=42)
df_valid = df.drop(df_train.index)
df_train[["prompt", "completion"]].to_json( # type: ignore
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)
df_valid[["prompt", "completion"]].to_json(
fnames[1], lines=True, orient="records", force_ascii=False, indent=None
)
n_classes, pos_class = get_classification_hyperparams(df)
additional_params += " --compute_classification_metrics"
if n_classes == 2:
additional_params += f' --classification_positive_class "{pos_class}"'
else:
additional_params += f" --classification_n_classes {n_classes}"
else:
assert len(fnames) == 1
df[["prompt", "completion"]].to_json(
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)
# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
valid_string = f' -v "{fnames[1]}"' if split else ""
separator_reminder = (
""
if len(common_prompt_suffix_new_line_handled) == 0
else f"After youve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
)
sys.stdout.write(
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
)
estimate_fine_tuning_time(df)
else:
sys.stdout.write("Aborting... did not write the file\n")
def infer_task_type(df: pd.DataFrame) -> str:
"""
Infer the likely fine-tuning task type from the data
"""
CLASSIFICATION_THRESHOLD = 3  # min average instances of each class
if sum(df.prompt.str.len()) == 0:
return "open-ended generation"
if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
return "classification"
return "conditional generation"
def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
"""
Finds the longest common suffix or prefix of all the values in a series
"""
common_xfix = ""
while True:
common_xfixes = (
series.str[-(len(common_xfix) + 1) :] if xfix == "suffix" else series.str[: len(common_xfix) + 1]
) # first few or last few characters
if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore
break
elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row
break
else: # the first or last few characters are still common across all rows - let's try to add one more
common_xfix = common_xfixes.values[0]
return common_xfix
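# --- Example (editor's sketch, not part of the SDK; requires pandas) ---------
# The loop grows the candidate prefix/suffix one character at a time and stops
# as soon as the rows diverge. The toy series is an assumption.
if __name__ == "__main__":
    _s = pd.Series(["Q: one ->", "Q: two ->", "Q: three ->"])
    assert get_common_xfix(_s, xfix="prefix") == "Q: "
    assert get_common_xfix(_s, xfix="suffix") == " ->"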
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
def get_validators() -> list[Validator]:
return [
num_examples_validator,
lambda x: necessary_column_validator(x, "prompt"),
lambda x: necessary_column_validator(x, "completion"),
additional_column_validator,
non_empty_field_validator,
format_inferrer_validator,
duplicated_rows_validator,
long_examples_validator,
lambda x: lower_case_validator(x, "prompt"),
lambda x: lower_case_validator(x, "completion"),
common_prompt_suffix_validator,
common_prompt_prefix_validator,
common_completion_prefix_validator,
common_completion_suffix_validator,
completions_space_start_validator,
]
def apply_validators(
df: pd.DataFrame,
fname: str,
remediation: Remediation | None,
validators: list[Validator],
auto_accept: bool,
write_out_file_func: Callable[..., Any],
) -> None:
optional_remediations: list[Remediation] = []
if remediation is not None:
optional_remediations.append(remediation)
for validator in validators:
remediation = validator(df)
if remediation is not None:
optional_remediations.append(remediation)
df = apply_necessary_remediation(df, remediation)
any_optional_or_necessary_remediations = any(
[
remediation
for remediation in optional_remediations
if remediation.optional_msg is not None or remediation.necessary_msg is not None
]
)
any_necessary_applied = any(
[remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
)
any_optional_applied = False
if any_optional_or_necessary_remediations:
sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
for remediation in optional_remediations:
df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
any_optional_applied = any_optional_applied or optional_applied
else:
sys.stdout.write("\n\nNo remediations found.\n")
any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied
write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
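# --- Usage sketch (editor's note, not part of the SDK) -----------------------
# How the pieces compose end to end; the file name is an assumption, and the
# flow prompts on stdin and writes JSONL files, so it is left commented.
#
#   df, remediation = read_any_format("training_data.csv")
#   apply_validators(
#       df,
#       "training_data.csv",
#       remediation,
#       validators=get_validators(),
#       auto_accept=True,
#       write_out_file_func=write_out_file,
#   )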

View File

@@ -0,0 +1,647 @@
from __future__ import annotations
import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing_extensions import Self, override
import httpx
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
from .._models import FinalRequestOptions
from .._streaming import Stream, AsyncStream
from .._exceptions import OpenAIError
from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
_deployments_endpoints = set(
[
"/completions",
"/chat/completions",
"/embeddings",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
"/images/edits",
]
)
AzureADTokenProvider = Callable[[], str]
AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
# we need to use a sentinel API key value for Azure AD
# as we don't want to make the `api_key` in the main client Optional
# and Azure AD tokens may be retrieved on a per-request basis
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
class MutuallyExclusiveAuthError(OpenAIError):
def __init__(self) -> None:
super().__init__(
"The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
)
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None
@override
def _build_request(
self,
options: FinalRequestOptions,
*,
retries_taken: int = 0,
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"
return super()._build_request(options, retries_taken=retries_taken)
@override
def _prepare_url(self, url: str) -> httpx.URL:
"""Adjust the URL if the client was configured with an Azure endpoint + deployment
and the API feature being called is **not** a deployments-based endpoint
(i.e. requires /deployments/deployment-name in the URL path).
"""
if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
merge_url = httpx.URL(url)
if merge_url.is_relative_url:
merge_raw_path = (
self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
)
return self._azure_endpoint.copy_with(raw_path=merge_raw_path)
return merge_url
return super()._prepare_url(url)
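# Editor's note (illustrative values, not part of the SDK): with
# azure_endpoint "https://example.openai.azure.com" and azure_deployment
# "my-gpt", a non-deployments path such as "/models" is merged into
# "https://example.openai.azure.com/openai/models", while a deployments-based
# path such as "/chat/completions" keeps the deployment-scoped base URL.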
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
def __init__(
self,
*,
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
base_url: str,
api_version: str | None = None,
api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
def __init__(
self,
*,
api_version: str | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None:
"""Construct a new synchronous azure openai client instance.
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Args:
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if azure_ad_token is None:
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
raise OpenAIError(
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
)
if api_version is None:
api_version = os.environ.get("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)
if default_query is None:
default_query = {"api-version": api_version}
else:
default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)
if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
if api_key is None:
# define a sentinel value to avoid any typing issues
api_key = API_KEY_SENTINEL
super().__init__(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
websocket_base_url=websocket_base_url,
_strict_response_validation=_strict_response_validation,
)
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
@override
def copy(
self,
*,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
return super().copy(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
websocket_base_url=websocket_base_url,
base_url=base_url,
timeout=timeout,
http_client=http_client,
max_retries=max_retries,
default_headers=default_headers,
set_default_headers=set_default_headers,
default_query=default_query,
set_default_query=set_default_query,
_extra_kwargs={
"api_version": api_version or self._api_version,
"azure_ad_token": azure_ad_token or self._azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
**_extra_kwargs,
},
)
with_options = copy
def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
provider = self._azure_ad_token_provider
if provider is not None:
token = provider()
if not token or not isinstance(token, str): # pyright: ignore[reportUnnecessaryIsInstance]
raise ValueError(
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
)
return token
return None
@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
options = model_copy(options)
options.headers = headers
azure_ad_token = self._get_azure_ad_token()
if azure_ad_token is not None:
if headers.get("Authorization") is None:
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif self.api_key is not API_KEY_SENTINEL:
if headers.get("api-key") is None:
headers["api-key"] = self.api_key
else:
# should never be hit
raise ValueError("Unable to handle auth")
return options
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")
url = realtime_url.copy_with(params={**query})
return url, auth_headers
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@overload
def __init__(
self,
*,
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
base_url: str,
api_version: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
def __init__(
self,
*,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
base_url: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None:
"""Construct a new asynchronous azure openai client instance.
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Args:
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if azure_ad_token is None:
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
raise OpenAIError(
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
)
if api_version is None:
api_version = os.environ.get("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)
if default_query is None:
default_query = {"api-version": api_version}
else:
default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)
if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
if api_key is None:
# define a sentinel value to avoid any typing issues
api_key = API_KEY_SENTINEL
super().__init__(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
websocket_base_url=websocket_base_url,
_strict_response_validation=_strict_response_validation,
)
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
@override
def copy(
self,
*,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
return super().copy(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
websocket_base_url=websocket_base_url,
base_url=base_url,
timeout=timeout,
http_client=http_client,
max_retries=max_retries,
default_headers=default_headers,
set_default_headers=set_default_headers,
default_query=default_query,
set_default_query=set_default_query,
_extra_kwargs={
"api_version": api_version or self._api_version,
"azure_ad_token": azure_ad_token or self._azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
**_extra_kwargs,
},
)
with_options = copy
async def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
provider = self._azure_ad_token_provider
if provider is not None:
token = provider()
if inspect.isawaitable(token):
token = await token
if not token or not isinstance(cast(Any, token), str):
raise ValueError(
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
)
return str(token)
return None
@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
options = model_copy(options)
options.headers = headers
azure_ad_token = await self._get_azure_ad_token()
if azure_ad_token is not None:
if headers.get("Authorization") is None:
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif self.api_key is not API_KEY_SENTINEL:
if headers.get("api-key") is None:
headers["api-key"] = self.api_key
else:
# should never be hit
raise ValueError("Unable to handle auth")
return options
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")
url = realtime_url.copy_with(params={**query})
return url, auth_headers
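
For orientation, a minimal usage sketch of the client defined above, relying on the environment-variable inference described in its docstring (the deployment name is a placeholder, not part of the diff):

```py
# Sketch only. Assumes AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT and
# OPENAI_API_VERSION are set in the environment; "my-gpt-4o" is a
# hypothetical deployment name.
import asyncio

from openai import AsyncAzureOpenAI


async def main() -> None:
    client = AsyncAzureOpenAI(azure_deployment="my-gpt-4o")
    completion = await client.chat.completions.create(
        model="my-gpt-4o",  # for Azure, `model` refers to the deployment name
        messages=[{"role": "user", "content": "ping"}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```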

View File

@@ -0,0 +1,8 @@
from ._assistants import (
AssistantEventHandler as AssistantEventHandler,
AssistantEventHandlerT as AssistantEventHandlerT,
AssistantStreamManager as AssistantStreamManager,
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
AsyncAssistantEventHandlerT as AsyncAssistantEventHandlerT,
AsyncAssistantStreamManager as AsyncAssistantStreamManager,
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
from ..._utils import is_dict, is_list
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
for key, delta_value in delta.items():
if key not in acc:
acc[key] = delta_value
continue
acc_value = acc[key]
if acc_value is None:
acc[key] = delta_value
continue
# the `index` property is used in arrays of objects so it should
# not be accumulated like other values e.g.
# [{'foo': 'bar', 'index': 0}]
#
# the same applies to `type` properties as they're used for
# discriminated unions
if key == "index" or key == "type":
acc[key] = delta_value
continue
if isinstance(acc_value, str) and isinstance(delta_value, str):
acc_value += delta_value
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
acc_value += delta_value
elif is_dict(acc_value) and is_dict(delta_value):
acc_value = accumulate_delta(acc_value, delta_value)
elif is_list(acc_value) and is_list(delta_value):
# for lists of non-dictionary items we'll only ever get new entries
# in the array; existing entries will never be changed
if all(isinstance(x, (str, int, float)) for x in acc_value):
acc_value.extend(delta_value)
continue
for delta_entry in delta_value:
if not is_dict(delta_entry):
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
try:
index = delta_entry["index"]
except KeyError as exc:
raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
if not isinstance(index, int):
raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
try:
acc_entry = acc_value[index]
except IndexError:
acc_value.insert(index, delta_entry)
else:
if not is_dict(acc_entry):
raise TypeError("not handled yet")
acc_value[index] = accumulate_delta(acc_entry, delta_entry)
acc[key] = acc_value
return acc
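
A small worked example of the accumulation rules above: strings concatenate, `None` is overwritten, and list entries merge by their `index` key. This is a sketch; the import path assumes the module lands at `openai.lib.streaming._deltas`, matching the relative imports later in this diff:

```py
from openai.lib.streaming._deltas import accumulate_delta

acc = {
    "content": "Hel",
    "usage": None,
    "tool_calls": [{"index": 0, "type": "function", "arguments": '{"a"'}],
}
delta = {
    "content": "lo",
    "usage": {"total_tokens": 5},
    "tool_calls": [{"index": 0, "arguments": ': 1}'}],
}

acc = accumulate_delta(acc, delta)
# -> {'content': 'Hello', 'usage': {'total_tokens': 5},
#     'tool_calls': [{'index': 0, 'type': 'function', 'arguments': '{"a": 1}'}]}
print(acc)
```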

View File

@@ -0,0 +1,27 @@
from ._types import (
ParsedChoiceSnapshot as ParsedChoiceSnapshot,
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
ParsedChatCompletionMessageSnapshot as ParsedChatCompletionMessageSnapshot,
)
from ._events import (
ChunkEvent as ChunkEvent,
ContentDoneEvent as ContentDoneEvent,
RefusalDoneEvent as RefusalDoneEvent,
ContentDeltaEvent as ContentDeltaEvent,
RefusalDeltaEvent as RefusalDeltaEvent,
LogprobsContentDoneEvent as LogprobsContentDoneEvent,
LogprobsRefusalDoneEvent as LogprobsRefusalDoneEvent,
ChatCompletionStreamEvent as ChatCompletionStreamEvent,
LogprobsContentDeltaEvent as LogprobsContentDeltaEvent,
LogprobsRefusalDeltaEvent as LogprobsRefusalDeltaEvent,
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
FunctionToolCallArgumentsDoneEvent as FunctionToolCallArgumentsDoneEvent,
FunctionToolCallArgumentsDeltaEvent as FunctionToolCallArgumentsDeltaEvent,
)
from ._completions import (
ChatCompletionStream as ChatCompletionStream,
AsyncChatCompletionStream as AsyncChatCompletionStream,
ChatCompletionStreamState as ChatCompletionStreamState,
ChatCompletionStreamManager as ChatCompletionStreamManager,
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
)

View File

@@ -0,0 +1,769 @@
from __future__ import annotations
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing_extensions import Self, Iterator, assert_never
from jiter import from_json
from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
from ._events import (
ChunkEvent,
ContentDoneEvent,
RefusalDoneEvent,
ContentDeltaEvent,
RefusalDeltaEvent,
LogprobsContentDoneEvent,
LogprobsRefusalDoneEvent,
ChatCompletionStreamEvent,
LogprobsContentDeltaEvent,
LogprobsRefusalDeltaEvent,
FunctionToolCallArgumentsDoneEvent,
FunctionToolCallArgumentsDeltaEvent,
)
from .._deltas import accumulate_delta
from ...._types import Omit, IncEx, omit
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._compat import model_dump
from ...._models import build, construct_type
from ..._parsing import (
ResponseFormatT,
has_parseable_input,
maybe_parse_content,
parse_chat_completion,
get_input_tool_by_name,
parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolUnionParam
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ....types.chat.chat_completion import ChoiceLogprobs
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
class ChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatic parsing of
responses & tool calls, and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
"""
def __init__(
self,
*,
raw_stream: Stream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return self._iterator.__next__()
def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for item in self._iterator:
yield item
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self._response.close()
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
"""
self.until_done()
return self._state.get_final_completion()
def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
consume_sync_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for sse_event in self._raw_stream:
if not _is_valid_chat_completion_chunk_weak(sse_event):
continue
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `ChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
with client.chat.completions.stream(...) as stream:
for event in stream:
...
```
"""
def __init__(
self,
api_request: Callable[[], Stream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
) -> None:
self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
raw_stream = self.__api_request()
self.__stream = ChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatic parsing of
responses & tool calls, and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
"""
def __init__(
self,
*,
raw_stream: AsyncStream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for item in self._iterator:
yield item
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self._response.aclose()
async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
"""
await self.until_done()
return self._state.get_final_completion()
async def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
await consume_async_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for sse_event in self._raw_stream:
if not _is_valid_chat_completion_chunk_weak(sse_event):
continue
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
async with client.chat.completions.stream(...) as stream:
async for event in stream:
...
```
"""
def __init__(
self,
api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
) -> None:
self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
raw_stream = await self.__api_request
self.__stream = AsyncChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
class ChatCompletionStreamState(Generic[ResponseFormatT]):
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
This is useful in cases where you can't always use the `.stream()` method, e.g.
```py
from openai.lib.streaming.chat import ChatCompletionStreamState
state = ChatCompletionStreamState()
stream = client.chat.completions.create(..., stream=True)
for chunk in stream:
state.handle_chunk(chunk)
# can also access the accumulated `ChatCompletion` mid-stream
state.current_completion_snapshot
print(state.get_final_completion())
```
"""
def __init__(
self,
*,
input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit = omit,
) -> None:
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
self.__choice_event_states: list[ChoiceEventState] = []
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
self._response_format = response_format
self._rich_response_format: type | Omit = response_format if inspect.isclass(response_format) else omit
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Parse the final completion object.
Note this does not provide any guarantees that the stream has actually finished; you must
only call this method when the stream is finished.
"""
return parse_chat_completion(
chat_completion=self.current_completion_snapshot,
response_format=self._rich_response_format,
input_tools=self._input_tools,
)
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
assert self.__current_completion_snapshot is not None
return self.__current_completion_snapshot
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
"""Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
return self._build_events(
chunk=chunk,
completion_snapshot=self.__current_completion_snapshot,
)
def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
try:
return self.__choice_event_states[choice.index]
except IndexError:
choice_state = ChoiceEventState(input_tools=self._input_tools)
self.__choice_event_states.append(choice_state)
return choice_state
def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
completion_snapshot = self.__current_completion_snapshot
if completion_snapshot is None:
return _convert_initial_chunk_into_snapshot(chunk)
for choice in chunk.choices:
try:
choice_snapshot = completion_snapshot.choices[choice.index]
previous_tool_calls = choice_snapshot.message.tool_calls or []
choice_snapshot.message = cast(
ParsedChatCompletionMessageSnapshot,
construct_type(
type_=ParsedChatCompletionMessageSnapshot,
value=accumulate_delta(
cast(
"dict[object, object]",
model_dump(
choice_snapshot.message,
# we don't want to serialise / deserialise our custom properties
# as they won't appear in the delta and we don't want to have to
# continuously reparse the content
exclude=cast(
# cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
IncEx,
{
"parsed": True,
"tool_calls": {
idx: {"function": {"parsed_arguments": True}}
for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
},
},
),
),
),
cast("dict[object, object]", choice.delta.to_dict()),
),
),
)
# ensure tools that have already been parsed are added back into the newly
# constructed message snapshot
for tool_index, prev_tool in enumerate(previous_tool_calls):
new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
if prev_tool.type == "function":
assert new_tool.type == "function"
new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(prev_tool)
except IndexError:
choice_snapshot = cast(
ParsedChoiceSnapshot,
construct_type(
type_=ParsedChoiceSnapshot,
value={
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
},
),
)
completion_snapshot.choices.append(choice_snapshot)
if choice.finish_reason:
choice_snapshot.finish_reason = choice.finish_reason
if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
if choice.finish_reason == "length":
# at the time of writing, `.usage` will always be `None` but
# we include it here in case that is changed in the future
raise LengthFinishReasonError(completion=completion_snapshot)
if choice.finish_reason == "content_filter":
raise ContentFilterFinishReasonError()
if (
choice_snapshot.message.content
and not choice_snapshot.message.refusal
and is_given(self._rich_response_format)
# partial parsing fails on white-space
and choice_snapshot.message.content.lstrip()
):
choice_snapshot.message.parsed = from_json(
bytes(choice_snapshot.message.content, "utf-8"),
partial_mode=True,
)
for tool_call_chunk in choice.delta.tool_calls or []:
tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
if tool_call_snapshot.type == "function":
input_tool = get_input_tool_by_name(
input_tools=self._input_tools, name=tool_call_snapshot.function.name
)
if (
input_tool
and input_tool.get("function", {}).get("strict")
and tool_call_snapshot.function.arguments
):
tool_call_snapshot.function.parsed_arguments = from_json(
bytes(tool_call_snapshot.function.arguments, "utf-8"),
partial_mode=True,
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
if choice.logprobs is not None:
if choice_snapshot.logprobs is None:
choice_snapshot.logprobs = build(
ChoiceLogprobs,
content=choice.logprobs.content,
refusal=choice.logprobs.refusal,
)
else:
if choice.logprobs.content:
if choice_snapshot.logprobs.content is None:
choice_snapshot.logprobs.content = []
choice_snapshot.logprobs.content.extend(choice.logprobs.content)
if choice.logprobs.refusal:
if choice_snapshot.logprobs.refusal is None:
choice_snapshot.logprobs.refusal = []
choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
completion_snapshot.usage = chunk.usage
completion_snapshot.system_fingerprint = chunk.system_fingerprint
return completion_snapshot
def _build_events(
self,
*,
chunk: ChatCompletionChunk,
completion_snapshot: ParsedChatCompletionSnapshot,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
events_to_fire.append(
build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
)
for choice in chunk.choices:
choice_state = self._get_choice_state(choice)
choice_snapshot = completion_snapshot.choices[choice.index]
if choice.delta.content is not None and choice_snapshot.message.content is not None:
events_to_fire.append(
build(
ContentDeltaEvent,
type="content.delta",
delta=choice.delta.content,
snapshot=choice_snapshot.message.content,
parsed=choice_snapshot.message.parsed,
)
)
if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
events_to_fire.append(
build(
RefusalDeltaEvent,
type="refusal.delta",
delta=choice.delta.refusal,
snapshot=choice_snapshot.message.refusal,
)
)
if choice.delta.tool_calls:
tool_calls = choice_snapshot.message.tool_calls
assert tool_calls is not None
for tool_call_delta in choice.delta.tool_calls:
tool_call = tool_calls[tool_call_delta.index]
if tool_call.type == "function":
assert tool_call_delta.function is not None
events_to_fire.append(
build(
FunctionToolCallArgumentsDeltaEvent,
type="tool_calls.function.arguments.delta",
name=tool_call.function.name,
index=tool_call_delta.index,
arguments=tool_call.function.arguments,
parsed_arguments=tool_call.function.parsed_arguments,
arguments_delta=tool_call_delta.function.arguments or "",
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call)
if choice.logprobs is not None and choice_snapshot.logprobs is not None:
if choice.logprobs.content and choice_snapshot.logprobs.content:
events_to_fire.append(
build(
LogprobsContentDeltaEvent,
type="logprobs.content.delta",
content=choice.logprobs.content,
snapshot=choice_snapshot.logprobs.content,
),
)
if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
events_to_fire.append(
build(
LogprobsRefusalDeltaEvent,
type="logprobs.refusal.delta",
refusal=choice.logprobs.refusal,
snapshot=choice_snapshot.logprobs.refusal,
),
)
events_to_fire.extend(
choice_state.get_done_events(
choice_chunk=choice,
choice_snapshot=choice_snapshot,
response_format=self._response_format,
)
)
return events_to_fire
class ChoiceEventState:
def __init__(self, *, input_tools: list[ChatCompletionToolUnionParam]) -> None:
self._input_tools = input_tools
self._content_done = False
self._refusal_done = False
self._logprobs_content_done = False
self._logprobs_refusal_done = False
self._done_tool_calls: set[int] = set()
self.__current_tool_call_index: int | None = None
def get_done_events(
self,
*,
choice_chunk: ChoiceChunk,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.finish_reason:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if (
self.__current_tool_call_index is not None
and self.__current_tool_call_index not in self._done_tool_calls
):
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
for tool_call in choice_chunk.delta.tool_calls or []:
if self.__current_tool_call_index != tool_call.index:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if self.__current_tool_call_index is not None:
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
self.__current_tool_call_index = tool_call.index
return events_to_fire
def _content_done_events(
self,
*,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.message.content and not self._content_done:
self._content_done = True
parsed = maybe_parse_content(
response_format=response_format,
message=choice_snapshot.message,
)
# update the parsed content to now use the richer `response_format`
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
choice_snapshot.message.parsed = parsed
events_to_fire.append(
build(
# we do this dance so that when the `ContentDoneEvent` instance
# is printed at runtime the class name will include the solved
# type variable, e.g. `ContentDoneEvent[MyModelType]`
cast( # pyright: ignore[reportUnnecessaryCast]
"type[ContentDoneEvent[ResponseFormatT]]",
cast(Any, ContentDoneEvent),
),
type="content.done",
content=choice_snapshot.message.content,
parsed=parsed,
),
)
if choice_snapshot.message.refusal is not None and not self._refusal_done:
self._refusal_done = True
events_to_fire.append(
build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.content is not None
and not self._logprobs_content_done
):
self._logprobs_content_done = True
events_to_fire.append(
build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.refusal is not None
and not self._logprobs_refusal_done
):
self._logprobs_refusal_done = True
events_to_fire.append(
build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
)
return events_to_fire
def _add_tool_done_event(
self,
*,
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
choice_snapshot: ParsedChoiceSnapshot,
tool_index: int,
) -> None:
if tool_index in self._done_tool_calls:
return
self._done_tool_calls.add(tool_index)
assert choice_snapshot.message.tool_calls is not None
tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
if tool_call_snapshot.type == "function":
parsed_arguments = parse_function_tool_arguments(
input_tools=self._input_tools, function=tool_call_snapshot.function
)
# update the parsed content to potentially use a richer type
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
tool_call_snapshot.function.parsed_arguments = parsed_arguments
events_to_fire.append(
build(
FunctionToolCallArgumentsDoneEvent,
type="tool_calls.function.arguments.done",
index=tool_index,
name=tool_call_snapshot.function.name,
arguments=tool_call_snapshot.function.arguments,
parsed_arguments=parsed_arguments,
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
data = chunk.to_dict()
choices = cast("list[object]", data["choices"])
for choice in chunk.choices:
choices[choice.index] = {
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
}
return cast(
ParsedChatCompletionSnapshot,
construct_type(
type_=ParsedChatCompletionSnapshot,
value={
"system_fingerprint": None,
**data,
"object": "chat.completion",
},
),
)
def _is_valid_chat_completion_chunk_weak(sse_event: ChatCompletionChunk) -> bool:
# Although the _raw_stream is always supposed to contain only objects adhering to the ChatCompletionChunk schema,
# Azure OpenAI breaks this when the Asynchronous Filter is enabled.
# An easy filter is to check for the "object" property:
# - should be "chat.completion.chunk" for a ChatCompletionChunk;
# - is an empty string for Asynchronous Filter events.
return sse_event.object == "chat.completion.chunk" # type: ignore # pylance reports this as a useless check
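
Taken together, a sketch of the consumer-facing flow these classes implement (assumes a configured API key; the model name is a placeholder):

```py
from openai import OpenAI

client = OpenAI()

with client.chat.completions.stream(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Say hello"}],
) as stream:
    for event in stream:
        if event.type == "content.delta":
            print(event.delta, end="", flush=True)
    # safe here: the iterator above has been fully consumed
    completion = stream.get_final_completion()

print()
print(completion.choices[0].message.content)
```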

View File

@@ -0,0 +1,123 @@
from typing import List, Union, Generic, Optional
from typing_extensions import Literal
from ._types import ParsedChatCompletionSnapshot
from ...._models import BaseModel, GenericModel
from ..._parsing import ResponseFormatT
from ....types.chat import ChatCompletionChunk, ChatCompletionTokenLogprob
class ChunkEvent(BaseModel):
type: Literal["chunk"]
chunk: ChatCompletionChunk
snapshot: ParsedChatCompletionSnapshot
class ContentDeltaEvent(BaseModel):
"""This event is yielded for every chunk with `choice.delta.content` data."""
type: Literal["content.delta"]
delta: str
snapshot: str
parsed: Optional[object] = None
class ContentDoneEvent(GenericModel, Generic[ResponseFormatT]):
type: Literal["content.done"]
content: str
parsed: Optional[ResponseFormatT] = None
class RefusalDeltaEvent(BaseModel):
type: Literal["refusal.delta"]
delta: str
snapshot: str
class RefusalDoneEvent(BaseModel):
type: Literal["refusal.done"]
refusal: str
class FunctionToolCallArgumentsDeltaEvent(BaseModel):
type: Literal["tool_calls.function.arguments.delta"]
name: str
index: int
arguments: str
"""Accumulated raw JSON string"""
parsed_arguments: object
"""The parsed arguments so far"""
arguments_delta: str
"""The JSON string delta"""
class FunctionToolCallArgumentsDoneEvent(BaseModel):
type: Literal["tool_calls.function.arguments.done"]
name: str
index: int
arguments: str
"""Accumulated raw JSON string"""
parsed_arguments: object
"""The parsed arguments"""
class LogprobsContentDeltaEvent(BaseModel):
type: Literal["logprobs.content.delta"]
content: List[ChatCompletionTokenLogprob]
snapshot: List[ChatCompletionTokenLogprob]
class LogprobsContentDoneEvent(BaseModel):
type: Literal["logprobs.content.done"]
content: List[ChatCompletionTokenLogprob]
class LogprobsRefusalDeltaEvent(BaseModel):
type: Literal["logprobs.refusal.delta"]
refusal: List[ChatCompletionTokenLogprob]
snapshot: List[ChatCompletionTokenLogprob]
class LogprobsRefusalDoneEvent(BaseModel):
type: Literal["logprobs.refusal.done"]
refusal: List[ChatCompletionTokenLogprob]
ChatCompletionStreamEvent = Union[
ChunkEvent,
ContentDeltaEvent,
ContentDoneEvent[ResponseFormatT],
RefusalDeltaEvent,
RefusalDoneEvent,
FunctionToolCallArgumentsDeltaEvent,
FunctionToolCallArgumentsDoneEvent,
LogprobsContentDeltaEvent,
LogprobsContentDoneEvent,
LogprobsRefusalDeltaEvent,
LogprobsRefusalDoneEvent,
]
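
Since every member of this union carries a distinct `type` literal, handlers can dispatch on it and type checkers narrow accordingly; a brief sketch:

```py
from openai.lib.streaming.chat import ChatCompletionStreamEvent


def handle(event: ChatCompletionStreamEvent[None]) -> None:
    # Dispatch on the `type` discriminator; unhandled event types fall through.
    if event.type == "content.delta":
        print(event.delta, end="")
    elif event.type == "tool_calls.function.arguments.done":
        print(f"\n{event.name}({event.arguments})")
```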

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from typing_extensions import TypeAlias
from ....types.chat import ParsedChoice, ParsedChatCompletion, ParsedChatCompletionMessage
ParsedChatCompletionSnapshot: TypeAlias = ParsedChatCompletion[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedChatCompletion` object.
"""
ParsedChatCompletionMessageSnapshot: TypeAlias = ParsedChatCompletionMessage[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedChatCompletionMessage` object.
If the content has been fully accumulated, the `.parsed` content will be
the `response_format` instance; otherwise it'll be the raw JSON-parsed version.
"""
ParsedChoiceSnapshot: TypeAlias = ParsedChoice[object]

View File

@@ -0,0 +1,13 @@
from ._events import (
ResponseTextDoneEvent as ResponseTextDoneEvent,
ResponseTextDeltaEvent as ResponseTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent as ResponseFunctionCallArgumentsDeltaEvent,
)
from ._responses import (
ResponseStream as ResponseStream,
AsyncResponseStream as AsyncResponseStream,
ResponseStreamEvent as ResponseStreamEvent,
ResponseStreamState as ResponseStreamState,
ResponseStreamManager as ResponseStreamManager,
AsyncResponseStreamManager as AsyncResponseStreamManager,
)

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
from typing import Optional
from typing_extensions import Union, Generic, TypeVar, Annotated, TypeAlias
from ...._utils import PropertyInfo
from ...._compat import GenericModel
from ....types.responses import (
ParsedResponse,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseQueuedEvent,
ResponseCreatedEvent,
ResponseTextDoneEvent as RawResponseTextDoneEvent,
ResponseAudioDoneEvent,
ResponseCompletedEvent as RawResponseCompletedEvent,
ResponseTextDeltaEvent as RawResponseTextDeltaEvent,
ResponseAudioDeltaEvent,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseRefusalDoneEvent,
ResponseRefusalDeltaEvent,
ResponseMcpCallFailedEvent,
ResponseOutputItemDoneEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemAddedEvent,
ResponseContentPartAddedEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallInProgressEvent,
ResponseMcpListToolsFailedEvent,
ResponseAudioTranscriptDoneEvent,
ResponseAudioTranscriptDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseImageGenCallCompletedEvent,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpListToolsCompletedEvent,
ResponseImageGenCallGeneratingEvent,
ResponseImageGenCallInProgressEvent,
ResponseMcpListToolsInProgressEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallSearchingEvent,
ResponseCustomToolCallInputDoneEvent,
ResponseFileSearchCallCompletedEvent,
ResponseFileSearchCallSearchingEvent,
ResponseWebSearchCallInProgressEvent,
ResponseCustomToolCallInputDeltaEvent,
ResponseFileSearchCallInProgressEvent,
ResponseImageGenCallPartialImageEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseOutputTextAnnotationAddedEvent,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent as RawResponseFunctionCallArgumentsDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
)
from ....types.responses.response_reasoning_text_done_event import ResponseReasoningTextDoneEvent
from ....types.responses.response_reasoning_text_delta_event import ResponseReasoningTextDeltaEvent
TextFormatT = TypeVar(
"TextFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
class ResponseTextDeltaEvent(RawResponseTextDeltaEvent):
snapshot: str
class ResponseTextDoneEvent(RawResponseTextDoneEvent, GenericModel, Generic[TextFormatT]):
parsed: Optional[TextFormatT] = None
class ResponseFunctionCallArgumentsDeltaEvent(RawResponseFunctionCallArgumentsDeltaEvent):
snapshot: str
class ResponseCompletedEvent(RawResponseCompletedEvent, GenericModel, Generic[TextFormatT]):
response: ParsedResponse[TextFormatT] # type: ignore[assignment]
ResponseStreamEvent: TypeAlias = Annotated[
Union[
# wrappers with snapshots added on
ResponseTextDeltaEvent,
ResponseTextDoneEvent[TextFormatT],
ResponseFunctionCallArgumentsDeltaEvent,
ResponseCompletedEvent[TextFormatT],
# the same as the non-accumulated API
ResponseAudioDeltaEvent,
ResponseAudioDoneEvent,
ResponseAudioTranscriptDeltaEvent,
ResponseAudioTranscriptDoneEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseErrorEvent,
ResponseFileSearchCallCompletedEvent,
ResponseFileSearchCallInProgressEvent,
ResponseFileSearchCallSearchingEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseInProgressEvent,
ResponseFailedEvent,
ResponseIncompleteEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseRefusalDeltaEvent,
ResponseRefusalDoneEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryTextDeltaEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponseImageGenCallCompletedEvent,
ResponseImageGenCallInProgressEvent,
ResponseImageGenCallGeneratingEvent,
ResponseImageGenCallPartialImageEvent,
ResponseMcpCallCompletedEvent,
ResponseMcpCallArgumentsDeltaEvent,
ResponseMcpCallArgumentsDoneEvent,
ResponseMcpCallFailedEvent,
ResponseMcpCallInProgressEvent,
ResponseMcpListToolsCompletedEvent,
ResponseMcpListToolsFailedEvent,
ResponseMcpListToolsInProgressEvent,
ResponseOutputTextAnnotationAddedEvent,
ResponseQueuedEvent,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseCustomToolCallInputDeltaEvent,
ResponseCustomToolCallInputDoneEvent,
],
PropertyInfo(discriminator="type"),
]

View File

@@ -0,0 +1,372 @@
from __future__ import annotations
import inspect
from types import TracebackType
from typing import Any, List, Generic, Iterable, Awaitable, cast
from typing_extensions import Self, Callable, Iterator, AsyncIterator
from ._types import ParsedResponseSnapshot
from ._events import (
ResponseStreamEvent,
ResponseTextDoneEvent,
ResponseCompletedEvent,
ResponseTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent,
)
from ...._types import Omit, omit
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._models import build, construct_type_unchecked
from ...._streaming import Stream, AsyncStream
from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
from ..._parsing._responses import TextFormatT, parse_text, parse_response
from ....types.responses.tool_param import ToolParam
from ....types.responses.parsed_response import (
ParsedContent,
ParsedResponseOutputMessage,
ParsedResponseFunctionToolCall,
)
class ResponseStream(Generic[TextFormatT]):
def __init__(
self,
*,
raw_stream: Stream[RawResponseStreamEvent],
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
self._starting_after = starting_after
def __next__(self) -> ResponseStreamEvent[TextFormatT]:
return self._iterator.__next__()
def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
for item in self._iterator:
yield item
def __enter__(self) -> Self:
return self
def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
for sse_event in self._raw_stream:
events_to_fire = self._state.handle_event(sse_event)
for event in events_to_fire:
if self._starting_after is None or event.sequence_number > self._starting_after:
yield event
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self._response.close()
def get_final_response(self) -> ParsedResponse[TextFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedResponse` object.
"""
self.until_done()
response = self._state._completed_response
if not response:
raise RuntimeError("Didn't receive a `response.completed` event.")
return response
def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
consume_sync_iterator(self)
return self
class ResponseStreamManager(Generic[TextFormatT]):
def __init__(
self,
api_request: Callable[[], Stream[RawResponseStreamEvent]],
*,
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
) -> None:
self.__stream: ResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
self.__starting_after = starting_after
def __enter__(self) -> ResponseStream[TextFormatT]:
raw_stream = self.__api_request()
self.__stream = ResponseStream(
raw_stream=raw_stream,
text_format=self.__text_format,
input_tools=self.__input_tools,
starting_after=self.__starting_after,
)
return self.__stream
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
class AsyncResponseStream(Generic[TextFormatT]):
def __init__(
self,
*,
raw_stream: AsyncStream[RawResponseStreamEvent],
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
self._starting_after = starting_after
async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
async for item in self._iterator:
yield item
async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
async for sse_event in self._raw_stream:
events_to_fire = self._state.handle_event(sse_event)
for event in events_to_fire:
if self._starting_after is None or event.sequence_number > self._starting_after:
yield event
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self._response.aclose()
async def get_final_response(self) -> ParsedResponse[TextFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedResponse` object.
"""
await self.until_done()
response = self._state._completed_response
if not response:
raise RuntimeError("Didn't receive a `response.completed` event.")
return response
async def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
await consume_async_iterator(self)
return self
class AsyncResponseStreamManager(Generic[TextFormatT]):
def __init__(
self,
api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
*,
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
) -> None:
self.__stream: AsyncResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
self.__starting_after = starting_after
async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
raw_stream = await self.__api_request
self.__stream = AsyncResponseStream(
raw_stream=raw_stream,
text_format=self.__text_format,
input_tools=self.__input_tools,
starting_after=self.__starting_after,
)
return self.__stream
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
class ResponseStreamState(Generic[TextFormatT]):
def __init__(
self,
*,
input_tools: Iterable[ToolParam] | Omit,
text_format: type[TextFormatT] | Omit,
) -> None:
self.__current_snapshot: ParsedResponseSnapshot | None = None
self._completed_response: ParsedResponse[TextFormatT] | None = None
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
self._text_format = text_format
self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit
def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
self.__current_snapshot = snapshot = self.accumulate_event(event)
events: List[ResponseStreamEvent[TextFormatT]] = []
if event.type == "response.output_text.delta":
output = snapshot.output[event.output_index]
assert output.type == "message"
content = output.content[event.content_index]
assert content.type == "output_text"
events.append(
build(
ResponseTextDeltaEvent,
content_index=event.content_index,
delta=event.delta,
item_id=event.item_id,
output_index=event.output_index,
sequence_number=event.sequence_number,
logprobs=event.logprobs,
type="response.output_text.delta",
snapshot=content.text,
)
)
elif event.type == "response.output_text.done":
output = snapshot.output[event.output_index]
assert output.type == "message"
content = output.content[event.content_index]
assert content.type == "output_text"
events.append(
build(
ResponseTextDoneEvent[TextFormatT],
content_index=event.content_index,
item_id=event.item_id,
output_index=event.output_index,
sequence_number=event.sequence_number,
logprobs=event.logprobs,
type="response.output_text.done",
text=event.text,
parsed=parse_text(event.text, text_format=self._text_format),
)
)
elif event.type == "response.function_call_arguments.delta":
output = snapshot.output[event.output_index]
assert output.type == "function_call"
events.append(
build(
ResponseFunctionCallArgumentsDeltaEvent,
delta=event.delta,
item_id=event.item_id,
output_index=event.output_index,
sequence_number=event.sequence_number,
type="response.function_call_arguments.delta",
snapshot=output.arguments,
)
)
elif event.type == "response.completed":
response = self._completed_response
assert response is not None
events.append(
build(
ResponseCompletedEvent,
sequence_number=event.sequence_number,
type="response.completed",
response=response,
)
)
else:
events.append(event)
return events
def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
snapshot = self.__current_snapshot
if snapshot is None:
return self._create_initial_response(event)
if event.type == "response.output_item.added":
if event.item.type == "function_call":
snapshot.output.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
)
)
elif event.item.type == "message":
snapshot.output.append(
construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
)
else:
snapshot.output.append(event.item)
elif event.type == "response.content_part.added":
output = snapshot.output[event.output_index]
if output.type == "message":
output.content.append(
construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
)
elif event.type == "response.output_text.delta":
output = snapshot.output[event.output_index]
if output.type == "message":
content = output.content[event.content_index]
assert content.type == "output_text"
content.text += event.delta
elif event.type == "response.function_call_arguments.delta":
output = snapshot.output[event.output_index]
if output.type == "function_call":
output.arguments += event.delta
elif event.type == "response.completed":
self._completed_response = parse_response(
text_format=self._text_format,
response=event.response,
input_tools=self._input_tools,
)
return snapshot
def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
if event.type != "response.created":
raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
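
For context, a sketch of the public streaming API these state classes back (assumes a configured API key; the model name is a placeholder):

```py
from openai import OpenAI

client = OpenAI()

with client.responses.stream(
    model="gpt-4o",
    input="Write a one-line haiku about diffs.",
) as stream:
    for event in stream:
        if event.type == "response.output_text.delta":
            # `event.snapshot` also carries the text accumulated so far
            print(event.delta, end="", flush=True)
    response = stream.get_final_response()

print()
print(response.output_text)
```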

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from typing_extensions import TypeAlias
from ....types.responses import ParsedResponse
ParsedResponseSnapshot: TypeAlias = ParsedResponse[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedResponse` object.
"""

View File

@@ -0,0 +1,190 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, List, Generic, TypeVar, Optional, cast
from typing_extensions import Protocol, override, runtime_checkable
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage
__all__ = [
"SyncPage",
"AsyncPage",
"SyncCursorPage",
"AsyncCursorPage",
"SyncConversationCursorPage",
"AsyncConversationCursorPage",
]
_T = TypeVar("_T")
@runtime_checkable
class CursorPageItem(Protocol):
id: Optional[str]
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""
data: List[_T]
object: str
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> None:
"""
This page represents a response that isn't actually paginated at the API level,
so there will never be a next page.
"""
return None
class AsyncPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""
data: List[_T]
object: str
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def next_page_info(self) -> None:
"""
This page represents a response that isn't actually paginated at the API level,
so there will never be a next page.
"""
return None
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]
has_more: Optional[bool] = None
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def has_next_page(self) -> bool:
has_more = self.has_more
if has_more is not None and has_more is False:
return False
return super().has_next_page()
@override
def next_page_info(self) -> Optional[PageInfo]:
data = self.data
if not data:
return None
item = cast(Any, data[-1])
if not isinstance(item, CursorPageItem) or item.id is None:
# TODO emit warning log
return None
return PageInfo(params={"after": item.id})
class AsyncCursorPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]
has_more: Optional[bool] = None
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def has_next_page(self) -> bool:
has_more = self.has_more
if has_more is not None and has_more is False:
return False
return super().has_next_page()
@override
def next_page_info(self) -> Optional[PageInfo]:
data = self.data
if not data:
return None
item = cast(Any, data[-1])
if not isinstance(item, CursorPageItem) or item.id is None:
# TODO emit warning log
return None
return PageInfo(params={"after": item.id})
class SyncConversationCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]
has_more: Optional[bool] = None
last_id: Optional[str] = None
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def has_next_page(self) -> bool:
has_more = self.has_more
if has_more is not None and has_more is False:
return False
return super().has_next_page()
@override
def next_page_info(self) -> Optional[PageInfo]:
last_id = self.last_id
if not last_id:
return None
return PageInfo(params={"after": last_id})
class AsyncConversationCursorPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]
has_more: Optional[bool] = None
last_id: Optional[str] = None
@override
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
return data
@override
def has_next_page(self) -> bool:
has_more = self.has_more
if has_more is not None and has_more is False:
return False
return super().has_next_page()
@override
def next_page_info(self) -> Optional[PageInfo]:
last_id = self.last_id
if not last_id:
return None
return PageInfo(params={"after": last_id})
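
These cursor pages are what makes auto-pagination work: iterating a page object keeps requesting the next page with `after=<id of the last item>` until `has_more` is false. A brief sketch (assumes a configured API key):

```py
from openai import OpenAI

client = OpenAI()

# `client.files.list()` returns a SyncCursorPage; iterating it transparently
# fetches follow-up pages until the API reports no more results.
for file in client.files.list(limit=100):
    print(file.id)
```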

View File

@@ -0,0 +1,243 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .beta import (
Beta,
AsyncBeta,
BetaWithRawResponse,
AsyncBetaWithRawResponse,
BetaWithStreamingResponse,
AsyncBetaWithStreamingResponse,
)
from .chat import (
Chat,
AsyncChat,
ChatWithRawResponse,
AsyncChatWithRawResponse,
ChatWithStreamingResponse,
AsyncChatWithStreamingResponse,
)
from .audio import (
Audio,
AsyncAudio,
AudioWithRawResponse,
AsyncAudioWithRawResponse,
AudioWithStreamingResponse,
AsyncAudioWithStreamingResponse,
)
from .evals import (
Evals,
AsyncEvals,
EvalsWithRawResponse,
AsyncEvalsWithRawResponse,
EvalsWithStreamingResponse,
AsyncEvalsWithStreamingResponse,
)
from .files import (
Files,
AsyncFiles,
FilesWithRawResponse,
AsyncFilesWithRawResponse,
FilesWithStreamingResponse,
AsyncFilesWithStreamingResponse,
)
from .images import (
Images,
AsyncImages,
ImagesWithRawResponse,
AsyncImagesWithRawResponse,
ImagesWithStreamingResponse,
AsyncImagesWithStreamingResponse,
)
from .models import (
Models,
AsyncModels,
ModelsWithRawResponse,
AsyncModelsWithRawResponse,
ModelsWithStreamingResponse,
AsyncModelsWithStreamingResponse,
)
from .skills import (
Skills,
AsyncSkills,
SkillsWithRawResponse,
AsyncSkillsWithRawResponse,
SkillsWithStreamingResponse,
AsyncSkillsWithStreamingResponse,
)
from .videos import (
Videos,
AsyncVideos,
VideosWithRawResponse,
AsyncVideosWithRawResponse,
VideosWithStreamingResponse,
AsyncVideosWithStreamingResponse,
)
from .batches import (
Batches,
AsyncBatches,
BatchesWithRawResponse,
AsyncBatchesWithRawResponse,
BatchesWithStreamingResponse,
AsyncBatchesWithStreamingResponse,
)
from .uploads import (
Uploads,
AsyncUploads,
UploadsWithRawResponse,
AsyncUploadsWithRawResponse,
UploadsWithStreamingResponse,
AsyncUploadsWithStreamingResponse,
)
from .containers import (
Containers,
AsyncContainers,
ContainersWithRawResponse,
AsyncContainersWithRawResponse,
ContainersWithStreamingResponse,
AsyncContainersWithStreamingResponse,
)
from .embeddings import (
Embeddings,
AsyncEmbeddings,
EmbeddingsWithRawResponse,
AsyncEmbeddingsWithRawResponse,
EmbeddingsWithStreamingResponse,
AsyncEmbeddingsWithStreamingResponse,
)
from .completions import (
Completions,
AsyncCompletions,
CompletionsWithRawResponse,
AsyncCompletionsWithRawResponse,
CompletionsWithStreamingResponse,
AsyncCompletionsWithStreamingResponse,
)
from .fine_tuning import (
FineTuning,
AsyncFineTuning,
FineTuningWithRawResponse,
AsyncFineTuningWithRawResponse,
FineTuningWithStreamingResponse,
AsyncFineTuningWithStreamingResponse,
)
from .moderations import (
Moderations,
AsyncModerations,
ModerationsWithRawResponse,
AsyncModerationsWithRawResponse,
ModerationsWithStreamingResponse,
AsyncModerationsWithStreamingResponse,
)
from .vector_stores import (
VectorStores,
AsyncVectorStores,
VectorStoresWithRawResponse,
AsyncVectorStoresWithRawResponse,
VectorStoresWithStreamingResponse,
AsyncVectorStoresWithStreamingResponse,
)
__all__ = [
"Completions",
"AsyncCompletions",
"CompletionsWithRawResponse",
"AsyncCompletionsWithRawResponse",
"CompletionsWithStreamingResponse",
"AsyncCompletionsWithStreamingResponse",
"Chat",
"AsyncChat",
"ChatWithRawResponse",
"AsyncChatWithRawResponse",
"ChatWithStreamingResponse",
"AsyncChatWithStreamingResponse",
"Embeddings",
"AsyncEmbeddings",
"EmbeddingsWithRawResponse",
"AsyncEmbeddingsWithRawResponse",
"EmbeddingsWithStreamingResponse",
"AsyncEmbeddingsWithStreamingResponse",
"Files",
"AsyncFiles",
"FilesWithRawResponse",
"AsyncFilesWithRawResponse",
"FilesWithStreamingResponse",
"AsyncFilesWithStreamingResponse",
"Images",
"AsyncImages",
"ImagesWithRawResponse",
"AsyncImagesWithRawResponse",
"ImagesWithStreamingResponse",
"AsyncImagesWithStreamingResponse",
"Audio",
"AsyncAudio",
"AudioWithRawResponse",
"AsyncAudioWithRawResponse",
"AudioWithStreamingResponse",
"AsyncAudioWithStreamingResponse",
"Moderations",
"AsyncModerations",
"ModerationsWithRawResponse",
"AsyncModerationsWithRawResponse",
"ModerationsWithStreamingResponse",
"AsyncModerationsWithStreamingResponse",
"Models",
"AsyncModels",
"ModelsWithRawResponse",
"AsyncModelsWithRawResponse",
"ModelsWithStreamingResponse",
"AsyncModelsWithStreamingResponse",
"FineTuning",
"AsyncFineTuning",
"FineTuningWithRawResponse",
"AsyncFineTuningWithRawResponse",
"FineTuningWithStreamingResponse",
"AsyncFineTuningWithStreamingResponse",
"VectorStores",
"AsyncVectorStores",
"VectorStoresWithRawResponse",
"AsyncVectorStoresWithRawResponse",
"VectorStoresWithStreamingResponse",
"AsyncVectorStoresWithStreamingResponse",
"Beta",
"AsyncBeta",
"BetaWithRawResponse",
"AsyncBetaWithRawResponse",
"BetaWithStreamingResponse",
"AsyncBetaWithStreamingResponse",
"Batches",
"AsyncBatches",
"BatchesWithRawResponse",
"AsyncBatchesWithRawResponse",
"BatchesWithStreamingResponse",
"AsyncBatchesWithStreamingResponse",
"Uploads",
"AsyncUploads",
"UploadsWithRawResponse",
"AsyncUploadsWithRawResponse",
"UploadsWithStreamingResponse",
"AsyncUploadsWithStreamingResponse",
"Evals",
"AsyncEvals",
"EvalsWithRawResponse",
"AsyncEvalsWithRawResponse",
"EvalsWithStreamingResponse",
"AsyncEvalsWithStreamingResponse",
"Containers",
"AsyncContainers",
"ContainersWithRawResponse",
"AsyncContainersWithRawResponse",
"ContainersWithStreamingResponse",
"AsyncContainersWithStreamingResponse",
"Skills",
"AsyncSkills",
"SkillsWithRawResponse",
"AsyncSkillsWithRawResponse",
"SkillsWithStreamingResponse",
"AsyncSkillsWithStreamingResponse",
"Videos",
"AsyncVideos",
"VideosWithRawResponse",
"AsyncVideosWithRawResponse",
"VideosWithStreamingResponse",
"AsyncVideosWithStreamingResponse",
]

View File

@@ -0,0 +1,61 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .audio import (
Audio,
AsyncAudio,
AudioWithRawResponse,
AsyncAudioWithRawResponse,
AudioWithStreamingResponse,
AsyncAudioWithStreamingResponse,
)
from .speech import (
Speech,
AsyncSpeech,
SpeechWithRawResponse,
AsyncSpeechWithRawResponse,
SpeechWithStreamingResponse,
AsyncSpeechWithStreamingResponse,
)
from .translations import (
Translations,
AsyncTranslations,
TranslationsWithRawResponse,
AsyncTranslationsWithRawResponse,
TranslationsWithStreamingResponse,
AsyncTranslationsWithStreamingResponse,
)
from .transcriptions import (
Transcriptions,
AsyncTranscriptions,
TranscriptionsWithRawResponse,
AsyncTranscriptionsWithRawResponse,
TranscriptionsWithStreamingResponse,
AsyncTranscriptionsWithStreamingResponse,
)
__all__ = [
"Transcriptions",
"AsyncTranscriptions",
"TranscriptionsWithRawResponse",
"AsyncTranscriptionsWithRawResponse",
"TranscriptionsWithStreamingResponse",
"AsyncTranscriptionsWithStreamingResponse",
"Translations",
"AsyncTranslations",
"TranslationsWithRawResponse",
"AsyncTranslationsWithRawResponse",
"TranslationsWithStreamingResponse",
"AsyncTranslationsWithStreamingResponse",
"Speech",
"AsyncSpeech",
"SpeechWithRawResponse",
"AsyncSpeechWithRawResponse",
"SpeechWithStreamingResponse",
"AsyncSpeechWithStreamingResponse",
"Audio",
"AsyncAudio",
"AudioWithRawResponse",
"AsyncAudioWithRawResponse",
"AudioWithStreamingResponse",
"AsyncAudioWithStreamingResponse",
]

View File

@@ -0,0 +1,184 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from .speech import (
Speech,
AsyncSpeech,
SpeechWithRawResponse,
AsyncSpeechWithRawResponse,
SpeechWithStreamingResponse,
AsyncSpeechWithStreamingResponse,
)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from .translations import (
Translations,
AsyncTranslations,
TranslationsWithRawResponse,
AsyncTranslationsWithRawResponse,
TranslationsWithStreamingResponse,
AsyncTranslationsWithStreamingResponse,
)
from .transcriptions import (
Transcriptions,
AsyncTranscriptions,
TranscriptionsWithRawResponse,
AsyncTranscriptionsWithRawResponse,
TranscriptionsWithStreamingResponse,
AsyncTranscriptionsWithStreamingResponse,
)
__all__ = ["Audio", "AsyncAudio"]
class Audio(SyncAPIResource):
@cached_property
def transcriptions(self) -> Transcriptions:
"""Turn audio into text or text into audio."""
return Transcriptions(self._client)
@cached_property
def translations(self) -> Translations:
"""Turn audio into text or text into audio."""
return Translations(self._client)
@cached_property
def speech(self) -> Speech:
"""Turn audio into text or text into audio."""
return Speech(self._client)
@cached_property
def with_raw_response(self) -> AudioWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AudioWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AudioWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AudioWithStreamingResponse(self)
class AsyncAudio(AsyncAPIResource):
@cached_property
def transcriptions(self) -> AsyncTranscriptions:
"""Turn audio into text or text into audio."""
return AsyncTranscriptions(self._client)
@cached_property
def translations(self) -> AsyncTranslations:
"""Turn audio into text or text into audio."""
return AsyncTranslations(self._client)
@cached_property
def speech(self) -> AsyncSpeech:
"""Turn audio into text or text into audio."""
return AsyncSpeech(self._client)
@cached_property
def with_raw_response(self) -> AsyncAudioWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncAudioWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncAudioWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncAudioWithStreamingResponse(self)
class AudioWithRawResponse:
def __init__(self, audio: Audio) -> None:
self._audio = audio
@cached_property
def transcriptions(self) -> TranscriptionsWithRawResponse:
"""Turn audio into text or text into audio."""
return TranscriptionsWithRawResponse(self._audio.transcriptions)
@cached_property
def translations(self) -> TranslationsWithRawResponse:
"""Turn audio into text or text into audio."""
return TranslationsWithRawResponse(self._audio.translations)
@cached_property
def speech(self) -> SpeechWithRawResponse:
"""Turn audio into text or text into audio."""
return SpeechWithRawResponse(self._audio.speech)
class AsyncAudioWithRawResponse:
def __init__(self, audio: AsyncAudio) -> None:
self._audio = audio
@cached_property
def transcriptions(self) -> AsyncTranscriptionsWithRawResponse:
"""Turn audio into text or text into audio."""
return AsyncTranscriptionsWithRawResponse(self._audio.transcriptions)
@cached_property
def translations(self) -> AsyncTranslationsWithRawResponse:
"""Turn audio into text or text into audio."""
return AsyncTranslationsWithRawResponse(self._audio.translations)
@cached_property
def speech(self) -> AsyncSpeechWithRawResponse:
"""Turn audio into text or text into audio."""
return AsyncSpeechWithRawResponse(self._audio.speech)
class AudioWithStreamingResponse:
def __init__(self, audio: Audio) -> None:
self._audio = audio
@cached_property
def transcriptions(self) -> TranscriptionsWithStreamingResponse:
"""Turn audio into text or text into audio."""
return TranscriptionsWithStreamingResponse(self._audio.transcriptions)
@cached_property
def translations(self) -> TranslationsWithStreamingResponse:
"""Turn audio into text or text into audio."""
return TranslationsWithStreamingResponse(self._audio.translations)
@cached_property
def speech(self) -> SpeechWithStreamingResponse:
"""Turn audio into text or text into audio."""
return SpeechWithStreamingResponse(self._audio.speech)
class AsyncAudioWithStreamingResponse:
def __init__(self, audio: AsyncAudio) -> None:
self._audio = audio
@cached_property
def transcriptions(self) -> AsyncTranscriptionsWithStreamingResponse:
"""Turn audio into text or text into audio."""
return AsyncTranscriptionsWithStreamingResponse(self._audio.transcriptions)
@cached_property
def translations(self) -> AsyncTranslationsWithStreamingResponse:
"""Turn audio into text or text into audio."""
return AsyncTranslationsWithStreamingResponse(self._audio.translations)
@cached_property
def speech(self) -> AsyncSpeechWithStreamingResponse:
"""Turn audio into text or text into audio."""
return AsyncSpeechWithStreamingResponse(self._audio.speech)
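A minimal usage sketch for the wrapper properties defined above (hedged: the audio file path is illustrative, and the client is assumed to read OPENAI_API_KEY from the environment):

from openai import OpenAI

client = OpenAI()

# `.with_raw_response` wraps the same method but returns the raw HTTP
# response first, so headers are inspectable before parsing the typed model.
raw = client.audio.transcriptions.with_raw_response.create(
    model="whisper-1",
    file=open("sample.mp3", "rb"),  # hypothetical local file
)
print(raw.headers.get("x-request-id"))
transcription = raw.parse()  # parse into the typed Transcription model
print(transcription.text)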

View File

@@ -0,0 +1,265 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union
from typing_extensions import Literal
import httpx
from ... import _legacy_response
from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
StreamedBinaryAPIResponse,
AsyncStreamedBinaryAPIResponse,
to_custom_streamed_response_wrapper,
async_to_custom_streamed_response_wrapper,
)
from ...types.audio import speech_create_params
from ..._base_client import make_request_options
from ...types.audio.speech_model import SpeechModel
__all__ = ["Speech", "AsyncSpeech"]
class Speech(SyncAPIResource):
"""Turn audio into text or text into audio."""
@cached_property
def with_raw_response(self) -> SpeechWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return SpeechWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> SpeechWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return SpeechWithStreamingResponse(self)
def create(
self,
*,
input: str,
model: Union[str, SpeechModel],
voice: Union[
str, Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"]
],
instructions: str | Omit = omit,
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | Omit = omit,
speed: float | Omit = omit,
stream_format: Literal["sse", "audio"] | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> _legacy_response.HttpxBinaryResponseContent:
"""
Generates audio from the input text.
Returns the audio file content, or a stream of audio events.
Args:
input: The text to generate audio for. The maximum length is 4096 characters.
model:
One of the available [TTS models](https://platform.openai.com/docs/models#tts):
`tts-1`, `tts-1-hd`, `gpt-4o-mini-tts`, or `gpt-4o-mini-tts-2025-12-15`.
voice: The voice to use when generating the audio. Supported built-in voices are
`alloy`, `ash`, `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`,
`shimmer`, `verse`, `marin`, and `cedar`. Previews of the voices are available
in the
[Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
instructions: Control the voice of your generated audio with additional instructions. Does not
work with `tts-1` or `tts-1-hd`.
response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`,
`flac`, `wav`, and `pcm`.
speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is
the default.
stream_format: The format to stream the audio in. Supported formats are `sse` and `audio`.
`sse` is not supported for `tts-1` or `tts-1-hd`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
return self._post(
"/audio/speech",
body=maybe_transform(
{
"input": input,
"model": model,
"voice": voice,
"instructions": instructions,
"response_format": response_format,
"speed": speed,
"stream_format": stream_format,
},
speech_create_params.SpeechCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
class AsyncSpeech(AsyncAPIResource):
"""Turn audio into text or text into audio."""
@cached_property
def with_raw_response(self) -> AsyncSpeechWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncSpeechWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncSpeechWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncSpeechWithStreamingResponse(self)
async def create(
self,
*,
input: str,
model: Union[str, SpeechModel],
voice: Union[
str, Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"]
],
instructions: str | Omit = omit,
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | Omit = omit,
speed: float | Omit = omit,
stream_format: Literal["sse", "audio"] | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> _legacy_response.HttpxBinaryResponseContent:
"""
Generates audio from the input text.
Returns the audio file content, or a stream of audio events.
Args:
input: The text to generate audio for. The maximum length is 4096 characters.
model:
One of the available [TTS models](https://platform.openai.com/docs/models#tts):
`tts-1`, `tts-1-hd`, `gpt-4o-mini-tts`, or `gpt-4o-mini-tts-2025-12-15`.
voice: The voice to use when generating the audio. Supported built-in voices are
`alloy`, `ash`, `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`,
`shimmer`, `verse`, `marin`, and `cedar`. Previews of the voices are available
in the
[Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
instructions: Control the voice of your generated audio with additional instructions. Does not
work with `tts-1` or `tts-1-hd`.
response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`,
`flac`, `wav`, and `pcm`.
speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is
the default.
stream_format: The format to stream the audio in. Supported formats are `sse` and `audio`.
`sse` is not supported for `tts-1` or `tts-1-hd`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
return await self._post(
"/audio/speech",
body=await async_maybe_transform(
{
"input": input,
"model": model,
"voice": voice,
"instructions": instructions,
"response_format": response_format,
"speed": speed,
"stream_format": stream_format,
},
speech_create_params.SpeechCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_legacy_response.HttpxBinaryResponseContent,
)
class SpeechWithRawResponse:
def __init__(self, speech: Speech) -> None:
self._speech = speech
self.create = _legacy_response.to_raw_response_wrapper(
speech.create,
)
class AsyncSpeechWithRawResponse:
def __init__(self, speech: AsyncSpeech) -> None:
self._speech = speech
self.create = _legacy_response.async_to_raw_response_wrapper(
speech.create,
)
class SpeechWithStreamingResponse:
def __init__(self, speech: Speech) -> None:
self._speech = speech
self.create = to_custom_streamed_response_wrapper(
speech.create,
StreamedBinaryAPIResponse,
)
class AsyncSpeechWithStreamingResponse:
def __init__(self, speech: AsyncSpeech) -> None:
self._speech = speech
self.create = async_to_custom_streamed_response_wrapper(
speech.create,
AsyncStreamedBinaryAPIResponse,
)
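A usage sketch for `Speech.create` (hedged: the model name, voice, and output path are assumptions for illustration, not taken from the diff):

from openai import OpenAI

client = OpenAI()

# `with_streaming_response` avoids buffering the whole audio body in memory;
# the context manager yields a StreamedBinaryAPIResponse.
with client.audio.speech.with_streaming_response.create(
    model="gpt-4o-mini-tts",  # assumed TTS model
    voice="alloy",
    input="Hello from the speech endpoint.",
) as response:
    response.stream_to_file("speech.mp3")  # assumed output path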

File diff suppressed because it is too large

View File

@@ -0,0 +1,371 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Union, Mapping, cast
from typing_extensions import Literal, overload, assert_never
import httpx
from ... import _legacy_response
from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, omit, not_given
from ..._utils import extract_files, maybe_transform, deepcopy_minimal, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ...types.audio import translation_create_params
from ..._base_client import make_request_options
from ...types.audio_model import AudioModel
from ...types.audio.translation import Translation
from ...types.audio_response_format import AudioResponseFormat
from ...types.audio.translation_verbose import TranslationVerbose
__all__ = ["Translations", "AsyncTranslations"]
log: logging.Logger = logging.getLogger("openai.audio.translations")
class Translations(SyncAPIResource):
"""Turn audio into text or text into audio."""
@cached_property
def with_raw_response(self) -> TranslationsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return TranslationsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> TranslationsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return TranslationsWithStreamingResponse(self)
@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Union[Literal["json"], Omit] = omit,
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Translation: ...
@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["verbose_json"],
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> TranslationVerbose: ...
@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["text", "srt", "vtt"],
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> str: ...
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
prompt: str | Omit = omit,
response_format: Union[Literal["json", "text", "srt", "verbose_json", "vtt"], Omit] = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Translation | TranslationVerbose | str:
"""
Translates audio into English.
Args:
file: The audio file object (not file name) to translate, in one of these formats: flac,
mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
model: ID of the model to use. Only `whisper-1` (which is powered by our open source
Whisper V2 model) is currently available.
prompt: An optional text to guide the model's style or continue a previous audio
segment. The
[prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should be in English.
response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
output more random, while lower values like 0.2 will make it more focused and
deterministic. If set to 0, the model will use
[log probability](https://en.wikipedia.org/wiki/Log_probability) to
automatically increase the temperature until certain thresholds are hit.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
body = deepcopy_minimal(
{
"file": file,
"model": model,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post( # type: ignore[return-value]
"/audio/translations",
body=maybe_transform(body, translation_create_params.TranslationCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_get_response_format_type(response_format),
)
class AsyncTranslations(AsyncAPIResource):
"""Turn audio into text or text into audio."""
@cached_property
def with_raw_response(self) -> AsyncTranslationsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncTranslationsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncTranslationsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncTranslationsWithStreamingResponse(self)
@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Union[Literal["json"], Omit] = omit,
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Translation: ...
@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["verbose_json"],
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> TranslationVerbose: ...
@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["text", "srt", "vtt"],
prompt: str | Omit = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> str: ...
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
prompt: str | Omit = omit,
response_format: Union[AudioResponseFormat, Omit] = omit,
temperature: float | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Translation | TranslationVerbose | str:
"""
Translates audio into English.
Args:
file: The audio file object (not file name) to translate, in one of these formats: flac,
mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
model: ID of the model to use. Only `whisper-1` (which is powered by our open source
Whisper V2 model) is currently available.
prompt: An optional text to guide the model's style or continue a previous audio
segment. The
[prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should be in English.
response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
output more random, while lower values like 0.2 will make it more focused and
deterministic. If set to 0, the model will use
[log probability](https://en.wikipedia.org/wiki/Log_probability) to
automatically increase the temperature until certain thresholds are hit.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
body = deepcopy_minimal(
{
"file": file,
"model": model,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature,
}
)
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
# It should be noted that the actual Content-Type header that will be
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/audio/translations",
body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=_get_response_format_type(response_format),
)
class TranslationsWithRawResponse:
def __init__(self, translations: Translations) -> None:
self._translations = translations
self.create = _legacy_response.to_raw_response_wrapper(
translations.create,
)
class AsyncTranslationsWithRawResponse:
def __init__(self, translations: AsyncTranslations) -> None:
self._translations = translations
self.create = _legacy_response.async_to_raw_response_wrapper(
translations.create,
)
class TranslationsWithStreamingResponse:
def __init__(self, translations: Translations) -> None:
self._translations = translations
self.create = to_streamed_response_wrapper(
translations.create,
)
class AsyncTranslationsWithStreamingResponse:
def __init__(self, translations: AsyncTranslations) -> None:
self._translations = translations
self.create = async_to_streamed_response_wrapper(
translations.create,
)
def _get_response_format_type(
response_format: AudioResponseFormat | Omit,
) -> type[Translation | TranslationVerbose | str]:
if isinstance(response_format, Omit) or response_format is None: # pyright: ignore[reportUnnecessaryComparison]
return Translation
if response_format == "json":
return Translation
elif response_format == "verbose_json":
return TranslationVerbose
elif response_format == "srt" or response_format == "text" or response_format == "vtt":
return str
elif TYPE_CHECKING and response_format != "diarized_json": # type: ignore[unreachable]
assert_never(response_format)
else:
log.warning("Unexpected audio response format: %s", response_format)
return Translation
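A sketch of how the overloads above resolve (hedged: the file name is illustrative): `response_format` selects both the wire format and the Python return type, which `_get_response_format_type` maps to `Translation`, `TranslationVerbose`, or `str`.

from openai import OpenAI

client = OpenAI()

# Default (json) -> Translation model with a `.text` field.
result = client.audio.translations.create(
    model="whisper-1",
    file=open("speech_fr.mp3", "rb"),  # hypothetical non-English audio
)
print(result.text)

# Plain-text formats ("text", "srt", "vtt") come back as `str` instead.
srt = client.audio.translations.create(
    model="whisper-1",
    file=open("speech_fr.mp3", "rb"),
    response_format="srt",
)
print(srt)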

View File

@@ -0,0 +1,546 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Optional
from typing_extensions import Literal
import httpx
from .. import _legacy_response
from ..types import batch_list_params, batch_create_params
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ..pagination import SyncCursorPage, AsyncCursorPage
from ..types.batch import Batch
from .._base_client import AsyncPaginator, make_request_options
from ..types.shared_params.metadata import Metadata
__all__ = ["Batches", "AsyncBatches"]
class Batches(SyncAPIResource):
"""Create large batches of API requests to run asynchronously."""
@cached_property
def with_raw_response(self) -> BatchesWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return BatchesWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> BatchesWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return BatchesWithStreamingResponse(self)
def create(
self,
*,
completion_window: Literal["24h"],
endpoint: Literal[
"/v1/responses",
"/v1/chat/completions",
"/v1/embeddings",
"/v1/completions",
"/v1/moderations",
"/v1/images/generations",
"/v1/images/edits",
],
input_file_id: str,
metadata: Optional[Metadata] | Omit = omit,
output_expires_after: batch_create_params.OutputExpiresAfter | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""
Creates and executes a batch from an uploaded file of requests
Args:
completion_window: The time frame within which the batch should be processed. Currently only `24h`
is supported.
endpoint: The endpoint to be used for all requests in the batch. Currently
`/v1/responses`, `/v1/chat/completions`, `/v1/embeddings`, `/v1/completions`,
`/v1/moderations`, `/v1/images/generations`, and `/v1/images/edits` are
supported. Note that `/v1/embeddings` batches are also restricted to a maximum
of 50,000 embedding inputs across all requests in the batch.
input_file_id: The ID of an uploaded file that contains requests for the new batch.
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
for how to upload a file.
Your input file must be formatted as a
[JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input),
and must be uploaded with the purpose `batch`. The file can contain up to 50,000
requests, and can be up to 200 MB in size.
metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful
for storing additional information about the object in a structured format, and
querying for objects via API or the dashboard.
Keys are strings with a maximum length of 64 characters. Values are strings with
a maximum length of 512 characters.
output_expires_after: The expiration policy for the output and/or error file that are generated for a
batch.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._post(
"/batches",
body=maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"output_expires_after": output_expires_after,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
def retrieve(
self,
batch_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""
Retrieves a batch.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._get(
f"/batches/{batch_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
def list(
self,
*,
after: str | Omit = omit,
limit: int | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> SyncCursorPage[Batch]:
"""List your organization's batches.
Args:
after: A cursor for use in pagination.
`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.
limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=SyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)
def cancel(
self,
batch_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""Cancels an in-progress batch.
The batch will be in status `cancelling` for up to
10 minutes, before changing to `cancelled`, where it will have partial results
(if any) available in the output file.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
class AsyncBatches(AsyncAPIResource):
"""Create large batches of API requests to run asynchronously."""
@cached_property
def with_raw_response(self) -> AsyncBatchesWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncBatchesWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncBatchesWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncBatchesWithStreamingResponse(self)
async def create(
self,
*,
completion_window: Literal["24h"],
endpoint: Literal[
"/v1/responses",
"/v1/chat/completions",
"/v1/embeddings",
"/v1/completions",
"/v1/moderations",
"/v1/images/generations",
"/v1/images/edits",
],
input_file_id: str,
metadata: Optional[Metadata] | Omit = omit,
output_expires_after: batch_create_params.OutputExpiresAfter | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""
Creates and executes a batch from an uploaded file of requests
Args:
completion_window: The time frame within which the batch should be processed. Currently only `24h`
is supported.
endpoint: The endpoint to be used for all requests in the batch. Currently
`/v1/responses`, `/v1/chat/completions`, `/v1/embeddings`, `/v1/completions`,
`/v1/moderations`, `/v1/images/generations`, and `/v1/images/edits` are
supported. Note that `/v1/embeddings` batches are also restricted to a maximum
of 50,000 embedding inputs across all requests in the batch.
input_file_id: The ID of an uploaded file that contains requests for the new batch.
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
for how to upload a file.
Your input file must be formatted as a
[JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input),
and must be uploaded with the purpose `batch`. The file can contain up to 50,000
requests, and can be up to 200 MB in size.
metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful
for storing additional information about the object in a structured format, and
querying for objects via API or the dashboard.
Keys are strings with a maximum length of 64 characters. Values are strings with
a maximum length of 512 characters.
output_expires_after: The expiration policy for the output and/or error file that are generated for a
batch.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return await self._post(
"/batches",
body=await async_maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"output_expires_after": output_expires_after,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
async def retrieve(
self,
batch_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""
Retrieves a batch.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return await self._get(
f"/batches/{batch_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
def list(
self,
*,
after: str | Omit = omit,
limit: int | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> AsyncPaginator[Batch, AsyncCursorPage[Batch]]:
"""List your organization's batches.
Args:
after: A cursor for use in pagination.
`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.
limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=AsyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)
async def cancel(
self,
batch_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Batch:
"""Cancels an in-progress batch.
The batch will be in status `cancelling` for up to
10 minutes, before changing to `cancelled`, where it will have partial results
(if any) available in the output file.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return await self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Batch,
)
class BatchesWithRawResponse:
def __init__(self, batches: Batches) -> None:
self._batches = batches
self.create = _legacy_response.to_raw_response_wrapper(
batches.create,
)
self.retrieve = _legacy_response.to_raw_response_wrapper(
batches.retrieve,
)
self.list = _legacy_response.to_raw_response_wrapper(
batches.list,
)
self.cancel = _legacy_response.to_raw_response_wrapper(
batches.cancel,
)
class AsyncBatchesWithRawResponse:
def __init__(self, batches: AsyncBatches) -> None:
self._batches = batches
self.create = _legacy_response.async_to_raw_response_wrapper(
batches.create,
)
self.retrieve = _legacy_response.async_to_raw_response_wrapper(
batches.retrieve,
)
self.list = _legacy_response.async_to_raw_response_wrapper(
batches.list,
)
self.cancel = _legacy_response.async_to_raw_response_wrapper(
batches.cancel,
)
class BatchesWithStreamingResponse:
def __init__(self, batches: Batches) -> None:
self._batches = batches
self.create = to_streamed_response_wrapper(
batches.create,
)
self.retrieve = to_streamed_response_wrapper(
batches.retrieve,
)
self.list = to_streamed_response_wrapper(
batches.list,
)
self.cancel = to_streamed_response_wrapper(
batches.cancel,
)
class AsyncBatchesWithStreamingResponse:
def __init__(self, batches: AsyncBatches) -> None:
self._batches = batches
self.create = async_to_streamed_response_wrapper(
batches.create,
)
self.retrieve = async_to_streamed_response_wrapper(
batches.retrieve,
)
self.list = async_to_streamed_response_wrapper(
batches.list,
)
self.cancel = async_to_streamed_response_wrapper(
batches.cancel,
)
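A sketch of the batch lifecycle described above (hedged: the JSONL path is a placeholder, and the upload purpose follows the docstring's requirement):

from openai import OpenAI

client = OpenAI()

# Input must be a JSONL file uploaded with purpose="batch".
batch_file = client.files.create(
    file=open("requests.jsonl", "rb"),  # hypothetical request file
    purpose="batch",
)

batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(batch.id, batch.status)

# SyncCursorPage supports transparent auto-pagination, so this loop walks
# every page behind the scenes using the `after` cursor described above.
for b in client.batches.list(limit=20):
    print(b.id, b.status)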

View File

@@ -0,0 +1,61 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .beta import (
Beta,
AsyncBeta,
BetaWithRawResponse,
AsyncBetaWithRawResponse,
BetaWithStreamingResponse,
AsyncBetaWithStreamingResponse,
)
from .chatkit import (
ChatKit,
AsyncChatKit,
ChatKitWithRawResponse,
AsyncChatKitWithRawResponse,
ChatKitWithStreamingResponse,
AsyncChatKitWithStreamingResponse,
)
from .threads import (
Threads,
AsyncThreads,
ThreadsWithRawResponse,
AsyncThreadsWithRawResponse,
ThreadsWithStreamingResponse,
AsyncThreadsWithStreamingResponse,
)
from .assistants import (
Assistants,
AsyncAssistants,
AssistantsWithRawResponse,
AsyncAssistantsWithRawResponse,
AssistantsWithStreamingResponse,
AsyncAssistantsWithStreamingResponse,
)
__all__ = [
"ChatKit",
"AsyncChatKit",
"ChatKitWithRawResponse",
"AsyncChatKitWithRawResponse",
"ChatKitWithStreamingResponse",
"AsyncChatKitWithStreamingResponse",
"Assistants",
"AsyncAssistants",
"AssistantsWithRawResponse",
"AsyncAssistantsWithRawResponse",
"AssistantsWithStreamingResponse",
"AsyncAssistantsWithStreamingResponse",
"Threads",
"AsyncThreads",
"ThreadsWithRawResponse",
"AsyncThreadsWithRawResponse",
"ThreadsWithStreamingResponse",
"AsyncThreadsWithStreamingResponse",
"Beta",
"AsyncBeta",
"BetaWithRawResponse",
"AsyncBetaWithRawResponse",
"BetaWithStreamingResponse",
"AsyncBetaWithStreamingResponse",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,199 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from ..._compat import cached_property
from .assistants import (
Assistants,
AsyncAssistants,
AssistantsWithRawResponse,
AsyncAssistantsWithRawResponse,
AssistantsWithStreamingResponse,
AsyncAssistantsWithStreamingResponse,
)
from ..._resource import SyncAPIResource, AsyncAPIResource
from .chatkit.chatkit import (
ChatKit,
AsyncChatKit,
ChatKitWithRawResponse,
AsyncChatKitWithRawResponse,
ChatKitWithStreamingResponse,
AsyncChatKitWithStreamingResponse,
)
from .threads.threads import (
Threads,
AsyncThreads,
ThreadsWithRawResponse,
AsyncThreadsWithRawResponse,
ThreadsWithStreamingResponse,
AsyncThreadsWithStreamingResponse,
)
from ...resources.chat import Chat, AsyncChat
from .realtime.realtime import (
Realtime,
AsyncRealtime,
)
__all__ = ["Beta", "AsyncBeta"]
class Beta(SyncAPIResource):
@cached_property
def chat(self) -> Chat:
return Chat(self._client)
@cached_property
def realtime(self) -> Realtime:
return Realtime(self._client)
@cached_property
def chatkit(self) -> ChatKit:
return ChatKit(self._client)
@cached_property
def assistants(self) -> Assistants:
"""Build Assistants that can call models and use tools."""
return Assistants(self._client)
@cached_property
def threads(self) -> Threads:
"""Build Assistants that can call models and use tools."""
return Threads(self._client)
@cached_property
def with_raw_response(self) -> BetaWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return BetaWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> BetaWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return BetaWithStreamingResponse(self)
class AsyncBeta(AsyncAPIResource):
@cached_property
def chat(self) -> AsyncChat:
return AsyncChat(self._client)
@cached_property
def realtime(self) -> AsyncRealtime:
return AsyncRealtime(self._client)
@cached_property
def chatkit(self) -> AsyncChatKit:
return AsyncChatKit(self._client)
@cached_property
def assistants(self) -> AsyncAssistants:
"""Build Assistants that can call models and use tools."""
return AsyncAssistants(self._client)
@cached_property
def threads(self) -> AsyncThreads:
"""Build Assistants that can call models and use tools."""
return AsyncThreads(self._client)
@cached_property
def with_raw_response(self) -> AsyncBetaWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncBetaWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncBetaWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncBetaWithStreamingResponse(self)
class BetaWithRawResponse:
def __init__(self, beta: Beta) -> None:
self._beta = beta
@cached_property
def chatkit(self) -> ChatKitWithRawResponse:
return ChatKitWithRawResponse(self._beta.chatkit)
@cached_property
def assistants(self) -> AssistantsWithRawResponse:
"""Build Assistants that can call models and use tools."""
return AssistantsWithRawResponse(self._beta.assistants)
@cached_property
def threads(self) -> ThreadsWithRawResponse:
"""Build Assistants that can call models and use tools."""
return ThreadsWithRawResponse(self._beta.threads)
class AsyncBetaWithRawResponse:
def __init__(self, beta: AsyncBeta) -> None:
self._beta = beta
@cached_property
def chatkit(self) -> AsyncChatKitWithRawResponse:
return AsyncChatKitWithRawResponse(self._beta.chatkit)
@cached_property
def assistants(self) -> AsyncAssistantsWithRawResponse:
"""Build Assistants that can call models and use tools."""
return AsyncAssistantsWithRawResponse(self._beta.assistants)
@cached_property
def threads(self) -> AsyncThreadsWithRawResponse:
"""Build Assistants that can call models and use tools."""
return AsyncThreadsWithRawResponse(self._beta.threads)
class BetaWithStreamingResponse:
def __init__(self, beta: Beta) -> None:
self._beta = beta
@cached_property
def chatkit(self) -> ChatKitWithStreamingResponse:
return ChatKitWithStreamingResponse(self._beta.chatkit)
@cached_property
def assistants(self) -> AssistantsWithStreamingResponse:
"""Build Assistants that can call models and use tools."""
return AssistantsWithStreamingResponse(self._beta.assistants)
@cached_property
def threads(self) -> ThreadsWithStreamingResponse:
"""Build Assistants that can call models and use tools."""
return ThreadsWithStreamingResponse(self._beta.threads)
class AsyncBetaWithStreamingResponse:
def __init__(self, beta: AsyncBeta) -> None:
self._beta = beta
@cached_property
def chatkit(self) -> AsyncChatKitWithStreamingResponse:
return AsyncChatKitWithStreamingResponse(self._beta.chatkit)
@cached_property
def assistants(self) -> AsyncAssistantsWithStreamingResponse:
"""Build Assistants that can call models and use tools."""
return AsyncAssistantsWithStreamingResponse(self._beta.assistants)
@cached_property
def threads(self) -> AsyncThreadsWithStreamingResponse:
"""Build Assistants that can call models and use tools."""
return AsyncThreadsWithStreamingResponse(self._beta.threads)
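A short sketch of reaching the `Beta` namespace above from a client (hedged: the assistant parameters are illustrative, not from the diff):

from openai import OpenAI

client = OpenAI()

# Sub-resources are lazily constructed via `cached_property`, so the first
# access builds the resource and later accesses reuse the same instance.
assistant = client.beta.assistants.create(
    model="gpt-4o-mini",  # assumed model name
    instructions="You summarize documents.",
)
print(assistant.id)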

View File

@@ -0,0 +1,47 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .chatkit import (
ChatKit,
AsyncChatKit,
ChatKitWithRawResponse,
AsyncChatKitWithRawResponse,
ChatKitWithStreamingResponse,
AsyncChatKitWithStreamingResponse,
)
from .threads import (
Threads,
AsyncThreads,
ThreadsWithRawResponse,
AsyncThreadsWithRawResponse,
ThreadsWithStreamingResponse,
AsyncThreadsWithStreamingResponse,
)
from .sessions import (
Sessions,
AsyncSessions,
SessionsWithRawResponse,
AsyncSessionsWithRawResponse,
SessionsWithStreamingResponse,
AsyncSessionsWithStreamingResponse,
)
__all__ = [
"Sessions",
"AsyncSessions",
"SessionsWithRawResponse",
"AsyncSessionsWithRawResponse",
"SessionsWithStreamingResponse",
"AsyncSessionsWithStreamingResponse",
"Threads",
"AsyncThreads",
"ThreadsWithRawResponse",
"AsyncThreadsWithRawResponse",
"ThreadsWithStreamingResponse",
"AsyncThreadsWithStreamingResponse",
"ChatKit",
"AsyncChatKit",
"ChatKitWithRawResponse",
"AsyncChatKitWithRawResponse",
"ChatKitWithStreamingResponse",
"AsyncChatKitWithStreamingResponse",
]

View File

@@ -0,0 +1,134 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from .threads import (
Threads,
AsyncThreads,
ThreadsWithRawResponse,
AsyncThreadsWithRawResponse,
ThreadsWithStreamingResponse,
AsyncThreadsWithStreamingResponse,
)
from .sessions import (
Sessions,
AsyncSessions,
SessionsWithRawResponse,
AsyncSessionsWithRawResponse,
SessionsWithStreamingResponse,
AsyncSessionsWithStreamingResponse,
)
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
__all__ = ["ChatKit", "AsyncChatKit"]
class ChatKit(SyncAPIResource):
@cached_property
def sessions(self) -> Sessions:
return Sessions(self._client)
@cached_property
def threads(self) -> Threads:
return Threads(self._client)
@cached_property
def with_raw_response(self) -> ChatKitWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return ChatKitWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> ChatKitWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return ChatKitWithStreamingResponse(self)
class AsyncChatKit(AsyncAPIResource):
@cached_property
def sessions(self) -> AsyncSessions:
return AsyncSessions(self._client)
@cached_property
def threads(self) -> AsyncThreads:
return AsyncThreads(self._client)
@cached_property
def with_raw_response(self) -> AsyncChatKitWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncChatKitWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncChatKitWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncChatKitWithStreamingResponse(self)
class ChatKitWithRawResponse:
def __init__(self, chatkit: ChatKit) -> None:
self._chatkit = chatkit
@cached_property
def sessions(self) -> SessionsWithRawResponse:
return SessionsWithRawResponse(self._chatkit.sessions)
@cached_property
def threads(self) -> ThreadsWithRawResponse:
return ThreadsWithRawResponse(self._chatkit.threads)
class AsyncChatKitWithRawResponse:
def __init__(self, chatkit: AsyncChatKit) -> None:
self._chatkit = chatkit
@cached_property
def sessions(self) -> AsyncSessionsWithRawResponse:
return AsyncSessionsWithRawResponse(self._chatkit.sessions)
@cached_property
def threads(self) -> AsyncThreadsWithRawResponse:
return AsyncThreadsWithRawResponse(self._chatkit.threads)
class ChatKitWithStreamingResponse:
def __init__(self, chatkit: ChatKit) -> None:
self._chatkit = chatkit
@cached_property
def sessions(self) -> SessionsWithStreamingResponse:
return SessionsWithStreamingResponse(self._chatkit.sessions)
@cached_property
def threads(self) -> ThreadsWithStreamingResponse:
return ThreadsWithStreamingResponse(self._chatkit.threads)
class AsyncChatKitWithStreamingResponse:
def __init__(self, chatkit: AsyncChatKit) -> None:
self._chatkit = chatkit
@cached_property
def sessions(self) -> AsyncSessionsWithStreamingResponse:
return AsyncSessionsWithStreamingResponse(self._chatkit.sessions)
@cached_property
def threads(self) -> AsyncThreadsWithStreamingResponse:
return AsyncThreadsWithStreamingResponse(self._chatkit.threads)
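A short sketch of the accessor pattern these docstrings describe: prefixing the call chain with `.with_raw_response` returns the raw HTTP response instead of the parsed model. The `user` and `workflow` values are placeholders.

from openai import OpenAI

client = OpenAI()

# The raw-response wrapper exposes headers alongside the parsed body.
raw = client.beta.chatkit.with_raw_response.sessions.create(
    user="user_123",  # placeholder end-user identifier
    workflow={"id": "wf_abc123"},  # placeholder workflow reference
)
print(raw.headers.get("x-request-id"))
session = raw.parse()  # parse into a ChatSession when needed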

View File

@@ -0,0 +1,305 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import httpx
from .... import _legacy_response
from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ...._utils import maybe_transform, async_maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ...._base_client import make_request_options
from ....types.beta.chatkit import (
ChatSessionWorkflowParam,
ChatSessionRateLimitsParam,
ChatSessionExpiresAfterParam,
ChatSessionChatKitConfigurationParam,
session_create_params,
)
from ....types.beta.chatkit.chat_session import ChatSession
from ....types.beta.chatkit.chat_session_workflow_param import ChatSessionWorkflowParam
from ....types.beta.chatkit.chat_session_rate_limits_param import ChatSessionRateLimitsParam
from ....types.beta.chatkit.chat_session_expires_after_param import ChatSessionExpiresAfterParam
from ....types.beta.chatkit.chat_session_chatkit_configuration_param import ChatSessionChatKitConfigurationParam
__all__ = ["Sessions", "AsyncSessions"]
class Sessions(SyncAPIResource):
@cached_property
def with_raw_response(self) -> SessionsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return SessionsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> SessionsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return SessionsWithStreamingResponse(self)
def create(
self,
*,
user: str,
workflow: ChatSessionWorkflowParam,
chatkit_configuration: ChatSessionChatKitConfigurationParam | Omit = omit,
expires_after: ChatSessionExpiresAfterParam | Omit = omit,
rate_limits: ChatSessionRateLimitsParam | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatSession:
"""
Create a ChatKit session.
Args:
user: A free-form string that identifies your end user; ensures this Session can
access other objects that have the same `user` scope.
workflow: Workflow that powers the session.
chatkit_configuration: Optional overrides for ChatKit runtime configuration features.
expires_after: Optional override for session expiration timing in seconds from creation.
Defaults to 10 minutes.
rate_limits: Optional override for per-minute request limits. When omitted, defaults to 10.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._post(
"/chatkit/sessions",
body=maybe_transform(
{
"user": user,
"workflow": workflow,
"chatkit_configuration": chatkit_configuration,
"expires_after": expires_after,
"rate_limits": rate_limits,
},
session_create_params.SessionCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatSession,
)
def cancel(
self,
session_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatSession:
"""
Cancel an active ChatKit session and return its most recent metadata.
Cancelling prevents new requests from using the issued client secret.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not session_id:
raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._post(
f"/chatkit/sessions/{session_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatSession,
)
class AsyncSessions(AsyncAPIResource):
@cached_property
def with_raw_response(self) -> AsyncSessionsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncSessionsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncSessionsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncSessionsWithStreamingResponse(self)
async def create(
self,
*,
user: str,
workflow: ChatSessionWorkflowParam,
chatkit_configuration: ChatSessionChatKitConfigurationParam | Omit = omit,
expires_after: ChatSessionExpiresAfterParam | Omit = omit,
rate_limits: ChatSessionRateLimitsParam | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatSession:
"""
Create a ChatKit session.
Args:
user: A free-form string that identifies your end user; ensures this Session can
access other objects that have the same `user` scope.
workflow: Workflow that powers the session.
chatkit_configuration: Optional overrides for ChatKit runtime configuration features.
expires_after: Optional override for session expiration timing in seconds from creation.
Defaults to 10 minutes.
rate_limits: Optional override for per-minute request limits. When omitted, defaults to 10.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return await self._post(
"/chatkit/sessions",
body=await async_maybe_transform(
{
"user": user,
"workflow": workflow,
"chatkit_configuration": chatkit_configuration,
"expires_after": expires_after,
"rate_limits": rate_limits,
},
session_create_params.SessionCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatSession,
)
async def cancel(
self,
session_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatSession:
"""
Cancel an active ChatKit session and return its most recent metadata.
Cancelling prevents new requests from using the issued client secret.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not session_id:
raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return await self._post(
f"/chatkit/sessions/{session_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatSession,
)
class SessionsWithRawResponse:
def __init__(self, sessions: Sessions) -> None:
self._sessions = sessions
self.create = _legacy_response.to_raw_response_wrapper(
sessions.create,
)
self.cancel = _legacy_response.to_raw_response_wrapper(
sessions.cancel,
)
class AsyncSessionsWithRawResponse:
def __init__(self, sessions: AsyncSessions) -> None:
self._sessions = sessions
self.create = _legacy_response.async_to_raw_response_wrapper(
sessions.create,
)
self.cancel = _legacy_response.async_to_raw_response_wrapper(
sessions.cancel,
)
class SessionsWithStreamingResponse:
def __init__(self, sessions: Sessions) -> None:
self._sessions = sessions
self.create = to_streamed_response_wrapper(
sessions.create,
)
self.cancel = to_streamed_response_wrapper(
sessions.cancel,
)
class AsyncSessionsWithStreamingResponse:
def __init__(self, sessions: AsyncSessions) -> None:
self._sessions = sessions
self.create = async_to_streamed_response_wrapper(
sessions.create,
)
self.cancel = async_to_streamed_response_wrapper(
sessions.cancel,
)
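Putting `create` and `cancel` together, a minimal sketch of the session lifecycle. The ids are placeholders and the `expires_after` shape is an assumption based on the `ChatSessionExpiresAfterParam` name.

from openai import OpenAI

client = OpenAI()

# Create a session bound to a single end user (ids are placeholders).
session = client.beta.chatkit.sessions.create(
    user="user_123",
    workflow={"id": "wf_abc123"},
    expires_after={"anchor": "created_at", "seconds": 600},  # assumed shape
)

# The client secret is what the browser ChatKit client authenticates with.
print(session.client_secret)

# Cancelling prevents new requests from using that client secret.
client.beta.chatkit.sessions.cancel(session.id)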

View File

@@ -0,0 +1,521 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Any, cast
from typing_extensions import Literal
import httpx
from .... import _legacy_response
from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
from ...._utils import maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ....pagination import SyncConversationCursorPage, AsyncConversationCursorPage
from ...._base_client import AsyncPaginator, make_request_options
from ....types.beta.chatkit import thread_list_params, thread_list_items_params
from ....types.beta.chatkit.chatkit_thread import ChatKitThread
from ....types.beta.chatkit.thread_delete_response import ThreadDeleteResponse
from ....types.beta.chatkit.chatkit_thread_item_list import Data
__all__ = ["Threads", "AsyncThreads"]
class Threads(SyncAPIResource):
@cached_property
def with_raw_response(self) -> ThreadsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return ThreadsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> ThreadsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return ThreadsWithStreamingResponse(self)
def retrieve(
self,
thread_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatKitThread:
"""
Retrieve a ChatKit thread by its identifier.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._get(
f"/chatkit/threads/{thread_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatKitThread,
)
def list(
self,
*,
after: str | Omit = omit,
before: str | Omit = omit,
limit: int | Omit = omit,
order: Literal["asc", "desc"] | Omit = omit,
user: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> SyncConversationCursorPage[ChatKitThread]:
"""
List ChatKit threads with optional pagination and user filters.
Args:
after: List threads created after this thread ID. Defaults to null for the first
page.
before: List threads created before this thread ID. Defaults to null for the newest
results.
limit: Maximum number of threads to return. Defaults to 20.
order: Sort order for results by creation time. Defaults to `desc`.
user: Filter threads that belong to this user identifier. Defaults to null to return
threads for all users.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._get_api_list(
"/chatkit/threads",
page=SyncConversationCursorPage[ChatKitThread],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"before": before,
"limit": limit,
"order": order,
"user": user,
},
thread_list_params.ThreadListParams,
),
),
model=ChatKitThread,
)
def delete(
self,
thread_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ThreadDeleteResponse:
"""
Delete a ChatKit thread along with its items and stored attachments.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._delete(
f"/chatkit/threads/{thread_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ThreadDeleteResponse,
)
def list_items(
self,
thread_id: str,
*,
after: str | Omit = omit,
before: str | Omit = omit,
limit: int | Omit = omit,
order: Literal["asc", "desc"] | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> SyncConversationCursorPage[Data]:
"""
List items that belong to a ChatKit thread.
Args:
after: List items created after this thread item ID. Defaults to null for the first
page.
before: List items created before this thread item ID. Defaults to null for the newest
results.
limit: Maximum number of thread items to return. Defaults to 20.
order: Sort order for results by creation time. Defaults to `desc`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._get_api_list(
f"/chatkit/threads/{thread_id}/items",
page=SyncConversationCursorPage[Data],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"before": before,
"limit": limit,
"order": order,
},
thread_list_items_params.ThreadListItemsParams,
),
),
model=cast(Any, Data), # Union types cannot be passed in as arguments in the type system
)
class AsyncThreads(AsyncAPIResource):
@cached_property
def with_raw_response(self) -> AsyncThreadsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncThreadsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncThreadsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncThreadsWithStreamingResponse(self)
async def retrieve(
self,
thread_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ChatKitThread:
"""
Retrieve a ChatKit thread by its identifier.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return await self._get(
f"/chatkit/threads/{thread_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatKitThread,
)
def list(
self,
*,
after: str | Omit = omit,
before: str | Omit = omit,
limit: int | Omit = omit,
order: Literal["asc", "desc"] | Omit = omit,
user: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> AsyncPaginator[ChatKitThread, AsyncConversationCursorPage[ChatKitThread]]:
"""
List ChatKit threads with optional pagination and user filters.
Args:
after: List threads created after this thread ID. Defaults to null for the first
page.
before: List threads created before this thread ID. Defaults to null for the newest
results.
limit: Maximum number of threads to return. Defaults to 20.
order: Sort order for results by creation time. Defaults to `desc`.
user: Filter threads that belong to this user identifier. Defaults to null to return
threads for all users.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._get_api_list(
"/chatkit/threads",
page=AsyncConversationCursorPage[ChatKitThread],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"before": before,
"limit": limit,
"order": order,
"user": user,
},
thread_list_params.ThreadListParams,
),
),
model=ChatKitThread,
)
async def delete(
self,
thread_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ThreadDeleteResponse:
"""
Delete a ChatKit thread along with its items and stored attachments.
Args:
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return await self._delete(
f"/chatkit/threads/{thread_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ThreadDeleteResponse,
)
def list_items(
self,
thread_id: str,
*,
after: str | Omit = omit,
before: str | Omit = omit,
limit: int | Omit = omit,
order: Literal["asc", "desc"] | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> AsyncPaginator[Data, AsyncConversationCursorPage[Data]]:
"""
List items that belong to a ChatKit thread.
Args:
after: List items created after this thread item ID. Defaults to null for the first
page.
before: List items created before this thread item ID. Defaults to null for the newest
results.
limit: Maximum number of thread items to return. Defaults to 20.
order: Sort order for results by creation time. Defaults to `desc`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
if not thread_id:
raise ValueError(f"Expected a non-empty value for `thread_id` but received {thread_id!r}")
extra_headers = {"OpenAI-Beta": "chatkit_beta=v1", **(extra_headers or {})}
return self._get_api_list(
f"/chatkit/threads/{thread_id}/items",
page=AsyncConversationCursorPage[Data],
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"before": before,
"limit": limit,
"order": order,
},
thread_list_items_params.ThreadListItemsParams,
),
),
model=cast(Any, Data), # Union types cannot be passed in as arguments in the type system
)
class ThreadsWithRawResponse:
def __init__(self, threads: Threads) -> None:
self._threads = threads
self.retrieve = _legacy_response.to_raw_response_wrapper(
threads.retrieve,
)
self.list = _legacy_response.to_raw_response_wrapper(
threads.list,
)
self.delete = _legacy_response.to_raw_response_wrapper(
threads.delete,
)
self.list_items = _legacy_response.to_raw_response_wrapper(
threads.list_items,
)
class AsyncThreadsWithRawResponse:
def __init__(self, threads: AsyncThreads) -> None:
self._threads = threads
self.retrieve = _legacy_response.async_to_raw_response_wrapper(
threads.retrieve,
)
self.list = _legacy_response.async_to_raw_response_wrapper(
threads.list,
)
self.delete = _legacy_response.async_to_raw_response_wrapper(
threads.delete,
)
self.list_items = _legacy_response.async_to_raw_response_wrapper(
threads.list_items,
)
class ThreadsWithStreamingResponse:
def __init__(self, threads: Threads) -> None:
self._threads = threads
self.retrieve = to_streamed_response_wrapper(
threads.retrieve,
)
self.list = to_streamed_response_wrapper(
threads.list,
)
self.delete = to_streamed_response_wrapper(
threads.delete,
)
self.list_items = to_streamed_response_wrapper(
threads.list_items,
)
class AsyncThreadsWithStreamingResponse:
def __init__(self, threads: AsyncThreads) -> None:
self._threads = threads
self.retrieve = async_to_streamed_response_wrapper(
threads.retrieve,
)
self.list = async_to_streamed_response_wrapper(
threads.list,
)
self.delete = async_to_streamed_response_wrapper(
threads.delete,
)
self.list_items = async_to_streamed_response_wrapper(
threads.list_items,
)
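Because both `list` methods return cursor pages, a brief sketch of the auto-pagination they enable: iterating the returned page object fetches follow-up pages on demand. The filter values and thread ID are placeholders.

from openai import OpenAI

client = OpenAI()

# Iterate threads for one user; the SDK follows the cursor automatically.
for thread in client.beta.chatkit.threads.list(user="user_123", limit=20):
    print(thread.id)

# Items within a thread paginate the same way; `Data` is a union of item types.
for item in client.beta.chatkit.threads.list_items("cthr_123", order="asc"):
    print(type(item).__name__)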

View File

@@ -0,0 +1,47 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .realtime import (
Realtime,
AsyncRealtime,
RealtimeWithRawResponse,
AsyncRealtimeWithRawResponse,
RealtimeWithStreamingResponse,
AsyncRealtimeWithStreamingResponse,
)
from .sessions import (
Sessions,
AsyncSessions,
SessionsWithRawResponse,
AsyncSessionsWithRawResponse,
SessionsWithStreamingResponse,
AsyncSessionsWithStreamingResponse,
)
from .transcription_sessions import (
TranscriptionSessions,
AsyncTranscriptionSessions,
TranscriptionSessionsWithRawResponse,
AsyncTranscriptionSessionsWithRawResponse,
TranscriptionSessionsWithStreamingResponse,
AsyncTranscriptionSessionsWithStreamingResponse,
)
__all__ = [
"Sessions",
"AsyncSessions",
"SessionsWithRawResponse",
"AsyncSessionsWithRawResponse",
"SessionsWithStreamingResponse",
"AsyncSessionsWithStreamingResponse",
"TranscriptionSessions",
"AsyncTranscriptionSessions",
"TranscriptionSessionsWithRawResponse",
"AsyncTranscriptionSessionsWithRawResponse",
"TranscriptionSessionsWithStreamingResponse",
"AsyncTranscriptionSessionsWithStreamingResponse",
"Realtime",
"AsyncRealtime",
"RealtimeWithRawResponse",
"AsyncRealtimeWithRawResponse",
"RealtimeWithStreamingResponse",
"AsyncRealtimeWithStreamingResponse",
]

File diff suppressed because it is too large.

View File

@@ -0,0 +1,424 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import List, Union, Iterable
from typing_extensions import Literal
import httpx
from .... import _legacy_response
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform, async_maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
from ...._base_client import make_request_options
from ....types.beta.realtime import session_create_params
from ....types.beta.realtime.session_create_response import SessionCreateResponse
__all__ = ["Sessions", "AsyncSessions"]
class Sessions(SyncAPIResource):
@cached_property
def with_raw_response(self) -> SessionsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return SessionsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> SessionsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return SessionsWithStreamingResponse(self)
def create(
self,
*,
client_secret: session_create_params.ClientSecret | NotGiven = NOT_GIVEN,
input_audio_format: Literal["pcm16", "g711_ulaw", "g711_alaw"] | NotGiven = NOT_GIVEN,
input_audio_noise_reduction: session_create_params.InputAudioNoiseReduction | NotGiven = NOT_GIVEN,
input_audio_transcription: session_create_params.InputAudioTranscription | NotGiven = NOT_GIVEN,
instructions: str | NotGiven = NOT_GIVEN,
max_response_output_tokens: Union[int, Literal["inf"]] | NotGiven = NOT_GIVEN,
modalities: List[Literal["text", "audio"]] | NotGiven = NOT_GIVEN,
model: Literal[
"gpt-realtime",
"gpt-realtime-2025-08-28",
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-realtime-preview-2025-06-03",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]
| NotGiven = NOT_GIVEN,
output_audio_format: Literal["pcm16", "g711_ulaw", "g711_alaw"] | NotGiven = NOT_GIVEN,
speed: float | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
tools: Iterable[session_create_params.Tool] | NotGiven = NOT_GIVEN,
tracing: session_create_params.Tracing | NotGiven = NOT_GIVEN,
turn_detection: session_create_params.TurnDetection | NotGiven = NOT_GIVEN,
voice: Union[str, Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"]]
| NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SessionCreateResponse:
"""
Create an ephemeral API token for use in client-side applications with the
Realtime API. Can be configured with the same session parameters as the
`session.update` client event.
It responds with a session object, plus a `client_secret` key containing a
usable ephemeral API token for authenticating browser clients with the
Realtime API.
Args:
client_secret: Configuration options for the generated client secret.
input_audio_format: The format of input audio. Options are `pcm16`, `g711_ulaw`, or `g711_alaw`. For
`pcm16`, input audio must be 16-bit PCM at a 24kHz sample rate, single channel
(mono), and little-endian byte order.
input_audio_noise_reduction: Configuration for input audio noise reduction. This can be set to `null` to turn
off. Noise reduction filters audio added to the input audio buffer before it is
sent to VAD and the model. Filtering the audio can improve VAD and turn
detection accuracy (reducing false positives) and model performance by improving
perception of the input audio.
input_audio_transcription: Configuration for input audio transcription, defaults to off and can be set to
`null` to turn off once on. Input audio transcription is not native to the
model, since the model consumes audio directly. Transcription runs
asynchronously through
[the /audio/transcriptions endpoint](https://platform.openai.com/docs/api-reference/audio/createTranscription)
and should be treated as guidance of input audio content rather than precisely
what the model heard. The client can optionally set the language and prompt for
transcription; these offer additional guidance to the transcription service.
instructions: The default system instructions (i.e. system message) prepended to model calls.
This field allows the client to guide the model on desired responses. The model
can be instructed on response content and format (e.g. "be extremely succinct",
"act friendly", "here are examples of good responses") and on audio behavior
(e.g. "talk quickly", "inject emotion into your voice", "laugh frequently"). The
instructions are not guaranteed to be followed by the model, but they provide
guidance to the model on the desired behavior.
Note that the server sets default instructions which will be used if this field
is not set and are visible in the `session.created` event at the start of the
session.
max_response_output_tokens: Maximum number of output tokens for a single assistant response, inclusive of
tool calls. Provide an integer between 1 and 4096 to limit output tokens, or
`inf` for the maximum available tokens for a given model. Defaults to `inf`.
modalities: The set of modalities the model can respond with. To disable audio, set this to
["text"].
model: The Realtime model used for this session.
output_audio_format: The format of output audio. Options are `pcm16`, `g711_ulaw`, or `g711_alaw`.
For `pcm16`, output audio is sampled at a rate of 24kHz.
speed: The speed of the model's spoken response. 1.0 is the default speed. 0.25 is the
minimum speed. 1.5 is the maximum speed. This value can only be changed in
between model turns, not while a response is in progress.
temperature: Sampling temperature for the model, limited to [0.6, 1.2]. For audio models a
temperature of 0.8 is highly recommended for best performance.
tool_choice: How the model chooses tools. Options are `auto`, `none`, `required`, or specify
a function.
tools: Tools (functions) available to the model.
tracing: Configuration options for tracing. Set to null to disable tracing. Once tracing
is enabled for a session, the configuration cannot be modified.
`auto` will create a trace for the session with default values for the workflow
name, group id, and metadata.
turn_detection: Configuration for turn detection, either Server VAD or Semantic VAD. This can be
set to `null` to turn off, in which case the client must manually trigger model
response. Server VAD means that the model will detect the start and end of
speech based on audio volume and respond at the end of user speech. Semantic VAD
is more advanced and uses a turn detection model (in conjunction with VAD) to
semantically estimate whether the user has finished speaking, then dynamically
sets a timeout based on this probability. For example, if user audio trails off
with "uhhm", the model will score a low probability of turn end and wait longer
for the user to continue speaking. This can be useful for more natural
conversations, but may have a higher latency.
voice: The voice the model uses to respond. Voice cannot be changed during the session
once the model has responded with audio at least once. Current voice options are
`alloy`, `ash`, `ballad`, `coral`, `echo`, `sage`, `shimmer`, and `verse`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "assistants=v2", **(extra_headers or {})}
return self._post(
"/realtime/sessions",
body=maybe_transform(
{
"client_secret": client_secret,
"input_audio_format": input_audio_format,
"input_audio_noise_reduction": input_audio_noise_reduction,
"input_audio_transcription": input_audio_transcription,
"instructions": instructions,
"max_response_output_tokens": max_response_output_tokens,
"modalities": modalities,
"model": model,
"output_audio_format": output_audio_format,
"speed": speed,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"tracing": tracing,
"turn_detection": turn_detection,
"voice": voice,
},
session_create_params.SessionCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=SessionCreateResponse,
)
class AsyncSessions(AsyncAPIResource):
@cached_property
def with_raw_response(self) -> AsyncSessionsWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
"""
return AsyncSessionsWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> AsyncSessionsWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
"""
return AsyncSessionsWithStreamingResponse(self)
async def create(
self,
*,
client_secret: session_create_params.ClientSecret | NotGiven = NOT_GIVEN,
input_audio_format: Literal["pcm16", "g711_ulaw", "g711_alaw"] | NotGiven = NOT_GIVEN,
input_audio_noise_reduction: session_create_params.InputAudioNoiseReduction | NotGiven = NOT_GIVEN,
input_audio_transcription: session_create_params.InputAudioTranscription | NotGiven = NOT_GIVEN,
instructions: str | NotGiven = NOT_GIVEN,
max_response_output_tokens: Union[int, Literal["inf"]] | NotGiven = NOT_GIVEN,
modalities: List[Literal["text", "audio"]] | NotGiven = NOT_GIVEN,
model: Literal[
"gpt-realtime",
"gpt-realtime-2025-08-28",
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-realtime-preview-2025-06-03",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]
| NotGiven = NOT_GIVEN,
output_audio_format: Literal["pcm16", "g711_ulaw", "g711_alaw"] | NotGiven = NOT_GIVEN,
speed: float | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
tools: Iterable[session_create_params.Tool] | NotGiven = NOT_GIVEN,
tracing: session_create_params.Tracing | NotGiven = NOT_GIVEN,
turn_detection: session_create_params.TurnDetection | NotGiven = NOT_GIVEN,
voice: Union[str, Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"]]
| NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SessionCreateResponse:
"""
Create an ephemeral API token for use in client-side applications with the
Realtime API. Can be configured with the same session parameters as the
`session.update` client event.
It responds with a session object, plus a `client_secret` key containing a
usable ephemeral API token for authenticating browser clients with the
Realtime API.
Args:
client_secret: Configuration options for the generated client secret.
input_audio_format: The format of input audio. Options are `pcm16`, `g711_ulaw`, or `g711_alaw`. For
`pcm16`, input audio must be 16-bit PCM at a 24kHz sample rate, single channel
(mono), and little-endian byte order.
input_audio_noise_reduction: Configuration for input audio noise reduction. This can be set to `null` to turn
off. Noise reduction filters audio added to the input audio buffer before it is
sent to VAD and the model. Filtering the audio can improve VAD and turn
detection accuracy (reducing false positives) and model performance by improving
perception of the input audio.
input_audio_transcription: Configuration for input audio transcription, defaults to off and can be set to
`null` to turn off once on. Input audio transcription is not native to the
model, since the model consumes audio directly. Transcription runs
asynchronously through
[the /audio/transcriptions endpoint](https://platform.openai.com/docs/api-reference/audio/createTranscription)
and should be treated as guidance of input audio content rather than precisely
what the model heard. The client can optionally set the language and prompt for
transcription; these offer additional guidance to the transcription service.
instructions: The default system instructions (i.e. system message) prepended to model calls.
This field allows the client to guide the model on desired responses. The model
can be instructed on response content and format (e.g. "be extremely succinct",
"act friendly", "here are examples of good responses") and on audio behavior
(e.g. "talk quickly", "inject emotion into your voice", "laugh frequently"). The
instructions are not guaranteed to be followed by the model, but they provide
guidance to the model on the desired behavior.
Note that the server sets default instructions which will be used if this field
is not set and are visible in the `session.created` event at the start of the
session.
max_response_output_tokens: Maximum number of output tokens for a single assistant response, inclusive of
tool calls. Provide an integer between 1 and 4096 to limit output tokens, or
`inf` for the maximum available tokens for a given model. Defaults to `inf`.
modalities: The set of modalities the model can respond with. To disable audio, set this to
["text"].
model: The Realtime model used for this session.
output_audio_format: The format of output audio. Options are `pcm16`, `g711_ulaw`, or `g711_alaw`.
For `pcm16`, output audio is sampled at a rate of 24kHz.
speed: The speed of the model's spoken response. 1.0 is the default speed. 0.25 is the
minimum speed. 1.5 is the maximum speed. This value can only be changed in
between model turns, not while a response is in progress.
temperature: Sampling temperature for the model, limited to [0.6, 1.2]. For audio models a
temperature of 0.8 is highly recommended for best performance.
tool_choice: How the model chooses tools. Options are `auto`, `none`, `required`, or specify
a function.
tools: Tools (functions) available to the model.
tracing: Configuration options for tracing. Set to null to disable tracing. Once tracing
is enabled for a session, the configuration cannot be modified.
`auto` will create a trace for the session with default values for the workflow
name, group id, and metadata.
turn_detection: Configuration for turn detection, either Server VAD or Semantic VAD. This can be
set to `null` to turn off, in which case the client must manually trigger model
response. Server VAD means that the model will detect the start and end of
speech based on audio volume and respond at the end of user speech. Semantic VAD
is more advanced and uses a turn detection model (in conjunction with VAD) to
semantically estimate whether the user has finished speaking, then dynamically
sets a timeout based on this probability. For example, if user audio trails off
with "uhhm", the model will score a low probability of turn end and wait longer
for the user to continue speaking. This can be useful for more natural
conversations, but may have a higher latency.
voice: The voice the model uses to respond. Voice cannot be changed during the session
once the model has responded with audio at least once. Current voice options are
`alloy`, `ash`, `ballad`, `coral`, `echo`, `sage`, `shimmer`, and `verse`.
extra_headers: Send extra headers
extra_query: Add additional query parameters to the request
extra_body: Add additional JSON properties to the request
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"OpenAI-Beta": "assistants=v2", **(extra_headers or {})}
return await self._post(
"/realtime/sessions",
body=await async_maybe_transform(
{
"client_secret": client_secret,
"input_audio_format": input_audio_format,
"input_audio_noise_reduction": input_audio_noise_reduction,
"input_audio_transcription": input_audio_transcription,
"instructions": instructions,
"max_response_output_tokens": max_response_output_tokens,
"modalities": modalities,
"model": model,
"output_audio_format": output_audio_format,
"speed": speed,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"tracing": tracing,
"turn_detection": turn_detection,
"voice": voice,
},
session_create_params.SessionCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=SessionCreateResponse,
)
class SessionsWithRawResponse:
def __init__(self, sessions: Sessions) -> None:
self._sessions = sessions
self.create = _legacy_response.to_raw_response_wrapper(
sessions.create,
)
class AsyncSessionsWithRawResponse:
def __init__(self, sessions: AsyncSessions) -> None:
self._sessions = sessions
self.create = _legacy_response.async_to_raw_response_wrapper(
sessions.create,
)
class SessionsWithStreamingResponse:
def __init__(self, sessions: Sessions) -> None:
self._sessions = sessions
self.create = to_streamed_response_wrapper(
sessions.create,
)
class AsyncSessionsWithStreamingResponse:
def __init__(self, sessions: AsyncSessions) -> None:
self._sessions = sessions
self.create = async_to_streamed_response_wrapper(
sessions.create,
)
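Finally, a hedged sketch of minting an ephemeral Realtime token with this endpoint. The model and voice come from the `Literal` options in the signature; the `client_secret.value` access assumes the nested layout implied by the docstring.

from openai import OpenAI

client = OpenAI()

# Most fields can stay at server defaults; pick a model, modalities, and voice.
session = client.beta.realtime.sessions.create(
    model="gpt-realtime",
    modalities=["text", "audio"],
    voice="verse",
)

# Hand the ephemeral token to the browser client; it expires quickly.
print(session.client_secret.value)  # assumed nested field on the response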

Some files were not shown because too many files have changed in this diff.