fix: 修复代理问题

This commit is contained in:
丹尼尔
2026-03-15 17:16:05 +08:00
parent 8b62c445fc
commit 15c9e1772a
100 changed files with 6157 additions and 69 deletions

View File

@@ -0,0 +1,20 @@
from ._version import __version__, __title__
from ._types import ProxyType
from ._helpers import parse_proxy_url
from ._errors import (
ProxyError,
ProxyTimeoutError,
ProxyConnectionError,
)
__all__ = (
'__title__',
'__version__',
'ProxyError',
'ProxyTimeoutError',
'ProxyConnectionError',
'ProxyType',
'parse_proxy_url',
)

View File

@@ -0,0 +1,40 @@
from typing import Optional
class SyncResolver:
def resolve(self, host, port=0, family=0):
raise NotImplementedError()
class AsyncResolver:
async def resolve(self, host, port=0, family=0):
raise NotImplementedError()
class SyncSocketStream:
def write_all(self, data: bytes):
raise NotImplementedError()
def read(self, max_bytes: Optional[int] = None):
raise NotImplementedError()
def read_exact(self, n: int):
raise NotImplementedError()
def close(self):
raise NotImplementedError()
class AsyncSocketStream:
async def write_all(self, data: bytes):
raise NotImplementedError()
async def read(self, max_bytes: Optional[int] = None):
raise NotImplementedError()
async def read_exact(self, n: int):
raise NotImplementedError()
async def close(self):
raise NotImplementedError()

View File

@@ -0,0 +1,21 @@
from .._abc import SyncSocketStream, AsyncSocketStream
class SyncConnector:
def connect(
self,
stream: SyncSocketStream,
host: str,
port: int,
):
raise NotImplementedError
class AsyncConnector:
async def connect(
self,
stream: AsyncSocketStream,
host: str,
port: int,
):
raise NotImplementedError

View File

@@ -0,0 +1,40 @@
from typing import Optional
from .._abc import AsyncResolver
from .._types import ProxyType
from .abc import AsyncConnector
from .socks5_async import Socks5AsyncConnector
from .socks4_async import Socks4AsyncConnector
from .http_async import HttpAsyncConnector
def create_connector(
proxy_type: ProxyType,
username: Optional[str],
password: Optional[str],
rdns: Optional[bool],
resolver: AsyncResolver,
) -> AsyncConnector:
if proxy_type == ProxyType.SOCKS4:
return Socks4AsyncConnector(
user_id=username,
rdns=rdns,
resolver=resolver,
)
if proxy_type == ProxyType.SOCKS5:
return Socks5AsyncConnector(
username=username,
password=password,
rdns=rdns,
resolver=resolver,
)
if proxy_type == ProxyType.HTTP:
return HttpAsyncConnector(
username=username,
password=password,
resolver=resolver,
)
raise ValueError(f'Invalid proxy type: {proxy_type}')

View File

@@ -0,0 +1,40 @@
from typing import Optional
from .._abc import SyncResolver
from .._types import ProxyType
from .abc import SyncConnector
from .socks5_sync import Socks5SyncConnector
from .socks4_sync import Socks4SyncConnector
from .http_sync import HttpSyncConnector
def create_connector(
proxy_type: ProxyType,
username: Optional[str],
password: Optional[str],
rdns: Optional[bool],
resolver: SyncResolver,
) -> SyncConnector:
if proxy_type == ProxyType.SOCKS4:
return Socks4SyncConnector(
user_id=username,
rdns=rdns,
resolver=resolver,
)
if proxy_type == ProxyType.SOCKS5:
return Socks5SyncConnector(
username=username,
password=password,
rdns=rdns,
resolver=resolver,
)
if proxy_type == ProxyType.HTTP:
return HttpSyncConnector(
username=username,
password=password,
resolver=resolver,
)
raise ValueError(f'Invalid proxy type: {proxy_type}')

View File

@@ -0,0 +1,38 @@
from typing import Optional
from .._abc import AsyncSocketStream, AsyncResolver
from .abc import AsyncConnector
from .._protocols import http
class HttpAsyncConnector(AsyncConnector):
def __init__(
self,
username: Optional[str],
password: Optional[str],
resolver: AsyncResolver,
):
self._username = username
self._password = password
self._resolver = resolver
async def connect(
self,
stream: AsyncSocketStream,
host: str,
port: int,
) -> http.ConnectReply:
conn = http.Connection()
request = http.ConnectRequest(
host=host,
port=port,
username=self._username,
password=self._password,
)
data = conn.send(request)
await stream.write_all(data)
data = await stream.read()
reply: http.ConnectReply = conn.receive(data)
return reply

View File

@@ -0,0 +1,38 @@
from typing import Optional
from .._abc import SyncSocketStream, SyncResolver
from .abc import SyncConnector
from .._protocols import http
class HttpSyncConnector(SyncConnector):
def __init__(
self,
username: Optional[str],
password: Optional[str],
resolver: SyncResolver,
):
self._username = username
self._password = password
self._resolver = resolver
def connect(
self,
stream: SyncSocketStream,
host: str,
port: int,
) -> http.ConnectReply:
conn = http.Connection()
request = http.ConnectRequest(
host=host,
port=port,
username=self._username,
password=self._password,
)
data = conn.send(request)
stream.write_all(data)
data = stream.read()
reply: http.ConnectReply = conn.receive(data)
return reply

View File

@@ -0,0 +1,45 @@
import socket
from typing import Optional
from .._abc import AsyncSocketStream, AsyncResolver
from .abc import AsyncConnector
from .._protocols import socks4
from .._helpers import is_ip_address
class Socks4AsyncConnector(AsyncConnector):
def __init__(
self,
user_id: Optional[str],
rdns: Optional[bool],
resolver: AsyncResolver,
):
if rdns is None:
rdns = False
self._user_id = user_id
self._rdns = rdns
self._resolver = resolver
async def connect(
self,
stream: AsyncSocketStream,
host: str,
port: int,
) -> socks4.ConnectReply:
conn = socks4.Connection()
if not is_ip_address(host) and not self._rdns:
_, host = await self._resolver.resolve(
host,
family=socket.AF_INET,
)
request = socks4.ConnectRequest(host=host, port=port, user_id=self._user_id)
data = conn.send(request)
await stream.write_all(data)
data = await stream.read_exact(socks4.ConnectReply.SIZE)
reply: socks4.ConnectReply = conn.receive(data)
return reply

View File

@@ -0,0 +1,45 @@
import socket
from typing import Optional
from .._abc import SyncSocketStream, SyncResolver
from .abc import SyncConnector
from .._protocols import socks4
from .._helpers import is_ip_address
class Socks4SyncConnector(SyncConnector):
def __init__(
self,
user_id: Optional[str],
rdns: Optional[bool],
resolver: SyncResolver,
):
if rdns is None:
rdns = False
self._user_id = user_id
self._rdns = rdns
self._resolver = resolver
def connect(
self,
stream: SyncSocketStream,
host: str,
port: int,
) -> socks4.ConnectReply:
conn = socks4.Connection()
if not is_ip_address(host) and not self._rdns:
_, host = self._resolver.resolve(
host,
family=socket.AF_INET,
)
request = socks4.ConnectRequest(host=host, port=port, user_id=self._user_id)
data = conn.send(request)
stream.write_all(data)
data = stream.read_exact(socks4.ConnectReply.SIZE)
reply: socks4.ConnectReply = conn.receive(data)
return reply

View File

