fix: 修复代理问题
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from ._proxy_chain import ProxyChain
|
||||
|
||||
__all__ = ('ProxyChain',)
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import Iterable
|
||||
import warnings
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(self, dest_host, dest_port, timeout=None):
|
||||
curr_socket = None
|
||||
proxies = list(self._proxies)
|
||||
|
||||
length = len(proxies) - 1
|
||||
for i in range(length):
|
||||
curr_socket = await proxies[i].connect(
|
||||
dest_host=proxies[i + 1].proxy_host,
|
||||
dest_port=proxies[i + 1].proxy_port,
|
||||
timeout=timeout,
|
||||
_socket=curr_socket,
|
||||
)
|
||||
|
||||
curr_socket = await proxies[length].connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
timeout=timeout,
|
||||
_socket=curr_socket,
|
||||
)
|
||||
|
||||
return curr_socket
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AnyioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = ('Proxy', 'ProxyChain')
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Iterable
|
||||
import warnings
|
||||
from ._proxy import AnyioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Iterable[AnyioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
_stream = None
|
||||
proxies = list(self._proxies)
|
||||
|
||||
length = len(proxies) - 1
|
||||
for i in range(length):
|
||||
_stream = await proxies[i].connect(
|
||||
dest_host=proxies[i + 1].proxy_host,
|
||||
dest_port=proxies[i + 1].proxy_port,
|
||||
timeout=timeout,
|
||||
_stream=_stream,
|
||||
)
|
||||
|
||||
_stream = await proxies[length].connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
_stream=_stream,
|
||||
)
|
||||
|
||||
return _stream
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
import anyio
|
||||
import anyio.abc
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_host: Optional[str] = None,
|
||||
) -> anyio.abc.SocketStream:
|
||||
|
||||
return await anyio.connect_tcp(
|
||||
remote_host=host,
|
||||
remote_port=port,
|
||||
local_host=local_host,
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
import anyio
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ._stream import AnyioSocketStream
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AnyioProxy:
|
||||
_stream: Optional[AnyioSocketStream]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AnyioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_stream = kwargs.get('_stream')
|
||||
if _stream is not None:
|
||||
warnings.warn(
|
||||
"The '_stream' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_host = kwargs.get('local_host')
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
if _stream is None:
|
||||
try:
|
||||
_stream = AnyioSocketStream(
|
||||
await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_host=local_host,
|
||||
)
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = _stream
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
|
||||
return stream
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
except TimeoutError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,22 @@
|
||||
import anyio
|
||||
import socket
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
infos = await anyio.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,59 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from anyio.streams.tls import TLSStream
|
||||
|
||||
from ..._errors import ProxyError
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
|
||||
|
||||
|
||||
class AnyioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: AnyioStreamType
|
||||
|
||||
def __init__(self, stream: AnyioStreamType) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data: bytes):
|
||||
await self._stream.send(item=data)
|
||||
|
||||
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
|
||||
try:
|
||||
return await self._stream.receive(max_bytes=max_bytes)
|
||||
except anyio.EndOfStream: # pragma: no cover
|
||||
return b""
|
||||
|
||||
async def read_exact(self, n: int):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self.read(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'AnyioSocketStream':
|
||||
ssl_stream = await TLSStream.wrap(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
hostname=hostname,
|
||||
standard_compatible=False,
|
||||
server_side=False,
|
||||
)
|
||||
return AnyioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
|
||||
return self._stream
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._proxy import AnyioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = (
|
||||
'Proxy',
|
||||
'ProxyChain',
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import AnyioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[AnyioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from ._stream import AnyioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_host: Optional[str] = None,
|
||||
) -> AnyioSocketStream:
|
||||
s = await anyio.connect_tcp(
|
||||
remote_host=host,
|
||||
remote_port=port,
|
||||
local_host=local_host,
|
||||
)
|
||||
return AnyioSocketStream(s)
|
||||
@@ -0,0 +1,135 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
import anyio
|
||||
|
||||
from ._connect import connect_tcp
|
||||
from ._stream import AnyioSocketStream
|
||||
from .._resolver import Resolver
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AnyioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['AnyioProxy'] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AnyioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_host = kwargs.get('local_host')
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_host=local_host,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_host: Optional[str] = None,
|
||||
) -> AnyioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_host=local_host,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
with anyio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,59 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from anyio.streams.tls import TLSStream
|
||||
|
||||
from ...._errors import ProxyError
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
|
||||
|
||||
|
||||
class AnyioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: AnyioStreamType
|
||||
|
||||
def __init__(self, stream: AnyioStreamType) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data: bytes):
|
||||
await self._stream.send(item=data)
|
||||
|
||||
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
|
||||
try:
|
||||
return await self._stream.receive(max_bytes=max_bytes)
|
||||
except anyio.EndOfStream: # pragma: no cover
|
||||
return b""
|
||||
|
||||
async def read_exact(self, n: int):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self.read(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'AnyioSocketStream':
|
||||
ssl_stream = await TLSStream.wrap(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
hostname=hostname,
|
||||
standard_compatible=False,
|
||||
server_side=False,
|
||||
)
|
||||
return AnyioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
|
||||
return self._stream
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AsyncioProxy as Proxy
|
||||
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,43 @@
|
||||
import socket
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ..._helpers import is_ipv4_address, is_ipv6_address
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> socket.socket:
|
||||
|
||||
family, host = await _resolve_host(host, loop)
|
||||
|
||||
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
|
||||
sock.setblocking(False)
|
||||
if local_addr is not None: # pragma: no cover
|
||||
sock.bind(local_addr)
|
||||
|
||||
if is_ipv6_address(host):
|
||||
address = (host, port, 0, 0) # to fix OSError: [WinError 10022]
|
||||
else:
|
||||
address = (host, port) # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
await loop.sock_connect(sock=sock, address=address)
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
async def _resolve_host(host, loop):
|
||||
if is_ipv4_address(host):
|
||||
return socket.AF_INET, host
|
||||
if is_ipv6_address(host):
|
||||
return socket.AF_INET6, host
|
||||
|
||||
resolver = Resolver(loop=loop)
|
||||
return await resolver.resolve(host=host)
|
||||
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
from ._stream import AsyncioSocketStream
|
||||
from ._resolver import Resolver
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
from ._connect import connect_tcp
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import asyncio as async_timeout # pylint:disable=reimported
|
||||
else:
|
||||
import async_timeout
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AsyncioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._loop = loop
|
||||
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver(loop=loop)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> socket.socket:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
async with async_timeout.timeout(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
_socket=_socket,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
) -> socket.socket:
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
loop=self._loop,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = AsyncioSocketStream(sock=_socket, loop=self._loop)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
return _socket
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except (asyncio.CancelledError, Exception): # pragma: no cover
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
infos = await self._loop.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,36 @@
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
from ..._errors import ProxyError
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class AsyncioSocketStream(abc.AsyncSocketStream):
|
||||
_loop: asyncio.AbstractEventLoop = None
|
||||
_socket = None
|
||||
|
||||
def __init__(self, sock: socket.socket, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._loop.sock_sendall(self._socket, data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._loop.sock_recv(self._socket, max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._loop.sock_recv(self._socket, n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import AsyncioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = ('Proxy', 'ProxyChain')
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import AsyncioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[AsyncioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
from ._stream import AsyncioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> AsyncioSocketStream:
|
||||
kwargs = {}
|
||||
if local_addr is not None:
|
||||
kwargs['local_addr'] = local_addr # pragma: no cover
|
||||
|
||||
reader, writer = await asyncio.open_connection(
|
||||
host=host,
|
||||
port=port,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
return AsyncioSocketStream(
|
||||
loop=loop,
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Any, Optional, Tuple
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
from .._resolver import Resolver
|
||||
from ._stream import AsyncioSocketStream
|
||||
from ._connect import connect_tcp
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import asyncio as async_timeout # pylint:disable=reimported
|
||||
else:
|
||||
import async_timeout
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class AsyncioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['AsyncioProxy'] = None,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
if loop is not None: # pragma: no cover
|
||||
warnings.warn(
|
||||
'The loop argument is deprecated and scheduled for removal in the future.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._loop = loop
|
||||
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver(loop=loop)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
async with async_timeout.timeout(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> AsyncioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
loop=self._loop,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except (asyncio.CancelledError, Exception):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'AsyncioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,91 @@
|
||||
import asyncio
|
||||
import ssl
|
||||
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class AsyncioSocketStream(abc.AsyncSocketStream):
|
||||
_loop: asyncio.AbstractEventLoop
|
||||
_reader: asyncio.StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
):
|
||||
self._loop = loop
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
||||
async def write_all(self, data):
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._reader.read(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
return await self._reader.readexactly(n)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
ssl_handshake_timeout=None,
|
||||
) -> 'AsyncioSocketStream':
|
||||
if hasattr(self._writer, 'start_tls'): # Python>=3.11
|
||||
await self._writer.start_tls(
|
||||
ssl_context,
|
||||
server_hostname=hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
return self
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
|
||||
transport: asyncio.Transport = await self._loop.start_tls(
|
||||
self._writer.transport, # type: ignore
|
||||
protocol,
|
||||
ssl_context,
|
||||
server_side=False,
|
||||
server_hostname=hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
|
||||
# reader.set_transport(transport)
|
||||
|
||||
# Initialize the protocol, so it is made aware of being tied to
|
||||
# a TLS connection.
|
||||
# See: https://github.com/encode/httpx/issues/859
|
||||
protocol.connection_made(transport)
|
||||
|
||||
writer = asyncio.StreamWriter(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
reader=reader,
|
||||
loop=self._loop,
|
||||
)
|
||||
|
||||
stream = AsyncioSocketStream(loop=self._loop, reader=reader, writer=writer)
|
||||
# When we return a new SocketStream with new StreamReader/StreamWriter instances
|
||||
# we need to keep references to the old StreamReader/StreamWriter so that they
|
||||
# are not garbage collected and closed while we're still using them.
|
||||
stream._inner = self # type: ignore # pylint:disable=W0212,W0201
|
||||
return stream
|
||||
|
||||
async def close(self):
|
||||
self._writer.close()
|
||||
self._writer.transport.abort() # noqa
|
||||
|
||||
@property
|
||||
def reader(self):
|
||||
return self._reader # pragma: no cover
|
||||
|
||||
@property
|
||||
def writer(self):
|
||||
return self._writer # pragma: no cover
|
||||
@@ -0,0 +1,4 @@
|
||||
from ._proxy import CurioProxy as Proxy
|
||||
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import curio
|
||||
import curio.io
|
||||
import curio.socket
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> curio.io.Socket:
|
||||
return await curio.open_connection(
|
||||
host=host,
|
||||
port=port,
|
||||
source_addr=local_addr,
|
||||
)
|
||||
@@ -0,0 +1,132 @@
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
import curio
|
||||
import curio.io
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._stream import CurioSocketStream
|
||||
from ._resolver import Resolver
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class CurioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> curio.io.Socket:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
return await curio.timeout_after(
|
||||
timeout,
|
||||
self._connect,
|
||||
dest_host,
|
||||
dest_port,
|
||||
_socket,
|
||||
local_addr,
|
||||
)
|
||||
except curio.TaskTimeout as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
):
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = CurioSocketStream(_socket)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
return _socket
|
||||
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException:
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'CurioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,25 @@
|
||||
import socket
|
||||
from curio.socket import getaddrinfo
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
|
||||
try:
|
||||
infos = await getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror: # pragma: no cover
|
||||
infos = None
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,32 @@
|
||||
import curio.io
|
||||
import curio.socket
|
||||
|
||||
from ... import _abc as abc
|
||||
from ..._errors import ProxyError
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class CurioSocketStream(abc.AsyncSocketStream):
|
||||
_socket: curio.io.Socket = None
|
||||
|
||||
def __init__(self, sock: curio.io.Socket):
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._socket.sendall(data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._socket.recv(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
await self._socket.close()
|
||||
@@ -0,0 +1,3 @@
|
||||
from ._proxy import TrioProxy as Proxy
|
||||
|
||||
__all__ = ('Proxy',)
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from ._resolver import Resolver
|
||||
from ..._helpers import is_ipv4_address, is_ipv6_address
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[Tuple[str, int]] = None,
|
||||
) -> trio.socket.SocketType:
|
||||
|
||||
family, host = await _resolve_host(host)
|
||||
|
||||
sock = trio.socket.socket(family=family, type=trio.socket.SOCK_STREAM)
|
||||
if local_addr is not None: # pragma: no cover
|
||||
await sock.bind(local_addr)
|
||||
|
||||
try:
|
||||
await sock.connect((host, port))
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
return sock
|
||||
|
||||
|
||||
async def _resolve_host(host):
|
||||
if is_ipv4_address(host):
|
||||
return trio.socket.AF_INET, host
|
||||
if is_ipv6_address(host):
|
||||
return trio.socket.AF_INET6, host
|
||||
|
||||
resolver = Resolver()
|
||||
return await resolver.resolve(host=host)
|
||||
@@ -0,0 +1,131 @@
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
import trio
|
||||
|
||||
from ..._types import ProxyType
|
||||
from ..._helpers import parse_proxy_url
|
||||
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ._stream import TrioSocketStream
|
||||
from ._resolver import Resolver
|
||||
from ._connect import connect_tcp
|
||||
|
||||
from ..._protocols.errors import ReplyError
|
||||
from ..._connectors.factory_async import create_connector
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class TrioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> trio.socket.SocketType:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
_socket = kwargs.get('_socket')
|
||||
if _socket is not None:
|
||||
warnings.warn(
|
||||
"The '_socket' argument is deprecated and will be removed in the future",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
with trio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
_socket=_socket,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except trio.TooSlowError as e:
|
||||
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
_socket=None,
|
||||
local_addr=None,
|
||||
) -> trio.socket.SocketType:
|
||||
if _socket is None:
|
||||
try:
|
||||
_socket = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
msg = 'Could not connect to proxy {}:{} [{}]'.format(
|
||||
self._proxy_host,
|
||||
self._proxy_port,
|
||||
e.strerror,
|
||||
)
|
||||
raise ProxyConnectionError(e.errno, msg) from e
|
||||
|
||||
stream = TrioSocketStream(sock=_socket)
|
||||
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
return _socket
|
||||
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException: # trio.Cancelled...
|
||||
with trio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self):
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self):
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,21 @@
|
||||
import trio
|
||||
|
||||
from ... import _abc as abc
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
async def resolve(self, host, port=0, family=trio.socket.AF_UNSPEC):
|
||||
infos = await trio.socket.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=trio.socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # pragma: no cover
|
||||
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
@@ -0,0 +1,35 @@
|
||||
import trio
|
||||
|
||||
from ..._errors import ProxyError
|
||||
from ... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
|
||||
class TrioSocketStream(abc.AsyncSocketStream):
|
||||
def __init__(self, sock):
|
||||
self._socket = sock
|
||||
|
||||
async def write_all(self, data):
|
||||
total_sent = 0
|
||||
while total_sent < len(data):
|
||||
remaining = data[total_sent:]
|
||||
sent = await self._socket.send(remaining)
|
||||
total_sent += sent
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._socket.recv(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._socket.recv(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def close(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
@@ -0,0 +1,7 @@
|
||||
from ._proxy import TrioProxy as Proxy
|
||||
from ._chain import ProxyChain
|
||||
|
||||
__all__ = (
|
||||
'Proxy',
|
||||
'ProxyChain',
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
import warnings
|
||||
from ._proxy import TrioProxy
|
||||
|
||||
|
||||
class ProxyChain:
|
||||
def __init__(self, proxies: Sequence[TrioProxy]):
|
||||
warnings.warn(
|
||||
'This implementation of ProxyChain is deprecated and will be removed in the future',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._proxies = proxies
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host,
|
||||
dest_port,
|
||||
dest_ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
forward = None
|
||||
for proxy in self._proxies:
|
||||
proxy._forward = forward
|
||||
forward = proxy
|
||||
|
||||
return await forward.connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
import trio
|
||||
from ._stream import TrioSocketStream
|
||||
|
||||
|
||||
async def connect_tcp(
|
||||
host: str,
|
||||
port: int,
|
||||
local_addr: Optional[str] = None,
|
||||
) -> TrioSocketStream:
|
||||
trio_stream = await trio.open_tcp_stream(
|
||||
host=host,
|
||||
port=port,
|
||||
local_address=local_addr,
|
||||
)
|
||||
return TrioSocketStream(trio_stream)
|
||||
@@ -0,0 +1,135 @@
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
import trio
|
||||
|
||||
from ._connect import connect_tcp
|
||||
from ._stream import TrioSocketStream
|
||||
from .._resolver import Resolver
|
||||
|
||||
from ...._types import ProxyType
|
||||
from ...._helpers import parse_proxy_url
|
||||
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
|
||||
|
||||
from ...._protocols.errors import ReplyError
|
||||
from ...._connectors.factory_async import create_connector
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
|
||||
class TrioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
rdns: Optional[bool] = None,
|
||||
proxy_ssl: Optional[ssl.SSLContext] = None,
|
||||
forward: Optional['TrioProxy'] = None,
|
||||
):
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._rdns = rdns
|
||||
|
||||
self._proxy_ssl = proxy_ssl
|
||||
self._forward = forward
|
||||
|
||||
self._resolver = Resolver()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
timeout: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> TrioSocketStream:
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
|
||||
local_addr = kwargs.get('local_addr')
|
||||
try:
|
||||
with trio.fail_after(timeout):
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
dest_ssl=dest_ssl,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except trio.TooSlowError as e:
|
||||
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
dest_ssl: Optional[ssl.SSLContext] = None,
|
||||
local_addr: Optional[str] = None,
|
||||
) -> TrioSocketStream:
|
||||
if self._forward is None:
|
||||
try:
|
||||
stream = await connect_tcp(
|
||||
host=self._proxy_host,
|
||||
port=self._proxy_port,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
except OSError as e:
|
||||
raise ProxyConnectionError(
|
||||
e.errno,
|
||||
"Couldn't connect to proxy"
|
||||
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
|
||||
) from e
|
||||
else:
|
||||
stream = await self._forward.connect(
|
||||
dest_host=self._proxy_host,
|
||||
dest_port=self._proxy_port,
|
||||
)
|
||||
|
||||
try:
|
||||
if self._proxy_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=self._proxy_host,
|
||||
ssl_context=self._proxy_ssl,
|
||||
)
|
||||
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=stream,
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
if dest_ssl is not None:
|
||||
stream = await stream.start_tls(
|
||||
hostname=dest_host,
|
||||
ssl_context=dest_ssl,
|
||||
)
|
||||
except ReplyError as e:
|
||||
await stream.close()
|
||||
raise ProxyError(e, error_code=e.error_code)
|
||||
except BaseException: # trio.Cancelled...
|
||||
with trio.CancelScope(shield=True):
|
||||
await stream.close()
|
||||
raise
|
||||
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs): # for backward compatibility
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,55 @@
|
||||
import ssl
|
||||
from typing import Union
|
||||
|
||||
import trio
|
||||
|
||||
from ...._errors import ProxyError
|
||||
from .... import _abc as abc
|
||||
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
TrioStreamType = Union[trio.SocketStream, trio.SSLStream]
|
||||
|
||||
|
||||
class TrioSocketStream(abc.AsyncSocketStream):
|
||||
_stream: TrioStreamType
|
||||
|
||||
def __init__(self, stream: TrioStreamType):
|
||||
self._stream = stream
|
||||
|
||||
async def write_all(self, data):
|
||||
await self._stream.send_all(data)
|
||||
|
||||
async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
|
||||
return await self._stream.receive_some(max_bytes)
|
||||
|
||||
async def read_exact(self, n):
|
||||
data = bytearray()
|
||||
while len(data) < n:
|
||||
packet = await self._stream.receive_some(n - len(data))
|
||||
if not packet: # pragma: no cover
|
||||
raise ProxyError('Connection closed unexpectedly')
|
||||
data += packet
|
||||
return data
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> 'TrioSocketStream':
|
||||
ssl_stream = trio.SSLStream(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
server_hostname=hostname,
|
||||
https_compatible=True,
|
||||
server_side=False,
|
||||
)
|
||||
await ssl_stream.do_handshake()
|
||||
return TrioSocketStream(ssl_stream)
|
||||
|
||||
async def close(self):
|
||||
await self._stream.aclose()
|
||||
|
||||
@property
|
||||
def trio_stream(self) -> TrioStreamType: # pragma: nocover
|
||||
return self._stream
|
||||
Reference in New Issue
Block a user