From a9d89ac39de2cd92aafd70258103488db0965522 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 18 Mar 2026 14:21:26 +0000 Subject: [PATCH 1/7] CU-869cgny1k: Add timing for tokenizer as well --- medcat-v2/medcat/pipeline/speed_utils.py | 145 ++++++++++++++++++----- 1 file changed, 117 insertions(+), 28 deletions(-) diff --git a/medcat-v2/medcat/pipeline/speed_utils.py b/medcat-v2/medcat/pipeline/speed_utils.py index bca621fce..0aa29bc12 100644 --- a/medcat-v2/medcat/pipeline/speed_utils.py +++ b/medcat-v2/medcat/pipeline/speed_utils.py @@ -1,8 +1,7 @@ -from typing import Callable, Union, Type, cast +from typing import Callable, Protocol, Union, Type, cast import time import contextlib import logging -from abc import ABC, abstractmethod from io import StringIO import cProfile from pstats import Stats @@ -11,6 +10,7 @@ from medcat.components.addons.addons import AddonComponent from medcat.components.types import BaseComponent, CoreComponent, CoreComponentType from medcat.pipeline import Pipeline +from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.tokenizing.tokens import MutableDocument from medcat.cat import AddonType @@ -20,42 +20,69 @@ logger.addHandler(logging.StreamHandler()) -class BaseTimedComponent(ABC): +class BaseTimedObject: - def __init__(self, component: BaseComponent): + def __init__(self, component: Union[BaseComponent, BaseTokenizer]): self._component = component - @property - def full_name(self): - return self._component.full_name - def __getattr__(self, name: str): if name == '_component': raise AttributeError('_component not set') return getattr(self._component, name) - @abstractmethod - def __call__(self, doc: MutableDocument) -> MutableDocument: - pass + @property + def full_name(self): + if isinstance(self._component, BaseComponent): + return self._component.full_name + else: + return f"Tokenizer:{self._component.__class__.__name__}" def __repr__(self): return f"{self.__class__.__name__}({self._component!r})" -class 
TimedComponent(BaseTimedComponent): - """Wraps a component and logs the time spent in it.""" +class BaseTimedComponent(Protocol): def __call__(self, doc: MutableDocument) -> MutableDocument: + pass + + +class PerDocTimedObject(BaseTimedObject): + + def time_it(self, to_run: Callable[[], MutableDocument]) -> MutableDocument: start = time.perf_counter() - result = self._component(doc) + result = to_run() elapsed_ms = (time.perf_counter() - start) * 1000 logger.info("Component %s took %.3fms", self.full_name, elapsed_ms) return result -class AveragingTimedComponent(BaseTimedComponent): +class TimedComponent(PerDocTimedObject): + """Wraps a component and logs the time spent in it.""" def __init__(self, component: BaseComponent, + ) -> None: + super().__init__(component) + self._component: BaseComponent + + def __call__(self, doc: MutableDocument) -> MutableDocument: + return self.time_it(lambda: self._component(doc)) + + +class TimedTokenizer(PerDocTimedObject): + + def __init__(self, component: BaseTokenizer, + ) -> None: + super().__init__(component) + self._component: BaseTokenizer + + def __call__(self, text: str) -> MutableDocument: + return self.time_it(lambda: self._component(text)) + + +class AveragingTimedObject(BaseTimedObject): + + def __init__(self, component: Union[BaseComponent, BaseTokenizer], condition: Callable[[int, float], bool]): super().__init__(component) self._condition = condition @@ -78,8 +105,8 @@ def _show_time(self): median_elapsed = statistics.median(self._to_average) max_elapsed = max(self._to_average) time_elapsed = time.perf_counter() - self._last_show - logger.info("Component %s took (min/mean/median/average): " - "%.3fms / %.3fms / %.3fms / %.3fms" + logger.info("Component %s took (min/mean/median/max): " + "%.3fms / %.3fms / %.3fms / %.3fms " "over %d docs and a total of %.3fs", self.full_name, min_elapsed, mean_elapsed, median_elapsed, max_elapsed, @@ -92,6 +119,15 @@ def _maybe_show_time(self, elapsed_ms: float): self._show_time() 
self._reset() + +class AveragingTimedComponent(AveragingTimedObject): + + def __init__(self, component: BaseComponent, + condition: Callable[[int, float], bool] + ) -> None: + super().__init__(component, condition) + self._component: BaseComponent + def __call__(self, doc: MutableDocument) -> MutableDocument: start = time.perf_counter() result = self._component(doc) @@ -100,19 +136,28 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: return result -class ProfiledComponent(BaseTimedComponent): - - def __init__(self, component: BaseComponent): - super().__init__(component) - self._profiler = cProfile.Profile() +class AveragingTimedTokenizer(AveragingTimedObject): + def __init__(self, component: BaseTokenizer, + condition: Callable[[int, float], bool] + ) -> None: + super().__init__(component, condition) + self._component: BaseTokenizer - def __call__(self, doc: MutableDocument) -> MutableDocument: - self._profiler.enable() - result = self._component(doc) - self._profiler.disable() + def __call__(self, text: str) -> MutableDocument: + start = time.perf_counter() + result = self._component(text) + elapsed_ms = (time.perf_counter() - start) * 1000 + self._maybe_show_time(elapsed_ms) return result + +class ProfiledObject(BaseTimedObject): + + def __init__(self, component: Union[BaseComponent, BaseTokenizer]): + super().__init__(component) + self._profiler = cProfile.Profile() + def _show_type(self, stats_type: str, limit: int): if not self._profiler.getstats(): logger.info("Component %s has no profiling data", self.full_name) @@ -128,10 +173,39 @@ def show_stats(self, limit: int = 20): self._show_type('cumtime', limit) +class ProfiledComponent(ProfiledObject): + + def __init__(self, component: BaseComponent, + ) -> None: + super().__init__(component) + self._component: BaseComponent + + def __call__(self, doc: MutableDocument) -> MutableDocument: + self._profiler.enable() + result = self._component(doc) + self._profiler.disable() + return result + + +class 
ProfiledTokenizer(ProfiledObject): + + def __init__(self, component: BaseTokenizer, + ) -> None: + super().__init__(component) + self._component: BaseTokenizer + + def __call__(self, text: str) -> MutableDocument: + self._profiler.enable() + result = self._component(text) + self._profiler.disable() + return result + + @contextlib.contextmanager def pipeline_per_doc_timer( pipeline: Pipeline, - timer_init: Callable[[BaseComponent], BaseTimedComponent] = TimedComponent + timer_init: Callable[[BaseComponent], BaseTimedComponent] = TimedComponent, + tknzer_timer_init: Callable[[BaseTokenizer], TimedTokenizer] = TimedTokenizer, ): """Time the pipeline on a per document basis. @@ -139,10 +213,13 @@ def pipeline_per_doc_timer( pipeline (Pipeline): The pipeline to time. timer_init (Callable[[BaseComponent], BaseTimedComponent])): The initialiser for the timer. Defaults to TimedComponent. + tknzer_timer_init (Callable[[BaseTokenizer], TimedTokenizer): The + initialiser for the timer for the tokenizer. Defaults to TimedTokenizer. Yields: Pipeline: The same pipeline. 
""" + original_tokenizer = pipeline._tokenizer original_components = pipeline._components original_addons = pipeline._addons @@ -153,12 +230,15 @@ def pipeline_per_doc_timer( cast(AddonComponent, timer_init(a)) for a in original_addons] + pipeline._tokenizer = cast( + BaseTokenizer, tknzer_timer_init(original_tokenizer)) pipeline._components = updated_core_components pipeline._addons = updated_addons try: yield pipeline finally: + pipeline._tokenizer = original_tokenizer pipeline._components = original_components pipeline._addons = original_addons @@ -198,6 +278,7 @@ def pipeline_timer_averaging_docs( if show_frequency_secs == -1 and show_frequency_docs == -1: show_frequency_docs = 100 + original_tokenizer = pipeline._tokenizer original_components = pipeline._components original_addons = pipeline._addons @@ -212,16 +293,24 @@ def wrapper_condition(num_docs: int, time_spent: float) -> bool: wrapped_addons = [ AveragingTimedComponent(addon, wrapper_condition) for addon in original_addons] + wrapped_tokenizer = AveragingTimedTokenizer( + original_tokenizer, wrapper_condition) + pipeline._tokenizer = wrapped_tokenizer # type: ignore pipeline._components = wrapped_core_comps # type: ignore pipeline._addons = wrapped_addons # type: ignore try: yield pipeline finally: + pipeline._tokenizer = original_tokenizer pipeline._components = original_components pipeline._addons = original_addons - for comp in [*wrapped_core_comps, *wrapped_addons]: + timed_objects: list[AveragingTimedObject] = [ + wrapped_tokenizer, *wrapped_core_comps, *wrapped_addons + ] + + for comp in timed_objects: if comp._to_average: comp._show_time() comp._reset() From f67016de457ff1f7e1236850eb28561cb31ccd01 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 18 Mar 2026 14:32:43 +0000 Subject: [PATCH 2/7] CU-869cgny1k: Add profiling for tokenizer as well --- medcat-v2/medcat/pipeline/speed_utils.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git 
a/medcat-v2/medcat/pipeline/speed_utils.py b/medcat-v2/medcat/pipeline/speed_utils.py
index 0aa29bc12..986b22534 100644
--- a/medcat-v2/medcat/pipeline/speed_utils.py
+++ b/medcat-v2/medcat/pipeline/speed_utils.py
@@ -1,4 +1,4 @@
-from typing import Callable, Protocol, Union, Type, cast
+from typing import Callable, Literal, Protocol, Union, Type, cast
 import time
 import contextlib
 import logging
@@ -319,7 +319,7 @@ def wrapper_condition(num_docs: int, time_spent: float) -> bool:
 @contextlib.contextmanager
 def profile_pipeline_component(
     pipeline: Pipeline,
-    comp_type: Union[CoreComponentType, Type[AddonType]],
+    comp_type: Union[CoreComponentType, Type[AddonType], Literal['tokenizer']],
     limit: int = 20,
 ):
     """Time a specific component of the pipeline.
@@ -330,21 +330,23 @@ def profile_pipeline_component(
 
     Args:
         pipeline (Pipeline): The pipeline to time.
-        comp_type (Union[CoreComponentType, Type[AddonType]]): The type of
-            component to profile. This can be either a core component
-            or an addon component.
+        comp_type (Union[CoreComponentType, Type[AddonType], Literal['tokenizer']]):
+            The type of component to profile. This can be either a core component
+            or an addon component, or the tokenizer.
         limit (int): The number of function calls to show in output.
             Defaults to 20.
 
     Yields:
         Pipeline: The same pipeline.
""" + original_tokenizer = pipeline._tokenizer original_components = pipeline._components original_addons = pipeline._addons updated_addons: list[AddonComponent] updated_core_comps: list[CoreComponent] if isinstance(comp_type, CoreComponentType): + updated_tokenizer = original_tokenizer changed_comp = pipeline.get_component(comp_type) updated_core_comps = [ comp if comp != changed_comp else @@ -352,7 +354,12 @@ def profile_pipeline_component( for comp in original_components ] updated_addons = original_addons + elif comp_type == 'tokenizer': + updated_tokenizer = cast(BaseTokenizer, ProfiledTokenizer(original_tokenizer)) + updated_core_comps = original_components + updated_addons = original_addons else: + updated_tokenizer = original_tokenizer changed_comps = [ addon for addon in pipeline.iter_addons() if isinstance(addon, comp_type) @@ -364,16 +371,18 @@ def profile_pipeline_component( for addon in original_addons ] profiled_comps = [ - comp for comp in updated_core_comps + updated_addons - if isinstance(comp, ProfiledComponent) + comp for comp in updated_core_comps + updated_addons + [updated_tokenizer,] + if isinstance(comp, ProfiledObject) ] + pipeline._tokenizer = updated_tokenizer pipeline._components = updated_core_comps pipeline._addons = updated_addons try: yield pipeline finally: + pipeline._tokenizer = original_tokenizer pipeline._components = original_components pipeline._addons = original_addons for comp in profiled_comps: From 714590c00e8bc09737843b64f06b3897276e5538 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 18 Mar 2026 15:05:52 +0000 Subject: [PATCH 3/7] CU-869cgny1k: Some protocol changes for clarity --- medcat-v2/medcat/pipeline/speed_utils.py | 34 +++++++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/medcat-v2/medcat/pipeline/speed_utils.py b/medcat-v2/medcat/pipeline/speed_utils.py index 986b22534..02908c85d 100644 --- a/medcat-v2/medcat/pipeline/speed_utils.py +++ b/medcat-v2/medcat/pipeline/speed_utils.py 
@@ -20,6 +20,18 @@ logger.addHandler(logging.StreamHandler()) +class BaseTimedObjectProtocol(Protocol): + @property + def full_name(self) -> str: + pass + + def __getattr__(self, name: str): + pass + + def __repr__(self) -> str: + pass + + class BaseTimedObject: def __init__(self, component: Union[BaseComponent, BaseTokenizer]): @@ -47,6 +59,18 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: pass +class BaseTimedTokenizer(Protocol): + + def __call__(self, text: str) -> MutableDocument: + pass + +class TimedComponentProtocol(BaseTimedObjectProtocol, BaseTimedComponent, Protocol): + pass + +class TimedTokenizerProtocol(BaseTimedObjectProtocol, BaseTimedTokenizer, Protocol): + pass + + class PerDocTimedObject(BaseTimedObject): def time_it(self, to_run: Callable[[], MutableDocument]) -> MutableDocument: @@ -204,16 +228,18 @@ def __call__(self, text: str) -> MutableDocument: @contextlib.contextmanager def pipeline_per_doc_timer( pipeline: Pipeline, - timer_init: Callable[[BaseComponent], BaseTimedComponent] = TimedComponent, - tknzer_timer_init: Callable[[BaseTokenizer], TimedTokenizer] = TimedTokenizer, + timer_init: Callable[[BaseComponent], + TimedComponentProtocol] = TimedComponent, + tknzer_timer_init: Callable[[BaseTokenizer], + TimedTokenizerProtocol] = TimedTokenizer, ): """Time the pipeline on a per document basis. Args: pipeline (Pipeline): The pipeline to time. - timer_init (Callable[[BaseComponent], BaseTimedComponent])): The + timer_init (Callable[[BaseComponent], TimedComponentProtocol])): The initialiser for the timer. Defaults to TimedComponent. - tknzer_timer_init (Callable[[BaseTokenizer], TimedTokenizer): The + tknzer_timer_init (Callable[[BaseTokenizer], TimedTokenizerProtocol): The initialiser for the timer for the tokenizer. Defaults to TimedTokenizer. 
Yields: From 4ac5b75ee1d992d6a8038433e81a3dd459bfafa9 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 18 Mar 2026 15:09:54 +0000 Subject: [PATCH 4/7] CU-869cgny1k: Allow honoring user-specified logging --- medcat-v2/medcat/pipeline/speed_utils.py | 33 ++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/medcat-v2/medcat/pipeline/speed_utils.py b/medcat-v2/medcat/pipeline/speed_utils.py index 02908c85d..085e56846 100644 --- a/medcat-v2/medcat/pipeline/speed_utils.py +++ b/medcat-v2/medcat/pipeline/speed_utils.py @@ -16,8 +16,34 @@ logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) -logger.addHandler(logging.StreamHandler()) + +@contextlib.contextmanager +def _with_logging(): + has_stream_handler = any( + type(h) is logging.StreamHandler + for h in logger.handlers + ) + handler = None + original_level = logger.level + if not has_stream_handler: + handler = logging.StreamHandler() + logger.addHandler(handler) + logger.setLevel(logging.INFO) + try: + yield + finally: + if handler is not None: + logger.removeHandler(handler) + logger.setLevel(original_level) + + +def with_logging(func): + @contextlib.wraps(func) + @contextlib.contextmanager + def wrapper(*args, **kwargs): + with _with_logging(): + yield from func(*args, **kwargs) + return wrapper class BaseTimedObjectProtocol(Protocol): @@ -225,6 +251,7 @@ def __call__(self, text: str) -> MutableDocument: return result +@with_logging @contextlib.contextmanager def pipeline_per_doc_timer( pipeline: Pipeline, @@ -269,6 +296,7 @@ def pipeline_per_doc_timer( pipeline._addons = original_addons +@with_logging @contextlib.contextmanager def pipeline_timer_averaging_docs( pipeline: Pipeline, @@ -342,6 +370,7 @@ def wrapper_condition(num_docs: int, time_spent: float) -> bool: comp._reset() +@with_logging @contextlib.contextmanager def profile_pipeline_component( pipeline: Pipeline, From e646910d02799ff1a2e9442e8c5be37df28d075e Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 
18 Mar 2026 15:14:53 +0000
Subject: [PATCH 5/7] CU-869cgny1k: Fix issue with context manager conflicts

---
 medcat-v2/medcat/pipeline/speed_utils.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/medcat-v2/medcat/pipeline/speed_utils.py b/medcat-v2/medcat/pipeline/speed_utils.py
index 085e56846..7a230c747 100644
--- a/medcat-v2/medcat/pipeline/speed_utils.py
+++ b/medcat-v2/medcat/pipeline/speed_utils.py
@@ -1,4 +1,5 @@
 from typing import Callable, Literal, Protocol, Union, Type, cast
+import functools
 import time
 import contextlib
 import logging
@@ -37,7 +38,7 @@ def _with_logging():
     logger.setLevel(original_level)
 
 
-def with_logging(func):
-    @contextlib.wraps(func)
+def context_manager_with_logging(func):
+    @functools.wraps(func)
     @contextlib.contextmanager
     def wrapper(*args, **kwargs):
@@ -251,8 +252,7 @@ def __call__(self, text: str) -> MutableDocument:
         return result
 
 
-@with_logging
-@contextlib.contextmanager
+@context_manager_with_logging
 def pipeline_per_doc_timer(
     pipeline: Pipeline,
     timer_init: Callable[[BaseComponent],
@@ -296,8 +296,7 @@ def pipeline_per_doc_timer(
     pipeline._addons = original_addons
 
 
-@with_logging
-@contextlib.contextmanager
+@context_manager_with_logging
 def pipeline_timer_averaging_docs(
     pipeline: Pipeline,
     show_frequency_docs: int = -1,
@@ -370,8 +369,7 @@ def wrapper_condition(num_docs: int, time_spent: float) -> bool:
     comp._reset()
 
 
-@with_logging
-@contextlib.contextmanager
+@context_manager_with_logging
 def profile_pipeline_component(
     pipeline: Pipeline,
     comp_type: Union[CoreComponentType, Type[AddonType], Literal['tokenizer']],

From 80a4378bfbe70fc8dd526095e88d493eee26eff0 Mon Sep 17 00:00:00 2001
From: mart-r
Date: Wed, 18 Mar 2026 15:26:25 +0000
Subject: [PATCH 6/7] CU-869cgny1k: Fix mocks in existing tests

---
 medcat-v2/tests/pipeline/test_speed_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/medcat-v2/tests/pipeline/test_speed_utils.py b/medcat-v2/tests/pipeline/test_speed_utils.py
index 1a9443d17..2b82a5675 100644
--- a/medcat-v2/tests/pipeline/test_speed_utils.py
+++ b/medcat-v2/tests/pipeline/test_speed_utils.py
@@ -34,6 +34,7 @@ def 
make_mock_pipeline(*component_names: str) -> MagicMock: pipeline = MagicMock(spec=Pipeline) pipeline._components = [make_mock_component(n) for n in component_names] pipeline._addons = [] + pipeline._tokenizer = None return pipeline @@ -338,6 +339,7 @@ def _make_core_pipeline(self, comp_type: CoreComponentType) -> MagicMock: comp = make_mock_component() pipeline._components = [comp] pipeline._addons = [] + pipeline._tokenizer = None pipeline.get_component.return_value = comp pipeline.iter_addons.return_value = iter([]) return pipeline @@ -349,6 +351,7 @@ def _make_addon_pipeline(self, addon_type) -> tuple[MagicMock, MagicMock]: addon.side_effect = lambda doc: doc pipeline._components = [] pipeline._addons = [addon] + pipeline._tokenizer = None pipeline.get_component.side_effect = RuntimeError("not a core comp") pipeline.iter_addons.return_value = iter([addon]) return pipeline, addon From 33bee19e896093afb3413bf264bf04331fdec0c3 Mon Sep 17 00:00:00 2001 From: mart-r Date: Wed, 18 Mar 2026 15:33:46 +0000 Subject: [PATCH 7/7] CU-869cgny1k: Add new tests to tokenizer timings and other new stuff --- medcat-v2/tests/pipeline/test_speed_utils.py | 181 ++++++++++++++++++- 1 file changed, 180 insertions(+), 1 deletion(-) diff --git a/medcat-v2/tests/pipeline/test_speed_utils.py b/medcat-v2/tests/pipeline/test_speed_utils.py index 2b82a5675..2bc05ff01 100644 --- a/medcat-v2/tests/pipeline/test_speed_utils.py +++ b/medcat-v2/tests/pipeline/test_speed_utils.py @@ -2,22 +2,26 @@ import time import unittest from unittest.mock import MagicMock, patch +from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.tokenizing.tokens import MutableDocument from medcat.pipeline import Pipeline from medcat.components.types import BaseComponent - +import logging import cProfile from unittest.mock import patch, MagicMock from medcat.components.types import CoreComponentType from medcat.pipeline.speed_utils import ( + AveragingTimedTokenizer, TimedComponent, AveragingTimedComponent, + 
TimedTokenizer, pipeline_per_doc_timer, pipeline_timer_averaging_docs, ProfiledComponent, profile_pipeline_component, + logger, ) @@ -428,5 +432,180 @@ def test_show_stats_called_on_exit_after_exception(self, mock_logger): self.assertEqual(mock_logger.info.call_count, 2) +def make_mock_tokenizer(name: str = "test_tokenizer") -> MagicMock: + tokenizer = MagicMock(spec=BaseTokenizer) + tokenizer.side_effect = lambda text: make_mock_doc() + return tokenizer + + +def make_mock_pipeline_with_tokenizer(*component_names: str) -> MagicMock: + pipeline = make_mock_pipeline(*component_names) + pipeline._tokenizer = make_mock_tokenizer() + return pipeline + + +class TestTimedTokenizer(unittest.TestCase): + + def test_underlying_tokenizer_called(self): + tokenizer = make_mock_tokenizer() + timed = TimedTokenizer(tokenizer) + timed("some text") + tokenizer.assert_called_once_with("some text") + + def test_returns_result_of_underlying_tokenizer(self): + tokenizer = make_mock_tokenizer() + expected = make_mock_doc() + tokenizer.side_effect = lambda text: expected + timed = TimedTokenizer(tokenizer) + result = timed("some text") + self.assertIs(result, expected) + + def test_full_name_includes_class_name(self): + tokenizer = make_mock_tokenizer() + timed = TimedTokenizer(tokenizer) + self.assertIn("Tokenizer", timed.full_name) + + @patch("medcat.pipeline.speed_utils.logger") + def test_logs_once_per_call(self, mock_logger): + tokenizer = make_mock_tokenizer() + timed = TimedTokenizer(tokenizer) + timed("text one") + timed("text two") + self.assertEqual(mock_logger.info.call_count, 2) + + +class TestAveragingTimedTokenizer(unittest.TestCase): + + def _every_n(self, n: int): + return lambda num_docs, time_spent: num_docs >= n + + def test_underlying_tokenizer_called(self): + tokenizer = make_mock_tokenizer() + timed = AveragingTimedTokenizer(tokenizer, self._every_n(1)) + timed("some text") + tokenizer.assert_called_once_with("some text") + + 
@patch("medcat.pipeline.speed_utils.logger") + def test_logs_after_n_calls(self, mock_logger): + tokenizer = make_mock_tokenizer() + timed = AveragingTimedTokenizer(tokenizer, self._every_n(3)) + for _ in range(3): + timed("text") + mock_logger.info.assert_called_once() + + @patch("medcat.pipeline.speed_utils.logger") + def test_does_not_log_before_n_calls(self, mock_logger): + tokenizer = make_mock_tokenizer() + timed = AveragingTimedTokenizer(tokenizer, self._every_n(3)) + for _ in range(2): + timed("text") + mock_logger.info.assert_not_called() + + +class TestTokenizerInPipelinePerDocTimer(unittest.TestCase): + + def test_tokenizer_replaced_inside_context(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original = pipeline._tokenizer + with pipeline_per_doc_timer(pipeline): + self.assertIsInstance(pipeline._tokenizer, TimedTokenizer) + self.assertIsNot(pipeline._tokenizer, original) + + def test_tokenizer_restored_after_context(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original = pipeline._tokenizer + with pipeline_per_doc_timer(pipeline): + pass + self.assertIs(pipeline._tokenizer, original) + + def test_tokenizer_restored_after_exception(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original = pipeline._tokenizer + with self.assertRaises(RuntimeError): + with pipeline_per_doc_timer(pipeline): + raise RuntimeError("boom") + self.assertIs(pipeline._tokenizer, original) + + def test_underlying_tokenizer_called(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original_tokenizer = pipeline._tokenizer + with pipeline_per_doc_timer(pipeline): + pipeline._tokenizer("some text") + original_tokenizer.assert_called_once_with("some text") + + +class TestTokenizerInPipelineTimerAveragingDocs(unittest.TestCase): + + def test_tokenizer_replaced_inside_context(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original = pipeline._tokenizer + with 
pipeline_timer_averaging_docs(pipeline, show_frequency_docs=10): + self.assertIsInstance(pipeline._tokenizer, AveragingTimedTokenizer) + self.assertIsNot(pipeline._tokenizer, original) + + def test_tokenizer_restored_after_context(self): + pipeline = make_mock_pipeline_with_tokenizer("comp_a") + original = pipeline._tokenizer + with pipeline_timer_averaging_docs(pipeline, show_frequency_docs=10): + pass + self.assertIs(pipeline._tokenizer, original) + + @patch("medcat.pipeline.speed_utils.logger") + def test_tokenizer_flushed_on_exit(self, mock_logger): + pipeline = make_mock_pipeline_with_tokenizer() + with pipeline_timer_averaging_docs(pipeline, show_frequency_docs=100): + for _ in range(5): + pipeline._tokenizer("text") + mock_logger.info.assert_called_once() + + +class TestWithLogging(unittest.TestCase): + + def test_stream_handler_added_when_none_present(self): + logger.handlers.clear() + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + self.assertTrue( + any(type(h) is logging.StreamHandler for h in logger.handlers)) + + def test_stream_handler_removed_after_context(self): + logger.handlers.clear() + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + pass + self.assertFalse( + any(type(h) is logging.StreamHandler for h in logger.handlers)) + + def test_stream_handler_not_added_if_already_present(self): + logger.handlers.clear() + existing = logging.StreamHandler() + logger.addHandler(existing) + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + stream_handlers = [ + h for h in logger.handlers + if type(h) is logging.StreamHandler + ] + self.assertEqual(len(stream_handlers), 1) + logger.removeHandler(existing) + + def test_log_level_restored_after_context(self): + original_level = logger.level + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + pass + self.assertEqual(logger.level, original_level) + + def test_log_level_restored_after_exception(self): + original_level = logger.level + with 
self.assertRaises(RuntimeError): + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + raise RuntimeError("boom") + self.assertEqual(logger.level, original_level) + + def test_stream_handler_removed_after_exception(self): + logger.handlers.clear() + with self.assertRaises(RuntimeError): + with pipeline_per_doc_timer(make_mock_pipeline_with_tokenizer()): + raise RuntimeError("boom") + self.assertFalse( + any(type(h) is logging.StreamHandler for h in logger.handlers)) + if __name__ == "__main__": unittest.main()