diff --git a/python/lib/sift_client/_internal/grpc/__init__.py b/python/lib/sift_client/_internal/grpc/__init__.py new file mode 100644 index 000000000..738259dc8 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/__init__.py @@ -0,0 +1,15 @@ +""" +This module is primarily concerned with configuring and initializing gRPC connections to the Sift API. + +Example of establishing a connection to Sift's gRPC APi: + +```python +from sift_client._internal.grpc.transport import SiftChannelConfig, use_sift_channel + +# Be sure not to include the url scheme i.e. 'https://' in the uri. +sift_channel_config = SiftChannelConfig(uri=SIFT_BASE_URI, apikey=SIFT_API_KEY) + +with use_sift_channel(sift_channel_config) as channel: + # Connect to Sift +``` +""" diff --git a/python/lib/sift_client/_internal/grpc/_async_interceptors/__init__.py b/python/lib/sift_client/_internal/grpc/_async_interceptors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_internal/grpc/_async_interceptors/base.py b/python/lib/sift_client/_internal/grpc/_async_interceptors/base.py new file mode 100644 index 000000000..f2b8ce908 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_async_interceptors/base.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, AsyncIterable, Callable, Iterable, TypeVar + +from grpc import aio as grpc_aio + +CallType = TypeVar("CallType", bound=grpc_aio.Call) +Continuation = Callable[[grpc_aio.ClientCallDetails, Any], CallType] + + +class ClientAsyncInterceptor( + grpc_aio.UnaryUnaryClientInterceptor, + grpc_aio.UnaryStreamClientInterceptor, + grpc_aio.StreamUnaryClientInterceptor, + grpc_aio.StreamStreamClientInterceptor, +): + @abstractmethod + async def intercept( + self, + method: Callable, + request_or_iterator: Any, + client_call_details: grpc_aio.ClientCallDetails, + ) -> Any: + pass + + async def intercept_unary_unary( + self, + continuation: Continuation[grpc_aio.UnaryUnaryCall], + client_call_details: grpc_aio.ClientCallDetails, + request: Any, + ): + return await self.intercept(_async_swap_args(continuation), request, client_call_details) + + async def intercept_unary_stream( + self, + continuation: Continuation[grpc_aio.UnaryStreamCall], + client_call_details: grpc_aio.ClientCallDetails, + request: Any, + ): + return await self.intercept(_async_swap_args(continuation), request, client_call_details) + + async def intercept_stream_unary( + self, + continuation: Continuation[grpc_aio.StreamUnaryCall], + client_call_details: grpc_aio.ClientCallDetails, + request_iterator: Iterable[Any] | AsyncIterable[Any], + ): + return await self.intercept( + _async_swap_args(continuation), request_iterator, client_call_details + ) + + async def intercept_stream_stream( + self, + continuation: Continuation[grpc_aio.StreamStreamCall], + client_call_details: grpc_aio.ClientCallDetails, + request_iterator: Iterable[Any] | AsyncIterable[Any], + ): + return await self.intercept( + _async_swap_args(continuation), request_iterator, client_call_details + ) + + +def _async_swap_args(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: + """ + Continuations are typed in such a way that details are the first argument, and the request second. + Code generated from protobuf however takes in the request first, then the details. Weird grpc library + quirk. This utility just flips the arguments. + """ + + async def new_fn(x, y): + return await fn(y, x) + + return new_fn diff --git a/python/lib/sift_client/_internal/grpc/_async_interceptors/metadata.py b/python/lib/sift_client/_internal/grpc/_async_interceptors/metadata.py new file mode 100644 index 000000000..95cc5a925 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_async_interceptors/metadata.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Any, Callable, List, Tuple, cast + +from grpc import aio as grpc_aio + +from sift_client._internal.grpc._async_interceptors.base import ClientAsyncInterceptor + +Metadata = List[Tuple[str, str]] + + +class MetadataAsyncInterceptor(ClientAsyncInterceptor): + metadata: Metadata + + """ + Interceptor to add metadata to all async unary and streaming RPCs + """ + + def __init__(self, metadata: Metadata): + self.metadata = metadata + + async def intercept( + self, + method: Callable, + request_or_iterator: Any, + client_call_details: grpc_aio.ClientCallDetails, + ): + call_details = cast("grpc_aio.ClientCallDetails", client_call_details) + new_details = grpc_aio.ClientCallDetails( + call_details.method, + call_details.timeout, + self.metadata, + call_details.credentials, + call_details.wait_for_ready, + ) + return await method(request_or_iterator, new_details) diff --git a/python/lib/sift_client/_internal/grpc/_interceptors/__init__.py b/python/lib/sift_client/_internal/grpc/_interceptors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_internal/grpc/_interceptors/base.py b/python/lib/sift_client/_internal/grpc/_interceptors/base.py new file mode 100644 index 000000000..e0a86f8f8 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_interceptors/base.py @@ -0,0 +1,61 @@ +from abc import abstractmethod +from typing import Any, Callable, Iterator + +import grpc + +Continuation = Callable[[grpc.ClientCallDetails, Any], Any] + + +class ClientInterceptor( + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.UnaryUnaryClientInterceptor, +): + @abstractmethod + def intercept( + self, + method: Continuation, + request_or_iterator: Any, + client_call_details: grpc.ClientCallDetails, + ): + pass + + def intercept_unary_unary( + self, + continuation: Continuation, + client_call_details: grpc.ClientCallDetails, + request: Any, + ): + return self.intercept(_swap_args(continuation), request, client_call_details) + + def intercept_stream_unary( + self, + continuation: Continuation, + client_call_details: grpc.ClientCallDetails, + request_iterator: Iterator[Any], + ): + return self.intercept(_swap_args(continuation), request_iterator, client_call_details) + + def intercept_unary_stream( + self, + continuation: Continuation, + client_call_details: grpc.ClientCallDetails, + request: Any, + ): + return self.intercept(_swap_args(continuation), request, client_call_details) + + def intercept_stream_stream( + self, + continuation: Continuation, + client_call_details: grpc.ClientCallDetails, + request_iterator: Iterator[Any], + ): + return self.intercept(_swap_args(continuation), request_iterator, client_call_details) + + +def _swap_args(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: + def new_fn(x, y): + return fn(y, x) + + return new_fn diff --git a/python/lib/sift_client/_internal/grpc/_interceptors/context.py b/python/lib/sift_client/_internal/grpc/_interceptors/context.py new file mode 100644 index 000000000..a45c6d8ab --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_interceptors/context.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Sequence + +import grpc + + +class ClientCallDetails(grpc.ClientCallDetails): + method: str + timeout: float | None + metadata: Sequence[tuple[str, str | bytes]] | None + credentials: grpc.CallCredentials | None + wait_for_ready: bool | None + + def __init__( + self, + method: str, + timeout: float | None, + metadata: Sequence[tuple[str, str | bytes]] | None, + credentials: grpc.CallCredentials | None, + wait_for_ready: bool | None, + ): + self.method = method + self.timeout = timeout + self.metadata = metadata + self.credentials = credentials + self.wait_for_ready = wait_for_ready diff --git a/python/lib/sift_client/_internal/grpc/_interceptors/metadata.py b/python/lib/sift_client/_internal/grpc/_interceptors/metadata.py new file mode 100644 index 000000000..afb5da50c --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_interceptors/metadata.py @@ -0,0 +1,33 @@ +from typing import Any, List, Tuple, cast + +import grpc + +from sift_client._internal.grpc._interceptors.base import ClientInterceptor, Continuation +from sift_client._internal.grpc._interceptors.context import ClientCallDetails + +Metadata = List[Tuple[str, str]] + + +class MetadataInterceptor(ClientInterceptor): + metadata: Metadata + + def __init__(self, metadata: Metadata): + self.metadata = metadata + + def intercept( + self, + method: Continuation, + request_or_iterator: Any, + client_call_details: grpc.ClientCallDetails, + ): + details = cast("ClientCallDetails", client_call_details) + + new_details = ClientCallDetails( + method=details.method, + timeout=details.timeout, + credentials=details.credentials, + wait_for_ready=details.wait_for_ready, + metadata=self.metadata, + ) + + return method(request_or_iterator, new_details) diff --git a/python/lib/sift_client/_internal/grpc/_retry.py b/python/lib/sift_client/_internal/grpc/_retry.py new file mode 100644 index 000000000..78eca2324 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/_retry.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import json +from typing import ClassVar, TypedDict + +from grpc import StatusCode +from typing_extensions import Self + + +class RetryPolicy: + """ + Retry policy meant to be used for `sift_py.grpc.transport.SiftChannel`. Users may have the ability to configure their own + custom retry policy in the future, but for now this is primarily intended for internal use. + + - [Retry policy schema](https://github.com/grpc/grpc-proto/blob/ec30f589e2519d595688b9a42f88a91bdd6b733f/grpc/service_config/service_config.proto#L136) + - [Enable gRPC retry option](https://github.com/grpc/grpc/blob/9a5fdfc3d3a7fc575a394360be4532ee09a85620/include/grpc/impl/channel_arg_names.h#L311) + - [Service config option](https://github.com/grpc/grpc/blob/9a5fdfc3d3a7fc575a394360be4532ee09a85620/include/grpc/impl/channel_arg_names.h#L207) + """ + + config: RetryConfig + + DEFAULT_POLICY: ClassVar[RetryConfig] = { + "methodConfig": [ + { + # We can configure this on a per-service and RPC basis but for now we'll + # apply this across all services and RPCs. + "name": [{}], + "retryPolicy": { + # gRPC does not allow more than 5 attempts + "maxAttempts": 5, + "initialBackoff": "0.05s", + "maxBackoff": "5s", + "backoffMultiplier": 4, + "retryableStatusCodes": [ + StatusCode.INTERNAL.name, + StatusCode.UNKNOWN.name, + StatusCode.UNAVAILABLE.name, + StatusCode.ABORTED.name, + StatusCode.DEADLINE_EXCEEDED.name, + ], + }, + } + ] + } + + def __init__(self, config: RetryConfig): + self.config = config + + def as_json(self) -> str: + return json.dumps(self.config) + + @classmethod + def default(cls) -> Self: + return cls(config=cls.DEFAULT_POLICY) + + +class RetryConfig(TypedDict): + methodConfig: list[MethodConfigDict] + + +class MethodConfigDict(TypedDict): + name: list[dict[str, str]] + retryPolicy: RetryConfigDict + + +class RetryConfigDict(TypedDict): + maxAttempts: int + initialBackoff: str + maxBackoff: str + backoffMultiplier: int + retryableStatusCodes: list[str] diff --git a/python/lib/sift_client/_internal/grpc/keepalive.py b/python/lib/sift_client/_internal/grpc/keepalive.py new file mode 100644 index 000000000..e2997e153 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/keepalive.py @@ -0,0 +1,34 @@ +from typing import TypedDict + +DEFAULT_KEEPALIVE_TIME_MS = 20_000 +"""Interval with which to send keepalive pings""" + +DEFAULT_KEEPALIVE_TIMEOUT_MS = 20_000 +"""Timeout while waiting for server to acknowledge keepalive ping""" + +DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS = 1 +"""Allows connection without any active RPCs""" + +DEFAULT_MAX_PINGS_WITHOUT_DATA = 0 +"""Disabled""" + + +# https://github.com/grpc/grpc/blob/master/doc/keepalive.md +class KeepaliveConfig(TypedDict): + """ + Make make this public in the future to allow folks to configure their own keepalive settings + if there is demand for it. + """ + + keepalive_time_ms: int + keepalive_timeout_ms: int + keepalive_permit_without_calls: int + max_pings_without_data: int + + +DEFAULT_KEEPALIVE_CONFIG: KeepaliveConfig = { + "keepalive_time_ms": DEFAULT_KEEPALIVE_TIME_MS, + "keepalive_timeout_ms": DEFAULT_KEEPALIVE_TIMEOUT_MS, + "keepalive_permit_without_calls": DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS, + "max_pings_without_data": DEFAULT_MAX_PINGS_WITHOUT_DATA, +} diff --git a/python/lib/sift_client/_internal/grpc/server_interceptors/__init__.py b/python/lib/sift_client/_internal/grpc/server_interceptors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/_internal/grpc/server_interceptors/server_interceptor.py b/python/lib/sift_client/_internal/grpc/server_interceptors/server_interceptor.py new file mode 100644 index 000000000..9715ed0d9 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/server_interceptors/server_interceptor.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import abc +from typing import Any, Callable, cast + +import grpc + + +class ServerInterceptor(grpc.ServerInterceptor, metaclass=abc.ABCMeta): + @abc.abstractmethod + def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + return method(request_or_iterator, context) + + def intercept_service(self, continuation, handler_call_details): + next_handler = continuation(handler_call_details) + if next_handler is None: + return + + handler_factory, next_handler_method = _get_factory_and_method(next_handler) + + def invoke_intercept_method(request_or_iterator, context): + method_name = handler_call_details.method + return self.intercept( + next_handler_method, + request_or_iterator, + context, + method_name, + ) + + return handler_factory( + invoke_intercept_method, + request_deserializer=next_handler.request_deserializer, + response_serializer=next_handler.response_serializer, + ) + + +class _RpcHandler(grpc.RpcMethodHandler): + unary_unary: Callable | None + unary_stream: Callable | None + stream_unary: Callable | None + stream_stream: Callable | None + + +def _get_factory_and_method( + rpc_handler: grpc.RpcMethodHandler, +) -> tuple[Callable, Callable]: + handler = cast("_RpcHandler", rpc_handler) + + if handler.unary_unary: + return grpc.unary_unary_rpc_method_handler, handler.unary_unary + elif handler.unary_stream: + return grpc.unary_stream_rpc_method_handler, handler.unary_stream + elif handler.stream_unary: + return grpc.stream_unary_rpc_method_handler, handler.stream_unary + elif handler.stream_stream: + return grpc.stream_stream_rpc_method_handler, handler.stream_stream + else: + raise Exception("Unreachable") diff --git a/python/lib/sift_client/_internal/grpc/transport.py b/python/lib/sift_client/_internal/grpc/transport.py new file mode 100644 index 000000000..1043245a8 --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/transport.py @@ -0,0 +1,252 @@ +""" +This module is concerned with creating a gRPC transport channel specifically for +interacting with Sift's gRPC API. the `use_sift_channel` method creates said channel +and should generally be used within a with-block for correct resource management. +""" + +from __future__ import annotations + +from importlib.metadata import PackageNotFoundError, version +from typing import TYPE_CHECKING, Any, TypedDict, cast +from urllib.parse import ParseResult, urlparse + +import grpc +import grpc.aio as grpc_aio +from typing_extensions import NotRequired, TypeAlias + +from sift_client._internal.grpc._async_interceptors.metadata import MetadataAsyncInterceptor +from sift_client._internal.grpc._interceptors.metadata import Metadata, MetadataInterceptor + +if TYPE_CHECKING: + from sift_client._internal.grpc._async_interceptors.base import ClientAsyncInterceptor + from sift_client._internal.grpc._interceptors.base import ClientInterceptor +from sift_client._internal.grpc._retry import RetryPolicy +from sift_client._internal.grpc.keepalive import DEFAULT_KEEPALIVE_CONFIG, KeepaliveConfig + +SiftChannel: TypeAlias = grpc.Channel +SiftAsyncChannel: TypeAlias = grpc_aio.Channel + + +def get_ssl_credentials(cert_via_openssl: bool) -> grpc.ChannelCredentials: + """ + Returns SSL credentials for use with gRPC. + Workaround for this issue: https://github.com/grpc/grpc/issues/29682 + """ + if not cert_via_openssl: + return grpc.ssl_channel_credentials() + + try: + import ssl + + from OpenSSL import crypto + + ssl_context = ssl.create_default_context() + certs_der = ssl_context.get_ca_certs(binary_form=True) + certs_x509 = [crypto.load_certificate(crypto.FILETYPE_ASN1, x) for x in certs_der] + certs_pem = [crypto.dump_certificate(crypto.FILETYPE_PEM, x) for x in certs_x509] + certs_bytes = b"".join(certs_pem) + + return grpc.ssl_channel_credentials(certs_bytes) + except ImportError as e: + raise Exception( + "Missing required dependencies for cert_via_openssl. Run `pip install sift-stack-py[openssl]` to install the required dependencies." + ) from e + + +def use_sift_channel( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> SiftChannel: + """ + Returns an intercepted channel that is meant to be used across all services that + make RPCs to Sift's API. It is highly encouraged to use this within a with-block + for correct resource clean-up. + + Should an RPC fail for a reason that isn't explicitly controlled by Sift, `SiftChannel` + will automatically leverage gRPC's retry mechanism to try and recover until the max-attempts + are exceeded, after which the underlying exception will be raised. + """ + use_ssl = config.get("use_ssl", True) + cert_via_openssl = config.get("cert_via_openssl", False) + + if not use_ssl: + return _use_insecure_sift_channel(config, metadata) + + credentials = get_ssl_credentials(cert_via_openssl) + options = _compute_channel_options(config) + api_uri = _clean_uri(config["uri"], use_ssl) + channel = grpc.secure_channel(api_uri, credentials, options) + interceptors = _compute_sift_interceptors(config, metadata) + return grpc.intercept_channel(channel, *interceptors) + + +def use_sift_async_channel( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> SiftAsyncChannel: + """ + Like `use_sift_channel` but returns a channel meant to be used within the context + of an async runtime when asynchonous I/O is required. + """ + use_ssl = config.get("use_ssl", True) + cert_via_openssl = config.get("cert_via_openssl", False) + + if not use_ssl: + return _use_insecure_sift_async_channel(config, metadata) + + return grpc_aio.secure_channel( + target=_clean_uri(config["uri"], use_ssl), + credentials=get_ssl_credentials(cert_via_openssl), + options=_compute_channel_options(config), + interceptors=_compute_sift_async_interceptors(config, metadata), + ) + + +def _use_insecure_sift_channel( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> SiftChannel: + """ + FOR DEVELOPMENT PURPOSES ONLY + """ + options = _compute_channel_options(config) + api_uri = _clean_uri(config["uri"], False) + channel = grpc.insecure_channel(api_uri, options) + interceptors = _compute_sift_interceptors(config, metadata) + return grpc.intercept_channel(channel, *interceptors) + + +def _use_insecure_sift_async_channel( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> SiftAsyncChannel: + """ + FOR DEVELOPMENT PURPOSES ONLY + """ + return grpc_aio.insecure_channel( + target=_clean_uri(config["uri"], False), + options=_compute_channel_options(config), + interceptors=_compute_sift_async_interceptors(config, metadata), + ) + + +def _compute_sift_interceptors( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> list[ClientInterceptor]: + """ + Initialized all interceptors here. + """ + return [ + _metadata_interceptor(config, metadata), + ] + + +def _compute_sift_async_interceptors( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> list[grpc_aio.ClientInterceptor]: + return [ + _metadata_async_interceptor(config, metadata), + ] + + +def _compute_channel_options(opts: SiftChannelConfig) -> list[tuple[str, Any]]: + """ + Initialize all [channel options](https://github.com/grpc/grpc/blob/v1.64.x/include/grpc/impl/channel_arg_names.h) here. + """ + + options = [ + ("grpc.enable_retries", 1), + ("grpc.service_config", RetryPolicy.default().as_json()), + # Primary cannot be overriden: + # https://github.com/grpc/grpc/blob/0498194240f55d7f4b12633ad01339fb690621bf/src/core/ext/filters/http/client/http_client_filter.cc#L97 + ("grpc.secondary_user_agent", _compute_user_agent()), + ] + + enable_keepalive = opts.get("enable_keepalive", True) + if isinstance(enable_keepalive, dict): + config = cast("KeepaliveConfig", enable_keepalive) + options.extend(_compute_keep_alive_channel_opts(config)) + elif enable_keepalive: + options.extend(_compute_keep_alive_channel_opts(DEFAULT_KEEPALIVE_CONFIG)) + + return options + + +def _metadata_interceptor( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> ClientInterceptor: + """ + Any new metadata goes here. + """ + apikey = config["apikey"] + md: Metadata = [("authorization", f"Bearer {apikey}")] + + if metadata: + for key, val in metadata.items(): + md.append((key, val)) + + return MetadataInterceptor(md) + + +def _metadata_async_interceptor( + config: SiftChannelConfig, metadata: dict[str, Any] | None = None +) -> ClientAsyncInterceptor: + """ + Any new metadata goes here for unary-unary calls. + """ + apikey = config["apikey"] + md: Metadata = [("authorization", f"Bearer {apikey}")] + + if metadata: + for key, val in metadata.items(): + md.append((key, val)) + + return MetadataAsyncInterceptor(md) + + +def _clean_uri(uri: str, use_ssl: bool) -> str: + """ + This will automatically transform the URI to an acceptable form regardless of whether or not + users included the scheme in the URL or included trailing slashes. + """ + + if "http://" in uri or "https://" in uri: + parsed: ParseResult = urlparse(uri) + return parsed.netloc + + full_uri = f"https://{uri}" if use_ssl else f"http://{uri}" + parsed_res: ParseResult = urlparse(full_uri) + return parsed_res.netloc + + +def _compute_user_agent() -> str: + try: + return f"sift_stack_py/{version('sift_stack_py')}" + except PackageNotFoundError: + return "sift-stack-py" + + +def _compute_keep_alive_channel_opts(config: KeepaliveConfig) -> list[tuple[str, int]]: + return [ + ("grpc.keepalive_time_ms", config["keepalive_time_ms"]), + ("grpc.keepalive_timeout_ms", config["keepalive_timeout_ms"]), + ("grpc.http2.max_pings_without_data", config["max_pings_without_data"]), + ("grpc.keepalive_permit_without_calls", config["keepalive_permit_without_calls"]), + ] + + +class SiftChannelConfig(TypedDict): + """ + Config class used to instantiate a `SiftChannel` via `use_sift_channel`. + - `uri`: The URI of Sift's gRPC API. The scheme portion of the URI i.e. `https://` should be ommitted. + - `apikey`: User-generated API key generated via the Sift application. + - `enable_keepalive`: Enabled by default, but can be disabled by passing in `False`. HTTP/2 keep-alive prevents connections from + being terminated during idle periods. A custom `sift_py.grpc.keepalive.KeepaliveConfig` may also be provided. + - `use_ssl`: INTERNAL USE. Meant to be used for local development. + - `cert_via_openssl`: Enable this if you want to use OpenSSL to load the certificates. + Run `pip install sift-stack-py[openssl]` to install the dependencies required to use this option. + This works around this issue with grpc loading SSL certificates: https://github.com/grpc/grpc/issues/29682. + Default is False. + """ + + uri: str + apikey: str + enable_keepalive: NotRequired[bool | KeepaliveConfig] + use_ssl: NotRequired[bool] + cert_via_openssl: NotRequired[bool] diff --git a/python/lib/sift_client/_internal/grpc/transport_test.py b/python/lib/sift_client/_internal/grpc/transport_test.py new file mode 100644 index 000000000..efccb6b4e --- /dev/null +++ b/python/lib/sift_client/_internal/grpc/transport_test.py @@ -0,0 +1,216 @@ +# ruff: noqa: N802 + +import re +from concurrent import futures +from contextlib import contextmanager +from typing import Any, Callable, Iterator, cast + +import grpc +import pytest +from pytest_mock import MockFixture, MockType +from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse +from sift.data.v2.data_pb2_grpc import ( + DataServiceServicer, + DataServiceStub, + add_DataServiceServicer_to_server, +) + +from sift_client._internal.grpc.server_interceptors.server_interceptor import ServerInterceptor +from sift_client._internal.grpc.transport import SiftChannelConfig, use_sift_channel + + +class DataService(DataServiceServicer): + def GetData(self, request: GetDataRequest, context: grpc.ServicerContext): + return GetDataResponse(next_page_token="next-page-token") + + +class AuthInterceptor(ServerInterceptor): + AUTH_REGEX = re.compile(r"^Bearer (.+)$") + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + authenticated = False + for metadata in context.invocation_metadata(): + if metadata.key == "authorization": + auth = self.__class__.AUTH_REGEX.match(metadata.value) + + if auth is not None and len(auth.group(1)) > 0: + authenticated = True + + break + + if authenticated: + return method(request_or_iterator, context) + else: + context.set_code(grpc.StatusCode.UNAUTHENTICATED) + context.set_details("Invalid or missing API key") + raise + + +class ForceFailInterceptor(ServerInterceptor): + """ + Force RPC to fail a few times before letting it pass. + + `failed_attempts`: Count of how many times failed + `expected_num_fails`: How many times you want call to fail + """ + + failed_attempts: int + expected_num_fails: int + failure_code: grpc.StatusCode + + def __init__( + self, expected_num_fails: int, failure_code: grpc.StatusCode = grpc.StatusCode.UNKNOWN + ): + self.expected_num_fails = expected_num_fails + self.failed_attempts = 0 + self.failure_code = failure_code + super().__init__() + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + if self.failed_attempts < self.expected_num_fails: + self.failed_attempts += 1 + context.set_code(self.failure_code) + context.set_details("something unknown happened") + raise + + return method(request_or_iterator, context) + + +@contextmanager +def server_spy(mocker: MockFixture, *interceptors: ServerInterceptor) -> Iterator[MockType]: + server = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=1), interceptors=list(interceptors) + ) + + data_service = DataService() + spy = mocker.spy(data_service, "GetData") + + add_DataServiceServicer_to_server(data_service, server) + server.add_insecure_port("[::]:50052") + server.start() + try: + yield spy + finally: + server.stop(None) + server.wait_for_termination() + + +def test_sift_channel(mocker: MockFixture): + with server_spy(mocker, AuthInterceptor()) as get_data_spy: + sift_channel_config_a: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_a) as channel: + stub = DataServiceStub(channel) + with pytest.raises(grpc.RpcError, match="UNAUTHENTICATED"): + _ = cast("GetDataResponse", stub.GetData(GetDataRequest())) + + get_data_spy.assert_not_called() + + sift_channel_config_b: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "some-token", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_b) as channel: + stub = DataServiceStub(channel) + res = cast("GetDataResponse", stub.GetData(GetDataRequest())) + assert res.next_page_token == "next-page-token" + get_data_spy.assert_called_once() + + force_fail_interceptor = ForceFailInterceptor(4) + with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy: + sift_channel_config_c: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "some-token", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_c) as channel: + stub = DataServiceStub(channel) + # This will attempt 5 times: fail 4 times, succeed on 5th + res = cast("GetDataResponse", stub.GetData(GetDataRequest())) + assert res.next_page_token == "next-page-token" + get_data_spy.assert_called_once() + + # fail 4 times, pass the 5th attempt + assert force_fail_interceptor.failed_attempts == 4 + + # Now we're going to fail beyond the max retry attempts + + force_fail_interceptor_max = ForceFailInterceptor(7) + with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy: + sift_channel_config_d: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "some-token", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_d) as channel: + stub = DataServiceStub(channel) + + # This will go beyond the max number of attempts + with pytest.raises(grpc.RpcError): + stub.GetData(GetDataRequest()) + + get_data_spy.assert_not_called() + + # All attempts failed + assert force_fail_interceptor_max.failed_attempts == 5 + + +def test_internal_error_retry(mocker: MockFixture): + force_fail_interceptor = ForceFailInterceptor(4, failure_code=grpc.StatusCode.INTERNAL) + with server_spy(mocker, AuthInterceptor(), force_fail_interceptor) as get_data_spy: + sift_channel_config_c: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "some-token", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_c) as channel: + stub = DataServiceStub(channel) + # This will attempt 5 times: fail 4 times, succeed on 5th + res = cast("GetDataResponse", stub.GetData(GetDataRequest())) + assert res.next_page_token == "next-page-token" + get_data_spy.assert_called_once() + + # fail 4 times, pass the 5th attempt + assert force_fail_interceptor.failed_attempts == 4 + + # Now we're going to fail beyond the max retry attempts + force_fail_interceptor_max = ForceFailInterceptor(7) + with server_spy(mocker, AuthInterceptor(), force_fail_interceptor_max) as get_data_spy: + sift_channel_config_d: SiftChannelConfig = { + "uri": "localhost:50052", + "apikey": "some-token", + "use_ssl": False, + } + + with use_sift_channel(sift_channel_config_d) as channel: + stub = DataServiceStub(channel) + + # This will go beyond the max number of attempts + with pytest.raises(grpc.RpcError): + stub.GetData(GetDataRequest()) + + get_data_spy.assert_not_called() + + # All attempts failed + assert force_fail_interceptor_max.failed_attempts == 5 diff --git a/python/lib/sift_client/_internal/low_level_wrappers/data.py b/python/lib/sift_client/_internal/low_level_wrappers/data.py index af469ba71..57b24e398 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/data.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/data.py @@ -15,9 +15,9 @@ Query, ) from sift.data.v2.data_pb2_grpc import DataServiceStub -from sift_py._internal.time import to_timestamp_nanos from sift_client._internal.low_level_wrappers.base import LowLevelClientBase +from sift_client._internal.time import to_timestamp_nanos from sift_client.sift_types.channel import Channel, ChannelDataType from sift_client.transport import WithGrpcClient diff --git a/python/lib/sift_client/_internal/rest.py b/python/lib/sift_client/_internal/rest.py new file mode 100644 index 000000000..5f5c954c3 --- /dev/null +++ b/python/lib/sift_client/_internal/rest.py @@ -0,0 +1,82 @@ +from abc import ABC +from typing import TypedDict + +import requests +from requests.adapters import HTTPAdapter +from typing_extensions import NotRequired +from urllib3.util import Retry + +from sift_client._internal.grpc.transport import _clean_uri + +_DEFAULT_REST_RETRY = Retry(total=3, status_forcelist=[500, 502, 503, 504], backoff_factor=1) + + +class SiftRestConfig(TypedDict): + """ + Config class used to to interact with services that use Sift's REST API.`. + - `uri`: The URI of Sift's REST API. The scheme portion of the URI i.e. `https://` should be ommitted. + - `apikey`: User-generated API key generated via the Sift application. + - `retry`: Urllib3 Retry configuration. If not provided, a default of 3 retries is used. + - `use_ssl`: INTERNAL USE. Meant to be used for local development. + - `cert_via_openssl`: Enable this if you want to use OpenSSL to load the certificates. + Run `pip install sift-stack-py[openssl]` to install the dependencies required to use this option. + Default is False. + """ + + uri: str + apikey: str + retry: NotRequired[Retry] + use_ssl: NotRequired[bool] + cert_via_openssl: NotRequired[bool] + + +def compute_uri(restconf: SiftRestConfig) -> str: + uri = restconf["uri"] + use_ssl = restconf.get("use_ssl", True) + clean_uri = _clean_uri(uri, use_ssl) + + if use_ssl: + return f"https://{clean_uri}" + + return f"http://{clean_uri}" + + +class _SiftHTTPAdapter(HTTPAdapter): + """Sift specific HTTP adapter.""" + + def __init__(self, rest_conf: SiftRestConfig, *args, **kwargs): + self._rest_conf = rest_conf + kwargs["max_retries"] = rest_conf.get("retry", _DEFAULT_REST_RETRY) + super().__init__(*args, **kwargs) + + def init_poolmanager(self, *args, **kwargs): + if self._rest_conf.get("cert_via_openssl", False): + try: + import ssl + + context = ssl.create_default_context() + context.load_default_certs() + kwargs["ssl_context"] = context + except ImportError as e: + raise Exception( + "Missing required dependencies for cert_via_openssl. Run `pip install sift-stack-py[openssl]` to install the required dependencies." + ) from e + return super().init_poolmanager(*args, **kwargs) + + +class _RestService(ABC): + """ + Abstract service that implements a REST session. + """ + + def __init__(self, rest_conf: SiftRestConfig): + self._rest_conf = rest_conf + self._base_uri = compute_uri(rest_conf) + self._apikey = rest_conf["apikey"] + + self._session = requests.Session() + self._session.headers = {"Authorization": f"Bearer {self._apikey}"} + + adapter = _SiftHTTPAdapter(rest_conf) + self._session.mount("https://", adapter) + self._session.mount("http://", adapter) diff --git a/python/lib/sift_client/_internal/time.py b/python/lib/sift_client/_internal/time.py new file mode 100644 index 000000000..9787996a6 --- /dev/null +++ b/python/lib/sift_client/_internal/time.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import cast + +import pandas as pd +from google.protobuf.timestamp_pb2 import Timestamp as TimestampPb + + +def to_timestamp_nanos(arg: TimestampPb | pd.Timestamp | datetime | str | int) -> pd.Timestamp: + """ + Converts a variety of time-types to a pandas timestamp which supports nano-second precision. + """ + + if isinstance(arg, pd.Timestamp): + return arg + elif isinstance(arg, TimestampPb): + seconds = arg.seconds + nanos = arg.nanos + + dt = datetime.fromtimestamp(seconds, tz=timezone.utc) + ts = pd.Timestamp(dt) + + return cast("pd.Timestamp", ts + pd.Timedelta(nanos, unit="ns")) + + elif isinstance(arg, int): + dt = datetime.fromtimestamp(arg, tz=timezone.utc) + return cast("pd.Timestamp", pd.Timestamp(dt)) + + else: + return cast("pd.Timestamp", pd.Timestamp(arg)) + + +def to_timestamp_pb(arg: datetime | str | int | float) -> TimestampPb: + """ + Mainly used for testing at the moment. If using this for non-testing purposes + should probably make this more robust and support nano-second precision. + """ + + ts = TimestampPb() + + if isinstance(arg, datetime): + ts.FromDatetime(arg) + return ts + elif isinstance(arg, (int, float)): + ts.FromDatetime(datetime.fromtimestamp(arg, tz=timezone.utc)) + return ts + else: + ts.FromDatetime(datetime.fromisoformat(arg)) + return ts diff --git a/python/lib/sift_client/resources/calculated_channels.py b/python/lib/sift_client/resources/calculated_channels.py index 789d52be1..c34083b30 100644 --- a/python/lib/sift_client/resources/calculated_channels.py +++ b/python/lib/sift_client/resources/calculated_channels.py @@ -243,7 +243,7 @@ async def update( ( updated_calculated_channel, - inapplicable_assets, + _inapplicable_assets, ) = await self._low_level_client.update_calculated_channel( update=update, user_notes=user_notes ) diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index dedbba934..95817a010 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -13,7 +13,7 @@ from typing import Any from urllib.parse import urlparse -from sift_py.grpc.transport import ( +from sift_client._internal.grpc.transport import ( SiftChannelConfig, use_sift_async_channel, ) diff --git a/python/lib/sift_client/transport/rest_transport.py b/python/lib/sift_client/transport/rest_transport.py index 40a83582b..e3b50a603 100644 --- a/python/lib/sift_client/transport/rest_transport.py +++ b/python/lib/sift_client/transport/rest_transport.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING from urllib.parse import urljoin -from sift_py.rest import _DEFAULT_REST_RETRY, SiftRestConfig, _RestService +from sift_client._internal.rest import _DEFAULT_REST_RETRY, SiftRestConfig, _RestService if TYPE_CHECKING: import requests