@@ -0,0 +1,95 @@
import socket
from typing import Optional
from .._abc import AsyncSocketStream, AsyncResolver
from .abc import AsyncConnector
from .._protocols import socks5
from .._helpers import is_ip_address
class Socks5AsyncConnector(AsyncConnector):
def __init__(
self,
username: Optional[str],
password: Optional[str],
rdns: Optional[bool],
resolver: AsyncResolver,
):
if rdns is None:
rdns = True
self._username = username
self._password = password
self._rdns = rdns
self._resolver = resolver
async def connect(
self,
stream: AsyncSocketStream,
host: str,
port: int,
) -> socks5.ConnectReply:
conn = socks5.Connection()
# Auth methods
request = socks5.AuthMethodsRequest(
username=self._username,
password=self._password,
)
data = conn.send(request)
await stream.write_all(data)
data = await stream.read_exact(socks5.AuthMethodReply.SIZE)
reply: socks5.AuthMethodReply = conn.receive(data)
# Authenticate
if reply.method == socks5.AuthMethod.USERNAME_PASSWORD:
request = socks5.AuthRequest(
username=self._username,
password=self._password,
)
data = conn.send(request)
await stream.write_all(data)
data = await stream.read_exact(socks5.AuthReply.SIZE)
_: socks5.AuthReply = conn.receive(data)
# Connect
if not is_ip_address(host) and not self._rdns:
_, host = await self._resolver.resolve(
host,
family=socket.AF_UNSPEC,
)
request = socks5.ConnectRequest(host=host, port=port)
data = conn.send(request)
await stream.write_all(data)
data = await self._read_reply(stream)
reply: socks5.ConnectReply = conn.receive(data)
return reply
# noinspection PyMethodMayBeStatic
async def _read_reply(self, stream: AsyncSocketStream) -> bytes:
data = await stream.read_exact(3)
if data[0] != socks5.SOCKS_VER:
return data
if data[1] != socks5.ReplyCode.SUCCEEDED:
return data
if data[2] != socks5.RSV:
return data
data += await stream.read_exact(1)
addr_type = data[3]
if addr_type == socks5.AddressType.IPV4:
data += await stream.read_exact(6)
elif addr_type == socks5.AddressType.IPV6:
data += await stream.read_exact(18)
elif addr_type == socks5.AddressType.DOMAIN:
data += await stream.read_exact(1)
host_len = data[-1]
data += await stream.read_exact(host_len + 2)
return data

View File

@@ -0,0 +1,86 @@
import socket
from typing import Optional
from .._abc import SyncSocketStream, SyncResolver
from .abc import SyncConnector
from .._protocols import socks5
from .._helpers import is_ip_address
class Socks5SyncConnector(SyncConnector):
def __init__(
self,
username: Optional[str],
password: Optional[str],
rdns: Optional[bool],
resolver: SyncResolver,
):
if rdns is None:
rdns = True
self._username = username
self._password = password
self._rdns = rdns
self._resolver = resolver
def connect(
self,
stream: SyncSocketStream,
host: str,
port: int,
) -> socks5.ConnectReply:
conn = socks5.Connection()
# Auth methods
request = socks5.AuthMethodsRequest(username=self._username, password=self._password)
data = conn.send(request)
stream.write_all(data)
data = stream.read_exact(socks5.AuthMethodReply.SIZE)
reply: socks5.AuthMethodReply = conn.receive(data)
# Authenticate
if reply.method == socks5.AuthMethod.USERNAME_PASSWORD:
request = socks5.AuthRequest(username=self._username, password=self._password)
data = conn.send(request)
stream.write_all(data)
data = stream.read_exact(socks5.AuthReply.SIZE)
_: socks5.AuthReply = conn.receive(data)
# Connect
if not is_ip_address(host) and not self._rdns:
_, host = self._resolver.resolve(host, family=socket.AF_UNSPEC)
request = socks5.ConnectRequest(host=host, port=port)
data = conn.send(request)
stream.write_all(data)
data = self._read_reply(stream)
reply: socks5.ConnectReply = conn.receive(data)
return reply
# noinspection PyMethodMayBeStatic
def _read_reply(self, stream: SyncSocketStream) -> bytes:
data = stream.read_exact(3)
if data[0] != socks5.SOCKS_VER:
return data
if data[1] != socks5.ReplyCode.SUCCEEDED:
return data
if data[2] != socks5.RSV:
return data
data += stream.read_exact(1)
addr_type = data[3]
if addr_type == socks5.AddressType.IPV4:
data += stream.read_exact(6)
elif addr_type == socks5.AddressType.IPV6:
data += stream.read_exact(18)
elif addr_type == socks5.AddressType.DOMAIN:
data += stream.read_exact(1)
host_len = data[-1]
data += stream.read_exact(host_len + 2)
return data

View File

@@ -0,0 +1,16 @@
class ProxyException(Exception):
pass
class ProxyTimeoutError(ProxyException, TimeoutError):
pass
class ProxyConnectionError(ProxyException, OSError):
pass
class ProxyError(ProxyException):
def __init__(self, message, error_code=None):
super().__init__(message)
self.error_code = error_code

View File

@@ -0,0 +1,81 @@
import functools
import re
from typing import Optional, Tuple
from urllib.parse import urlparse, unquote
from ._types import ProxyType
# pylint:disable-next=invalid-name
_ipv4_pattern = (
r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$'
)
# pylint:disable-next=invalid-name
_ipv6_pattern = (
r'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
r'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
r'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
r'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
r'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
r':|:(:[A-F0-9]{1,4}){7})$'
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)
def _is_ip_address(regex, regexb, host):
# if host is None:
# return False
if isinstance(host, str):
return bool(regex.match(host))
elif isinstance(host, (bytes, bytearray, memoryview)):
return bool(regexb.match(host))
else:
raise TypeError(
'{} [{}] is not a str or bytes'.format(host, type(host)) # pragma: no cover
)
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)
def is_ip_address(host):
return is_ipv4_address(host) or is_ipv6_address(host)
def parse_proxy_url(url: str) -> Tuple[ProxyType, str, int, Optional[str], Optional[str]]:
parsed = urlparse(url)
scheme = parsed.scheme
if scheme == 'socks5':
proxy_type = ProxyType.SOCKS5
elif scheme == 'socks4':
proxy_type = ProxyType.SOCKS4
elif scheme == 'http':
proxy_type = ProxyType.HTTP
else:
raise ValueError(f'Invalid scheme component: {scheme}') # pragma: no cover
host = parsed.hostname
if not host:
raise ValueError('Empty host component') # pragma: no cover
try:
port = parsed.port
assert port is not None
except (ValueError, TypeError, AssertionError) as e: # pragma: no cover
raise ValueError('Invalid port component') from e
try:
username, password = (unquote(parsed.username), unquote(parsed.password))
except (AttributeError, TypeError):
username, password = '', ''
return proxy_type, host, port, username, password

View File

@@ -0,0 +1,4 @@
class ReplyError(Exception):
def __init__(self, message, error_code=None):
super().__init__(message)
self.error_code = error_code

View File

