diff --git a/CHANGELOG.md b/CHANGELOG.md index e5510fc..b5de6b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- `IPForceAdapter` unified factory function +- `IPForceSession` unified session class +- `IPVersion` enum (`V4`, `V6`) +- `IPForceMethod` enum (`GLOBAL`, `LOCK`) - `IPv6LockAdapter` class - `IPv4LockAdapter` class - Logo diff --git a/README.md b/README.md index 62b4472..369b69e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ - + @@ -43,7 +43,7 @@
PyPI Counter
Github Stars
- +
Code Quality CodeFactor
@@ -60,58 +60,78 @@ - `pip install ipforce==0.1` ## Usage -### Enforce IPv4 +### Enforce IPv4 Use when you need to ensure connections only use IPv4 addresses, useful for legacy systems that don't support IPv6, networks with IPv4-only infrastructure, or testing IPv4 connectivity. ```python +from ipforce import IPForceAdapter, IPVersion, IPForceMethod import requests -from ipforce import IPv4TransportAdapter -# Create a session that will only use IPv4 addresses session = requests.Session() -session.mount('http://', IPv4TransportAdapter()) -session.mount('https://', IPv4TransportAdapter()) +adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK) +session.mount('http://', adapter) +session.mount('https://', adapter) -# All requests through this session will only resolve to IPv4 addresses response = session.get('https://ifconfig.co/json') ``` ### Enforce IPv6 - Use when you need to ensure connections only use IPv6 addresses, useful for modern networks with IPv6 infrastructure, testing IPv6 connectivity, or applications requiring IPv6-specific features. ```python +from ipforce import IPForceAdapter, IPVersion, IPForceMethod import requests -from ipforce import IPv6TransportAdapter -# Create a session that will only use IPv6 addresses session = requests.Session() -session.mount('http://', IPv6TransportAdapter()) -session.mount('https://', IPv6TransportAdapter()) +adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.LOCK) +session.mount('http://', adapter) +session.mount('https://', adapter) -# All requests through this session will only resolve to IPv6 addresses response = session.get('https://ifconfig.co/json') ``` +### Using IPForceSession + +```python +from ipforce import IPForceSession, IPVersion + +with IPForceSession(IPVersion.V4) as session: + response = session.get('https://ifconfig.co/json') +``` + +### Available Methods + +| Method | Description | +|--------|-------------| +| `IPForceMethod.LOCK` | Thread-safe — global lock serialization (default) | +| `IPForceMethod.GLOBAL` | Non-thread-safe — temporary getaddrinfo patch | + > [!WARNING] -> `IPv4TransportAdapter` / `IPv6TransportAdapter` are NOT thread-safe. They modify the global `socket.getaddrinfo` function, which can cause race conditions in multi-threaded applications. Use the thread-safe adapters below for concurrent usage. +> `IPForceMethod.GLOBAL` is NOT thread-safe. It modifies the global `socket.getaddrinfo` function, which can cause race conditions in multi-threaded applications. Use `IPForceMethod.LOCK` (the default) for concurrent usage. -### Thread-Safe: Lock-Based Adapters +### Direct Class Usage (Deprecated) -A process-wide lock serializes access to `socket.getaddrinfo`, guaranteeing correctness under concurrent access. +The following direct class usage still works but is deprecated in favor of the unified API above: ```python -import requests from ipforce import IPv4LockAdapter, IPv6LockAdapter session = requests.Session() session.mount('http://', IPv4LockAdapter()) # or IPv6LockAdapter() session.mount('https://', IPv4LockAdapter()) # or IPv6LockAdapter() - response = session.get('https://ifconfig.co/json') ``` +### Roadmap + +| Method | Description | +|--------|-------------| +| `IPForceMethod.THREAD_LOCAL` | Per-thread dispatch (fully concurrent) | +| `IPForceMethod.CONTEXT_VAR` | ContextVar dispatch (async-safe) | +| `IPForceMethod.CONNECTION` | urllib3 connection-level (zero global state) | +| `IPForceMethod.AUTO` | Automatically select best available | + ## Issues & Bug Reports Just fill an issue and describe it. We'll check it ASAP! diff --git a/ipforce/__init__.py b/ipforce/__init__.py index 349378d..28a5ef0 100644 --- a/ipforce/__init__.py +++ b/ipforce/__init__.py @@ -1,12 +1,16 @@ # -*- coding: utf-8 -*- """ipforce modules.""" from .params import IPFORCE_VERSION +from .enums import IPVersion, IPForceMethod +from .api import IPForceAdapter, IPForceSession from .adapters import IPv4TransportAdapter, IPv6TransportAdapter from .adapters import IPv4LockAdapter, IPv6LockAdapter __version__ = IPFORCE_VERSION __all__ = [ + "IPVersion", "IPForceMethod", + "IPForceAdapter", "IPForceSession", "IPv4TransportAdapter", "IPv6TransportAdapter", "IPv4LockAdapter", "IPv6LockAdapter", ] diff --git a/ipforce/adapters.py b/ipforce/adapters.py index ff3308a..8e498be 100644 --- a/ipforce/adapters.py +++ b/ipforce/adapters.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """IPForce Adapters to force IPv4 or IPv6 for requests.""" import socket +import warnings from typing import Any, List, Tuple from requests.adapters import HTTPAdapter from threading import Lock @@ -13,6 +14,15 @@ class IPv4TransportAdapter(HTTPAdapter): """A custom HTTPAdapter that enforces the use of IPv4 for DNS resolution during HTTP(S) requests using the requests library.""" + def __init__(self, *args, **kwargs) -> None: + """Initialize the adapter and emit a deprecation warning.""" + warnings.warn( + "IPv4TransportAdapter is deprecated, use IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL) instead", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + def send(self, *args: list, **kwargs: dict) -> Any: """ Override send method to apply the monkey patch only during the request. @@ -43,6 +53,15 @@ def ipv4_only_getaddrinfo(*gargs: list, **gkwargs: dict) -> List[Tuple]: class IPv6TransportAdapter(HTTPAdapter): """A custom HTTPAdapter that enforces the use of IPv6 for DNS resolution during HTTP(S) requests using the requests library.""" + def __init__(self, *args, **kwargs) -> None: + """Initialize the adapter and emit a deprecation warning.""" + warnings.warn( + "IPv6TransportAdapter is deprecated, use IPForceAdapter(IPVersion.V6, IPForceMethod.GLOBAL) instead", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + def send(self, *args: list, **kwargs: dict) -> Any: """ Override send method to apply the monkey patch only during the request. diff --git a/ipforce/api.py b/ipforce/api.py new file mode 100644 index 0000000..eb06fc8 --- /dev/null +++ b/ipforce/api.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +"""Unified public API for IPForce adapter and session creation.""" +import warnings + +from requests import Session +from requests.adapters import HTTPAdapter + +from .enums import IPVersion, IPForceMethod +from .adapters import ( + IPv4TransportAdapter, IPv6TransportAdapter, + IPv4LockAdapter, IPv6LockAdapter, +) + +_ADAPTER_REGISTRY = { + (IPVersion.V4, IPForceMethod.GLOBAL): IPv4TransportAdapter, + (IPVersion.V6, IPForceMethod.GLOBAL): IPv6TransportAdapter, + (IPVersion.V4, IPForceMethod.LOCK): IPv4LockAdapter, + (IPVersion.V6, IPForceMethod.LOCK): IPv6LockAdapter, +} + + +def IPForceAdapter( + ip_version: IPVersion, + method: IPForceMethod = IPForceMethod.LOCK, +) -> HTTPAdapter: + """ + Create an HTTP adapter that forces a specific IP version. + + :param ip_version: IPVersion.V4 or IPVersion.V6 + :param method: thread-safety strategy (default: LOCK) + :return: configured HTTPAdapter instance + :raises ValueError: if the (ip_version, method) combination is not registered + """ + adapter_cls = _ADAPTER_REGISTRY.get((ip_version, method)) + if adapter_cls is None: + raise ValueError("Unsupported combination: {v} + {m}".format(v=ip_version, m=method)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return adapter_cls() + + +class IPForceSession(Session): + """A requests.Session pre-configured to force a specific IP version.""" + + def __init__( + self, + ip_version: IPVersion, + method: IPForceMethod = IPForceMethod.LOCK, + ) -> None: + """ + Initialize the session with an IP-version-forced adapter. + + :param ip_version: IPVersion.V4 or IPVersion.V6 + :param method: thread-safety strategy (default: LOCK) + """ + super().__init__() + adapter = IPForceAdapter(ip_version, method) + self.mount('http://', adapter) + self.mount('https://', adapter) diff --git a/ipforce/enums.py b/ipforce/enums.py new file mode 100644 index 0000000..096d383 --- /dev/null +++ b/ipforce/enums.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +"""IPForce enumerations for IP version and resolution method selection.""" +from enum import Enum + + +class IPVersion(Enum): + """IP protocol version to enforce for DNS resolution.""" + + V4 = "ipv4" + V6 = "ipv6" + + +class IPForceMethod(Enum): + """Thread-safety strategy for address family enforcement.""" + + GLOBAL = "global" + LOCK = "lock" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..4b2938f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +filterwarnings = + ignore:.*is deprecated, use IP.*:DeprecationWarning diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..aa7b40e --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,108 @@ +"""Tests for the unified IPForceAdapter / IPForceSession API.""" +import socket +import warnings +import unittest + +from requests.adapters import HTTPAdapter + +from ipforce import ( + IPVersion, IPForceMethod, + IPForceAdapter, IPForceSession, + IPv4TransportAdapter, IPv6TransportAdapter, +) +from ipforce.adapters import _BaseLockAdapter + + +class TestIPForceAdapterFactory(unittest.TestCase): + """Test that IPForceAdapter returns correct adapter types.""" + + def test_v4_lock(self): + adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK) + self.assertIsInstance(adapter, _BaseLockAdapter) + self.assertEqual(adapter._family, socket.AF_INET) + + def test_v6_lock(self): + adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.LOCK) + self.assertIsInstance(adapter, _BaseLockAdapter) + self.assertEqual(adapter._family, socket.AF_INET6) + + def test_v4_global(self): + adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL) + self.assertIsInstance(adapter, HTTPAdapter) + + def test_v6_global(self): + adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.GLOBAL) + self.assertIsInstance(adapter, HTTPAdapter) + + def test_default_method_is_lock(self): + adapter = IPForceAdapter(IPVersion.V4) + self.assertIsInstance(adapter, _BaseLockAdapter) + + def test_invalid_combination_raises(self): + with self.assertRaises((ValueError, KeyError)): + IPForceAdapter(IPVersion.V4, "not_a_method") + + +class TestIPForceSession(unittest.TestCase): + """Test IPForceSession class.""" + + def test_v4_session_mounts_lock_adapter(self): + with IPForceSession(IPVersion.V4) as session: + adapter = session.get_adapter('https://example.com') + self.assertIsInstance(adapter, _BaseLockAdapter) + + def test_v6_session_mounts_lock_adapter(self): + with IPForceSession(IPVersion.V6) as session: + adapter = session.get_adapter('https://example.com') + self.assertIsInstance(adapter, _BaseLockAdapter) + self.assertEqual(adapter._family, socket.AF_INET6) + + def test_session_with_global_method(self): + with IPForceSession(IPVersion.V4, method=IPForceMethod.GLOBAL) as session: + adapter = session.get_adapter('https://example.com') + self.assertIsInstance(adapter, HTTPAdapter) + + def test_session_context_manager(self): + with IPForceSession(IPVersion.V4) as session: + self.assertIsInstance(session, IPForceSession) + + +class TestDeprecationWarnings(unittest.TestCase): + """Old v0.1 classes emit DeprecationWarning; new API does not.""" + + def test_ipv4_transport_adapter_warns(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + IPv4TransportAdapter() + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("IPForceAdapter", str(w[0].message)) + + def test_ipv6_transport_adapter_warns(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + IPv6TransportAdapter() + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + + def test_new_api_does_not_warn(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK) + IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL) + session = IPForceSession(IPVersion.V4) + session.close() + dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + self.assertEqual(len(dep_warnings), 0) + + +class TestEnums(unittest.TestCase): + """Test enum values.""" + + def test_ip_version_values(self): + self.assertEqual(IPVersion.V4.value, "ipv4") + self.assertEqual(IPVersion.V6.value, "ipv6") + + def test_method_values(self): + self.assertEqual(IPForceMethod.GLOBAL.value, "global") + self.assertEqual(IPForceMethod.LOCK.value, "lock")