diff --git a/CHANGELOG.md b/CHANGELOG.md
index e5510fc..b5de6b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
+- `IPForceAdapter` unified factory function
+- `IPForceSession` unified session class
+- `IPVersion` enum (`V4`, `V6`)
+- `IPForceMethod` enum (`GLOBAL`, `LOCK`)
- `IPv6LockAdapter` class
- `IPv4LockAdapter` class
- Logo
diff --git a/README.md b/README.md
index 62b4472..369b69e 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
| PyPI Counter |
-  |
+  |
| Github Stars |
@@ -43,7 +43,7 @@
@@ -60,58 +60,78 @@
- `pip install ipforce==0.1`
## Usage
-### Enforce IPv4
+### Enforce IPv4
Use when you need to ensure connections only use IPv4 addresses, useful for legacy systems that don't support IPv6, networks with IPv4-only infrastructure, or testing IPv4 connectivity.
```python
+from ipforce import IPForceAdapter, IPVersion, IPForceMethod
import requests
-from ipforce import IPv4TransportAdapter
-# Create a session that will only use IPv4 addresses
session = requests.Session()
-session.mount('http://', IPv4TransportAdapter())
-session.mount('https://', IPv4TransportAdapter())
+adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK)
+session.mount('http://', adapter)
+session.mount('https://', adapter)
-# All requests through this session will only resolve to IPv4 addresses
response = session.get('https://ifconfig.co/json')
```
### Enforce IPv6
-
Use when you need to ensure connections only use IPv6 addresses, useful for modern networks with IPv6 infrastructure, testing IPv6 connectivity, or applications requiring IPv6-specific features.
```python
+from ipforce import IPForceAdapter, IPVersion, IPForceMethod
import requests
-from ipforce import IPv6TransportAdapter
-# Create a session that will only use IPv6 addresses
session = requests.Session()
-session.mount('http://', IPv6TransportAdapter())
-session.mount('https://', IPv6TransportAdapter())
+adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.LOCK)
+session.mount('http://', adapter)
+session.mount('https://', adapter)
-# All requests through this session will only resolve to IPv6 addresses
response = session.get('https://ifconfig.co/json')
```
+### Using IPForceSession
+
+```python
+from ipforce import IPForceSession, IPVersion
+
+with IPForceSession(IPVersion.V4) as session:
+ response = session.get('https://ifconfig.co/json')
+```
+
+### Available Methods
+
+| Method | Description |
+|--------|-------------|
+| `IPForceMethod.LOCK` | Thread-safe — global lock serialization (default) |
+| `IPForceMethod.GLOBAL` | Non-thread-safe — temporary getaddrinfo patch |
+
> [!WARNING]
-> `IPv4TransportAdapter` / `IPv6TransportAdapter` are NOT thread-safe. They modify the global `socket.getaddrinfo` function, which can cause race conditions in multi-threaded applications. Use the thread-safe adapters below for concurrent usage.
+> `IPForceMethod.GLOBAL` is NOT thread-safe. It modifies the global `socket.getaddrinfo` function, which can cause race conditions in multi-threaded applications. Use `IPForceMethod.LOCK` (the default) for concurrent usage.
-### Thread-Safe: Lock-Based Adapters
+### Direct Class Usage (Deprecated)
-A process-wide lock serializes access to `socket.getaddrinfo`, guaranteeing correctness under concurrent access.
+The following direct class usage still works but is deprecated in favor of the unified API above:
```python
-import requests
from ipforce import IPv4LockAdapter, IPv6LockAdapter
session = requests.Session()
session.mount('http://', IPv4LockAdapter()) # or IPv6LockAdapter()
session.mount('https://', IPv4LockAdapter()) # or IPv6LockAdapter()
-
response = session.get('https://ifconfig.co/json')
```
+### Roadmap
+
+| Method | Description |
+|--------|-------------|
+| `IPForceMethod.THREAD_LOCAL` | Per-thread dispatch (fully concurrent) |
+| `IPForceMethod.CONTEXT_VAR` | ContextVar dispatch (async-safe) |
+| `IPForceMethod.CONNECTION` | urllib3 connection-level (zero global state) |
+| `IPForceMethod.AUTO` | Automatically select best available |
+
## Issues & Bug Reports
Just fill an issue and describe it. We'll check it ASAP!
diff --git a/ipforce/__init__.py b/ipforce/__init__.py
index 349378d..28a5ef0 100644
--- a/ipforce/__init__.py
+++ b/ipforce/__init__.py
@@ -1,12 +1,16 @@
# -*- coding: utf-8 -*-
"""ipforce modules."""
from .params import IPFORCE_VERSION
+from .enums import IPVersion, IPForceMethod
+from .api import IPForceAdapter, IPForceSession
from .adapters import IPv4TransportAdapter, IPv6TransportAdapter
from .adapters import IPv4LockAdapter, IPv6LockAdapter
__version__ = IPFORCE_VERSION
__all__ = [
+ "IPVersion", "IPForceMethod",
+ "IPForceAdapter", "IPForceSession",
"IPv4TransportAdapter", "IPv6TransportAdapter",
"IPv4LockAdapter", "IPv6LockAdapter",
]
diff --git a/ipforce/adapters.py b/ipforce/adapters.py
index ff3308a..8e498be 100644
--- a/ipforce/adapters.py
+++ b/ipforce/adapters.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""IPForce Adapters to force IPv4 or IPv6 for requests."""
import socket
+import warnings
from typing import Any, List, Tuple
from requests.adapters import HTTPAdapter
from threading import Lock
@@ -13,6 +14,15 @@
class IPv4TransportAdapter(HTTPAdapter):
"""A custom HTTPAdapter that enforces the use of IPv4 for DNS resolution during HTTP(S) requests using the requests library."""
+ def __init__(self, *args, **kwargs) -> None:
+ """Initialize the adapter and emit a deprecation warning."""
+ warnings.warn(
+ "IPv4TransportAdapter is deprecated, use IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL) instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(*args, **kwargs)
+
def send(self, *args: list, **kwargs: dict) -> Any:
"""
Override send method to apply the monkey patch only during the request.
@@ -43,6 +53,15 @@ def ipv4_only_getaddrinfo(*gargs: list, **gkwargs: dict) -> List[Tuple]:
class IPv6TransportAdapter(HTTPAdapter):
"""A custom HTTPAdapter that enforces the use of IPv6 for DNS resolution during HTTP(S) requests using the requests library."""
+ def __init__(self, *args, **kwargs) -> None:
+ """Initialize the adapter and emit a deprecation warning."""
+ warnings.warn(
+ "IPv6TransportAdapter is deprecated, use IPForceAdapter(IPVersion.V6, IPForceMethod.GLOBAL) instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(*args, **kwargs)
+
def send(self, *args: list, **kwargs: dict) -> Any:
"""
Override send method to apply the monkey patch only during the request.
diff --git a/ipforce/api.py b/ipforce/api.py
new file mode 100644
index 0000000..eb06fc8
--- /dev/null
+++ b/ipforce/api.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+"""Unified public API for IPForce adapter and session creation."""
+import warnings
+
+from requests import Session
+from requests.adapters import HTTPAdapter
+
+from .enums import IPVersion, IPForceMethod
+from .adapters import (
+ IPv4TransportAdapter, IPv6TransportAdapter,
+ IPv4LockAdapter, IPv6LockAdapter,
+)
+
+_ADAPTER_REGISTRY = {
+ (IPVersion.V4, IPForceMethod.GLOBAL): IPv4TransportAdapter,
+ (IPVersion.V6, IPForceMethod.GLOBAL): IPv6TransportAdapter,
+ (IPVersion.V4, IPForceMethod.LOCK): IPv4LockAdapter,
+ (IPVersion.V6, IPForceMethod.LOCK): IPv6LockAdapter,
+}
+
+
+def IPForceAdapter(
+ ip_version: IPVersion,
+ method: IPForceMethod = IPForceMethod.LOCK,
+) -> HTTPAdapter:
+ """
+ Create an HTTP adapter that forces a specific IP version.
+
+ :param ip_version: IPVersion.V4 or IPVersion.V6
+ :param method: thread-safety strategy (default: LOCK)
+ :return: configured HTTPAdapter instance
+ :raises ValueError: if the (ip_version, method) combination is not registered
+ """
+ adapter_cls = _ADAPTER_REGISTRY.get((ip_version, method))
+ if adapter_cls is None:
+ raise ValueError("Unsupported combination: {v} + {m}".format(v=ip_version, m=method))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", DeprecationWarning)
+ return adapter_cls()
+
+
+class IPForceSession(Session):
+ """A requests.Session pre-configured to force a specific IP version."""
+
+ def __init__(
+ self,
+ ip_version: IPVersion,
+ method: IPForceMethod = IPForceMethod.LOCK,
+ ) -> None:
+ """
+ Initialize the session with an IP-version-forced adapter.
+
+ :param ip_version: IPVersion.V4 or IPVersion.V6
+ :param method: thread-safety strategy (default: LOCK)
+ """
+ super().__init__()
+ adapter = IPForceAdapter(ip_version, method)
+ self.mount('http://', adapter)
+ self.mount('https://', adapter)
diff --git a/ipforce/enums.py b/ipforce/enums.py
new file mode 100644
index 0000000..096d383
--- /dev/null
+++ b/ipforce/enums.py
@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+"""IPForce enumerations for IP version and resolution method selection."""
+from enum import Enum
+
+
+class IPVersion(Enum):
+ """IP protocol version to enforce for DNS resolution."""
+
+ V4 = "ipv4"
+ V6 = "ipv6"
+
+
+class IPForceMethod(Enum):
+ """Thread-safety strategy for address family enforcement."""
+
+ GLOBAL = "global"
+ LOCK = "lock"
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..4b2938f
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+filterwarnings =
+ ignore:.*is deprecated, use IP.*:DeprecationWarning
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..aa7b40e
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,108 @@
+"""Tests for the unified IPForceAdapter / IPForceSession API."""
+import socket
+import warnings
+import unittest
+
+from requests.adapters import HTTPAdapter
+
+from ipforce import (
+ IPVersion, IPForceMethod,
+ IPForceAdapter, IPForceSession,
+ IPv4TransportAdapter, IPv6TransportAdapter,
+)
+from ipforce.adapters import _BaseLockAdapter
+
+
+class TestIPForceAdapterFactory(unittest.TestCase):
+ """Test that IPForceAdapter returns correct adapter types."""
+
+ def test_v4_lock(self):
+ adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK)
+ self.assertIsInstance(adapter, _BaseLockAdapter)
+ self.assertEqual(adapter._family, socket.AF_INET)
+
+ def test_v6_lock(self):
+ adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.LOCK)
+ self.assertIsInstance(adapter, _BaseLockAdapter)
+ self.assertEqual(adapter._family, socket.AF_INET6)
+
+ def test_v4_global(self):
+ adapter = IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL)
+ self.assertIsInstance(adapter, HTTPAdapter)
+
+ def test_v6_global(self):
+ adapter = IPForceAdapter(IPVersion.V6, IPForceMethod.GLOBAL)
+ self.assertIsInstance(adapter, HTTPAdapter)
+
+ def test_default_method_is_lock(self):
+ adapter = IPForceAdapter(IPVersion.V4)
+ self.assertIsInstance(adapter, _BaseLockAdapter)
+
+ def test_invalid_combination_raises(self):
+ with self.assertRaises((ValueError, KeyError)):
+ IPForceAdapter(IPVersion.V4, "not_a_method")
+
+
+class TestIPForceSession(unittest.TestCase):
+ """Test IPForceSession class."""
+
+ def test_v4_session_mounts_lock_adapter(self):
+ with IPForceSession(IPVersion.V4) as session:
+ adapter = session.get_adapter('https://example.com')
+ self.assertIsInstance(adapter, _BaseLockAdapter)
+
+ def test_v6_session_mounts_lock_adapter(self):
+ with IPForceSession(IPVersion.V6) as session:
+ adapter = session.get_adapter('https://example.com')
+ self.assertIsInstance(adapter, _BaseLockAdapter)
+ self.assertEqual(adapter._family, socket.AF_INET6)
+
+ def test_session_with_global_method(self):
+ with IPForceSession(IPVersion.V4, method=IPForceMethod.GLOBAL) as session:
+ adapter = session.get_adapter('https://example.com')
+ self.assertIsInstance(adapter, HTTPAdapter)
+
+ def test_session_context_manager(self):
+ with IPForceSession(IPVersion.V4) as session:
+ self.assertIsInstance(session, IPForceSession)
+
+
+class TestDeprecationWarnings(unittest.TestCase):
+ """Old v0.1 classes emit DeprecationWarning; new API does not."""
+
+ def test_ipv4_transport_adapter_warns(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ IPv4TransportAdapter()
+ self.assertEqual(len(w), 1)
+ self.assertTrue(issubclass(w[0].category, DeprecationWarning))
+ self.assertIn("IPForceAdapter", str(w[0].message))
+
+ def test_ipv6_transport_adapter_warns(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ IPv6TransportAdapter()
+ self.assertEqual(len(w), 1)
+ self.assertTrue(issubclass(w[0].category, DeprecationWarning))
+
+ def test_new_api_does_not_warn(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ IPForceAdapter(IPVersion.V4, IPForceMethod.LOCK)
+ IPForceAdapter(IPVersion.V4, IPForceMethod.GLOBAL)
+ session = IPForceSession(IPVersion.V4)
+ session.close()
+ dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
+ self.assertEqual(len(dep_warnings), 0)
+
+
+class TestEnums(unittest.TestCase):
+ """Test enum values."""
+
+ def test_ip_version_values(self):
+ self.assertEqual(IPVersion.V4.value, "ipv4")
+ self.assertEqual(IPVersion.V6.value, "ipv6")
+
+ def test_method_values(self):
+ self.assertEqual(IPForceMethod.GLOBAL.value, "global")
+ self.assertEqual(IPForceMethod.LOCK.value, "lock")