@@ -0,0 +1,148 @@
import sys
from dataclasses import dataclass
import base64
import binascii
from collections import namedtuple
from typing import Optional
from .._version import __title__, __version__
from .errors import ReplyError
DEFAULT_USER_AGENT = 'Python/{0[0]}.{0[1]} {1}/{2}'.format(
sys.version_info,
__title__,
__version__,
)
CRLF = '\r\n'
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
"""Http basic authentication helper."""
def __new__(cls, login: str, password: str = '', encoding: str = 'latin1') -> 'BasicAuth':
if login is None:
raise ValueError('None is not allowed as login value')
if password is None:
raise ValueError('None is not allowed as password value')
if ':' in login:
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
# noinspection PyTypeChecker,PyArgumentList
return super().__new__(cls, login, password, encoding)
@classmethod
def decode(cls, auth_header: str, encoding: str = 'latin1') -> 'BasicAuth':
"""Create a BasicAuth object from an Authorization HTTP header."""
try:
auth_type, encoded_credentials = auth_header.split(' ', 1)
except ValueError:
raise ValueError('Could not parse authorization header.')
if auth_type.lower() != 'basic':
raise ValueError('Unknown authorization method %s' % auth_type)
try:
decoded = base64.b64decode(encoded_credentials.encode('ascii'), validate=True).decode(
encoding
)
except binascii.Error:
raise ValueError('Invalid base64 encoding.')
try:
# RFC 2617 HTTP Authentication
# https://www.ietf.org/rfc/rfc2617.txt
# the colon must be present, but the username and password may be
# otherwise blank.
username, password = decoded.split(':', 1)
except ValueError:
raise ValueError('Invalid credentials.')
# noinspection PyTypeChecker
return cls(username, password, encoding=encoding)
def encode(self) -> str:
"""Encode credentials."""
creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)
class _Buffer:
def __init__(self, encoding: str = 'utf-8'):
self._encoding = encoding
self._buffer = bytearray()
def append_line(self, line: str = ""):
if line:
self._buffer.extend(line.encode(self._encoding))
self._buffer.extend(CRLF.encode('ascii'))
def dumps(self) -> bytes:
return bytes(self._buffer)
@dataclass
class ConnectRequest:
host: str
port: int
username: Optional[str]
password: Optional[str]
def dumps(self) -> bytes:
buff = _Buffer()
buff.append_line(f'CONNECT {self.host}:{self.port} HTTP/1.1')
buff.append_line(f'Host: {self.host}:{self.port}')
buff.append_line(f'User-Agent: {DEFAULT_USER_AGENT}')
if self.username and self.password:
auth = BasicAuth(self.username, self.password)
buff.append_line(f'Proxy-Authorization: {auth.encode()}')
buff.append_line()
return buff.dumps()
@dataclass
class ConnectReply:
status_code: int
message: str
@classmethod
def loads(cls, data: bytes) -> 'ConnectReply':
if not data:
raise ReplyError('Invalid proxy response') # pragma: no cover
line = data.split(CRLF.encode('ascii'), 1)[0]
line = line.decode('utf-8', 'surrogateescape')
try:
version, code, *reason = line.split()
except ValueError: # pragma: no cover
raise ReplyError(f'Invalid status line: {line}')
try:
status_code = int(code)
except ValueError: # pragma: no cover
raise ReplyError(f'Invalid status code: {code}')
status_message = " ".join(reason)
if status_code != 200:
msg = f'{status_code} {status_message}'
raise ReplyError(msg, error_code=status_code)
return cls(status_code=status_code, message=status_message)
# noinspection PyMethodMayBeStatic
class Connection:
def send(self, request: ConnectRequest) -> bytes:
return request.dumps()
def receive(self, data: bytes) -> ConnectReply:
return ConnectReply.loads(data)

View File

@@ -0,0 +1,116 @@
import enum
import ipaddress
import socket
from dataclasses import dataclass
from typing import Optional
from .errors import ReplyError
from .._helpers import is_ipv4_address
RSV = NULL = 0x00
SOCKS_VER = 0x04
class Command(enum.IntEnum):
CONNECT = 0x01
BIND = 0x02
class ReplyCode(enum.IntEnum):
REQUEST_GRANTED = 0x5A
REQUEST_REJECTED_OR_FAILED = 0x5B
CONNECTION_FAILED = 0x5C
AUTHENTICATION_FAILED = 0x5D
ReplyMessages = {
ReplyCode.REQUEST_GRANTED: 'Request granted',
ReplyCode.REQUEST_REJECTED_OR_FAILED: 'Request rejected or failed',
ReplyCode.CONNECTION_FAILED: (
'Request rejected because SOCKS server cannot connect to identd on the client'
),
ReplyCode.AUTHENTICATION_FAILED: (
'Request rejected because the client program and identd report different user-ids'
),
}
@dataclass
class ConnectRequest:
host: str # hostname or IPv4 address
port: int
user_id: Optional[str]
def dumps(self):
port_bytes = self.port.to_bytes(2, 'big')
include_hostname = False
if is_ipv4_address(self.host):
host_bytes = ipaddress.IPv4Address(self.host).packed
else:
include_hostname = True
host_bytes = bytes([NULL, NULL, NULL, 1])
data = bytearray([SOCKS_VER, Command.CONNECT])
data += port_bytes
data += host_bytes
if self.user_id:
data += self.user_id.encode('ascii')
data.append(NULL)
if include_hostname:
data += self.host.encode('idna')
data.append(NULL)
return bytes(data)
@dataclass
class ConnectReply:
SIZE = 8
rsv: int
reply: ReplyCode
host: str # should be ignored when using Command.CONNECT
port: int # should be ignored when using Command.CONNECT
@classmethod
def loads(cls, data: bytes) -> 'ConnectReply':
if len(data) != cls.SIZE:
raise ReplyError('Malformed connect reply')
rsv = data[0]
if rsv != RSV: # pragma: no cover
raise ReplyError(f'Unexpected reply version: {data[0]:#02X}')
try:
reply = ReplyCode(data[1])
except ValueError:
raise ReplyError(f'Invalid reply code: {data[1]:#02X}')
if reply != ReplyCode.REQUEST_GRANTED: # pragma: no cover
msg = ReplyMessages.get(reply, 'Unknown error')
raise ReplyError(msg, error_code=reply)
try:
port = int.from_bytes(data[2:4], byteorder="big")
except ValueError:
raise ReplyError('Invalid port data')
try:
host = socket.inet_ntop(socket.AF_INET, data[4:8])
except ValueError:
raise ReplyError('Invalid port data')
return cls(rsv=rsv, reply=reply, host=host, port=port)
# noinspection PyMethodMayBeStatic
class Connection:
def send(self, request: ConnectRequest) -> bytes:
return request.dumps()
def receive(self, data: bytes) -> ConnectReply:
return ConnectReply.loads(data)

View File

