fix: 修复代理问题
This commit is contained in:
20
.venv/lib/python3.9/site-packages/python_socks/__init__.py
Normal file
20
.venv/lib/python3.9/site-packages/python_socks/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from ._version import __version__, __title__
|
||||
|
||||
from ._types import ProxyType
|
||||
from ._helpers import parse_proxy_url
|
||||
|
||||
from ._errors import (
|
||||
ProxyError,
|
||||
ProxyTimeoutError,
|
||||
ProxyConnectionError,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'__title__',
|
||||
'__version__',
|
||||
'ProxyError',
|
||||
'ProxyTimeoutError',
|
||||
'ProxyConnectionError',
|
||||
'ProxyType',
|
||||
'parse_proxy_url',
|
||||
)
|
||||
40
.venv/lib/python3.9/site-packages/python_socks/_abc.py
Normal file
40
.venv/lib/python3.9/site-packages/python_socks/_abc.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SyncResolver:
|
||||
def resolve(self, host, port=0, family=0):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AsyncResolver:
|
||||
async def resolve(self, host, port=0, family=0):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SyncSocketStream:
|
||||
|
||||
def write_all(self, data: bytes):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read(self, max_bytes: Optional[int] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_exact(self, n: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AsyncSocketStream:
|
||||
async def write_all(self, data: bytes):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def read(self, max_bytes: Optional[int] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def read_exact(self, n: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def close(self):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,21 @@
|
||||
from .._abc import SyncSocketStream, AsyncSocketStream
|
||||
|
||||
|
||||
class SyncConnector:
|
||||
def connect(
|
||||
self,
|
||||
stream: SyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncConnector:
|
||||
async def connect(
|
||||
self,
|
||||
stream: AsyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
):
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
from .._abc import AsyncResolver
|
||||
from .._types import ProxyType
|
||||
|
||||
from .abc import AsyncConnector
|
||||
from .socks5_async import Socks5AsyncConnector
|
||||
from .socks4_async import Socks4AsyncConnector
|
||||
from .http_async import HttpAsyncConnector
|
||||
|
||||
|
||||
def create_connector(
|
||||
proxy_type: ProxyType,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: AsyncResolver,
|
||||
) -> AsyncConnector:
|
||||
if proxy_type == ProxyType.SOCKS4:
|
||||
return Socks4AsyncConnector(
|
||||
user_id=username,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
if proxy_type == ProxyType.SOCKS5:
|
||||
return Socks5AsyncConnector(
|
||||
username=username,
|
||||
password=password,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
if proxy_type == ProxyType.HTTP:
|
||||
return HttpAsyncConnector(
|
||||
username=username,
|
||||
password=password,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
raise ValueError(f'Invalid proxy type: {proxy_type}')
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
from .._abc import SyncResolver
|
||||
from .._types import ProxyType
|
||||
|
||||
from .abc import SyncConnector
|
||||
from .socks5_sync import Socks5SyncConnector
|
||||
from .socks4_sync import Socks4SyncConnector
|
||||
from .http_sync import HttpSyncConnector
|
||||
|
||||
|
||||
def create_connector(
|
||||
proxy_type: ProxyType,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: SyncResolver,
|
||||
) -> SyncConnector:
|
||||
if proxy_type == ProxyType.SOCKS4:
|
||||
return Socks4SyncConnector(
|
||||
user_id=username,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
if proxy_type == ProxyType.SOCKS5:
|
||||
return Socks5SyncConnector(
|
||||
username=username,
|
||||
password=password,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
if proxy_type == ProxyType.HTTP:
|
||||
return HttpSyncConnector(
|
||||
username=username,
|
||||
password=password,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
raise ValueError(f'Invalid proxy type: {proxy_type}')
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
from .._abc import AsyncSocketStream, AsyncResolver
|
||||
from .abc import AsyncConnector
|
||||
|
||||
from .._protocols import http
|
||||
|
||||
|
||||
class HttpAsyncConnector(AsyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
resolver: AsyncResolver,
|
||||
):
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._resolver = resolver
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
stream: AsyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> http.ConnectReply:
|
||||
conn = http.Connection()
|
||||
|
||||
request = http.ConnectRequest(
|
||||
host=host,
|
||||
port=port,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
)
|
||||
data = conn.send(request)
|
||||
await stream.write_all(data)
|
||||
|
||||
data = await stream.read()
|
||||
reply: http.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
from .._abc import SyncSocketStream, SyncResolver
|
||||
from .abc import SyncConnector
|
||||
|
||||
from .._protocols import http
|
||||
|
||||
|
||||
class HttpSyncConnector(SyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
resolver: SyncResolver,
|
||||
):
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._resolver = resolver
|
||||
|
||||
def connect(
|
||||
self,
|
||||
stream: SyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> http.ConnectReply:
|
||||
conn = http.Connection()
|
||||
|
||||
request = http.ConnectRequest(
|
||||
host=host,
|
||||
port=port,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
)
|
||||
data = conn.send(request)
|
||||
stream.write_all(data)
|
||||
|
||||
data = stream.read()
|
||||
reply: http.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
@@ -0,0 +1,45 @@
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from .._abc import AsyncSocketStream, AsyncResolver
|
||||
from .abc import AsyncConnector
|
||||
|
||||
from .._protocols import socks4
|
||||
from .._helpers import is_ip_address
|
||||
|
||||
|
||||
class Socks4AsyncConnector(AsyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: AsyncResolver,
|
||||
):
|
||||
if rdns is None:
|
||||
rdns = False
|
||||
|
||||
self._user_id = user_id
|
||||
self._rdns = rdns
|
||||
self._resolver = resolver
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
stream: AsyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> socks4.ConnectReply:
|
||||
conn = socks4.Connection()
|
||||
|
||||
if not is_ip_address(host) and not self._rdns:
|
||||
_, host = await self._resolver.resolve(
|
||||
host,
|
||||
family=socket.AF_INET,
|
||||
)
|
||||
|
||||
request = socks4.ConnectRequest(host=host, port=port, user_id=self._user_id)
|
||||
data = conn.send(request)
|
||||
await stream.write_all(data)
|
||||
|
||||
data = await stream.read_exact(socks4.ConnectReply.SIZE)
|
||||
reply: socks4.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
@@ -0,0 +1,45 @@
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from .._abc import SyncSocketStream, SyncResolver
|
||||
from .abc import SyncConnector
|
||||
|
||||
from .._protocols import socks4
|
||||
from .._helpers import is_ip_address
|
||||
|
||||
|
||||
class Socks4SyncConnector(SyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: SyncResolver,
|
||||
):
|
||||
if rdns is None:
|
||||
rdns = False
|
||||
|
||||
self._user_id = user_id
|
||||
self._rdns = rdns
|
||||
self._resolver = resolver
|
||||
|
||||
def connect(
|
||||
self,
|
||||
stream: SyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> socks4.ConnectReply:
|
||||
conn = socks4.Connection()
|
||||
|
||||
if not is_ip_address(host) and not self._rdns:
|
||||
_, host = self._resolver.resolve(
|
||||
host,
|
||||
family=socket.AF_INET,
|
||||
)
|
||||
|
||||
request = socks4.ConnectRequest(host=host, port=port, user_id=self._user_id)
|
||||
data = conn.send(request)
|
||||
stream.write_all(data)
|
||||
|
||||
data = stream.read_exact(socks4.ConnectReply.SIZE)
|
||||
reply: socks4.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
@@ -0,0 +1,95 @@
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from .._abc import AsyncSocketStream, AsyncResolver
|
||||
from .abc import AsyncConnector
|
||||
|
||||
from .._protocols import socks5
|
||||
from .._helpers import is_ip_address
|
||||
|
||||
|
||||
class Socks5AsyncConnector(AsyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: AsyncResolver,
|
||||
):
|
||||
if rdns is None:
|
||||
rdns = True
|
||||
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
self._resolver = resolver
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
stream: AsyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> socks5.ConnectReply:
|
||||
conn = socks5.Connection()
|
||||
|
||||
# Auth methods
|
||||
request = socks5.AuthMethodsRequest(
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
)
|
||||
data = conn.send(request)
|
||||
await stream.write_all(data)
|
||||
|
||||
data = await stream.read_exact(socks5.AuthMethodReply.SIZE)
|
||||
reply: socks5.AuthMethodReply = conn.receive(data)
|
||||
|
||||
# Authenticate
|
||||
if reply.method == socks5.AuthMethod.USERNAME_PASSWORD:
|
||||
request = socks5.AuthRequest(
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
)
|
||||
data = conn.send(request)
|
||||
await stream.write_all(data)
|
||||
|
||||
data = await stream.read_exact(socks5.AuthReply.SIZE)
|
||||
_: socks5.AuthReply = conn.receive(data)
|
||||
|
||||
# Connect
|
||||
if not is_ip_address(host) and not self._rdns:
|
||||
_, host = await self._resolver.resolve(
|
||||
host,
|
||||
family=socket.AF_UNSPEC,
|
||||
)
|
||||
|
||||
request = socks5.ConnectRequest(host=host, port=port)
|
||||
data = conn.send(request)
|
||||
await stream.write_all(data)
|
||||
|
||||
data = await self._read_reply(stream)
|
||||
reply: socks5.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
async def _read_reply(self, stream: AsyncSocketStream) -> bytes:
|
||||
data = await stream.read_exact(3)
|
||||
if data[0] != socks5.SOCKS_VER:
|
||||
return data
|
||||
if data[1] != socks5.ReplyCode.SUCCEEDED:
|
||||
return data
|
||||
if data[2] != socks5.RSV:
|
||||
return data
|
||||
|
||||
data += await stream.read_exact(1)
|
||||
addr_type = data[3]
|
||||
|
||||
if addr_type == socks5.AddressType.IPV4:
|
||||
data += await stream.read_exact(6)
|
||||
elif addr_type == socks5.AddressType.IPV6:
|
||||
data += await stream.read_exact(18)
|
||||
elif addr_type == socks5.AddressType.DOMAIN:
|
||||
data += await stream.read_exact(1)
|
||||
host_len = data[-1]
|
||||
data += await stream.read_exact(host_len + 2)
|
||||
|
||||
return data
|
||||
@@ -0,0 +1,86 @@
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from .._abc import SyncSocketStream, SyncResolver
|
||||
from .abc import SyncConnector
|
||||
|
||||
from .._protocols import socks5
|
||||
from .._helpers import is_ip_address
|
||||
|
||||
|
||||
class Socks5SyncConnector(SyncConnector):
|
||||
def __init__(
|
||||
self,
|
||||
username: Optional[str],
|
||||
password: Optional[str],
|
||||
rdns: Optional[bool],
|
||||
resolver: SyncResolver,
|
||||
):
|
||||
if rdns is None:
|
||||
rdns = True
|
||||
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
self._resolver = resolver
|
||||
|
||||
def connect(
|
||||
self,
|
||||
stream: SyncSocketStream,
|
||||
host: str,
|
||||
port: int,
|
||||
) -> socks5.ConnectReply:
|
||||
conn = socks5.Connection()
|
||||
|
||||
# Auth methods
|
||||
request = socks5.AuthMethodsRequest(username=self._username, password=self._password)
|
||||
data = conn.send(request)
|
||||
stream.write_all(data)
|
||||
|
||||
data = stream.read_exact(socks5.AuthMethodReply.SIZE)
|
||||
reply: socks5.AuthMethodReply = conn.receive(data)
|
||||
|
||||
# Authenticate
|
||||
if reply.method == socks5.AuthMethod.USERNAME_PASSWORD:
|
||||
request = socks5.AuthRequest(username=self._username, password=self._password)
|
||||
data = conn.send(request)
|
||||
stream.write_all(data)
|
||||
|
||||
data = stream.read_exact(socks5.AuthReply.SIZE)
|
||||
_: socks5.AuthReply = conn.receive(data)
|
||||
|
||||
# Connect
|
||||
if not is_ip_address(host) and not self._rdns:
|
||||
_, host = self._resolver.resolve(host, family=socket.AF_UNSPEC)
|
||||
|
||||
request = socks5.ConnectRequest(host=host, port=port)
|
||||
data = conn.send(request)
|
||||
stream.write_all(data)
|
||||
|
||||
data = self._read_reply(stream)
|
||||
reply: socks5.ConnectReply = conn.receive(data)
|
||||
return reply
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def _read_reply(self, stream: SyncSocketStream) -> bytes:
|
||||
data = stream.read_exact(3)
|
||||
if data[0] != socks5.SOCKS_VER:
|
||||
return data
|
||||
if data[1] != socks5.ReplyCode.SUCCEEDED:
|
||||
return data
|
||||
if data[2] != socks5.RSV:
|
||||
return data
|
||||
|
||||
data += stream.read_exact(1)
|
||||
addr_type = data[3]
|
||||
|
||||
if addr_type == socks5.AddressType.IPV4:
|
||||
data += stream.read_exact(6)
|
||||
elif addr_type == socks5.AddressType.IPV6:
|
||||
data += stream.read_exact(18)
|
||||
elif addr_type == socks5.AddressType.DOMAIN:
|
||||
data += stream.read_exact(1)
|
||||
host_len = data[-1]
|
||||
data += stream.read_exact(host_len + 2)
|
||||
|
||||
return data
|
||||
16
.venv/lib/python3.9/site-packages/python_socks/_errors.py
Normal file
16
.venv/lib/python3.9/site-packages/python_socks/_errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class ProxyException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyTimeoutError(ProxyException, TimeoutError):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyConnectionError(ProxyException, OSError):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyError(ProxyException):
|
||||
def __init__(self, message, error_code=None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
81
.venv/lib/python3.9/site-packages/python_socks/_helpers.py
Normal file
81
.venv/lib/python3.9/site-packages/python_socks/_helpers.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import functools
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from ._types import ProxyType
|
||||
|
||||
# pylint:disable-next=invalid-name
|
||||
_ipv4_pattern = (
|
||||
r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
|
||||
r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$'
|
||||
)
|
||||
|
||||
# pylint:disable-next=invalid-name
|
||||
_ipv6_pattern = (
|
||||
r'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
|
||||
r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
|
||||
r'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
|
||||
r'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
|
||||
r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
|
||||
r'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
|
||||
r'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
|
||||
r':|:(:[A-F0-9]{1,4}){7})$'
|
||||
)
|
||||
|
||||
_ipv4_regex = re.compile(_ipv4_pattern)
|
||||
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
|
||||
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
|
||||
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _is_ip_address(regex, regexb, host):
|
||||
# if host is None:
|
||||
# return False
|
||||
if isinstance(host, str):
|
||||
return bool(regex.match(host))
|
||||
elif isinstance(host, (bytes, bytearray, memoryview)):
|
||||
return bool(regexb.match(host))
|
||||
else:
|
||||
raise TypeError(
|
||||
'{} [{}] is not a str or bytes'.format(host, type(host)) # pragma: no cover
|
||||
)
|
||||
|
||||
|
||||
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
|
||||
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)
|
||||
|
||||
|
||||
def is_ip_address(host):
|
||||
return is_ipv4_address(host) or is_ipv6_address(host)
|
||||
|
||||
|
||||
def parse_proxy_url(url: str) -> Tuple[ProxyType, str, int, Optional[str], Optional[str]]:
|
||||
parsed = urlparse(url)
|
||||
|
||||
scheme = parsed.scheme
|
||||
if scheme == 'socks5':
|
||||
proxy_type = ProxyType.SOCKS5
|
||||
elif scheme == 'socks4':
|
||||
proxy_type = ProxyType.SOCKS4
|
||||
elif scheme == 'http':
|
||||
proxy_type = ProxyType.HTTP
|
||||
else:
|
||||
raise ValueError(f'Invalid scheme component: {scheme}') # pragma: no cover
|
||||
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError('Empty host component') # pragma: no cover
|
||||
|
||||
try:
|
||||
port = parsed.port
|
||||
assert port is not None
|
||||
except (ValueError, TypeError, AssertionError) as e: # pragma: no cover
|
||||
raise ValueError('Invalid port component') from e
|
||||
|
||||
try:
|
||||
username, password = (unquote(parsed.username), unquote(parsed.password))
|
||||
except (AttributeError, TypeError):
|
||||
username, password = '', ''
|
||||
|
||||
return proxy_type, host, port, username, password
|
||||
@@ -0,0 +1,4 @@
|
||||
class ReplyError(Exception):
|
||||
def __init__(self, message, error_code=None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
@@ -0,0 +1,148 @@
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import base64
|
||||
import binascii
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
|
||||
from .._version import __title__, __version__
|
||||
|
||||
from .errors import ReplyError
|
||||
|
||||
DEFAULT_USER_AGENT = 'Python/{0[0]}.{0[1]} {1}/{2}'.format(
|
||||
sys.version_info,
|
||||
__title__,
|
||||
__version__,
|
||||
)
|
||||
|
||||
CRLF = '\r\n'
|
||||
|
||||
|
||||
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
|
||||
"""Http basic authentication helper."""
|
||||
|
||||
def __new__(cls, login: str, password: str = '', encoding: str = 'latin1') -> 'BasicAuth':
|
||||
if login is None:
|
||||
raise ValueError('None is not allowed as login value')
|
||||
|
||||
if password is None:
|
||||
raise ValueError('None is not allowed as password value')
|
||||
|
||||
if ':' in login:
|
||||
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
|
||||
|
||||
# noinspection PyTypeChecker,PyArgumentList
|
||||
return super().__new__(cls, login, password, encoding)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, auth_header: str, encoding: str = 'latin1') -> 'BasicAuth':
|
||||
"""Create a BasicAuth object from an Authorization HTTP header."""
|
||||
try:
|
||||
auth_type, encoded_credentials = auth_header.split(' ', 1)
|
||||
except ValueError:
|
||||
raise ValueError('Could not parse authorization header.')
|
||||
|
||||
if auth_type.lower() != 'basic':
|
||||
raise ValueError('Unknown authorization method %s' % auth_type)
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(encoded_credentials.encode('ascii'), validate=True).decode(
|
||||
encoding
|
||||
)
|
||||
except binascii.Error:
|
||||
raise ValueError('Invalid base64 encoding.')
|
||||
|
||||
try:
|
||||
# RFC 2617 HTTP Authentication
|
||||
# https://www.ietf.org/rfc/rfc2617.txt
|
||||
# the colon must be present, but the username and password may be
|
||||
# otherwise blank.
|
||||
username, password = decoded.split(':', 1)
|
||||
except ValueError:
|
||||
raise ValueError('Invalid credentials.')
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
return cls(username, password, encoding=encoding)
|
||||
|
||||
def encode(self) -> str:
|
||||
"""Encode credentials."""
|
||||
creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
|
||||
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)
|
||||
|
||||
|
||||
class _Buffer:
|
||||
def __init__(self, encoding: str = 'utf-8'):
|
||||
self._encoding = encoding
|
||||
self._buffer = bytearray()
|
||||
|
||||
def append_line(self, line: str = ""):
|
||||
if line:
|
||||
self._buffer.extend(line.encode(self._encoding))
|
||||
|
||||
self._buffer.extend(CRLF.encode('ascii'))
|
||||
|
||||
def dumps(self) -> bytes:
|
||||
return bytes(self._buffer)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectRequest:
|
||||
host: str
|
||||
port: int
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
|
||||
def dumps(self) -> bytes:
|
||||
buff = _Buffer()
|
||||
buff.append_line(f'CONNECT {self.host}:{self.port} HTTP/1.1')
|
||||
buff.append_line(f'Host: {self.host}:{self.port}')
|
||||
buff.append_line(f'User-Agent: {DEFAULT_USER_AGENT}')
|
||||
|
||||
if self.username and self.password:
|
||||
auth = BasicAuth(self.username, self.password)
|
||||
buff.append_line(f'Proxy-Authorization: {auth.encode()}')
|
||||
|
||||
buff.append_line()
|
||||
|
||||
return buff.dumps()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectReply:
|
||||
status_code: int
|
||||
message: str
|
||||
|
||||
@classmethod
|
||||
def loads(cls, data: bytes) -> 'ConnectReply':
|
||||
if not data:
|
||||
raise ReplyError('Invalid proxy response') # pragma: no cover
|
||||
|
||||
line = data.split(CRLF.encode('ascii'), 1)[0]
|
||||
line = line.decode('utf-8', 'surrogateescape')
|
||||
|
||||
try:
|
||||
version, code, *reason = line.split()
|
||||
except ValueError: # pragma: no cover
|
||||
raise ReplyError(f'Invalid status line: {line}')
|
||||
|
||||
try:
|
||||
status_code = int(code)
|
||||
except ValueError: # pragma: no cover
|
||||
raise ReplyError(f'Invalid status code: {code}')
|
||||
|
||||
status_message = " ".join(reason)
|
||||
|
||||
if status_code != 200:
|
||||
msg = f'{status_code} {status_message}'
|
||||
raise ReplyError(msg, error_code=status_code)
|
||||
|
||||
return cls(status_code=status_code, message=status_message)
|
||||
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
class Connection:
|
||||
def send(self, request: ConnectRequest) -> bytes:
|
||||
return request.dumps()
|
||||
|
||||
def receive(self, data: bytes) -> ConnectReply:
|
||||
return ConnectReply.loads(data)
|
||||
@@ -0,0 +1,116 @@
|
||||
import enum
|
||||
import ipaddress
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .errors import ReplyError
|
||||
from .._helpers import is_ipv4_address
|
||||
|
||||
RSV = NULL = 0x00
|
||||
SOCKS_VER = 0x04
|
||||
|
||||
|
||||
class Command(enum.IntEnum):
|
||||
CONNECT = 0x01
|
||||
BIND = 0x02
|
||||
|
||||
|
||||
class ReplyCode(enum.IntEnum):
|
||||
REQUEST_GRANTED = 0x5A
|
||||
REQUEST_REJECTED_OR_FAILED = 0x5B
|
||||
CONNECTION_FAILED = 0x5C
|
||||
AUTHENTICATION_FAILED = 0x5D
|
||||
|
||||
|
||||
ReplyMessages = {
|
||||
ReplyCode.REQUEST_GRANTED: 'Request granted',
|
||||
ReplyCode.REQUEST_REJECTED_OR_FAILED: 'Request rejected or failed',
|
||||
ReplyCode.CONNECTION_FAILED: (
|
||||
'Request rejected because SOCKS server cannot connect to identd on the client'
|
||||
),
|
||||
ReplyCode.AUTHENTICATION_FAILED: (
|
||||
'Request rejected because the client program and identd report different user-ids'
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectRequest:
|
||||
host: str # hostname or IPv4 address
|
||||
port: int
|
||||
user_id: Optional[str]
|
||||
|
||||
def dumps(self):
|
||||
port_bytes = self.port.to_bytes(2, 'big')
|
||||
include_hostname = False
|
||||
|
||||
if is_ipv4_address(self.host):
|
||||
host_bytes = ipaddress.IPv4Address(self.host).packed
|
||||
else:
|
||||
include_hostname = True
|
||||
host_bytes = bytes([NULL, NULL, NULL, 1])
|
||||
|
||||
data = bytearray([SOCKS_VER, Command.CONNECT])
|
||||
data += port_bytes
|
||||
data += host_bytes
|
||||
|
||||
if self.user_id:
|
||||
data += self.user_id.encode('ascii')
|
||||
|
||||
data.append(NULL)
|
||||
|
||||
if include_hostname:
|
||||
data += self.host.encode('idna')
|
||||
data.append(NULL)
|
||||
|
||||
return bytes(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectReply:
|
||||
SIZE = 8
|
||||
|
||||
rsv: int
|
||||
reply: ReplyCode
|
||||
host: str # should be ignored when using Command.CONNECT
|
||||
port: int # should be ignored when using Command.CONNECT
|
||||
|
||||
@classmethod
|
||||
def loads(cls, data: bytes) -> 'ConnectReply':
|
||||
if len(data) != cls.SIZE:
|
||||
raise ReplyError('Malformed connect reply')
|
||||
|
||||
rsv = data[0]
|
||||
if rsv != RSV: # pragma: no cover
|
||||
raise ReplyError(f'Unexpected reply version: {data[0]:#02X}')
|
||||
|
||||
try:
|
||||
reply = ReplyCode(data[1])
|
||||
except ValueError:
|
||||
raise ReplyError(f'Invalid reply code: {data[1]:#02X}')
|
||||
|
||||
if reply != ReplyCode.REQUEST_GRANTED: # pragma: no cover
|
||||
msg = ReplyMessages.get(reply, 'Unknown error')
|
||||
raise ReplyError(msg, error_code=reply)
|
||||
|
||||
try:
|
||||
port = int.from_bytes(data[2:4], byteorder="big")
|
||||
except ValueError:
|
||||
raise ReplyError('Invalid port data')
|
||||
|
||||
try:
|
||||
host = socket.inet_ntop(socket.AF_INET, data[4:8])
|
||||
except ValueError:
|
||||
raise ReplyError('Invalid port data')
|
||||
|
||||
return cls(rsv=rsv, reply=reply, host=host, port=port)
|
||||
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
class Connection:
|
||||
def send(self, request: ConnectRequest) -> bytes:
|
||||
return request.dumps()
|
||||
|
||||
def receive(self, data: bytes) -> ConnectReply:
|
||||
return ConnectReply.loads(data)
|
||||
@@ -0,0 +1,355 @@
|
||||
import enum
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .errors import ReplyError
|
||||
from .._helpers import is_ip_address
|
||||
|
||||
|
||||
RSV = NULL = AUTH_GRANTED = 0x00
|
||||
SOCKS_VER = 0x05
|
||||
|
||||
|
||||
class AuthMethod(enum.IntEnum):
|
||||
ANONYMOUS = 0x00
|
||||
GSSAPI = 0x01
|
||||
USERNAME_PASSWORD = 0x02
|
||||
NO_ACCEPTABLE = 0xFF
|
||||
|
||||
|
||||
class AddressType(enum.IntEnum):
|
||||
IPV4 = 0x01
|
||||
DOMAIN = 0x03
|
||||
IPV6 = 0x04
|
||||
|
||||
@classmethod
|
||||
def from_ip_ver(cls, ver: int):
|
||||
if ver == 4:
|
||||
return cls.IPV4
|
||||
if ver == 6:
|
||||
return cls.IPV6
|
||||
|
||||
raise ValueError('Invalid IP version')
|
||||
|
||||
|
||||
class Command(enum.IntEnum):
|
||||
CONNECT = 0x01
|
||||
BIND = 0x02
|
||||
UDP_ASSOCIATE = 0x03
|
||||
|
||||
|
||||
class ReplyCode(enum.IntEnum):
|
||||
SUCCEEDED = 0x00
|
||||
GENERAL_FAILURE = 0x01
|
||||
CONNECTION_NOT_ALLOWED = 0x02
|
||||
NETWORK_UNREACHABLE = 0x03
|
||||
HOST_UNREACHABLE = 0x04
|
||||
CONNECTION_REFUSED = 0x05
|
||||
TTL_EXPIRED = 0x06
|
||||
COMMAND_NOT_SUPPORTED = 0x07
|
||||
ADDRESS_TYPE_NOT_SUPPORTED = 0x08
|
||||
|
||||
|
||||
ReplyMessages = {
|
||||
ReplyCode.SUCCEEDED: 'Request granted',
|
||||
ReplyCode.GENERAL_FAILURE: 'General SOCKS server failure',
|
||||
ReplyCode.CONNECTION_NOT_ALLOWED: 'Connection not allowed by ruleset',
|
||||
ReplyCode.NETWORK_UNREACHABLE: 'Network unreachable',
|
||||
ReplyCode.HOST_UNREACHABLE: 'Host unreachable',
|
||||
ReplyCode.CONNECTION_REFUSED: 'Connection refused by destination host',
|
||||
ReplyCode.TTL_EXPIRED: 'TTL expired',
|
||||
ReplyCode.COMMAND_NOT_SUPPORTED: 'Command not supported or protocol error',
|
||||
ReplyCode.ADDRESS_TYPE_NOT_SUPPORTED: 'Address type not supported',
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthMethodsRequest:
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
methods: bytearray = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
methods = bytearray([AuthMethod.ANONYMOUS])
|
||||
|
||||
if self.username and self.password:
|
||||
methods.append(AuthMethod.USERNAME_PASSWORD)
|
||||
|
||||
self.methods = methods
|
||||
|
||||
def dumps(self) -> bytes:
|
||||
return bytes([SOCKS_VER, len(self.methods)]) + self.methods
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthMethodReply:
|
||||
SIZE = 2
|
||||
|
||||
ver: int
|
||||
method: AuthMethod
|
||||
|
||||
def validate(self, request: AuthMethodsRequest):
|
||||
if self.method not in request.methods: # pragma: no cover
|
||||
raise ReplyError(f'Unexpected SOCKS authentication method: {self.method}')
|
||||
|
||||
@classmethod
|
||||
def loads(cls, data: bytes) -> 'AuthMethodReply':
|
||||
if len(data) != cls.SIZE:
|
||||
raise ReplyError('Malformed authentication method reply')
|
||||
|
||||
ver = data[0]
|
||||
if ver != SOCKS_VER: # pragma: no cover
|
||||
raise ReplyError(f'Unexpected SOCKS version number: {ver}')
|
||||
|
||||
try:
|
||||
method = AuthMethod(data[1])
|
||||
except ValueError:
|
||||
raise ReplyError(f'Invalid authentication method: {data[1]:#02X}')
|
||||
|
||||
if method == AuthMethod.NO_ACCEPTABLE: # pragma: no cover
|
||||
raise ReplyError('No acceptable authentication methods were offered')
|
||||
|
||||
return cls(ver=ver, method=method)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthRequest:
|
||||
VER = 0x01
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
def dumps(self) -> bytes:
|
||||
data = bytearray()
|
||||
data.append(self.VER)
|
||||
data.append(len(self.username))
|
||||
data += self.username.encode('ascii')
|
||||
data.append(len(self.password))
|
||||
data += self.password.encode('ascii')
|
||||
return bytes(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthReply:
|
||||
SIZE = 2
|
||||
|
||||
ver: int
|
||||
status: int
|
||||
|
||||
@classmethod
|
||||
def loads(cls, data: bytes) -> 'AuthReply':
|
||||
if len(data) != cls.SIZE:
|
||||
raise ReplyError('Malformed auth reply')
|
||||
|
||||
ver = data[0]
|
||||
if ver != AuthRequest.VER: # pragma: no cover
|
||||
raise ReplyError('Invalid authentication response')
|
||||
|
||||
status = data[1]
|
||||
if status != AUTH_GRANTED: # pragma: no cover
|
||||
raise ReplyError('Username and password authentication failure')
|
||||
|
||||
return cls(ver=ver, status=status)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectRequest:
|
||||
host: str # hostname or IPv4 or IPv6 address
|
||||
port: int
|
||||
|
||||
def dumps(self) -> bytes:
|
||||
data = bytearray([SOCKS_VER, Command.CONNECT, RSV])
|
||||
data += self._build_addr_request()
|
||||
return bytes(data)
|
||||
|
||||
def _build_addr_request(self) -> bytes:
|
||||
port = self.port.to_bytes(2, 'big')
|
||||
|
||||
if is_ip_address(self.host):
|
||||
ip = ipaddress.ip_address(self.host)
|
||||
address_type = AddressType.from_ip_ver(ip.version)
|
||||
return bytes([address_type]) + ip.packed + port
|
||||
else:
|
||||
address_type = AddressType.DOMAIN
|
||||
host = self.host.encode('idna')
|
||||
return bytes([address_type, len(host)]) + host + port
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectReply:
|
||||
ver: int
|
||||
reply: ReplyCode
|
||||
rsv: int
|
||||
bound_host: str
|
||||
bound_port: int
|
||||
|
||||
def validate(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def loads(cls, data: bytes) -> 'ConnectReply':
|
||||
if not data:
|
||||
raise ReplyError('Empty connect reply')
|
||||
|
||||
ver = data[0]
|
||||
if ver != SOCKS_VER: # pragma: no cover
|
||||
raise ReplyError(f'Unexpected SOCKS version number: {ver:#02X}')
|
||||
|
||||
try:
|
||||
reply = ReplyCode(data[1])
|
||||
except IndexError:
|
||||
raise ReplyError('Malformed connect reply')
|
||||
except ValueError:
|
||||
raise ReplyError(f'Invalid reply code: {data[1]:#02X}')
|
||||
|
||||
if reply != ReplyCode.SUCCEEDED: # pragma: no cover
|
||||
msg = ReplyMessages.get(reply, 'Unknown error') # type: ignore
|
||||
raise ReplyError(msg, error_code=reply)
|
||||
|
||||
try:
|
||||
rsv = data[2]
|
||||
except IndexError:
|
||||
raise ReplyError('Malformed connect reply')
|
||||
|
||||
if rsv != RSV: # pragma: no cover
|
||||
raise ReplyError(f'The reserved byte must be {RSV:#02X}')
|
||||
|
||||
try:
|
||||
addr_type = data[3]
|
||||
bnd_host_data = data[4:-2]
|
||||
bnd_port_data = data[-2:]
|
||||
except IndexError:
|
||||
raise ReplyError('Malformed connect reply')
|
||||
|
||||
if addr_type == AddressType.IPV4:
|
||||
bnd_host = socket.inet_ntop(socket.AF_INET, bnd_host_data)
|
||||
elif addr_type == AddressType.IPV6:
|
||||
bnd_host = socket.inet_ntop(socket.AF_INET6, bnd_host_data)
|
||||
elif addr_type == AddressType.DOMAIN: # pragma: no cover
|
||||
# host_len = bnd_host_data[0]
|
||||
bnd_host = bnd_host_data[1:].decode()
|
||||
else: # pragma: no cover
|
||||
raise ReplyError(f'Invalid address type: {addr_type:#02X}')
|
||||
|
||||
bnd_port = int.from_bytes(bnd_port_data, 'big')
|
||||
|
||||
return cls(
|
||||
ver=ver,
|
||||
reply=reply,
|
||||
rsv=rsv,
|
||||
bound_host=bnd_host,
|
||||
bound_port=bnd_port,
|
||||
)
|
||||
|
||||
|
||||
class StateServerWaitingForAuthMethods:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateClientSentAuthMethods:
|
||||
data: AuthMethodsRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateServerWaitingForAuth:
|
||||
data: AuthMethodReply
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateClientAuthenticated:
|
||||
data: Optional[AuthReply] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateClientSentAuthRequest:
|
||||
data: AuthRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateClientSentConnectRequest:
|
||||
data: ConnectRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateServerConnected:
|
||||
data: ConnectReply
|
||||
|
||||
|
||||
Request = Union[
|
||||
AuthMethodsRequest,
|
||||
AuthRequest,
|
||||
ConnectRequest,
|
||||
]
|
||||
|
||||
Reply = Union[
|
||||
AuthMethodReply,
|
||||
AuthReply,
|
||||
ConnectReply,
|
||||
]
|
||||
|
||||
ConnectionState = Union[
|
||||
StateServerWaitingForAuthMethods,
|
||||
StateClientSentAuthMethods,
|
||||
StateServerWaitingForAuth,
|
||||
StateClientSentAuthRequest,
|
||||
StateClientAuthenticated,
|
||||
StateClientSentConnectRequest,
|
||||
StateServerConnected,
|
||||
]
|
||||
|
||||
|
||||
class Connection:
|
||||
_state: ConnectionState
|
||||
|
||||
def __init__(self):
|
||||
self._state = StateServerWaitingForAuthMethods()
|
||||
|
||||
def send(self, request: Request) -> bytes:
|
||||
if type(request) is AuthMethodsRequest:
|
||||
if type(self._state) is not StateServerWaitingForAuthMethods:
|
||||
raise RuntimeError('Server is not currently waiting for auth methods')
|
||||
self._state = StateClientSentAuthMethods(request)
|
||||
return request.dumps()
|
||||
|
||||
if type(request) is AuthRequest:
|
||||
if type(self._state) is not StateServerWaitingForAuth:
|
||||
raise RuntimeError('Server is not currently waiting for authentication')
|
||||
self._state = StateClientSentAuthRequest(request)
|
||||
return request.dumps()
|
||||
|
||||
if type(request) is ConnectRequest:
|
||||
if type(self._state) is not StateClientAuthenticated:
|
||||
raise RuntimeError('Client is not authenticated')
|
||||
self._state = StateClientSentConnectRequest(request)
|
||||
return request.dumps()
|
||||
|
||||
raise RuntimeError(f'Invalid request type: {type(request)}')
|
||||
|
||||
def receive(self, data: bytes) -> Reply:
|
||||
if type(self._state) is StateClientSentAuthMethods:
|
||||
reply = AuthMethodReply.loads(data)
|
||||
reply.validate(self._state.data)
|
||||
if reply.method == AuthMethod.USERNAME_PASSWORD:
|
||||
self._state = StateServerWaitingForAuth(data=reply)
|
||||
else:
|
||||
self._state = StateClientAuthenticated()
|
||||
return reply
|
||||
|
||||
if type(self._state) is StateClientSentAuthRequest:
|
||||
reply = AuthReply.loads(data)
|
||||
self._state = StateClientAuthenticated(data=reply)
|
||||
return reply
|
||||
|
||||
if type(self._state) is StateClientSentConnectRequest:
|
||||
reply = ConnectReply.loads(data)
|
||||
self._state = StateServerConnected(data=reply)
|
||||
return reply
|
||||
|
||||
raise RuntimeError(f'Invalid connection state: {self._state}')
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
7
.venv/lib/python3.9/site-packages/python_socks/_types.py
Normal file
7
.venv/lib/python3.9/site-packages/python_socks/_types.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ProxyType(Enum):
|
||||
SOCKS4 = 1
|
||||
SOCKS5 = 2
|
||||
HTTP = 3
|
||||
@@ -0,0 +1,2 @@
|
||||
__title__ = 'python-socks'
|
||||
__version__ = '2.8.1'
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._proxy_chain import ProxyChain
|
||||
|
||||
__all__ = ('ProxyChain',)
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import Iterable
|
||||
import warnings
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(self, dest_host, dest_port, timeout=None):
|
||||
curr_socket = None
|
||||
proxies = list(self._proxies)
|
||||
|
||||
length = len(proxies) - 1
|
||||
for i in range(length):
|
||||
curr_socket = await proxies[i].connect(
|
||||
dest_host=proxies[i + 1].proxy_host,
|
||||
dest_port=proxies[i + 1].proxy_port,
|
||||
timeout=timeout,
|
||||
_socket=curr_socket,
|
||||
)
|
||||
|
||||
curr_socket = await proxies[length].connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
timeout=timeout,
|
||||
_socket=curr_socket,
|
||||
)
|
||||
|
||||
return curr_socket
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AnyioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = ('Proxy', 'ProxyChain')
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Iterable
|
||||
import warnings
|
||||
from ._proxy import AnyioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable[AnyioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
_stream = None
|
||||
proxies = list(self._proxies)
|
||||
|
||||
length = len(proxies) - 1
|
||||
for i in range(length):
|
||||
_stream = await proxies[i].connect(
|
||||
dest_host=proxies[i + 1].proxy_host,
|
||||
dest_port=proxies[i + 1].proxy_port,
|
||||
timeout=timeout,
|
||||
_stream=_stream,
|
||||
)
|
||||
|
||||
_stream = await proxies[length].connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
_stream=_stream,
|
||||
)
|
||||
|
||||
return _stream
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
import anyio
|
||||
import anyio.abc
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_host: Optional[str] = None,
|
||||
) -> anyio.abc.SocketStream:
|
||||
|
||||
return await anyio.connect_tcp(
|
||||
remote_host=host,
|
||||
remote_port=port,
|
||||
local_host=local_host,
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
import anyio
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ._stream import AnyioSocketStream
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AnyioProxy:
|
||||
_stream: Optional[AnyioSocketStream]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AnyioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_stream = kwargs.get('_stream')
|
||||
if _stream is not None:
|
||||
warnings.warn(
|
||||
"The '_stream' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_host = kwargs.get('local_host')
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
if _stream is None:
|
||||
try:
|
||||
_stream = AnyioSocketStream(
|
||||
await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_host=local_host,
|
||||
)
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = _stream
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
|
||||
return stream
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
except TimeoutError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,22 @@
|
||||
import anyio
|
||||
import socket
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
infos = await anyio.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,59 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from anyio.streams.tls import TLSStream
|
||||
|
||||
from ..._errors import ProxyError
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
|
||||
|
||||
|
||||
class AnyioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: AnyioStreamType
|
||||
|
||||
def __init__(self, stream: AnyioStreamType) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data: bytes):
|
||||
await self._stream.send(item=data)
|
||||
|
||||
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
|
||||
try:
|
||||
return await self._stream.receive(max_bytes=max_bytes)
|
||||
except anyio.EndOfStream: # pragma: no cover
|
||||
return b""
|
||||
|
||||
async def read_exact(self, n: int):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self.read(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'AnyioSocketStream':
|
||||
ssl_stream = await TLSStream.wrap(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
hostname=hostname,
|
||||
standard_compatible=False,
|
||||
server_side=False,
|
||||
)
|
||||
return AnyioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
|
||||
return self._stream
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._proxy import AnyioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = (
|
||||
'Proxy',
|
||||
'ProxyChain',
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import AnyioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[AnyioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from ._stream import AnyioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_host: Optional[str] = None,
|
||||
) -> AnyioSocketStream:
|
||||
s = await anyio.connect_tcp(
|
||||
remote_host=host,
|
||||
remote_port=port,
|
||||
local_host=local_host,
|
||||
)
|
||||
return AnyioSocketStream(s)
|
||||
@@ -0,0 +1,135 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
import anyio
|
||||
|
||||
from ._connect import connect_tcp
|
||||
from ._stream import AnyioSocketStream
|
||||
from .._resolver import Resolver
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AnyioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['AnyioProxy'] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AnyioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_host = kwargs.get('local_host')
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_host=local_host,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_host: Optional[str] = None,
|
||||
) -> AnyioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_host=local_host,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
with anyio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,59 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from anyio.streams.tls import TLSStream
|
||||
|
||||
from ...._errors import ProxyError
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
|
||||
|
||||
|
||||
class AnyioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: AnyioStreamType
|
||||
|
||||
def __init__(self, stream: AnyioStreamType) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data: bytes):
|
||||
await self._stream.send(item=data)
|
||||
|
||||
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
|
||||
try:
|
||||
return await self._stream.receive(max_bytes=max_bytes)
|
||||
except anyio.EndOfStream: # pragma: no cover
|
||||
return b""
|
||||
|
||||
async def read_exact(self, n: int):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self.read(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'AnyioSocketStream':
|
||||
ssl_stream = await TLSStream.wrap(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
hostname=hostname,
|
||||
standard_compatible=False,
|
||||
server_side=False,
|
||||
)
|
||||
return AnyioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
|
||||
return self._stream
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AsyncioProxy as Proxy
|
||||
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,43 @@
|
||||
import socket
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ..._helpers import is_ipv4_address, is_ipv6_address
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> socket.socket:
|
||||
|
||||
family, host = await _resolve_host(host, loop)
|
||||
|
||||
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
|
||||
sock.setblocking(False)
|
||||
if local_addr is not None: # pragma: no cover
|
||||
sock.bind(local_addr)
|
||||
|
||||
if is_ipv6_address(host):
|
||||
address = (host, port, 0, 0) # to fix OSError: [WinError 10022]
|
||||
else:
|
||||
address = (host, port) # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
await loop.sock_connect(sock=sock, address=address)
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
async def _resolve_host(host, loop):
|
||||
if is_ipv4_address(host):
|
||||
return socket.AF_INET, host
|
||||
if is_ipv6_address(host):
|
||||
return socket.AF_INET6, host
|
||||
|
||||
resolver = Resolver(loop=loop)
|
||||
return await resolver.resolve(host=host)
|
||||
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
from ._stream import AsyncioSocketStream
|
||||
from ._resolver import Resolver
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
from ._connect import connect_tcp
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import asyncio as async_timeout # pylint:disable=reimported
|
||||
else:
|
||||
import async_timeout
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AsyncioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._loop = loop
|
||||
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver(loop=loop)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> socket.socket:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
async with async_timeout.timeout(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
_socket=_socket,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
) -> socket.socket:
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
loop=self._loop,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = AsyncioSocketStream(sock=_socket, loop=self._loop)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
return _socket
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except (asyncio.CancelledError, Exception): # pragma: no cover
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
infos = await self._loop.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,36 @@
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
from ..._errors import ProxyError
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class AsyncioSocketStream(abc.AsyncSocketStream):
|
||||
_loop: asyncio.AbstractEventLoop = None
|
||||
_socket = None
|
||||
|
||||
def __init__(self, sock: socket.socket, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._loop.sock_sendall(self._socket, data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._loop.sock_recv(self._socket, max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._loop.sock_recv(self._socket, n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AsyncioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = ('Proxy', 'ProxyChain')
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import AsyncioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[AsyncioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
from ._stream import AsyncioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> AsyncioSocketStream:
|
||||
kwargs = {}
|
||||
if local_addr is not None:
|
||||
kwargs['local_addr'] = local_addr # pragma: no cover
|
||||
|
||||
reader, writer = await asyncio.open_connection(
|
||||
host=host,
|
||||
port=port,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
return AsyncioSocketStream(
|
||||
loop=loop,
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Any, Optional, Tuple
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
from .._resolver import Resolver
|
||||
from ._stream import AsyncioSocketStream
|
||||
from ._connect import connect_tcp
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import asyncio as async_timeout # pylint:disable=reimported
|
||||
else:
|
||||
import async_timeout
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AsyncioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['AsyncioProxy'] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
if loop is not None: # pragma: no cover
|
||||
warnings.warn(
|
||||
'The loop argument is deprecated and scheduled for removal in the future.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._loop = loop
|
||||
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver(loop=loop)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
async with async_timeout.timeout(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> AsyncioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
loop=self._loop,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except (asyncio.CancelledError, Exception):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,91 @@
|
||||
import asyncio
|
||||
import ssl
|
||||
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class AsyncioSocketStream(abc.AsyncSocketStream):
|
||||
_loop: asyncio.AbstractEventLoop
|
||||
_reader: asyncio.StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
):
|
||||
self._loop = loop
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
||||
async def write_all(self, data):
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._reader.read(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
return await self._reader.readexactly(n)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
ssl_handshake_timeout=None,
|
||||
) -> 'AsyncioSocketStream':
|
||||
if hasattr(self._writer, 'start_tls'): # Python>=3.11
|
||||
await self._writer.start_tls(
|
||||
ssl_context,
|
||||
server_hostname=hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
return self
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
|
||||
transport: asyncio.Transport = await self._loop.start_tls(
|
||||
self._writer.transport, # type: ignore
|
||||
protocol,
|
||||
ssl_context,
|
||||
server_side=False,
|
||||
server_hostname=hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
|
||||
# reader.set_transport(transport)
|
||||
|
||||
# Initialize the protocol, so it is made aware of being tied to
|
||||
# a TLS connection.
|
||||
# See: https://github.com/encode/httpx/issues/859
|
||||
protocol.connection_made(transport)
|
||||
|
||||
writer = asyncio.StreamWriter(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
reader=reader,
|
||||
loop=self._loop,
|
||||
)
|
||||
|
||||
stream = AsyncioSocketStream(loop=self._loop, reader=reader, writer=writer)
|
||||
# When we return a new SocketStream with new StreamReader/StreamWriter instances
|
||||
# we need to keep references to the old StreamReader/StreamWriter so that they
|
||||
# are not garbage collected and closed while we're still using them.
|
||||
stream._inner = self # type: ignore # pylint:disable=W0212,W0201
|
||||
return stream
|
||||
|
||||
async def close(self):
|
||||
self._writer.close()
|
||||
self._writer.transport.abort() # noqa
|
||||
|
||||
@property
|
||||
def reader(self):
|
||||
return self._reader # pragma: no cover
|
||||
|
||||
@property
|
||||
def writer(self):
|
||||
return self._writer # pragma: no cover
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import CurioProxy as Proxy
|
||||
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import curio
|
||||
import curio.io
|
||||
import curio.socket
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> curio.io.Socket:
|
||||
return await curio.open_connection(
|
||||
host=host,
|
||||
port=port,
|
||||
source_addr=local_addr,
|
||||
)
|
||||
@@ -0,0 +1,132 @@
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
import curio
|
||||
import curio.io
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._stream import CurioSocketStream
|
||||
from ._resolver import Resolver
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class CurioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> curio.io.Socket:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
return await curio.timeout_after(
|
||||
timeout,
|
||||
self._connect,
|
||||
dest_host,
|
||||
dest_port,
|
||||
_socket,
|
||||
local_addr,
|
||||
)
|
||||
except curio.TaskTimeout as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
):
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = CurioSocketStream(_socket)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
return _socket
|
||||
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'CurioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,25 @@
|
||||
import socket
|
||||
from curio.socket import getaddrinfo
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
try:
|
||||
infos = await getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror: # pragma: no cover
|
||||
infos = None
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,32 @@
|
||||
import curio.io
|
||||
import curio.socket
|
||||
|
||||
from ... import _abc as abc
|
||||
from ..._errors import ProxyError
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class CurioSocketStream(abc.AsyncSocketStream):
|
||||
_socket: curio.io.Socket = None
|
||||
|
||||
def __init__(self, sock: curio.io.Socket):
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._socket.sendall(data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._socket.recv(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
await self._socket.close()
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._proxy import TrioProxy as Proxy
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ..._helpers import is_ipv4_address, is_ipv6_address
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> trio.socket.SocketType:
|
||||
|
||||
family, host = await _resolve_host(host)
|
||||
|
||||
sock = trio.socket.socket(family=family, type=trio.socket.SOCK_STREAM)
|
||||
if local_addr is not None: # pragma: no cover
|
||||
await sock.bind(local_addr)
|
||||
|
||||
try:
|
||||
await sock.connect((host, port))
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
async def _resolve_host(host):
|
||||
if is_ipv4_address(host):
|
||||
return trio.socket.AF_INET, host
|
||||
if is_ipv6_address(host):
|
||||
return trio.socket.AF_INET6, host
|
||||
|
||||
resolver = Resolver()
|
||||
return await resolver.resolve(host=host)
|
||||
@@ -0,0 +1,131 @@
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
import trio
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._stream import TrioSocketStream
|
||||
from ._resolver import Resolver
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class TrioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> trio.socket.SocketType:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
with trio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
_socket=_socket,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except trio.TooSlowError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
) -> trio.socket.SocketType:
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = TrioSocketStream(sock=_socket)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
return _socket
|
||||
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException: # trio.Cancelled...
|
||||
with trio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,21 @@
|
||||
import trio
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=trio.socket.AF_UNSPEC):
|
||||
infos = await trio.socket.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=trio.socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,35 @@
|
||||
import trio
|
||||
|
||||
from ..._errors import ProxyError
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class TrioSocketStream(abc.AsyncSocketStream):
|
||||
def __init__(self, sock):
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
total_sent = 0
|
||||
while total_sent < len(data):
|
||||
remaining = data[total_sent:]
|
||||
sent = await self._socket.send(remaining)
|
||||
total_sent += sent
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._socket.recv(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._proxy import TrioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = (
|
||||
'Proxy',
|
||||
'ProxyChain',
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import TrioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[TrioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
import trio
|
||||
from ._stream import TrioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[str] = None,
|
||||
) -> TrioSocketStream:
|
||||
trio_stream = await trio.open_tcp_stream(
|
||||
host=host,
|
||||
port=port,
|
||||
local_address=local_addr,
|
||||
)
|
||||
return TrioSocketStream(trio_stream)
|
||||
@@ -0,0 +1,135 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
import trio
|
||||
|
||||
from ._connect import connect_tcp
|
||||
from ._stream import TrioSocketStream
|
||||
from .._resolver import Resolver
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class TrioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['TrioProxy'] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> TrioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
with trio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except trio.TooSlowError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_addr: Optional[str] = None,
|
||||
) -> TrioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException: # trio.Cancelled...
|
||||
with trio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,55 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import trio
|
||||
|
||||
from ...._errors import ProxyError
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
TrioStreamType = Union[trio.SocketStream, trio.SSLStream]
|
||||
|
||||
|
||||
class TrioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: TrioStreamType
|
||||
|
||||
def __init__(self, stream: TrioStreamType):
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._stream.send_all(data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._stream.receive_some(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._stream.receive_some(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'TrioSocketStream':
|
||||
ssl_stream = trio.SSLStream(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
server_hostname=hostname,
|
||||
https_compatible=True,
|
||||
server_side=False,
|
||||
)
|
||||
await ssl_stream.do_handshake()
|
||||
return TrioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def trio_stream(self) -> TrioStreamType: # pragma: nocover
|
||||
return self._stream
|
||||
@@ -0,0 +1,5 @@
|
||||
from ._proxy import SyncProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
|
||||
__all__ = ('Proxy', 'ProxyChain')
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Iterable
|
||||
import warnings
|
||||
from ._proxy import SyncProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable[SyncProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
def connect(self, dest_host, dest_port, timeout=None):
|
||||
curr_socket = None
|
||||
proxies = list(self._proxies)
|
||||
|
||||
length = len(proxies) - 1
|
||||
for i in range(length):
|
||||
curr_socket = proxies[i].connect(
|
||||
dest_host=proxies[i + 1].proxy_host,
|
||||
dest_port=proxies[i + 1].proxy_port,
|
||||
timeout=timeout,
|
||||
_socket=curr_socket,
|
||||
)
|
||||
|
||||
curr_socket = proxies[length].connect(
|
||||
dest_host=dest_host, dest_port=dest_port, timeout=timeout, _socket=curr_socket
|
||||
)
|
||||
|
||||
return curr_socket
|
||||
@@ -0,0 +1,16 @@
|
||||
import socket
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: Optional[float] = None,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> socket.socket:
|
||||
address = (host, port)
|
||||
return socket.create_connection(
|
||||
address,
|
||||
timeout,
|
||||
source_address=local_addr,
|
||||
)
|
||||
116
.venv/lib/python3.9/site-packages/python_socks/sync/_proxy.py
Normal file
116
.venv/lib/python3.9/site-packages/python_socks/sync/_proxy.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import socket
|
||||
from typing import Optional, Any
|
||||
import warnings
|
||||
|
||||
from .._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from .._types import ProxyType
|
||||
from .._helpers import parse_proxy_url
|
||||
from .._protocols.errors import ReplyError
|
||||
from .._connectors.factory_sync import create_connector
|
||||
|
||||
from ._stream import SyncSocketStream
|
||||
from ._resolver import SyncResolver
|
||||
from ._connect import connect_tcp
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class SyncProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = SyncResolver()
|
||||
|
||||
def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> socket.socket:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if _socket is None:
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
_socket = connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
timeout=timeout,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = SyncSocketStream(_socket)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
return _socket
|
||||
except socket.timeout as e:
|
||||
stream.close()
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
except ReplyError as e:
|
||||
stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except Exception:
|
||||
stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'SyncProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,16 @@
|
||||
import socket
|
||||
from .. import _abc as abc
|
||||
|
||||
|
||||
class SyncResolver(abc.SyncResolver):
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
infos = socket.getaddrinfo(host=host, port=port, family=family, type=socket.SOCK_STREAM)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,32 @@
|
||||
import socket
|
||||
|
||||
from .._errors import ProxyError
|
||||
from .. import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class SyncSocketStream(abc.SyncSocketStream):
|
||||
_socket: socket.socket
|
||||
|
||||
def __init__(self, sock: socket.socket):
|
||||
self._socket = sock
|
||||
|
||||
def write_all(self, data):
|
||||
self._socket.sendall(data)
|
||||
|
||||
def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return self._socket.recv(max_bytes)
|
||||
|
||||
def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
def close(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._proxy import SyncProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = (
|
||||
'Proxy',
|
||||
'ProxyChain',
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import Iterable
|
||||
from ._proxy import SyncProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable[SyncProxy]):
|
||||
self._proxies = proxies
|
||||
|
||||
def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
import socket
|
||||
from typing import Optional, Tuple
|
||||
from ._stream import SyncSocketStream
|
||||
|
||||
|
||||
def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: Optional[float] = None,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> SyncSocketStream:
|
||||
address = (host, port)
|
||||
sock = socket.create_connection(
|
||||
address,
|
||||
timeout,
|
||||
source_address=local_addr,
|
||||
)
|
||||
|
||||
return SyncSocketStream(sock)
|
||||
121
.venv/lib/python3.9/site-packages/python_socks/sync/v2/_proxy.py
Normal file
121
.venv/lib/python3.9/site-packages/python_socks/sync/v2/_proxy.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
from ._connect import connect_tcp
|
||||
from ._stream import SyncSocketStream
|
||||
from .._resolver import SyncResolver
|
||||
from ..._types import ProxyType
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
from ..._helpers import parse_proxy_url
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_sync import create_connector
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class SyncProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['SyncProxy'] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = SyncResolver()
|
||||
|
||||
def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> SyncSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
if self._forward is None:
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
stream = connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
timeout=timeout,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
else:
|
||||
stream = self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
|
||||
return stream
|
||||
|
||||
except socket.timeout as e:
|
||||
stream.close()
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
except ReplyError as e:
|
||||
stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except Exception:
|
||||
stream.close()
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'SyncProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Copied from urllib3.util.ssltransport
|
||||
"""
|
||||
import io
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
|
||||
SSL_BLOCKSIZE = 16384
|
||||
|
||||
|
||||
class SSLTransport:
|
||||
"""
|
||||
The SSLTransport wraps an existing socket and establishes an SSL connection.
|
||||
|
||||
Contrary to Python's implementation of SSLSocket, it allows you to chain
|
||||
multiple TLS connections together. It's particularly useful if you need to
|
||||
implement TLS within TLS.
|
||||
|
||||
The class supports most of the socket API operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
|
||||
):
|
||||
"""
|
||||
Create an SSLTransport around socket using the provided ssl_context.
|
||||
"""
|
||||
self.incoming = ssl.MemoryBIO()
|
||||
self.outgoing = ssl.MemoryBIO()
|
||||
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
self.socket = socket
|
||||
|
||||
self.sslobj = ssl_context.wrap_bio(
|
||||
self.incoming, self.outgoing, server_hostname=server_hostname
|
||||
)
|
||||
|
||||
# Perform initial handshake.
|
||||
self._ssl_io_loop(self.sslobj.do_handshake)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self.close()
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno()
|
||||
|
||||
def read(self, len=1024, buffer=None):
|
||||
return self._wrap_ssl_read(len, buffer)
|
||||
|
||||
def recv(self, len=1024, flags=0):
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv")
|
||||
return self._wrap_ssl_read(len)
|
||||
|
||||
def recv_into(self, buffer, nbytes=None, flags=0):
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to recv_into")
|
||||
if buffer and (nbytes is None):
|
||||
nbytes = len(buffer)
|
||||
elif nbytes is None:
|
||||
nbytes = 1024
|
||||
return self.read(nbytes, buffer)
|
||||
|
||||
def sendall(self, data, flags=0):
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to sendall")
|
||||
count = 0
|
||||
with memoryview(data) as view, view.cast("B") as byte_view:
|
||||
amount = len(byte_view)
|
||||
while count < amount:
|
||||
v = self.send(byte_view[count:])
|
||||
count += v
|
||||
|
||||
def send(self, data, flags=0):
|
||||
if flags != 0:
|
||||
raise ValueError("non-zero flags not allowed in calls to send")
|
||||
response = self._ssl_io_loop(self.sslobj.write, data)
|
||||
return response
|
||||
|
||||
def makefile(
|
||||
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
|
||||
):
|
||||
"""
|
||||
Python's httpclient uses makefile and buffered io when reading HTTP
|
||||
messages and we need to support it.
|
||||
|
||||
This is unfortunately a copy and paste of socket.py makefile with small
|
||||
changes to point to the socket directly.
|
||||
"""
|
||||
if not set(mode) <= {"r", "w", "b"}:
|
||||
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
|
||||
|
||||
writing = "w" in mode
|
||||
reading = "r" in mode or not writing
|
||||
assert reading or writing
|
||||
binary = "b" in mode
|
||||
rawmode = ""
|
||||
if reading:
|
||||
rawmode += "r"
|
||||
if writing:
|
||||
rawmode += "w"
|
||||
raw = socket.SocketIO(self, rawmode)
|
||||
self.socket._io_refs += 1
|
||||
if buffering is None:
|
||||
buffering = -1
|
||||
if buffering < 0:
|
||||
buffering = io.DEFAULT_BUFFER_SIZE
|
||||
if buffering == 0:
|
||||
if not binary:
|
||||
raise ValueError("unbuffered streams must be binary")
|
||||
return raw
|
||||
if reading and writing:
|
||||
buffer = io.BufferedRWPair(raw, raw, buffering)
|
||||
elif reading:
|
||||
buffer = io.BufferedReader(raw, buffering)
|
||||
else:
|
||||
assert writing
|
||||
buffer = io.BufferedWriter(raw, buffering)
|
||||
if binary:
|
||||
return buffer
|
||||
text = io.TextIOWrapper(buffer, encoding, errors, newline)
|
||||
text.mode = mode
|
||||
return text
|
||||
|
||||
def unwrap(self):
|
||||
self._ssl_io_loop(self.sslobj.unwrap)
|
||||
|
||||
def close(self):
|
||||
self.socket.close()
|
||||
|
||||
def getpeercert(self, binary_form=False):
|
||||
return self.sslobj.getpeercert(binary_form)
|
||||
|
||||
def version(self):
|
||||
return self.sslobj.version()
|
||||
|
||||
def cipher(self):
|
||||
return self.sslobj.cipher()
|
||||
|
||||
def selected_alpn_protocol(self):
|
||||
return self.sslobj.selected_alpn_protocol()
|
||||
|
||||
def selected_npn_protocol(self):
|
||||
return self.sslobj.selected_npn_protocol()
|
||||
|
||||
def shared_ciphers(self):
|
||||
return self.sslobj.shared_ciphers()
|
||||
|
||||
def compression(self):
|
||||
return self.sslobj.compression()
|
||||
|
||||
def settimeout(self, value):
|
||||
self.socket.settimeout(value)
|
||||
|
||||
def gettimeout(self):
|
||||
return self.socket.gettimeout()
|
||||
|
||||
def _decref_socketios(self):
|
||||
self.socket._decref_socketios()
|
||||
|
||||
def _wrap_ssl_read(self, len, buffer=None):
|
||||
try:
|
||||
return self._ssl_io_loop(self.sslobj.read, len, buffer)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||
return 0 # eof, return 0.
|
||||
else:
|
||||
raise
|
||||
|
||||
def _ssl_io_loop(self, func, *args):
|
||||
"""Performs an I/O loop between incoming/outgoing and the socket."""
|
||||
should_loop = True
|
||||
ret = None
|
||||
|
||||
while should_loop:
|
||||
errno = None
|
||||
try:
|
||||
ret = func(*args)
|
||||
except ssl.SSLError as e:
|
||||
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
|
||||
# WANT_READ, and WANT_WRITE are expected, others are not.
|
||||
raise e
|
||||
errno = e.errno
|
||||
|
||||
buf = self.outgoing.read()
|
||||
self.socket.sendall(buf)
|
||||
|
||||
if errno is None:
|
||||
should_loop = False
|
||||
elif errno == ssl.SSL_ERROR_WANT_READ:
|
||||
buf = self.socket.recv(SSL_BLOCKSIZE)
|
||||
if buf:
|
||||
self.incoming.write(buf)
|
||||
else:
|
||||
self.incoming.write_eof()
|
||||
return ret
|
||||
@@ -0,0 +1,56 @@
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
from ._ssl_transport import SSLTransport
|
||||
|
||||
from ..._errors import ProxyError
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
SocketType = Union[socket.socket, ssl.SSLSocket, SSLTransport]
|
||||
|
||||
|
||||
class SyncSocketStream(abc.SyncSocketStream):
|
||||
_socket: SocketType
|
||||
|
||||
def __init__(self, sock: SocketType):
|
||||
self._socket = sock
|
||||
|
||||
def write_all(self, data):
|
||||
self._socket.sendall(data)
|
||||
|
||||
def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return self._socket.recv(max_bytes)
|
||||
|
||||
def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
def start_tls(self, hostname: str, ssl_context: ssl.SSLContext) -> 'SyncSocketStream':
|
||||
if isinstance(self._socket, (ssl.SSLSocket, SSLTransport)):
|
||||
ssl_socket = SSLTransport(
|
||||
self._socket,
|
||||
ssl_context=ssl_context,
|
||||
server_hostname=hostname,
|
||||
)
|
||||
else: # plain socket?
|
||||
ssl_socket = ssl_context.wrap_socket(
|
||||
self._socket,
|
||||
server_hostname=hostname,
|
||||
)
|
||||
|
||||
return SyncSocketStream(ssl_socket)
|
||||
|
||||
def close(self):
|
||||
self._socket.close()
|
||||
|
||||
@property
|
||||
def socket(self) -> SocketType: # pragma: nocover
|
||||
return self._socket
|
||||
Reference in New Issue
Block a user