diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0600355..51ba373 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -5,6 +5,10 @@ on: branches: - main - master + pull_request: + branches: + - main + - master jobs: build: diff --git a/README.md b/README.md index 778efe5..f3d6b0d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Apparser is a Python library for automating desktop applications and interacting # Installation ```bash -# Base Apparser package +# Base Apparser package with base ocr model pip install apparser # Apparser with text recognition support @@ -31,13 +31,17 @@ pip install "apparser[all]" ``` # Examples +> Disclaimer: This example is provided for educational purposes only to demonstrate UI automation concepts. +> Do not use it with Counter-Strike 2, Steam, multiplayer games, VAC-protected games, ranking/progression systems, +> or any application/service where automation is prohibited by its terms of service. +> This project is not affiliated with, endorsed by, or sponsored by Valve, Steam, or Counter-Strike 2. + 1) Open CS2 and start a game #### Code ```python from apparser import App from apparser.instructions import OCRAlgorithm from apparser.instructions.ocr import WaitText, ClickOnText -from apparser.text_readers import ScreensController, RapidOcrReader # Text labels that the OCR algorithm will look for on the screen. play_button = "play" @@ -56,8 +60,8 @@ algorithm = OCRAlgorithm([ # Select the hostage group and start the match. WaitText(group_button), ClickOnText(group_button), - ClickOnText(start_button, min_similarity=0.5), -], text_reader=ScreensController(RapidOcrReader())) + ClickOnText(start_button, min_similarity=0.5) +]) # Launch CS2 app = App(['cmd', '/c', 'start', 'steam://rungameid/730'], timeout=20) @@ -67,7 +71,7 @@ algorithm.perform(app.ui) ``` #### Video - + # Docs Full documentation is available here
diff --git a/apparser/core/app.py b/apparser/core/app.py index 2fbc7db..eeee884 100644 --- a/apparser/core/app.py +++ b/apparser/core/app.py @@ -16,9 +16,9 @@ def __init__(self, start_command: str | list[str], """Initialize an application controller. :param start_command: App start command. - :type start_command: str + :type start_command: str | list[str] :param window_title: Title of the window to attach to. - :type window_title: str + :type window_title: str | None :param timeout: Delay before the window lookup starts. :type timeout: float :raises TypeError: If any argument has an invalid type. diff --git a/apparser/core/ui/coordinates.py b/apparser/core/ui/coordinates.py index d74a422..a9a3c30 100644 --- a/apparser/core/ui/coordinates.py +++ b/apparser/core/ui/coordinates.py @@ -22,8 +22,8 @@ def __init__( :type from_ui: BaseUi :param point_one: First point of the nested region. :type point_one: Point | RelativelyPoint - :param point_two: Second point of the nested region or region size. - :type point_two: Point | RelativelyPoint | Size + :param point_two: Second point of the nested region. + :type point_two: Point | RelativelyPoint :raises TypeError: If any argument has an invalid type. """ if not isinstance(from_ui, BaseUi): diff --git a/apparser/cv/events/__init__.py b/apparser/cv/events/__init__.py index 0f40278..c272c94 100644 --- a/apparser/cv/events/__init__.py +++ b/apparser/cv/events/__init__.py @@ -2,7 +2,7 @@ from apparser.cv.events.moved import Moved from apparser.cv.events.detected import Detected from apparser.cv.events.resized import Resized -from apparser.cv.events.undetected import UnDetected +from apparser.cv.events.undetected import Undetected __all__ = ["CvEvent", "Moved", - "Detected", "Resized", "UnDetected"] + "Detected", "Resized", "Undetected"] diff --git a/apparser/cv/events/undetected.py b/apparser/cv/events/undetected.py index f446e53..ca0d1cb 100644 --- a/apparser/cv/events/undetected.py +++ b/apparser/cv/events/undetected.py @@ -1,7 +1,7 @@ from apparser.cv.events.base import CvEvent -class UnDetected(CvEvent): +class Undetected(CvEvent): """Represent a previously tracked object that disappeared.""" def __str__(self) -> str: diff --git a/apparser/cv/handlers/default.py b/apparser/cv/handlers/default.py index 27ce1f2..6e2d62e 100644 --- a/apparser/cv/handlers/default.py +++ b/apparser/cv/handlers/default.py @@ -1,4 +1,5 @@ -from typing import Callable, Type, Optional, Any +from types import UnionType +from typing import Callable, Type, Optional, Any, Union, get_args, get_origin import inspect from apparser.core import BaseUi @@ -7,6 +8,27 @@ from apparser.cv.models import CvAllData, CvHandler, CvChangeData +def _is_annotation_matches(annotation: Any, value: Any) -> bool: + if annotation is inspect.Parameter.empty: + return False + + if annotation is Any: + return True + + annotation_origin = get_origin(annotation) + annotation_args = get_args(annotation) + if annotation_origin in (Union, UnionType): + return any(_is_annotation_matches(i, value) for i in annotation_args) + + if isinstance(annotation_origin, type): + return isinstance(value, annotation_origin) + + if isinstance(annotation, type): + return isinstance(value, annotation) + + return annotation is type(value) + + def _form_args(function: Callable, *args) -> dict[str, Any]: """Build keyword arguments matching the annotated handler signature. @@ -20,7 +42,7 @@ def _form_args(function: Callable, *args) -> dict[str, Any]: function_signature = inspect.signature(function) for arg in function_signature.parameters.values(): for a in args: - if arg.annotation is type(a): + if _is_annotation_matches(arg.annotation, a): result[arg.name] = a return result diff --git a/apparser/cv/readers/yolo.py b/apparser/cv/readers/yolo.py index 8fcb027..e13720b 100644 --- a/apparser/cv/readers/yolo.py +++ b/apparser/cv/readers/yolo.py @@ -50,12 +50,12 @@ def read(self, ui: BaseUi) -> CvAllData: track_id = int(track_id.item()) cls_name = names[class_index] x1, y1, x2, y2 = box.xyxy[0].tolist() - x = int(x1) - y = int(y1) - x2 = int(x2) - y2 = int(y2) - width = x2 - x1 - height = y2 - y1 + x = round(x1) + y = round(y1) + x2 = round(x2) + y2 = round(y2) + width = x2 - x + height = y2 - y box_ui = CoordinatesUi(ui, Point(x, y), Point(x2, y2)) boxes.append( CvBox( diff --git a/apparser/cv/utils/changes_checker.py b/apparser/cv/utils/changes_checker.py index e704b08..8eff7b6 100644 --- a/apparser/cv/utils/changes_checker.py +++ b/apparser/cv/utils/changes_checker.py @@ -1,30 +1,12 @@ from apparser.cv.models import CvAllData, CvChangeData, CvBox -from apparser.cv.events import Detected, UnDetected, Moved, Resized +from apparser.cv.events import Detected, Undetected, Moved, Resized def _is_moved(box: CvBox, old_box: CvBox) -> bool: - """Check whether a box position changed. - - :param box: Current box state. - :type box: CvBox - :param old_box: Previous box state. - :type old_box: CvBox - :return: True if the box coordinates changed. - :rtype: bool - """ return abs(box.x - old_box.x) > 0 or abs(box.y - old_box.y) > 0 def _is_resized(box: CvBox, old_box: CvBox) -> bool: - """Check whether both box dimensions changed. - - :param box: Current box state. - :type box: CvBox - :param old_box: Previous box state. - :type old_box: CvBox - :return: True if width and height both changed. - :rtype: bool - """ return abs(box.width - old_box.width) > 0 or abs(box.height - old_box.height) > 0 @@ -59,7 +41,7 @@ def __get_undetected(self, current_data: CvAllData) -> list[CvChangeData]: :rtype: list[CvChangeData] """ new_ids = [i.track_id for i in current_data.boxes if i.track_id is not None] - return [CvChangeData(UnDetected, i, i) for i in self.__old_data.boxes if + return [CvChangeData(Undetected, i, i) for i in self.__old_data.boxes if i.track_id not in new_ids and i.track_id is not None] def check(self, data: CvAllData) -> list[CvChangeData]: diff --git a/apparser/instructions/default/click.py b/apparser/instructions/default/click.py index e0bc955..e3d2e99 100644 --- a/apparser/instructions/default/click.py +++ b/apparser/instructions/default/click.py @@ -12,7 +12,7 @@ def __init__(self, click_type: BaseKeyCode = LeftClick()): :param click_type: Mouse button to click. :type click_type: BaseKeyCode - :raises TypeError: If ``click_type`` is neither :class:`BaseKeyCode`. + :raises TypeError: If ``click_type`` is not a :class:`BaseKeyCode`. """ if not isinstance(click_type, BaseKeyCode): raise TypeError('click_type must be BaseKeyCode') diff --git a/apparser/instructions/default/press.py b/apparser/instructions/default/press.py index 5235db4..025fb45 100644 --- a/apparser/instructions/default/press.py +++ b/apparser/instructions/default/press.py @@ -30,16 +30,17 @@ def perform(self, *args, **kwargs): class PressKeysCombination(BaseInstruction): """Send a keyboard shortcut as a pressed combination.""" - def __init__(self, keys: list[BaseKeyCode | str]): + def __init__(self, keys: list[BaseKeyCode | str] | str): """Initialize a key combination instruction. :param keys: Keys to press together. - :type keys: list[BaseKeyCode | str] + :type keys: list[BaseKeyCode | str] | str """ self.__keys = keys self.__validate() def __validate(self): + for key in self.__keys: if not (isinstance(key, BaseKeyCode) or isinstance(key, str)): raise TypeError('key_code must be BaseKeyCode or str') diff --git a/apparser/instructions/default/sleep.py b/apparser/instructions/default/sleep.py index 8f4ce1e..bd22f52 100644 --- a/apparser/instructions/default/sleep.py +++ b/apparser/instructions/default/sleep.py @@ -12,7 +12,11 @@ def __init__(self, sleep_time: float): :param sleep_time: Delay duration in seconds. :type sleep_time: float :raises ValueError: If ``sleep_time`` is not greater than zero. + :raises TypeError: If ``sleep_time`` is not a number. """ + if not isinstance(sleep_time, float) and not isinstance(sleep_time, int): + raise TypeError("sleep_time must be a number.") + if sleep_time <= 0: raise ValueError("sleep_time must be > 0") diff --git a/apparser/instructions/ocr/click_on_text.py b/apparser/instructions/ocr/click_on_text.py index 8b8086d..62c1a0b 100644 --- a/apparser/instructions/ocr/click_on_text.py +++ b/apparser/instructions/ocr/click_on_text.py @@ -18,7 +18,7 @@ def __init__(self, text: str, min_similarity: float = 0.8, offset: Point | RelativelyPoint = Point(0, 0), text_getter: GetText | None = None, - sleep_time_before_move: float = 0.1): + sleep_time_after_move: float = 0.1): """Initialize a text click instruction. :param text: Text to locate before clicking. @@ -31,8 +31,8 @@ def __init__(self, text: str, :type offset: Point | RelativelyPoint :param text_getter: Instruction used to extract text from the screen. If None use GetText() :type text_getter: GetText | None - :param sleep_time_before_move: Delay before the click is performed. - :type sleep_time_before_move: float + :param sleep_time_after_move: Delay before the click is performed. + :type sleep_time_after_move: float """ if text_getter is None: @@ -40,7 +40,7 @@ def __init__(self, text: str, self.__mouse_mover = MoveToText(text, min_similarity, offset, text_getter) self.__click_type = click_type - self.__sleep = Sleep(sleep_time_before_move) + self.__sleep = Sleep(sleep_time_after_move) @property def id(self) -> int: diff --git a/apparser/instructions/ocr/plot_text.py b/apparser/instructions/ocr/plot_text.py index dff78cb..21af8d0 100644 --- a/apparser/instructions/ocr/plot_text.py +++ b/apparser/instructions/ocr/plot_text.py @@ -5,7 +5,7 @@ from apparser.core import BaseUi from apparser.geometry import Point -from apparser.text_readers import BaseTextReader, TextData\ +from apparser.text_readers import BaseTextReader, TextData from apparser.instructions.ocr.base import OCRInstruction from apparser.instructions.ocr.text_getter import GetText @@ -41,7 +41,7 @@ def __paint_cords(self, data: TextData): if y < 0: y = data.coordinates.right_bottom.y - self.__text_move.y x = data.coordinates.left_top.x + self.__text_move.x - if y < 0: + if x < 0: x = data.coordinates.right_bottom.x - self.__text_move.x self.__draw.text((x, y), data.text, fill=self.__color) diff --git a/apparser/instructions/ui/algorithms/ids.py b/apparser/instructions/ui/algorithms/ids.py index df16dc3..605294e 100644 --- a/apparser/instructions/ui/algorithms/ids.py +++ b/apparser/instructions/ui/algorithms/ids.py @@ -1,5 +1,6 @@ import inspect -from typing import Any +from types import UnionType +from typing import Any, Union, get_args, get_origin from apparser.core import BaseUi from apparser.instructions import BaseInstruction @@ -22,6 +23,27 @@ def _check_instruction(instruction: tuple[int, list[Any]]) -> tuple[int, list[An return instruction_id, instruction_args +def _is_annotation_matches(annotation: Any, value: Any) -> bool: + if annotation is inspect.Parameter.empty: + return False + + if annotation is Any: + return True + + annotation_origin = get_origin(annotation) + annotation_args = get_args(annotation) + if annotation_origin in (Union, UnionType): + return any(_is_annotation_matches(i, value) for i in annotation_args) + + if isinstance(annotation_origin, type): + return isinstance(value, annotation_origin) + + if isinstance(annotation, type): + return isinstance(value, annotation) + + return annotation is type(value) + + class IdsAlgorithm(BaseAlgorithm): """Resolve and execute instructions by their numeric identifiers.""" @@ -44,7 +66,7 @@ def __init__(self, raise TypeError("attributes must be list") if not isinstance(instructions, list): - raise TypeError("attributes must be list") + raise TypeError("instructions must be list") if not isinstance(debugger, BaseDebugger) and not isinstance(debugger, bool): raise TypeError(f"debugger must be a bool or BaseDebugger") @@ -68,7 +90,7 @@ def __form_args(self, instruction: BaseInstruction, *additional_args) -> dict[st function_signature = inspect.signature(instruction.perform) for arg in function_signature.parameters.values(): for a in self.__attributes + list(additional_args): - if arg.annotation is type(a): + if _is_annotation_matches(arg.annotation, a): result[arg.name] = a return result @@ -91,7 +113,7 @@ def perform(self, ui: BaseUi, *args, **kwargs): if self.__debugger is not None: self.__debugger.try_perform(instruction, **perform_kwargs) else: - instruction.perform(ui, **perform_kwargs) + instruction.perform(**perform_kwargs) def add_instruction(self, instruction: tuple[int, list[Any]]): _check_instruction(instruction) diff --git a/apparser/instructions/ui/algorithms/names.py b/apparser/instructions/ui/algorithms/names.py index f1fa940..4bdae59 100644 --- a/apparser/instructions/ui/algorithms/names.py +++ b/apparser/instructions/ui/algorithms/names.py @@ -1,5 +1,6 @@ import inspect -from typing import Any +from types import UnionType +from typing import Any, Union, get_args, get_origin from apparser.core import BaseUi @@ -23,6 +24,27 @@ def _check_instruction(instruction: tuple[str, list[Any]]) -> tuple[str, list[An return instruction_name, instruction_args +def _is_annotation_matches(annotation: Any, value: Any) -> bool: + if annotation is inspect.Parameter.empty: + return False + + if annotation is Any: + return True + + annotation_origin = get_origin(annotation) + annotation_args = get_args(annotation) + if annotation_origin in (Union, UnionType): + return any(_is_annotation_matches(i, value) for i in annotation_args) + + if isinstance(annotation_origin, type): + return isinstance(value, annotation_origin) + + if isinstance(annotation, type): + return isinstance(value, annotation) + + return annotation is type(value) + + class NamesAlgorithm(BaseAlgorithm): """Resolve and execute instructions by their registered names.""" @@ -44,7 +66,7 @@ def __init__(self, raise TypeError("attributes must be list") if not isinstance(instructions, list): - raise TypeError("attributes must be list") + raise TypeError("instructions must be list") if not isinstance(debugger, BaseDebugger) and not isinstance(debugger, bool): raise TypeError(f"debugger must be a bool or BaseDebugger") @@ -68,7 +90,7 @@ def __form_args(self, instruction: BaseInstruction, *additional_args) -> dict[st function_signature = inspect.signature(instruction.perform) for arg in function_signature.parameters.values(): for a in self.__attributes + list(additional_args): - if arg.annotation is type(a): + if _is_annotation_matches(arg.annotation, a): result[arg.name] = a return result @@ -92,7 +114,7 @@ def perform(self, ui: BaseUi, *args, **kwargs): if self.__debugger is not None: self.__debugger.try_perform(instruction, **perform_kwargs) else: - instruction.perform(ui, **perform_kwargs) + instruction.perform(**perform_kwargs) def add_instruction(self, instruction: tuple[str, list[Any]]): _check_instruction(instruction) diff --git a/apparser/instructions/ui/algorithms/ocr.py b/apparser/instructions/ui/algorithms/ocr.py index 55d373f..c5d1bf4 100644 --- a/apparser/instructions/ui/algorithms/ocr.py +++ b/apparser/instructions/ui/algorithms/ocr.py @@ -1,6 +1,6 @@ from apparser.core import BaseUi -from apparser.text_readers import BaseTextReader, EasyOcrReader, ScreensController +from apparser.text_readers import BaseTextReader, RapidOcrReader, ScreensController from apparser.instructions.debuggers import BaseDebugger, Debugger from apparser.instructions.ui.algorithms.base import BaseAlgorithm @@ -25,7 +25,7 @@ def __init__(self, :raises TypeError: If ``text_reader`` or ``debugger`` has an invalid type. """ if text_reader is None: - text_reader = ScreensController(EasyOcrReader()) + text_reader = ScreensController(RapidOcrReader()) if not isinstance(text_reader, BaseTextReader): raise TypeError("text_reader must be BaseTextReader") diff --git a/apparser/instructions/ui/algorithms/unique.py b/apparser/instructions/ui/algorithms/unique.py index 0562925..2021a58 100644 --- a/apparser/instructions/ui/algorithms/unique.py +++ b/apparser/instructions/ui/algorithms/unique.py @@ -1,4 +1,5 @@ -from typing import Any +from types import UnionType +from typing import Any, Union, get_args, get_origin import inspect from apparser.core import BaseUi @@ -8,6 +9,27 @@ from apparser.instructions.base import BaseInstruction +def _is_annotation_matches(annotation: Any, value: Any) -> bool: + if annotation is inspect.Parameter.empty: + return False + + if annotation is Any: + return True + + annotation_origin = get_origin(annotation) + annotation_args = get_args(annotation) + if annotation_origin in (Union, UnionType): + return any(_is_annotation_matches(i, value) for i in annotation_args) + + if isinstance(annotation_origin, type): + return isinstance(value, annotation_origin) + + if isinstance(annotation, type): + return isinstance(value, annotation) + + return annotation is type(value) + + class UniqueAlgorithm(BaseAlgorithm): """Run instructions with arguments resolved from unique attribute types.""" @@ -45,7 +67,7 @@ def __form_args(self, instruction: BaseInstruction) -> dict[str, Any]: function_signature = inspect.signature(instruction.perform) for arg in function_signature.parameters.values(): for a in self.__attributes: - if arg.annotation is type(a): + if _is_annotation_matches(arg.annotation, a): result[arg.name] = a return result diff --git a/docs/api/geometry/QuadPoints.rst b/docs/api/geometry/QuadPoints.rst new file mode 100644 index 0000000..1e9f313 --- /dev/null +++ b/docs/api/geometry/QuadPoints.rst @@ -0,0 +1,10 @@ +QuadPoints +========== + +.. currentmodule:: apparser.geometry + +.. autoclass:: QuadPoints + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/docs/examples/ocr.rst b/docs/examples/ocr.rst index da23d16..4ce7c1e 100644 --- a/docs/examples/ocr.rst +++ b/docs/examples/ocr.rst @@ -51,8 +51,8 @@ Code ui = WindowByDisplayUi(app.ui.window) try: - hello_world_algorithm.perform(ui) new_tab_algorithm.perform(ui) + hello_world_algorithm.perform(ui) finally: app.stop_app() diff --git a/docs/info/instructions_ids.rst b/docs/info/instructions_ids.rst index bec6635..cf28850 100644 --- a/docs/info/instructions_ids.rst +++ b/docs/info/instructions_ids.rst @@ -27,8 +27,8 @@ non-abstract instruction classes. .. important:: - Algorithm instructions currently not included in - ``get_instruction_by_id()``. + Algorithm instructions are currently not included + in ``get_instruction_by_id()``. Current numbering layout ------------------------ diff --git a/docs/requirements.txt b/docs/requirements.txt index 3f07068..0e77686 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,9 +1,9 @@ sphinx>=9.1.0 shibuya -appwindows >= 1.3.1, -keyboard >= 0.13.2, -mouse >= 0.7.1, -numpy >= 1.20, -pillow >= 11.0, -screeninfo >= 0.8, +appwindows >= 1.3.1 +keyboard >= 0.13.2 +mouse >= 0.7.1 +numpy >= 1.20 +pillow >= 11.0 +screeninfo >= 0.8 thefuzz >= 0.20.0 diff --git a/example.gif b/example.gif index fd9ec70..34bddcd 100644 Binary files a/example.gif and b/example.gif differ diff --git a/pyproject.toml b/pyproject.toml index a0eb667..99d3528 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,16 @@ [project] dependencies = [ - "appwindows >= 1.3.0", + "appwindows >= 1.3.1", "PyAutoGui >= 0.9.0", "numpy >= 1.20", "pillow >= 11.0", "screeninfo >= 0.8", - "thefuzz >= 0.20.0" + "thefuzz >= 0.20.0", + "rapidocr>=3.0.0", + "onnxruntime>=1.20.1" ] name = "apparser" -version = "1.1.0" +version = "1.1.1" authors = [ { name = "Terochkin A.S", email = "apparser.development@gmail.com" }, ] @@ -60,8 +62,6 @@ ocr = [ "easyocr >= 1.7.2, < 2.0", "paddleocr >= 3.5.0", "paddlepaddle >= 3.3.1", - "rapidocr>=3.0.0", - "onnxruntime>=1.20.1" ] speak = [ diff --git a/requirements.txt b/requirements.txt index 002d2b4..923602f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -appwindows >= 1.3.1, -numpy >= 1.20, -pillow >= 11.0, +appwindows >= 1.3.1 +numpy >= 1.20 +pillow >= 11.0 PyAutoGui >= 0.9.0 -screeninfo >= 0.8, +screeninfo >= 0.8 thefuzz >= 0.20.0 +rapidocr>=3.0.0 +onnxruntime>=1.20.1 \ No newline at end of file diff --git a/tests/apparser/cv/events/test_undetected.py b/tests/apparser/cv/events/test_undetected.py index 692e368..52cbb37 100644 --- a/tests/apparser/cv/events/test_undetected.py +++ b/tests/apparser/cv/events/test_undetected.py @@ -1,7 +1,7 @@ from __future__ import annotations -from apparser.cv.events.undetected import UnDetected +from apparser.cv.events.undetected import Undetected def test_undetected_string_representation() -> None: - assert str(UnDetected()) == "UnDetected" + assert str(Undetected()) == "UnDetected" diff --git a/tests/apparser/cv/utils/test_changes_checker.py b/tests/apparser/cv/utils/test_changes_checker.py index 84aa970..3c8b025 100644 --- a/tests/apparser/cv/utils/test_changes_checker.py +++ b/tests/apparser/cv/utils/test_changes_checker.py @@ -1,6 +1,6 @@ from __future__ import annotations -from apparser.cv.events import Detected, Moved, Resized, UnDetected +from apparser.cv.events import Detected, Moved, Resized, Undetected from apparser.cv.models import CvAllData, CvBox from apparser.cv.utils.changes_checker import ChangesChecker, _is_moved, _is_resized from tests.utils import FakeUi @@ -36,7 +36,7 @@ def test_changes_checker_reports_detected_moved_resized_and_undetected() -> None next_result = checker.check(CvAllData([new_box])) - assert next_result[0].event is UnDetected + assert next_result[0].event is Undetected def test_changes_checker_reports_new_and_changed_boxes_in_order() -> None: diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 99e52da..4a40815 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -5,6 +5,7 @@ install_external_stubs, pyautogui_stub, paddleocr_stub, + rapidocr_stub, reset_external_stubs, screeninfo_stub, sounddevice_stub, @@ -49,9 +50,9 @@ "create_wave_file", "easyocr_stub", "install_external_stubs", - "keyboard_stub", "pyautogui_stub", "paddleocr_stub", + "rapidocr_stub", "reset_external_stubs", "screeninfo_stub", "sounddevice_stub", diff --git a/tests/utils/external_stubs.py b/tests/utils/external_stubs.py index a12aa1a..6fdb8ac 100644 --- a/tests/utils/external_stubs.py +++ b/tests/utils/external_stubs.py @@ -8,6 +8,7 @@ from tests.utils.stubs.ml.torch_stub import TorchStub from tests.utils.stubs.text.easy_ocr_stub import EasyOcrStub from tests.utils.stubs.text.paddle_ocr_stub import PaddleOcrStub +from tests.utils.stubs.text.rapid_ocr_stub import RapidOcrStub from tests.utils.stubs.text.thefuzz_stub import TheFuzzStub from tests.utils.stubs.vision.ultralytics_stub import UltralyticsStub @@ -19,6 +20,7 @@ sounddevice_stub = SoundDeviceStub() easyocr_stub = EasyOcrStub() paddleocr_stub = PaddleOcrStub() +rapidocr_stub = RapidOcrStub() torch_stub = TorchStub() chattts_stub = ChatTTSStub() @@ -31,6 +33,7 @@ def install_external_stubs() -> None: sys.modules["sounddevice"] = sounddevice_stub sys.modules["easyocr"] = easyocr_stub sys.modules["paddleocr"] = paddleocr_stub + sys.modules["rapidocr"] = rapidocr_stub sys.modules["torch"] = torch_stub sys.modules["ChatTTS"] = chattts_stub @@ -44,6 +47,7 @@ def reset_external_stubs() -> None: sounddevice_stub, easyocr_stub, paddleocr_stub, + rapidocr_stub, torch_stub, chattts_stub, ]: diff --git a/tests/utils/stubs/text/__init__.py b/tests/utils/stubs/text/__init__.py index a725aa0..171eb6e 100644 --- a/tests/utils/stubs/text/__init__.py +++ b/tests/utils/stubs/text/__init__.py @@ -3,6 +3,7 @@ from tests.utils.stubs.text.fuzz_namespace import FuzzNamespace from tests.utils.stubs.text.paddle_ocr_reader_stub import PaddleOcrReaderStub from tests.utils.stubs.text.paddle_ocr_stub import PaddleOcrStub +from tests.utils.stubs.text.rapid_ocr_stub import RapidOcrEngineStub, RapidOcrStub from tests.utils.stubs.text.thefuzz_stub import TheFuzzStub __all__ = [ @@ -11,5 +12,7 @@ "FuzzNamespace", "PaddleOcrReaderStub", "PaddleOcrStub", + "RapidOcrEngineStub", + "RapidOcrStub", "TheFuzzStub", ] diff --git a/tests/utils/stubs/text/rapid_ocr_stub.py b/tests/utils/stubs/text/rapid_ocr_stub.py new file mode 100644 index 0000000..1aaffc2 --- /dev/null +++ b/tests/utils/stubs/text/rapid_ocr_stub.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any + +import numpy + + +class RapidOcrEngineStub: + instances: list["RapidOcrEngineStub"] = [] + + def __init__(self, **settings: Any) -> None: + self.settings = settings + self.result: Any = [] + self.calls: list[dict[str, Any]] = [] + self.__class__.instances.append(self) + + def __call__(self, image: numpy.ndarray, **settings: Any) -> Any: + self.calls.append({"image": image, "settings": settings}) + return self.result + + +class RapidOcrStub(ModuleType): + def __init__(self) -> None: + super().__init__("rapidocr") + self.reset() + + def reset(self) -> None: + RapidOcrEngineStub.instances = [] + self.RapidOCR = RapidOcrEngineStub