@@ -0,0 +1,355 @@
import enum
import ipaddress
import socket
from typing import Optional, Union
from dataclasses import dataclass, field
from .errors import ReplyError
from .._helpers import is_ip_address
RSV = NULL = AUTH_GRANTED = 0x00
SOCKS_VER = 0x05
class AuthMethod(enum.IntEnum):
ANONYMOUS = 0x00
GSSAPI = 0x01
USERNAME_PASSWORD = 0x02
NO_ACCEPTABLE = 0xFF
class AddressType(enum.IntEnum):
IPV4 = 0x01
DOMAIN = 0x03
IPV6 = 0x04
@classmethod
def from_ip_ver(cls, ver: int):
if ver == 4:
return cls.IPV4
if ver == 6:
return cls.IPV6
raise ValueError('Invalid IP version')
class Command(enum.IntEnum):
CONNECT = 0x01
BIND = 0x02
UDP_ASSOCIATE = 0x03
class ReplyCode(enum.IntEnum):
SUCCEEDED = 0x00
GENERAL_FAILURE = 0x01
CONNECTION_NOT_ALLOWED = 0x02
NETWORK_UNREACHABLE = 0x03
HOST_UNREACHABLE = 0x04
CONNECTION_REFUSED = 0x05
TTL_EXPIRED = 0x06
COMMAND_NOT_SUPPORTED = 0x07
ADDRESS_TYPE_NOT_SUPPORTED = 0x08
ReplyMessages = {
ReplyCode.SUCCEEDED: 'Request granted',
ReplyCode.GENERAL_FAILURE: 'General SOCKS server failure',
ReplyCode.CONNECTION_NOT_ALLOWED: 'Connection not allowed by ruleset',
ReplyCode.NETWORK_UNREACHABLE: 'Network unreachable',
ReplyCode.HOST_UNREACHABLE: 'Host unreachable',
ReplyCode.CONNECTION_REFUSED: 'Connection refused by destination host',
ReplyCode.TTL_EXPIRED: 'TTL expired',
ReplyCode.COMMAND_NOT_SUPPORTED: 'Command not supported or protocol error',
ReplyCode.ADDRESS_TYPE_NOT_SUPPORTED: 'Address type not supported',
}
@dataclass
class AuthMethodsRequest:
username: Optional[str]
password: Optional[str]
methods: bytearray = field(init=False)
def __post_init__(self):
methods = bytearray([AuthMethod.ANONYMOUS])
if self.username and self.password:
methods.append(AuthMethod.USERNAME_PASSWORD)
self.methods = methods
def dumps(self) -> bytes:
return bytes([SOCKS_VER, len(self.methods)]) + self.methods
@dataclass
class AuthMethodReply:
SIZE = 2
ver: int
method: AuthMethod
def validate(self, request: AuthMethodsRequest):
if self.method not in request.methods: # pragma: no cover
raise ReplyError(f'Unexpected SOCKS authentication method: {self.method}')
@classmethod
def loads(cls, data: bytes) -> 'AuthMethodReply':
if len(data) != cls.SIZE:
raise ReplyError('Malformed authentication method reply')
ver = data[0]
if ver != SOCKS_VER: # pragma: no cover
raise ReplyError(f'Unexpected SOCKS version number: {ver}')
try:
method = AuthMethod(data[1])
except ValueError:
raise ReplyError(f'Invalid authentication method: {data[1]:#02X}')
if method == AuthMethod.NO_ACCEPTABLE: # pragma: no cover
raise ReplyError('No acceptable authentication methods were offered')
return cls(ver=ver, method=method)
@dataclass
class AuthRequest:
VER = 0x01
username: str
password: str
def dumps(self) -> bytes:
data = bytearray()
data.append(self.VER)
data.append(len(self.username))
data += self.username.encode('ascii')
data.append(len(self.password))
data += self.password.encode('ascii')
return bytes(data)
@dataclass
class AuthReply:
SIZE = 2
ver: int
status: int
@classmethod
def loads(cls, data: bytes) -> 'AuthReply':
if len(data) != cls.SIZE:
raise ReplyError('Malformed auth reply')
ver = data[0]
if ver != AuthRequest.VER: # pragma: no cover
raise ReplyError('Invalid authentication response')
status = data[1]
if status != AUTH_GRANTED: # pragma: no cover
raise ReplyError('Username and password authentication failure')
return cls(ver=ver, status=status)
@dataclass
class ConnectRequest:
host: str # hostname or IPv4 or IPv6 address
port: int
def dumps(self) -> bytes:
data = bytearray([SOCKS_VER, Command.CONNECT, RSV])
data += self._build_addr_request()
return bytes(data)
def _build_addr_request(self) -> bytes:
port = self.port.to_bytes(2, 'big')
if is_ip_address(self.host):
ip = ipaddress.ip_address(self.host)
address_type = AddressType.from_ip_ver(ip.version)
return bytes([address_type]) + ip.packed + port
else:
address_type = AddressType.DOMAIN
host = self.host.encode('idna')
return bytes([address_type, len(host)]) + host + port
@dataclass
class ConnectReply:
ver: int
reply: ReplyCode
rsv: int
bound_host: str
bound_port: int
def validate(self):
pass
@classmethod
def loads(cls, data: bytes) -> 'ConnectReply':
if not data:
raise ReplyError('Empty connect reply')
ver = data[0]
if ver != SOCKS_VER: # pragma: no cover
raise ReplyError(f'Unexpected SOCKS version number: {ver:#02X}')
try:
reply = ReplyCode(data[1])
except IndexError:
raise ReplyError('Malformed connect reply')
except ValueError:
raise ReplyError(f'Invalid reply code: {data[1]:#02X}')
if reply != ReplyCode.SUCCEEDED: # pragma: no cover
msg = ReplyMessages.get(reply, 'Unknown error') # type: ignore
raise ReplyError(msg, error_code=reply)
try:
rsv = data[2]
except IndexError:
raise ReplyError('Malformed connect reply')
if rsv != RSV: # pragma: no cover
raise ReplyError(f'The reserved byte must be {RSV:#02X}')
try:
addr_type = data[3]
bnd_host_data = data[4:-2]
bnd_port_data = data[-2:]
except IndexError:
raise ReplyError('Malformed connect reply')
if addr_type == AddressType.IPV4:
bnd_host = socket.inet_ntop(socket.AF_INET, bnd_host_data)
elif addr_type == AddressType.IPV6:
bnd_host = socket.inet_ntop(socket.AF_INET6, bnd_host_data)
elif addr_type == AddressType.DOMAIN: # pragma: no cover
# host_len = bnd_host_data[0]
bnd_host = bnd_host_data[1:].decode()
else: # pragma: no cover
raise ReplyError(f'Invalid address type: {addr_type:#02X}')
bnd_port = int.from_bytes(bnd_port_data, 'big')
return cls(
ver=ver,
reply=reply,
rsv=rsv,
bound_host=bnd_host,
bound_port=bnd_port,
)
class StateServerWaitingForAuthMethods:
pass
@dataclass
class StateClientSentAuthMethods:
data: AuthMethodsRequest
@dataclass
class StateServerWaitingForAuth:
data: AuthMethodReply
@dataclass
class StateClientAuthenticated:
data: Optional[AuthReply] = None
@dataclass
class StateClientSentAuthRequest:
data: AuthRequest
@dataclass
class StateClientSentConnectRequest:
data: ConnectRequest
@dataclass
class StateServerConnected:
data: ConnectReply
Request = Union[
AuthMethodsRequest,
AuthRequest,
ConnectRequest,
]
Reply = Union[
AuthMethodReply,
AuthReply,
ConnectReply,
]
ConnectionState = Union[
StateServerWaitingForAuthMethods,
StateClientSentAuthMethods,
StateServerWaitingForAuth,
StateClientSentAuthRequest,
StateClientAuthenticated,
StateClientSentConnectRequest,
StateServerConnected,
]
class Connection:
_state: ConnectionState
def __init__(self):
self._state = StateServerWaitingForAuthMethods()
def send(self, request: Request) -> bytes:
if type(request) is AuthMethodsRequest:
if type(self._state) is not StateServerWaitingForAuthMethods:
raise RuntimeError('Server is not currently waiting for auth methods')
self._state = StateClientSentAuthMethods(request)
return request.dumps()
if type(request) is AuthRequest:
if type(self._state) is not StateServerWaitingForAuth:
raise RuntimeError('Server is not currently waiting for authentication')
self._state = StateClientSentAuthRequest(request)
return request.dumps()
if type(request) is ConnectRequest:
if type(self._state) is not StateClientAuthenticated:
raise RuntimeError('Client is not authenticated')
self._state = StateClientSentConnectRequest(request)
return request.dumps()
raise RuntimeError(f'Invalid request type: {type(request)}')
def receive(self, data: bytes) -> Reply:
if type(self._state) is StateClientSentAuthMethods:
reply = AuthMethodReply.loads(data)
reply.validate(self._state.data)
if reply.method == AuthMethod.USERNAME_PASSWORD:
self._state = StateServerWaitingForAuth(data=reply)
else:
self._state = StateClientAuthenticated()
return reply
if type(self._state) is StateClientSentAuthRequest:
reply = AuthReply.loads(data)
self._state = StateClientAuthenticated(data=reply)
return reply
if type(self._state) is StateClientSentConnectRequest:
reply = ConnectReply.loads(data)
self._state = StateServerConnected(data=reply)
return reply
raise RuntimeError(f'Invalid connection state: {self._state}')
@property
def state(self):
return self._state

View File

@@ -0,0 +1,7 @@
from enum import Enum
class ProxyType(Enum):
SOCKS4 = 1
SOCKS5 = 2
HTTP = 3

View File

@@ -0,0 +1,2 @@
__title__ = 'python-socks'
__version__ = '2.8.1'

View File

@@ -0,0 +1,3 @@
from ._proxy_chain import ProxyChain
__all__ = ('ProxyChain',)

View File

@@ -0,0 +1,34 @@
from typing import Iterable
import warnings
class ProxyChain:
def __init__(self, proxies: Iterable):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(self, dest_host, dest_port, timeout=None):
curr_socket = None
proxies = list(self._proxies)
length = len(proxies) - 1
for i in range(length):
curr_socket = await proxies[i].connect(
dest_host=proxies[i + 1].proxy_host,
dest_port=proxies[i + 1].proxy_port,
timeout=timeout,
_socket=curr_socket,
)
curr_socket = await proxies[length].connect(
dest_host=dest_host,
dest_port=dest_port,
timeout=timeout,
_socket=curr_socket,
)
return curr_socket

View File

@@ -0,0 +1,4 @@
from ._proxy import AnyioProxy as Proxy
from ._chain import ProxyChain
__all__ = ('Proxy', 'ProxyChain')

View File

@@ -0,0 +1,42 @@
from typing import Iterable
import warnings
from ._proxy import AnyioProxy
class ProxyChain:
def __init__(self, proxies: Iterable[AnyioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
_stream = None
proxies = list(self._proxies)
length = len(proxies) - 1
for i in range(length):
_stream = await proxies[i].connect(
dest_host=proxies[i + 1].proxy_host,
dest_port=proxies[i + 1].proxy_port,
timeout=timeout,
_stream=_stream,
)
_stream = await proxies[length].connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
_stream=_stream,
)
return _stream

View File

@@ -0,0 +1,16 @@
from typing import Optional
import anyio
import anyio.abc
async def connect_tcp(
host: str,
port: int,
local_host: Optional[str] = None,
) -> anyio.abc.SocketStream:
return await anyio.connect_tcp(
remote_host=host,
remote_port=port,
local_host=local_host,
)

View File

@@ -0,0 +1,137 @@
import ssl
from typing import Any, Optional
import warnings
import anyio
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._resolver import Resolver
from ._stream import AnyioSocketStream
from ._connect import connect_tcp
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class AnyioProxy:
_stream: Optional[AnyioSocketStream]
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> AnyioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_stream = kwargs.get('_stream')
if _stream is not None:
warnings.warn(
"The '_stream' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
local_host = kwargs.get('local_host')
try:
with anyio.fail_after(timeout):
if _stream is None:
try:
_stream = AnyioSocketStream(
await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_host=local_host,
)
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = _stream
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
return stream
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException:
await stream.close()
raise
except TimeoutError as e:
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,22 @@
import anyio
import socket
from ... import _abc as abc
class Resolver(abc.AsyncResolver):
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
infos = await anyio.getaddrinfo(
host=host,
port=port,
family=family,
type=socket.SOCK_STREAM,
)
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,59 @@
import ssl
from typing import Union
import anyio
import anyio.abc
from anyio.streams.tls import TLSStream
from ..._errors import ProxyError
from ... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
class AnyioSocketStream(abc.AsyncSocketStream):
_stream: AnyioStreamType
def __init__(self, stream: AnyioStreamType) -> None:
self._stream = stream
async def write_all(self, data: bytes):
await self._stream.send(item=data)
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: no cover
return b""
async def read_exact(self, n: int):
data = bytearray()
while len(data) < n:
packet = await self.read(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
) -> 'AnyioSocketStream':
ssl_stream = await TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=hostname,
standard_compatible=False,
server_side=False,
)
return AnyioSocketStream(ssl_stream)
async def close(self):
await self._stream.aclose()
@property
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
return self._stream

View File

@@ -0,0 +1,7 @@
from ._proxy import AnyioProxy as Proxy
from ._chain import ProxyChain
__all__ = (
'Proxy',
'ProxyChain',
)

View File

@@ -0,0 +1,32 @@
from typing import Sequence
import warnings
from ._proxy import AnyioProxy
class ProxyChain:
def __init__(self, proxies: Sequence[AnyioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
forward = None
for proxy in self._proxies:
proxy._forward = forward
forward = proxy
return await forward.connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
)

View File

@@ -0,0 +1,17 @@
from typing import Optional
import anyio
import anyio.abc
from ._stream import AnyioSocketStream
async def connect_tcp(
host: str,
port: int,
local_host: Optional[str] = None,
) -> AnyioSocketStream:
s = await anyio.connect_tcp(
remote_host=host,
remote_port=port,
local_host=local_host,
)
return AnyioSocketStream(s)

View File

@@ -0,0 +1,135 @@
import ssl
from typing import Any, Optional
import anyio
from ._connect import connect_tcp
from ._stream import AnyioSocketStream
from .._resolver import Resolver
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class AnyioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
forward: Optional['AnyioProxy'] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._username = username
self._password = password
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._forward = forward
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> AnyioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
local_host = kwargs.get('local_host')
try:
with anyio.fail_after(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
local_host=local_host,
)
except TimeoutError as e:
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
async def _connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
local_host: Optional[str] = None,
) -> AnyioSocketStream:
if self._forward is None:
try:
stream = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_host=local_host,
)
except OSError as e:
raise ProxyConnectionError(
e.errno,
"Couldn't connect to proxy"
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
) from e
else:
stream = await self._forward.connect(
dest_host=self._proxy_host,
dest_port=self._proxy_port,
)
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException:
with anyio.CancelScope(shield=True):
await stream.close()
raise
return stream
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,59 @@
import ssl
from typing import Union
import anyio
import anyio.abc
from anyio.streams.tls import TLSStream
from ...._errors import ProxyError
from .... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
class AnyioSocketStream(abc.AsyncSocketStream):
_stream: AnyioStreamType
def __init__(self, stream: AnyioStreamType) -> None:
self._stream = stream
async def write_all(self, data: bytes):
await self._stream.send(item=data)
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: no cover
return b""
async def read_exact(self, n: int):
data = bytearray()
while len(data) < n:
packet = await self.read(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
) -> 'AnyioSocketStream':
ssl_stream = await TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=hostname,
standard_compatible=False,
server_side=False,
)
return AnyioSocketStream(ssl_stream)
async def close(self):
await self._stream.aclose()
@property
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
return self._stream

View File

@@ -0,0 +1,4 @@
from ._proxy import AsyncioProxy as Proxy
__all__ = ('Proxy',)

View File

@@ -0,0 +1,43 @@
import socket
import asyncio
from typing import Optional, Tuple
from ._resolver import Resolver
from ..._helpers import is_ipv4_address, is_ipv6_address
async def connect_tcp(
host: str,
port: int,
loop: asyncio.AbstractEventLoop,
local_addr: Optional[Tuple[str, int]] = None,
) -> socket.socket:
family, host = await _resolve_host(host, loop)
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setblocking(False)
if local_addr is not None: # pragma: no cover
sock.bind(local_addr)
if is_ipv6_address(host):
address = (host, port, 0, 0) # to fix OSError: [WinError 10022]
else:
address = (host, port) # type: ignore[assignment]
try:
await loop.sock_connect(sock=sock, address=address)
except OSError:
sock.close()
raise
return sock
async def _resolve_host(host, loop):
if is_ipv4_address(host):
return socket.AF_INET, host
if is_ipv6_address(host):
return socket.AF_INET6, host
resolver = Resolver(loop=loop)
return await resolver.resolve(host=host)

View File

@@ -0,0 +1,143 @@
import asyncio
import socket
import sys
from typing import Any, Optional
import warnings
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._stream import AsyncioSocketStream
from ._resolver import Resolver
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
from ._connect import connect_tcp
if sys.version_info >= (3, 11):
import asyncio as async_timeout # pylint:disable=reimported
else:
import async_timeout
DEFAULT_TIMEOUT = 60
class AsyncioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._resolver = Resolver(loop=loop)
async def connect(
self,
dest_host: str,
dest_port: int,
timeout: Optional[float] = None,
**kwargs: Any,
) -> socket.socket:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_socket = kwargs.get('_socket')
if _socket is not None:
warnings.warn(
"The '_socket' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
local_addr = kwargs.get('local_addr')
try:
async with async_timeout.timeout(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
_socket=_socket,
local_addr=local_addr,
)
except asyncio.TimeoutError as e:
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
async def _connect(
self,
dest_host,
dest_port,
_socket=None,
local_addr=None,
) -> socket.socket:
if _socket is None:
try:
_socket = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
loop=self._loop,
local_addr=local_addr,
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = AsyncioSocketStream(sock=_socket, loop=self._loop)
try:
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
return _socket
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except (asyncio.CancelledError, Exception): # pragma: no cover
await stream.close()
raise
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,25 @@
import asyncio
import socket
from ... import _abc as abc
class Resolver(abc.AsyncResolver):
def __init__(self, loop: asyncio.AbstractEventLoop):
self._loop = loop
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
infos = await self._loop.getaddrinfo(
host=host,
port=port,
family=family,
type=socket.SOCK_STREAM,
)
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,36 @@
import asyncio
import socket
from ..._errors import ProxyError
from ... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
class AsyncioSocketStream(abc.AsyncSocketStream):
_loop: asyncio.AbstractEventLoop = None
_socket = None
def __init__(self, sock: socket.socket, loop: asyncio.AbstractEventLoop):
self._loop = loop
self._socket = sock
async def write_all(self, data):
await self._loop.sock_sendall(self._socket, data)
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return await self._loop.sock_recv(self._socket, max_bytes)
async def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = await self._loop.sock_recv(self._socket, n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def close(self):
if self._socket is not None:
self._socket.close()

View File

@@ -0,0 +1,4 @@
from ._proxy import AsyncioProxy as Proxy
from ._chain import ProxyChain
__all__ = ('Proxy', 'ProxyChain')

View File

@@ -0,0 +1,32 @@
from typing import Sequence
import warnings
from ._proxy import AsyncioProxy
class ProxyChain:
def __init__(self, proxies: Sequence[AsyncioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl=None,
timeout=None,
):
forward = None
for proxy in self._proxies:
proxy._forward = forward
forward = proxy
return await forward.connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
)

View File

@@ -0,0 +1,26 @@
import asyncio
from typing import Optional, Tuple
from ._stream import AsyncioSocketStream
async def connect_tcp(
host: str,
port: int,
loop: asyncio.AbstractEventLoop,
local_addr: Optional[Tuple[str, int]] = None,
) -> AsyncioSocketStream:
kwargs = {}
if local_addr is not None:
kwargs['local_addr'] = local_addr # pragma: no cover
reader, writer = await asyncio.open_connection(
host=host,
port=port,
**kwargs, # type: ignore
)
return AsyncioSocketStream(
loop=loop,
reader=reader,
writer=writer,
)

View File

@@ -0,0 +1,157 @@
import asyncio
import ssl
from typing import Any, Optional, Tuple
import warnings
import sys
from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector
from .._resolver import Resolver
from ._stream import AsyncioSocketStream
from ._connect import connect_tcp
if sys.version_info >= (3, 11):
import asyncio as async_timeout # pylint:disable=reimported
else:
import async_timeout
DEFAULT_TIMEOUT = 60
class AsyncioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
forward: Optional['AsyncioProxy'] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
if loop is not None: # pragma: no cover
warnings.warn(
'The loop argument is deprecated and scheduled for removal in the future.',
DeprecationWarning,
stacklevel=2,
)
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._username = username
self._password = password
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._forward = forward
self._resolver = Resolver(loop=loop)
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> AsyncioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
local_addr = kwargs.get('local_addr')
try:
async with async_timeout.timeout(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
local_addr=local_addr,
)
except asyncio.TimeoutError as e:
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
async def _connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
local_addr: Optional[Tuple[str, int]] = None,
) -> AsyncioSocketStream:
if self._forward is None:
try:
stream = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
loop=self._loop,
local_addr=local_addr,
)
except OSError as e:
raise ProxyConnectionError(
e.errno,
"Couldn't connect to proxy"
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
) from e
else:
stream = await self._forward.connect(
dest_host=self._proxy_host,
dest_port=self._proxy_port,
)
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except (asyncio.CancelledError, Exception):
await stream.close()
raise
return stream
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,91 @@
import asyncio
import ssl
from .... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
class AsyncioSocketStream(abc.AsyncSocketStream):
_loop: asyncio.AbstractEventLoop
_reader: asyncio.StreamReader
_writer: asyncio.StreamWriter
def __init__(
self,
loop: asyncio.AbstractEventLoop,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
):
self._loop = loop
self._reader = reader
self._writer = writer
async def write_all(self, data):
self._writer.write(data)
await self._writer.drain()
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return await self._reader.read(max_bytes)
async def read_exact(self, n):
return await self._reader.readexactly(n)
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
ssl_handshake_timeout=None,
) -> 'AsyncioSocketStream':
if hasattr(self._writer, 'start_tls'): # Python>=3.11
await self._writer.start_tls(
ssl_context,
server_hostname=hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
)
return self
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport: asyncio.Transport = await self._loop.start_tls(
self._writer.transport, # type: ignore
protocol,
ssl_context,
server_side=False,
server_hostname=hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
)
# reader.set_transport(transport)
# Initialize the protocol, so it is made aware of being tied to
# a TLS connection.
# See: https://github.com/encode/httpx/issues/859
protocol.connection_made(transport)
writer = asyncio.StreamWriter(
transport=transport,
protocol=protocol,
reader=reader,
loop=self._loop,
)
stream = AsyncioSocketStream(loop=self._loop, reader=reader, writer=writer)
# When we return a new SocketStream with new StreamReader/StreamWriter instances
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
stream._inner = self # type: ignore # pylint:disable=W0212,W0201
return stream
async def close(self):
self._writer.close()
self._writer.transport.abort() # noqa
@property
def reader(self):
return self._reader # pragma: no cover
@property
def writer(self):
return self._writer # pragma: no cover

View File

@@ -0,0 +1,4 @@
from ._proxy import CurioProxy as Proxy
__all__ = ('Proxy',)

View File

@@ -0,0 +1,17 @@
from typing import Optional, Tuple
import curio
import curio.io
import curio.socket
async def connect_tcp(
host: str,
port: int,
local_addr: Optional[Tuple[str, int]] = None,
) -> curio.io.Socket:
return await curio.open_connection(
host=host,
port=port,
source_addr=local_addr,
)

View File

@@ -0,0 +1,132 @@
from typing import Any, Optional
import warnings
import curio
import curio.io
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._stream import CurioSocketStream
from ._resolver import Resolver
from ._connect import connect_tcp
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class CurioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
timeout: Optional[float] = None,
**kwargs: Any,
) -> curio.io.Socket:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_socket = kwargs.get('_socket')
if _socket is not None:
warnings.warn(
"The '_socket' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
local_addr = kwargs.get('local_addr')
try:
return await curio.timeout_after(
timeout,
self._connect,
dest_host,
dest_port,
_socket,
local_addr,
)
except curio.TaskTimeout as e:
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
async def _connect(
self,
dest_host: str,
dest_port: int,
_socket=None,
local_addr=None,
):
if _socket is None:
try:
_socket = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_addr=local_addr,
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = CurioSocketStream(_socket)
try:
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
return _socket
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException:
await stream.close()
raise
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'CurioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,25 @@
import socket
from curio.socket import getaddrinfo
from ... import _abc as abc
class Resolver(abc.AsyncResolver):
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
try:
infos = await getaddrinfo(
host=host,
port=port,
family=family,
type=socket.SOCK_STREAM,
)
except socket.gaierror: # pragma: no cover
infos = None
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,32 @@
import curio.io
import curio.socket
from ... import _abc as abc
from ..._errors import ProxyError
DEFAULT_RECEIVE_SIZE = 65536
class CurioSocketStream(abc.AsyncSocketStream):
_socket: curio.io.Socket = None
def __init__(self, sock: curio.io.Socket):
self._socket = sock
async def write_all(self, data):
await self._socket.sendall(data)
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return await self._socket.recv(max_bytes)
async def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = await self._socket.recv(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def close(self):
await self._socket.close()

View File

@@ -0,0 +1,3 @@
from ._proxy import TrioProxy as Proxy
__all__ = ('Proxy',)

View File

@@ -0,0 +1,36 @@
from typing import Optional, Tuple
import trio
from ._resolver import Resolver
from ..._helpers import is_ipv4_address, is_ipv6_address
async def connect_tcp(
host: str,
port: int,
local_addr: Optional[Tuple[str, int]] = None,
) -> trio.socket.SocketType:
family, host = await _resolve_host(host)
sock = trio.socket.socket(family=family, type=trio.socket.SOCK_STREAM)
if local_addr is not None: # pragma: no cover
await sock.bind(local_addr)
try:
await sock.connect((host, port))
except OSError:
sock.close()
raise
return sock
async def _resolve_host(host):
if is_ipv4_address(host):
return trio.socket.AF_INET, host
if is_ipv6_address(host):
return trio.socket.AF_INET6, host
resolver = Resolver()
return await resolver.resolve(host=host)

View File

@@ -0,0 +1,131 @@
from typing import Any, Optional
import warnings
import trio
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._stream import TrioSocketStream
from ._resolver import Resolver
from ._connect import connect_tcp
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class TrioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
timeout: Optional[float] = None,
**kwargs: Any,
) -> trio.socket.SocketType:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_socket = kwargs.get('_socket')
if _socket is not None:
warnings.warn(
"The '_socket' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
local_addr = kwargs.get('local_addr')
try:
with trio.fail_after(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
_socket=_socket,
local_addr=local_addr,
)
except trio.TooSlowError as e:
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
async def _connect(
self,
dest_host: str,
dest_port: int,
_socket=None,
local_addr=None,
) -> trio.socket.SocketType:
if _socket is None:
try:
_socket = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_addr=local_addr,
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = TrioSocketStream(sock=_socket)
try:
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
return _socket
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException: # trio.Cancelled...
with trio.CancelScope(shield=True):
await stream.close()
raise
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,21 @@
import trio
from ... import _abc as abc
class Resolver(abc.AsyncResolver):
async def resolve(self, host, port=0, family=trio.socket.AF_UNSPEC):
infos = await trio.socket.getaddrinfo(
host=host,
port=port,
family=family,
type=trio.socket.SOCK_STREAM,
)
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,35 @@
import trio
from ..._errors import ProxyError
from ... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
class TrioSocketStream(abc.AsyncSocketStream):
def __init__(self, sock):
self._socket = sock
async def write_all(self, data):
total_sent = 0
while total_sent < len(data):
remaining = data[total_sent:]
sent = await self._socket.send(remaining)
total_sent += sent
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return await self._socket.recv(max_bytes)
async def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = await self._socket.recv(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def close(self):
if self._socket is not None:
self._socket.close()
await trio.lowlevel.checkpoint()

View File

@@ -0,0 +1,7 @@
from ._proxy import TrioProxy as Proxy
from ._chain import ProxyChain
__all__ = (
'Proxy',
'ProxyChain',
)

View File

@@ -0,0 +1,32 @@
from typing import Sequence
import warnings
from ._proxy import TrioProxy
class ProxyChain:
def __init__(self, proxies: Sequence[TrioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
forward = None
for proxy in self._proxies:
proxy._forward = forward
forward = proxy
return await forward.connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
)

View File

@@ -0,0 +1,17 @@
from typing import Optional
import trio
from ._stream import TrioSocketStream
async def connect_tcp(
host: str,
port: int,
local_addr: Optional[str] = None,
) -> TrioSocketStream:
trio_stream = await trio.open_tcp_stream(
host=host,
port=port,
local_address=local_addr,
)
return TrioSocketStream(trio_stream)

View File

@@ -0,0 +1,135 @@
import ssl
from typing import Any, Optional
import trio
from ._connect import connect_tcp
from ._stream import TrioSocketStream
from .._resolver import Resolver
from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class TrioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
forward: Optional['TrioProxy'] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._username = username
self._password = password
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._forward = forward
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> TrioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
local_addr = kwargs.get('local_addr')
try:
with trio.fail_after(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
local_addr=local_addr,
)
except trio.TooSlowError as e:
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
async def _connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
local_addr: Optional[str] = None,
) -> TrioSocketStream:
if self._forward is None:
try:
stream = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_addr=local_addr,
)
except OSError as e:
raise ProxyConnectionError(
e.errno,
"Couldn't connect to proxy"
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
) from e
else:
stream = await self._forward.connect(
dest_host=self._proxy_host,
dest_port=self._proxy_port,
)
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException: # trio.Cancelled...
with trio.CancelScope(shield=True):
await stream.close()
raise
return stream
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,55 @@
import ssl
from typing import Union
import trio
from ...._errors import ProxyError
from .... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
TrioStreamType = Union[trio.SocketStream, trio.SSLStream]
class TrioSocketStream(abc.AsyncSocketStream):
_stream: TrioStreamType
def __init__(self, stream: TrioStreamType):
self._stream = stream
async def write_all(self, data):
await self._stream.send_all(data)
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return await self._stream.receive_some(max_bytes)
async def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = await self._stream.receive_some(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
) -> 'TrioSocketStream':
ssl_stream = trio.SSLStream(
self._stream,
ssl_context=ssl_context,
server_hostname=hostname,
https_compatible=True,
server_side=False,
)
await ssl_stream.do_handshake()
return TrioSocketStream(ssl_stream)
async def close(self):
await self._stream.aclose()
@property
def trio_stream(self) -> TrioStreamType: # pragma: nocover
return self._stream

View File

@@ -0,0 +1,5 @@
from ._proxy import SyncProxy as Proxy
from ._chain import ProxyChain
__all__ = ('Proxy', 'ProxyChain')

View File

@@ -0,0 +1,32 @@
from typing import Iterable
import warnings
from ._proxy import SyncProxy
class ProxyChain:
def __init__(self, proxies: Iterable[SyncProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
def connect(self, dest_host, dest_port, timeout=None):
curr_socket = None
proxies = list(self._proxies)
length = len(proxies) - 1
for i in range(length):
curr_socket = proxies[i].connect(
dest_host=proxies[i + 1].proxy_host,
dest_port=proxies[i + 1].proxy_port,
timeout=timeout,
_socket=curr_socket,
)
curr_socket = proxies[length].connect(
dest_host=dest_host, dest_port=dest_port, timeout=timeout, _socket=curr_socket
)
return curr_socket

View File

@@ -0,0 +1,16 @@
import socket
from typing import Optional, Tuple
def connect_tcp(
host: str,
port: int,
timeout: Optional[float] = None,
local_addr: Optional[Tuple[str, int]] = None,
) -> socket.socket:
address = (host, port)
return socket.create_connection(
address,
timeout,
source_address=local_addr,
)

View File

@@ -0,0 +1,116 @@
import socket
from typing import Optional, Any
import warnings
from .._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from .._types import ProxyType
from .._helpers import parse_proxy_url
from .._protocols.errors import ReplyError
from .._connectors.factory_sync import create_connector
from ._stream import SyncSocketStream
from ._resolver import SyncResolver
from ._connect import connect_tcp
DEFAULT_TIMEOUT = 60
class SyncProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._resolver = SyncResolver()
def connect(
self,
dest_host: str,
dest_port: int,
timeout: Optional[float] = None,
**kwargs: Any,
) -> socket.socket:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_socket = kwargs.get('_socket')
if _socket is not None:
warnings.warn(
"The '_socket' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
if _socket is None:
local_addr = kwargs.get('local_addr')
try:
_socket = connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
timeout=timeout,
local_addr=local_addr,
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = SyncSocketStream(_socket)
try:
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
return _socket
except socket.timeout as e:
stream.close()
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
except ReplyError as e:
stream.close()
raise ProxyError(e, error_code=e.error_code)
except Exception:
stream.close()
raise
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'SyncProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,16 @@
import socket
from .. import _abc as abc
class SyncResolver(abc.SyncResolver):
# noinspection PyMethodMayBeStatic
def resolve(self, host, port=0, family=socket.AF_UNSPEC):
infos = socket.getaddrinfo(host=host, port=port, family=family, type=socket.SOCK_STREAM)
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,32 @@
import socket
from .._errors import ProxyError
from .. import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
class SyncSocketStream(abc.SyncSocketStream):
_socket: socket.socket
def __init__(self, sock: socket.socket):
self._socket = sock
def write_all(self, data):
self._socket.sendall(data)
def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return self._socket.recv(max_bytes)
def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = self._socket.recv(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
def close(self):
if self._socket is not None:
self._socket.close()

View File

@@ -0,0 +1,7 @@
from ._proxy import SyncProxy as Proxy
from ._chain import ProxyChain
__all__ = (
'Proxy',
'ProxyChain',
)

View File

@@ -0,0 +1,26 @@
from typing import Iterable
from ._proxy import SyncProxy
class ProxyChain:
def __init__(self, proxies: Iterable[SyncProxy]):
self._proxies = proxies
def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
forward = None
for proxy in self._proxies:
proxy._forward = forward
forward = proxy
return forward.connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
)

View File

@@ -0,0 +1,19 @@
import socket
from typing import Optional, Tuple
from ._stream import SyncSocketStream
def connect_tcp(
host: str,
port: int,
timeout: Optional[float] = None,
local_addr: Optional[Tuple[str, int]] = None,
) -> SyncSocketStream:
address = (host, port)
sock = socket.create_connection(
address,
timeout,
source_address=local_addr,
)
return SyncSocketStream(sock)

View File

@@ -0,0 +1,121 @@
import socket
import ssl
from typing import Any, Optional
from ._connect import connect_tcp
from ._stream import SyncSocketStream
from .._resolver import SyncResolver
from ..._types import ProxyType
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ..._helpers import parse_proxy_url
from ..._protocols.errors import ReplyError
from ..._connectors.factory_sync import create_connector
DEFAULT_TIMEOUT = 60
class SyncProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
forward: Optional['SyncProxy'] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._username = username
self._password = password
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._forward = forward
self._resolver = SyncResolver()
def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> SyncSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
if self._forward is None:
local_addr = kwargs.get('local_addr')
try:
stream = connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
timeout=timeout,
local_addr=local_addr,
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
else:
stream = self._forward.connect(
dest_host=self._proxy_host,
dest_port=self._proxy_port,
timeout=timeout,
)
try:
if self._proxy_ssl is not None:
stream = stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
return stream
except socket.timeout as e:
stream.close()
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
except ReplyError as e:
stream.close()
raise ProxyError(e, error_code=e.error_code)
except Exception:
stream.close()
raise
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'SyncProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,200 @@
"""
Copied from urllib3.util.ssltransport
"""
import io
import socket
import ssl
SSL_BLOCKSIZE = 16384
class SSLTransport:
"""
The SSLTransport wraps an existing socket and establishes an SSL connection.
Contrary to Python's implementation of SSLSocket, it allows you to chain
multiple TLS connections together. It's particularly useful if you need to
implement TLS within TLS.
The class supports most of the socket API operations.
"""
def __init__(
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
):
"""
Create an SSLTransport around socket using the provided ssl_context.
"""
self.incoming = ssl.MemoryBIO()
self.outgoing = ssl.MemoryBIO()
self.suppress_ragged_eofs = suppress_ragged_eofs
self.socket = socket
self.sslobj = ssl_context.wrap_bio(
self.incoming, self.outgoing, server_hostname=server_hostname
)
# Perform initial handshake.
self._ssl_io_loop(self.sslobj.do_handshake)
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
def fileno(self):
return self.socket.fileno()
def read(self, len=1024, buffer=None):
return self._wrap_ssl_read(len, buffer)
def recv(self, len=1024, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv")
return self._wrap_ssl_read(len)
def recv_into(self, buffer, nbytes=None, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv_into")
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
return self.read(nbytes, buffer)
def sendall(self, data, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to sendall")
count = 0
with memoryview(data) as view, view.cast("B") as byte_view:
amount = len(byte_view)
while count < amount:
v = self.send(byte_view[count:])
count += v
def send(self, data, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to send")
response = self._ssl_io_loop(self.sslobj.write, data)
return response
def makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
"""
Python's httpclient uses makefile and buffered io when reading HTTP
messages and we need to support it.
This is unfortunately a copy and paste of socket.py makefile with small
changes to point to the socket directly.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = socket.SocketIO(self, rawmode)
self.socket._io_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
def unwrap(self):
self._ssl_io_loop(self.sslobj.unwrap)
def close(self):
self.socket.close()
def getpeercert(self, binary_form=False):
return self.sslobj.getpeercert(binary_form)
def version(self):
return self.sslobj.version()
def cipher(self):
return self.sslobj.cipher()
def selected_alpn_protocol(self):
return self.sslobj.selected_alpn_protocol()
def selected_npn_protocol(self):
return self.sslobj.selected_npn_protocol()
def shared_ciphers(self):
return self.sslobj.shared_ciphers()
def compression(self):
return self.sslobj.compression()
def settimeout(self, value):
self.socket.settimeout(value)
def gettimeout(self):
return self.socket.gettimeout()
def _decref_socketios(self):
self.socket._decref_socketios()
def _wrap_ssl_read(self, len, buffer=None):
try:
return self._ssl_io_loop(self.sslobj.read, len, buffer)
except ssl.SSLError as e:
if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
return 0 # eof, return 0.
else:
raise
def _ssl_io_loop(self, func, *args):
"""Performs an I/O loop between incoming/outgoing and the socket."""
should_loop = True
ret = None
while should_loop:
errno = None
try:
ret = func(*args)
except ssl.SSLError as e:
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
# WANT_READ, and WANT_WRITE are expected, others are not.
raise e
errno = e.errno
buf = self.outgoing.read()
self.socket.sendall(buf)
if errno is None:
should_loop = False
elif errno == ssl.SSL_ERROR_WANT_READ:
buf = self.socket.recv(SSL_BLOCKSIZE)
if buf:
self.incoming.write(buf)
else:
self.incoming.write_eof()
return ret

View File

@@ -0,0 +1,56 @@
import socket
import ssl
from typing import Union
from ._ssl_transport import SSLTransport
from ..._errors import ProxyError
from ... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
SocketType = Union[socket.socket, ssl.SSLSocket, SSLTransport]
class SyncSocketStream(abc.SyncSocketStream):
_socket: SocketType
def __init__(self, sock: SocketType):
self._socket = sock
def write_all(self, data):
self._socket.sendall(data)
def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
return self._socket.recv(max_bytes)
def read_exact(self, n):
data = bytearray()
while len(data) < n:
packet = self._socket.recv(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
def start_tls(self, hostname: str, ssl_context: ssl.SSLContext) -> 'SyncSocketStream':
if isinstance(self._socket, (ssl.SSLSocket, SSLTransport)):
ssl_socket = SSLTransport(
self._socket,
ssl_context=ssl_context,
server_hostname=hostname,
)
else: # plain socket?
ssl_socket = ssl_context.wrap_socket(
self._socket,
server_hostname=hostname,
)
return SyncSocketStream(ssl_socket)
def close(self):
self._socket.close()
@property
def socket(self) -> SocketType: # pragma: nocover
return self._socket