diff options
author | uael <uael@google.com> | 2023-09-29 18:40:06 +0000 |
---|---|---|
committer | uael <uael@google.com> | 2023-09-29 18:40:06 +0000 |
commit | 1decd46067626cdfb075435e789b0bc9946d9923 (patch) | |
tree | aa98f29f308756c4a060711d20036a9e31d819fa | |
parent | 400265218fe6cc5cfa834c8d608d879f058c90ca (diff) | |
parent | 6f2b623e3ce909be4962b042c314ca03d12e8e2d (diff) | |
download | bumble-1decd46067626cdfb075435e789b0bc9946d9923.tar.gz |
Merge remote-tracking branch 'aosp/upstream-main' into main
Change-Id: I4b924760be1b02d29cf933930830d00b9cf89663
39 files changed, 1110 insertions, 449 deletions
diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml index b6cf8fd..021b1e4 100644 --- a/.github/workflows/code-check.yml +++ b/.github/workflows/code-check.yml @@ -14,6 +14,10 @@ jobs: check: name: Check Code runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + fail-fast: false steps: - name: Check out from Git diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index c8a1031..4cc3e73 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -12,10 +12,10 @@ permissions: jobs: build: - - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: + os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] python-version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: false @@ -41,6 +41,7 @@ jobs: run: | inv build inv build.mkdocs + build-rust: runs-on: ubuntu-latest strategy: @@ -64,6 +65,8 @@ jobs: with: components: clippy,rustfmt toolchain: ${{ matrix.rust-version }} + - name: Check License Headers + run: cd rust && cargo run --features dev-tools --bin file-header check-all - name: Rust Build run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets # Lints after build so what clippy needs is already built diff --git a/.vscode/settings.json b/.vscode/settings.json index 864fe69..57e682a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -39,10 +39,12 @@ "libusb", "MITM", "NDIS", + "netsim", "NONBLOCK", "NONCONN", "OXIMETER", "popleft", + "protobuf", "psms", "pyee", "pyusb", diff --git a/apps/console.py b/apps/console.py index 0ea9e5b..9a529dd 100644 --- a/apps/console.py +++ b/apps/console.py @@ -1172,7 +1172,7 @@ class ScanResult: name = '' # Remove any '/P' qualifier suffix from the address string - address_str = str(self.address).replace('/P', '') + address_str = self.address.to_string(with_type_qualifier=False) # RSSI bar bar_string = rssi_bar(self.rssi) diff --git a/apps/controller_info.py b/apps/controller_info.py index 4707983..5be4f3d 100644 --- a/apps/controller_info.py +++ b/apps/controller_info.py @@ -63,7 +63,8 @@ async def get_classic_info(host): if command_succeeded(response): print() print( - color('Classic Address:', 'yellow'), response.return_parameters.bd_addr + color('Classic Address:', 'yellow'), + response.return_parameters.bd_addr.to_string(False), ) if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): diff --git a/apps/pandora_server.py b/apps/pandora_server.py index b577f82..16bc211 100644 --- a/apps/pandora_server.py +++ b/apps/pandora_server.py @@ -3,7 +3,7 @@ import click import logging import json -from bumble.pandora import PandoraDevice, serve +from bumble.pandora import PandoraDevice, Config, serve from typing import Dict, Any BUMBLE_SERVER_GRPC_PORT = 7999 @@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No transport = transport.replace('<rootcanal-port>', str(rootcanal_port)) bumble_config = retrieve_config(config) - if 'transport' not in bumble_config.keys(): - bumble_config.update({'transport': transport}) + bumble_config.setdefault('transport', transport) device = PandoraDevice(bumble_config) + server_config = Config() + server_config.load_from_dict(bumble_config.get('server', {})) + logging.basicConfig(level=logging.DEBUG) - asyncio.run(serve(device, port=grpc_port)) + asyncio.run(serve(device, config=server_config, port=grpc_port)) def retrieve_config(config: str) -> Dict[str, Any]: diff --git a/apps/speaker/speaker.py b/apps/speaker/speaker.py index 1a1eac3..e451c04 100644 --- a/apps/speaker/speaker.py +++ b/apps/speaker/speaker.py @@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput): except HCI_StatusError: pass peer_name = '' if connection.peer_name is None else connection.peer_name - peer_address = str(connection.peer_address).replace('/P', '') + peer_address = connection.peer_address.to_string(False) await self.send_message( 'connection', peer_address=peer_address, @@ -376,7 +376,7 @@ class UiServer: if connection := self.speaker().connection: await self.send_message( 'connection', - peer_address=str(connection.peer_address).replace('/P', ''), + peer_address=connection.peer_address.to_string(False), peer_name=connection.peer_name, ) diff --git a/bumble/att.py b/bumble/att.py index 55ae8a5..db8d2ba 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -23,13 +23,14 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations +import enum import functools import struct from pyee import EventEmitter -from typing import Dict, Type, TYPE_CHECKING +from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING -from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError -from bumble.hci import HCI_Object, key_with_value, HCI_Constant +from bumble.core import UUID, name_or_number, ProtocolError +from bumble.hci import HCI_Object, key_with_value from bumble.colors import color if TYPE_CHECKING: @@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 # pylint: enable=line-too-long # pylint: disable=invalid-name + # ----------------------------------------------------------------------------- # Exceptions # ----------------------------------------------------------------------------- @@ -209,7 +211,7 @@ class ATT_PDU: pdu_classes: Dict[int, Type[ATT_PDU]] = {} op_code = 0 - name = None + name: str @staticmethod def from_bytes(pdu): @@ -720,47 +722,67 @@ class ATT_Handle_Value_Confirmation(ATT_PDU): # ----------------------------------------------------------------------------- -class Attribute(EventEmitter): - # Permission flags - READABLE = 0x01 - WRITEABLE = 0x02 - READ_REQUIRES_ENCRYPTION = 0x04 - WRITE_REQUIRES_ENCRYPTION = 0x08 - READ_REQUIRES_AUTHENTICATION = 0x10 - WRITE_REQUIRES_AUTHENTICATION = 0x20 - READ_REQUIRES_AUTHORIZATION = 0x40 - WRITE_REQUIRES_AUTHORIZATION = 0x80 - - PERMISSION_NAMES = { - READABLE: 'READABLE', - WRITEABLE: 'WRITEABLE', - READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION', - WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION', - READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION', - WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION', - READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION', - WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION', - } +class ConnectionValue(Protocol): + def read(self, connection) -> bytes: + ... - @staticmethod - def string_to_permissions(permissions_str: str): - try: - return functools.reduce( - lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y), - permissions_str.split(","), - 0, - ) - except TypeError as exc: - raise TypeError( - f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}" - ) from exc + def write(self, connection, value: bytes) -> None: + ... - def __init__(self, attribute_type, permissions, value=b''): + +# ----------------------------------------------------------------------------- +class Attribute(EventEmitter): + class Permissions(enum.IntFlag): + READABLE = 0x01 + WRITEABLE = 0x02 + READ_REQUIRES_ENCRYPTION = 0x04 + WRITE_REQUIRES_ENCRYPTION = 0x08 + READ_REQUIRES_AUTHENTICATION = 0x10 + WRITE_REQUIRES_AUTHENTICATION = 0x20 + READ_REQUIRES_AUTHORIZATION = 0x40 + WRITE_REQUIRES_AUTHORIZATION = 0x80 + + @classmethod + def from_string(cls, permissions_str: str) -> Attribute.Permissions: + try: + return functools.reduce( + lambda x, y: x | Attribute.Permissions[y], + permissions_str.replace('|', ',').split(","), + Attribute.Permissions(0), + ) + except TypeError as exc: + # The check for `p.name is not None` here is needed because for InFlag + # enums, the .name property can be None, when the enum value is 0, + # so the type hint for .name is Optional[str]. + enum_list: List[str] = [p.name for p in cls if p.name is not None] + enum_list_str = ",".join(enum_list) + raise TypeError( + f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}" + ) from exc + + # Permission flags(legacy-use only) + READABLE = Permissions.READABLE + WRITEABLE = Permissions.WRITEABLE + READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION + WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION + READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION + WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION + READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION + WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION + + value: Union[str, bytes, ConnectionValue] + + def __init__( + self, + attribute_type: Union[str, bytes, UUID], + permissions: Union[str, Attribute.Permissions], + value: Union[str, bytes, ConnectionValue] = b'', + ) -> None: EventEmitter.__init__(self) self.handle = 0 self.end_group_handle = 0 if isinstance(permissions, str): - self.permissions = self.string_to_permissions(permissions) + self.permissions = Attribute.Permissions.from_string(permissions) else: self.permissions = permissions @@ -778,22 +800,26 @@ class Attribute(EventEmitter): else: self.value = value - def encode_value(self, value): + def encode_value(self, value: Any) -> bytes: return value - def decode_value(self, value_bytes): + def decode_value(self, value_bytes: bytes) -> Any: return value_bytes - def read_value(self, connection: Connection): + def read_value(self, connection: Optional[Connection]) -> bytes: if ( - self.permissions & self.READ_REQUIRES_ENCRYPTION - ) and not connection.encryption: + (self.permissions & self.READ_REQUIRES_ENCRYPTION) + and connection is not None + and not connection.encryption + ): raise ATT_Error( error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle ) if ( - self.permissions & self.READ_REQUIRES_AUTHENTICATION - ) and not connection.authenticated: + (self.permissions & self.READ_REQUIRES_AUTHENTICATION) + and connection is not None + and not connection.authenticated + ): raise ATT_Error( error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle ) @@ -803,9 +829,9 @@ class Attribute(EventEmitter): error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle ) - if read := getattr(self.value, 'read', None): + if hasattr(self.value, 'read'): try: - value = read(connection) # pylint: disable=not-callable + value = self.value.read(connection) except ATT_Error as error: raise ATT_Error( error_code=error.error_code, att_handle=self.handle @@ -815,7 +841,7 @@ class Attribute(EventEmitter): return self.encode_value(value) - def write_value(self, connection: Connection, value_bytes): + def write_value(self, connection: Connection, value_bytes: bytes) -> None: if ( self.permissions & self.WRITE_REQUIRES_ENCRYPTION ) and not connection.encryption: @@ -836,9 +862,9 @@ class Attribute(EventEmitter): value = self.decode_value(value_bytes) - if write := getattr(self.value, 'write', None): + if hasattr(self.value, 'write'): try: - write(connection, value) # pylint: disable=not-callable + self.value.write(connection, value) # pylint: disable=not-callable except ATT_Error as error: raise ATT_Error( error_code=error.error_code, att_handle=self.handle diff --git a/bumble/core.py b/bumble/core.py index 4dff432..4a67d6e 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -80,7 +80,7 @@ class BaseError(Exception): def __init__( self, - error_code: int | None, + error_code: Optional[int], error_namespace: str = '', error_name: str = '', details: str = '', diff --git a/bumble/device.py b/bumble/device.py index 9a784e7..b01dc58 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1186,8 +1186,8 @@ class Device(CompositeEventEmitter): def create_l2cap_registrar(self, psm): return lambda handler: self.register_l2cap_server(psm, handler) - def register_l2cap_server(self, psm, server): - self.l2cap_channel_manager.register_server(psm, server) + def register_l2cap_server(self, psm, server) -> int: + return self.l2cap_channel_manager.register_server(psm, server) def register_l2cap_channel_server( self, @@ -2758,7 +2758,9 @@ class Device(CompositeEventEmitter): self.abort_on( 'flush', self.start_advertising( - advertising_type=self.advertising_type, auto_restart=True + advertising_type=self.advertising_type, + own_address_type=self.advertising_own_address_type, + auto_restart=True, ), ) diff --git a/bumble/gatt.py b/bumble/gatt.py index 067f31d..fe3e85c 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -28,7 +28,7 @@ import enum import functools import logging import struct -from typing import Optional, Sequence, List +from typing import Optional, Sequence, Iterable, List, Union from .colors import color from .core import UUID, get_dict_key_by_value @@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi # ----------------------------------------------------------------------------- -def show_services(services): +def show_services(services: Iterable[Service]) -> None: for service in services: print(color(str(service), 'cyan')) @@ -210,11 +210,11 @@ class Service(Attribute): def __init__( self, - uuid, + uuid: Union[str, UUID], characteristics: List[Characteristic], primary=True, included_services: List[Service] = [], - ): + ) -> None: # Convert the uuid to a UUID object if it isn't already if isinstance(uuid, str): uuid = UUID(uuid) @@ -239,7 +239,7 @@ class Service(Attribute): """ return None - def __str__(self): + def __str__(self) -> str: return ( f'Service(handle=0x{self.handle:04X}, ' f'end=0x{self.end_group_handle:04X}, ' @@ -255,9 +255,11 @@ class TemplateService(Service): to expose their UUID as a class property ''' - UUID: Optional[UUID] = None + UUID: UUID - def __init__(self, characteristics, primary=True): + def __init__( + self, characteristics: List[Characteristic], primary: bool = True + ) -> None: super().__init__(self.UUID, characteristics, primary) @@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute): service: Service - def __init__(self, service): + def __init__(self, service: Service) -> None: declaration_bytes = struct.pack( '<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() ) @@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute): ) self.service = service - def __str__(self): + def __str__(self) -> str: return ( f'IncludedServiceDefinition(handle=0x{self.handle:04X}, ' f'group_starting_handle=0x{self.service.handle:04X}, ' @@ -326,7 +328,7 @@ class Characteristic(Attribute): f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}" ) - def __str__(self): + def __str__(self) -> str: # NOTE: we override this method to offer a consistent result between python # versions: the value returned by IntFlag.__str__() changed in version 11. return '|'.join( @@ -348,10 +350,10 @@ class Characteristic(Attribute): def __init__( self, - uuid, + uuid: Union[str, bytes, UUID], properties: Characteristic.Properties, - permissions, - value=b'', + permissions: Union[str, Attribute.Permissions], + value: Union[str, bytes, CharacteristicValue] = b'', descriptors: Sequence[Descriptor] = (), ): super().__init__(uuid, permissions, value) @@ -369,7 +371,7 @@ class Characteristic(Attribute): def has_properties(self, properties: Characteristic.Properties) -> bool: return self.properties & properties == properties - def __str__(self): + def __str__(self) -> str: return ( f'Characteristic(handle=0x{self.handle:04X}, ' f'end=0x{self.end_group_handle:04X}, ' @@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute): characteristic: Characteristic - def __init__(self, characteristic, value_handle): + def __init__(self, characteristic: Characteristic, value_handle: int) -> None: declaration_bytes = ( struct.pack('<BH', characteristic.properties, value_handle) + characteristic.uuid.to_pdu_bytes() @@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute): self.value_handle = value_handle self.characteristic = characteristic - def __str__(self): + def __str__(self) -> str: return ( f'CharacteristicDeclaration(handle=0x{self.handle:04X}, ' f'value_handle=0x{self.value_handle:04X}, ' @@ -520,7 +522,7 @@ class CharacteristicAdapter: return self.wrapped_characteristic.unsubscribe(subscriber) - def __str__(self): + def __str__(self) -> str: wrapped = str(self.wrapped_characteristic) return f'{self.__class__.__name__}({wrapped})' @@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter): Adapter that converts strings to/from bytes using UTF-8 encoding ''' - def encode_value(self, value): + def encode_value(self, value: str) -> bytes: return value.encode('utf-8') - def decode_value(self, value): + def decode_value(self, value: bytes) -> str: return value.decode('utf-8') @@ -613,7 +615,7 @@ class Descriptor(Attribute): See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations ''' - def __str__(self): + def __str__(self) -> str: return ( f'Descriptor(handle=0x{self.handle:04X}, ' f'type={self.type}, ' diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index a33039e..e3b8bb2 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -28,7 +28,18 @@ import asyncio import logging import struct from datetime import datetime -from typing import List, Optional, Dict, Tuple, Callable, Union, Any +from typing import ( + List, + Optional, + Dict, + Tuple, + Callable, + Union, + Any, + Iterable, + Type, + TYPE_CHECKING, +) from pyee import EventEmitter @@ -66,8 +77,12 @@ from .gatt import ( GATT_INCLUDE_ATTRIBUTE_TYPE, Characteristic, ClientCharacteristicConfigurationBits, + TemplateService, ) +if TYPE_CHECKING: + from bumble.device import Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -78,16 +93,16 @@ logger = logging.getLogger(__name__) # Proxies # ----------------------------------------------------------------------------- class AttributeProxy(EventEmitter): - client: Client - - def __init__(self, client, handle, end_group_handle, attribute_type): + def __init__( + self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID + ) -> None: EventEmitter.__init__(self) self.client = client self.handle = handle self.end_group_handle = end_group_handle self.type = attribute_type - async def read_value(self, no_long_read=False): + async def read_value(self, no_long_read: bool = False) -> bytes: return self.decode_value( await self.client.read_value(self.handle, no_long_read) ) @@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter): self.handle, self.encode_value(value), with_response ) - def encode_value(self, value): + def encode_value(self, value: Any) -> bytes: return value - def decode_value(self, value_bytes): + def decode_value(self, value_bytes: bytes) -> Any: return value_bytes - def __str__(self): + def __str__(self) -> str: return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' @@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy): def get_characteristics_by_uuid(self, uuid): return self.client.get_characteristics_by_uuid(uuid, self) - def __str__(self): + def __str__(self) -> str: return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' class CharacteristicProxy(AttributeProxy): properties: Characteristic.Properties descriptors: List[DescriptorProxy] - subscribers: Dict[Any, Callable] + subscribers: Dict[Any, Callable[[bytes], Any]] def __init__( self, @@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy): return await self.client.discover_descriptors(self) async def subscribe( - self, subscriber: Optional[Callable] = None, prefer_notify=True + self, + subscriber: Optional[Callable[[bytes], Any]] = None, + prefer_notify: bool = True, ): if subscriber is not None: if subscriber in self.subscribers: @@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy): return await self.client.unsubscribe(self, subscriber) - def __str__(self): + def __str__(self) -> str: return ( f'Characteristic(handle=0x{self.handle:04X}, ' f'uuid={self.uuid}, ' @@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy): def __init__(self, client, handle, descriptor_type): super().__init__(client, handle, 0, descriptor_type) - def __str__(self): + def __str__(self) -> str: return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})' @@ -216,8 +233,10 @@ class ProfileServiceProxy: Base class for profile-specific service proxies ''' + SERVICE_CLASS: Type[TemplateService] + @classmethod - def from_client(cls, client): + def from_client(cls, client: Client) -> ProfileServiceProxy: return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) @@ -227,8 +246,12 @@ class ProfileServiceProxy: class Client: services: List[ServiceProxy] cached_values: Dict[int, Tuple[datetime, bytes]] + notification_subscribers: Dict[int, Callable[[bytes], Any]] + indication_subscribers: Dict[int, Callable[[bytes], Any]] + pending_response: Optional[asyncio.futures.Future[ATT_PDU]] + pending_request: Optional[ATT_PDU] - def __init__(self, connection): + def __init__(self, connection: Connection) -> None: self.connection = connection self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1) @@ -241,16 +264,16 @@ class Client: self.services = [] self.cached_values = {} - def send_gatt_pdu(self, pdu): + def send_gatt_pdu(self, pdu: bytes) -> None: self.connection.send_l2cap_pdu(ATT_CID, pdu) - async def send_command(self, command): + async def send_command(self, command: ATT_PDU) -> None: logger.debug( f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' ) self.send_gatt_pdu(command.to_bytes()) - async def send_request(self, request): + async def send_request(self, request: ATT_PDU): logger.debug( f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' ) @@ -279,14 +302,14 @@ class Client: return response - def send_confirmation(self, confirmation): + def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None: logger.debug( f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'{confirmation}' ) self.send_gatt_pdu(confirmation.to_bytes()) - async def request_mtu(self, mtu): + async def request_mtu(self, mtu: int) -> int: # Check the range if mtu < ATT_DEFAULT_MTU: raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') @@ -313,10 +336,12 @@ class Client: return self.connection.att_mtu - def get_services_by_uuid(self, uuid): + def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]: return [service for service in self.services if service.uuid == uuid] - def get_characteristics_by_uuid(self, uuid, service=None): + def get_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] = None + ) -> List[CharacteristicProxy]: services = [service] if service else self.services return [ c @@ -363,7 +388,7 @@ class Client: if not already_known: self.services.append(service) - async def discover_services(self, uuids=None) -> List[ServiceProxy]: + async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.1 Discover All Primary Services ''' @@ -435,7 +460,7 @@ class Client: return services - async def discover_service(self, uuid): + async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID ''' @@ -468,7 +493,7 @@ class Client: f'{HCI_Constant.error_name(response.error_code)}' ) # TODO raise appropriate exception - return + return [] break for attribute_handle, end_group_handle in response.handles_information: @@ -480,7 +505,7 @@ class Client: logger.warning( f'bogus handle values: {attribute_handle} {end_group_handle}' ) - return + return [] # Create a service proxy for this service service = ServiceProxy( @@ -721,7 +746,7 @@ class Client: return descriptors - async def discover_attributes(self): + async def discover_attributes(self) -> List[AttributeProxy]: ''' Discover all attributes, regardless of type ''' @@ -844,7 +869,9 @@ class Client: # No more subscribers left await self.write_value(cccd, b'\x00\x00', with_response=True) - async def read_value(self, attribute, no_long_read=False): + async def read_value( + self, attribute: Union[int, AttributeProxy], no_long_read: bool = False + ) -> Any: ''' See Vol 3, Part G - 4.8.1 Read Characteristic Value @@ -905,7 +932,9 @@ class Client: # Return the value as bytes return attribute_value - async def read_characteristics_by_uuid(self, uuid, service): + async def read_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] + ) -> List[bytes]: ''' See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID ''' @@ -960,7 +989,12 @@ class Client: return characteristics_values - async def write_value(self, attribute, value, with_response=False): + async def write_value( + self, + attribute: Union[int, AttributeProxy], + value: bytes, + with_response: bool = False, + ) -> None: ''' See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value @@ -990,7 +1024,7 @@ class Client: ) ) - def on_gatt_pdu(self, att_pdu): + def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' ) @@ -1013,6 +1047,7 @@ class Client: return # Return the response to the coroutine that is waiting for it + assert self.pending_response is not None self.pending_response.set_result(att_pdu) else: handler_name = f'on_{att_pdu.name.lower()}' @@ -1060,7 +1095,7 @@ class Client: # Confirm that we received the indication self.send_confirmation(ATT_Handle_Value_Confirmation()) - def cache_value(self, attribute_handle: int, value: bytes): + def cache_value(self, attribute_handle: int, value: bytes) -> None: self.cached_values[attribute_handle] = ( datetime.now(), value, diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 3624905..cdf1b5e 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -23,11 +23,12 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging from collections import defaultdict import struct -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from pyee import EventEmitter from .colors import color @@ -42,6 +43,7 @@ from .att import ( ATT_INVALID_OFFSET_ERROR, ATT_REQUEST_NOT_SUPPORTED_ERROR, ATT_REQUESTS, + ATT_PDU, ATT_UNLIKELY_ERROR_ERROR, ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ATT_Error, @@ -73,6 +75,8 @@ from .gatt import ( Service, ) +if TYPE_CHECKING: + from bumble.device import Device, Connection # ----------------------------------------------------------------------------- # Logging @@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517 # ----------------------------------------------------------------------------- class Server(EventEmitter): attributes: List[Attribute] + services: List[Service] + attributes_by_handle: Dict[int, Attribute] + subscribers: Dict[int, Dict[int, bytes]] + indication_semaphores: defaultdict[int, asyncio.Semaphore] + pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]] - def __init__(self, device): + def __init__(self, device: Device) -> None: super().__init__() self.device = device self.services = [] @@ -107,16 +116,16 @@ class Server(EventEmitter): self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.pending_confirmations = defaultdict(lambda: None) - def __str__(self): + def __str__(self) -> str: return "\n".join(map(str, self.attributes)) - def send_gatt_pdu(self, connection_handle, pdu): + def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None: self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) - def next_handle(self): + def next_handle(self) -> int: return 1 + len(self.attributes) - def get_advertising_service_data(self): + def get_advertising_service_data(self) -> Dict[Attribute, bytes]: return { attribute: data for attribute in self.attributes @@ -124,7 +133,7 @@ class Server(EventEmitter): and (data := attribute.get_advertising_data()) } - def get_attribute(self, handle): + def get_attribute(self, handle: int) -> Optional[Attribute]: attribute = self.attributes_by_handle.get(handle) if attribute: return attribute @@ -173,12 +182,17 @@ class Server(EventEmitter): return next( ( - (attribute, self.get_attribute(attribute.characteristic.handle)) + ( + attribute, + self.get_attribute(attribute.characteristic.handle), + ) # type: ignore for attribute in map( self.get_attribute, range(service_handle.handle, service_handle.end_group_handle + 1), ) - if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE + if attribute is not None + and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE + and isinstance(attribute, CharacteristicDeclaration) and attribute.characteristic.uuid == characteristic_uuid ), None, @@ -197,7 +211,7 @@ class Server(EventEmitter): return next( ( - attribute + attribute # type: ignore for attribute in map( self.get_attribute, range( @@ -205,12 +219,12 @@ class Server(EventEmitter): characteristic_value.end_group_handle + 1, ), ) - if attribute.type == descriptor_uuid + if attribute is not None and attribute.type == descriptor_uuid ), None, ) - def add_attribute(self, attribute): + def add_attribute(self, attribute: Attribute) -> None: # Assign a handle to this attribute attribute.handle = self.next_handle() attribute.end_group_handle = ( @@ -220,7 +234,7 @@ class Server(EventEmitter): # Add this attribute to the list self.attributes.append(attribute) - def add_service(self, service: Service): + def add_service(self, service: Service) -> None: # Add the service attribute to the DB self.add_attribute(service) @@ -285,11 +299,13 @@ class Server(EventEmitter): service.end_group_handle = self.attributes[-1].handle self.services.append(service) - def add_services(self, services): + def add_services(self, services: Iterable[Service]) -> None: for service in services: self.add_service(service) - def read_cccd(self, connection, characteristic): + def read_cccd( + self, connection: Optional[Connection], characteristic: Characteristic + ) -> bytes: if connection is None: return bytes([0, 0]) @@ -300,7 +316,12 @@ class Server(EventEmitter): return cccd or bytes([0, 0]) - def write_cccd(self, connection, characteristic, value): + def write_cccd( + self, + connection: Connection, + characteristic: Characteristic, + value: bytes, + ) -> None: logger.debug( f'Subscription update for connection=0x{connection.handle:04X}, ' f'handle=0x{characteristic.handle:04X}: {value.hex()}' @@ -327,13 +348,19 @@ class Server(EventEmitter): indicate_enabled, ) - def send_response(self, connection, response): + def send_response(self, connection: Connection, response: ATT_PDU) -> None: logger.debug( f'GATT Response from server: [0x{connection.handle:04X}] {response}' ) self.send_gatt_pdu(connection.handle, response.to_bytes()) - async def notify_subscriber(self, connection, attribute, value=None, force=False): + async def notify_subscriber( + self, + connection: Connection, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Check if there's a subscriber if not force: subscribers = self.subscribers.get(connection.handle) @@ -370,7 +397,13 @@ class Server(EventEmitter): ) self.send_gatt_pdu(connection.handle, bytes(notification)) - async def indicate_subscriber(self, connection, attribute, value=None, force=False): + async def indicate_subscriber( + self, + connection: Connection, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Check if there's a subscriber if not force: subscribers = self.subscribers.get(connection.handle) @@ -411,15 +444,13 @@ class Server(EventEmitter): assert self.pending_confirmations[connection.handle] is None # Create a future value to hold the eventual response - self.pending_confirmations[ + pending_confirmation = self.pending_confirmations[ connection.handle ] = asyncio.get_running_loop().create_future() try: self.send_gatt_pdu(connection.handle, indication.to_bytes()) - await asyncio.wait_for( - self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT - ) + await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) except asyncio.TimeoutError as error: logger.warning(color('!!! GATT Indicate timeout', 'red')) raise TimeoutError(f'GATT timeout for {indication.name}') from error @@ -427,8 +458,12 @@ class Server(EventEmitter): self.pending_confirmations[connection.handle] = None async def notify_or_indicate_subscribers( - self, indicate, attribute, value=None, force=False - ): + self, + indicate: bool, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ) -> None: # Get all the connections for which there's at least one subscription connections = [ connection @@ -450,13 +485,23 @@ class Server(EventEmitter): ] ) - async def notify_subscribers(self, attribute, value=None, force=False): + async def notify_subscribers( + self, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ): return await self.notify_or_indicate_subscribers(False, attribute, value, force) - async def indicate_subscribers(self, attribute, value=None, force=False): + async def indicate_subscribers( + self, + attribute: Attribute, + value: Optional[bytes] = None, + force: bool = False, + ): return await self.notify_or_indicate_subscribers(True, attribute, value, force) - def on_disconnection(self, connection): + def on_disconnection(self, connection: Connection) -> None: if connection.handle in self.subscribers: del self.subscribers[connection.handle] if connection.handle in self.indication_semaphores: @@ -464,7 +509,7 @@ class Server(EventEmitter): if connection.handle in self.pending_confirmations: del self.pending_confirmations[connection.handle] - def on_gatt_pdu(self, connection, att_pdu): + def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None: logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') handler_name = f'on_{att_pdu.name.lower()}' handler = getattr(self, handler_name, None) @@ -506,7 +551,7 @@ class Server(EventEmitter): ####################################################### # ATT handlers ####################################################### - def on_att_request(self, connection, pdu): + def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None: ''' Handler for requests without a more specific handler ''' @@ -679,7 +724,6 @@ class Server(EventEmitter): and attribute.handle <= request.ending_handle and pdu_space_available ): - try: attribute_value = attribute.read_value(connection) except ATT_Error as error: diff --git a/bumble/hci.py b/bumble/hci.py index b7014dd..41deed2 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4397,7 +4397,7 @@ class HCI_Event(HCI_Packet): if len(parameters) != length: raise ValueError('invalid packet length') - cls: Type[HCI_Event | HCI_LE_Meta_Event] | None + cls: Any if event_code == HCI_LE_META_EVENT: # We do this dispatch here and not in the subclass in order to avoid call # loops diff --git a/bumble/l2cap.py b/bumble/l2cap.py index fea8a1d..cccb172 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import enum import logging import struct @@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame): # ----------------------------------------------------------------------------- class Channel(EventEmitter): - # States - CLOSED = 0x00 - WAIT_CONNECT = 0x01 - WAIT_CONNECT_RSP = 0x02 - OPEN = 0x03 - WAIT_DISCONNECT = 0x04 - WAIT_CREATE = 0x05 - WAIT_CREATE_RSP = 0x06 - WAIT_MOVE = 0x07 - WAIT_MOVE_RSP = 0x08 - WAIT_MOVE_CONFIRM = 0x09 - WAIT_CONFIRM_RSP = 0x0A - - # CONFIG substates - WAIT_CONFIG = 0x10 - WAIT_SEND_CONFIG = 0x11 - WAIT_CONFIG_REQ_RSP = 0x12 - WAIT_CONFIG_RSP = 0x13 - WAIT_CONFIG_REQ = 0x14 - WAIT_IND_FINAL_RSP = 0x15 - WAIT_FINAL_RSP = 0x16 - WAIT_CONTROL_IND = 0x17 - - STATE_NAMES = { - CLOSED: 'CLOSED', - WAIT_CONNECT: 'WAIT_CONNECT', - WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP', - OPEN: 'OPEN', - WAIT_DISCONNECT: 'WAIT_DISCONNECT', - WAIT_CREATE: 'WAIT_CREATE', - WAIT_CREATE_RSP: 'WAIT_CREATE_RSP', - WAIT_MOVE: 'WAIT_MOVE', - WAIT_MOVE_RSP: 'WAIT_MOVE_RSP', - WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM', - WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP', - WAIT_CONFIG: 'WAIT_CONFIG', - WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG', - WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP', - WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP', - WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ', - WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP', - WAIT_FINAL_RSP: 'WAIT_FINAL_RSP', - WAIT_CONTROL_IND: 'WAIT_CONTROL_IND', - } + class State(enum.IntEnum): + # States + CLOSED = 0x00 + WAIT_CONNECT = 0x01 + WAIT_CONNECT_RSP = 0x02 + OPEN = 0x03 + WAIT_DISCONNECT = 0x04 + WAIT_CREATE = 0x05 + WAIT_CREATE_RSP = 0x06 + WAIT_MOVE = 0x07 + WAIT_MOVE_RSP = 0x08 + WAIT_MOVE_CONFIRM = 0x09 + WAIT_CONFIRM_RSP = 0x0A + + # CONFIG substates + WAIT_CONFIG = 0x10 + WAIT_SEND_CONFIG = 0x11 + WAIT_CONFIG_REQ_RSP = 0x12 + WAIT_CONFIG_RSP = 0x13 + WAIT_CONFIG_REQ = 0x14 + WAIT_IND_FINAL_RSP = 0x15 + WAIT_FINAL_RSP = 0x16 + WAIT_CONTROL_IND = 0x17 connection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]] response: Optional[asyncio.Future[bytes]] sink: Optional[Callable[[bytes], Any]] - state: int + state: State connection: Connection def __init__( @@ -741,7 +721,7 @@ class Channel(EventEmitter): self.manager = manager self.connection = connection self.signaling_cid = signaling_cid - self.state = Channel.CLOSED + self.state = self.State.CLOSED self.mtu = mtu self.psm = psm self.source_cid = source_cid @@ -751,13 +731,11 @@ class Channel(EventEmitter): self.disconnection_result = None self.sink = None - def change_state(self, new_state: int) -> None: - logger.debug( - f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}' - ) + def _change_state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') self.state = new_state - def send_pdu(self, pdu: SupportsBytes | bytes) -> None: + def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: @@ -767,7 +745,7 @@ class Channel(EventEmitter): # Check that there isn't already a request pending if self.response: raise InvalidStateError('request already pending') - if self.state != Channel.OPEN: + if self.state != self.State.OPEN: raise InvalidStateError('channel not open') self.response = asyncio.get_running_loop().create_future() @@ -787,14 +765,14 @@ class Channel(EventEmitter): ) async def connect(self) -> None: - if self.state != Channel.CLOSED: + if self.state != self.State.CLOSED: raise InvalidStateError('invalid state') # Check that we can start a new connection if self.connection_result: raise RuntimeError('connection already pending') - self.change_state(Channel.WAIT_CONNECT_RSP) + self._change_state(self.State.WAIT_CONNECT_RSP) self.send_control_frame( L2CAP_Connection_Request( identifier=self.manager.next_identifier(self.connection), @@ -814,10 +792,10 @@ class Channel(EventEmitter): self.connection_result = None async def disconnect(self) -> None: - if self.state != Channel.OPEN: + if self.state != self.State.OPEN: raise InvalidStateError('invalid state') - self.change_state(Channel.WAIT_DISCONNECT) + self._change_state(self.State.WAIT_DISCONNECT) self.send_control_frame( L2CAP_Disconnection_Request( identifier=self.manager.next_identifier(self.connection), @@ -832,8 +810,8 @@ class Channel(EventEmitter): return await self.disconnection_result def abort(self) -> None: - if self.state == self.OPEN: - self.change_state(self.CLOSED) + if self.state == self.State.OPEN: + self._change_state(self.State.CLOSED) self.emit('close') def send_configure_request(self) -> None: @@ -856,7 +834,7 @@ class Channel(EventEmitter): def on_connection_request(self, request) -> None: self.destination_cid = request.source_cid - self.change_state(Channel.WAIT_CONNECT) + self._change_state(self.State.WAIT_CONNECT) self.send_control_frame( L2CAP_Connection_Response( identifier=request.identifier, @@ -866,24 +844,24 @@ class Channel(EventEmitter): status=0x0000, ) ) - self.change_state(Channel.WAIT_CONFIG) + self._change_state(self.State.WAIT_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_REQ_RSP) + self._change_state(self.State.WAIT_CONFIG_REQ_RSP) def on_connection_response(self, response): - if self.state != Channel.WAIT_CONNECT_RSP: + if self.state != self.State.WAIT_CONNECT_RSP: logger.warning(color('invalid state', 'red')) return if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: self.destination_cid = response.destination_cid - self.change_state(Channel.WAIT_CONFIG) + self._change_state(self.State.WAIT_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_REQ_RSP) + self._change_state(self.State.WAIT_CONFIG_REQ_RSP) elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING: pass else: - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) self.connection_result.set_exception( ProtocolError( response.result, @@ -895,9 +873,9 @@ class Channel(EventEmitter): def on_configure_request(self, request) -> None: if self.state not in ( - Channel.WAIT_CONFIG, - Channel.WAIT_CONFIG_REQ, - Channel.WAIT_CONFIG_REQ_RSP, + self.State.WAIT_CONFIG, + self.State.WAIT_CONFIG_REQ, + self.State.WAIT_CONFIG_REQ_RSP, ): logger.warning(color('invalid state', 'red')) return @@ -918,25 +896,28 @@ class Channel(EventEmitter): options=request.options, # TODO: don't accept everything blindly ) ) - if self.state == Channel.WAIT_CONFIG: - self.change_state(Channel.WAIT_SEND_CONFIG) + if self.state == self.State.WAIT_CONFIG: + self._change_state(self.State.WAIT_SEND_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_RSP) - elif self.state == Channel.WAIT_CONFIG_REQ: - self.change_state(Channel.OPEN) + self._change_state(self.State.WAIT_CONFIG_RSP) + elif self.state == self.State.WAIT_CONFIG_REQ: + self._change_state(self.State.OPEN) if self.connection_result: self.connection_result.set_result(None) self.connection_result = None self.emit('open') - elif self.state == Channel.WAIT_CONFIG_REQ_RSP: - self.change_state(Channel.WAIT_CONFIG_RSP) + elif self.state == self.State.WAIT_CONFIG_REQ_RSP: + self._change_state(self.State.WAIT_CONFIG_RSP) def on_configure_response(self, response) -> None: if response.result == L2CAP_Configure_Response.SUCCESS: - if self.state == Channel.WAIT_CONFIG_REQ_RSP: - self.change_state(Channel.WAIT_CONFIG_REQ) - elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND): - self.change_state(Channel.OPEN) + if self.state == self.State.WAIT_CONFIG_REQ_RSP: + self._change_state(self.State.WAIT_CONFIG_REQ) + elif self.state in ( + self.State.WAIT_CONFIG_RSP, + self.State.WAIT_CONTROL_IND, + ): + self._change_state(self.State.OPEN) if self.connection_result: self.connection_result.set_result(None) self.connection_result = None @@ -966,7 +947,7 @@ class Channel(EventEmitter): # TODO: decide how to fail gracefully def on_disconnection_request(self, request) -> None: - if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT): + if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT): self.send_control_frame( L2CAP_Disconnection_Response( identifier=request.identifier, @@ -974,14 +955,14 @@ class Channel(EventEmitter): source_cid=request.source_cid, ) ) - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) self.emit('close') self.manager.on_channel_closed(self) else: logger.warning(color('invalid state', 'red')) def on_disconnection_response(self, response) -> None: - if self.state != Channel.WAIT_DISCONNECT: + if self.state != self.State.WAIT_DISCONNECT: logger.warning(color('invalid state', 'red')) return @@ -992,7 +973,7 @@ class Channel(EventEmitter): logger.warning('unexpected source or destination CID') return - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) if self.disconnection_result: self.disconnection_result.set_result(None) self.disconnection_result = None @@ -1004,7 +985,7 @@ class Channel(EventEmitter): f'Channel({self.source_cid}->{self.destination_cid}, ' f'PSM={self.psm}, ' f'MTU={self.mtu}, ' - f'state={Channel.STATE_NAMES[self.state]})' + f'state={self.state.name})' ) @@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter): LE Credit-based Connection Oriented Channel """ - INIT = 0 - CONNECTED = 1 - CONNECTING = 2 - DISCONNECTING = 3 - DISCONNECTED = 4 - CONNECTION_ERROR = 5 - - STATE_NAMES = { - INIT: 'INIT', - CONNECTED: 'CONNECTED', - CONNECTING: 'CONNECTING', - DISCONNECTING: 'DISCONNECTING', - DISCONNECTED: 'DISCONNECTED', - CONNECTION_ERROR: 'CONNECTION_ERROR', - } + class State(enum.IntEnum): + INIT = 0 + CONNECTED = 1 + CONNECTING = 2 + DISCONNECTING = 3 + DISCONNECTED = 4 + CONNECTION_ERROR = 5 out_queue: Deque[bytes] connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] disconnection_result: Optional[asyncio.Future[None]] out_sdu: Optional[bytes] - state: int + state: State connection: Connection - @staticmethod - def state_name(state: int) -> str: - return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state) - def __init__( self, manager: ChannelManager, @@ -1083,22 +1052,20 @@ class LeConnectionOrientedChannel(EventEmitter): self.drained.set() if connected: - self.state = LeConnectionOrientedChannel.CONNECTED + self.state = self.State.CONNECTED else: - self.state = LeConnectionOrientedChannel.INIT + self.state = self.State.INIT - def change_state(self, new_state: int) -> None: - logger.debug( - f'{self} state change -> {color(self.state_name(new_state), "cyan")}' - ) + def _change_state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') self.state = new_state - if new_state == self.CONNECTED: + if new_state == self.State.CONNECTED: self.emit('open') - elif new_state == self.DISCONNECTED: + elif new_state == self.State.DISCONNECTED: self.emit('close') - def send_pdu(self, pdu: SupportsBytes | bytes) -> None: + def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: self.manager.send_pdu(self.connection, self.destination_cid, pdu) def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: @@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter): async def connect(self) -> LeConnectionOrientedChannel: # Check that we're in the right state - if self.state != self.INIT: + if self.state != self.State.INIT: raise InvalidStateError('not in a connectable state') # Check that we can start a new connection @@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter): if identifier in self.manager.le_coc_requests: raise RuntimeError('too many concurrent connection requests') - self.change_state(self.CONNECTING) + self._change_state(self.State.CONNECTING) request = L2CAP_LE_Credit_Based_Connection_Request( identifier=identifier, le_psm=self.le_psm, @@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter): async def disconnect(self) -> None: # Check that we're connected - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: raise InvalidStateError('not connected') - self.change_state(self.DISCONNECTING) + self._change_state(self.State.DISCONNECTING) self.flush_output() self.send_control_frame( L2CAP_Disconnection_Request( @@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter): return await self.disconnection_result def abort(self) -> None: - if self.state == self.CONNECTED: - self.change_state(self.DISCONNECTED) + if self.state == self.State.CONNECTED: + self._change_state(self.State.DISCONNECTED) def on_pdu(self, pdu: bytes) -> None: if self.sink is None: logger.warning('received pdu without a sink') return - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: logger.warning('received PDU while not connected, dropping') # Manage the peer credits @@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.credits = response.initial_credits self.connected = True self.connection_result.set_result(self) - self.change_state(self.CONNECTED) + self._change_state(self.State.CONNECTED) else: self.connection_result.set_exception( ProtocolError( @@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter): ), ) ) - self.change_state(self.CONNECTION_ERROR) + self._change_state(self.State.CONNECTION_ERROR) # Cleanup self.connection_result = None @@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter): source_cid=request.source_cid, ) ) - self.change_state(self.DISCONNECTED) + self._change_state(self.State.DISCONNECTED) self.flush_output() def on_disconnection_response(self, response) -> None: - if self.state != self.DISCONNECTING: + if self.state != self.State.DISCONNECTING: logger.warning(color('invalid state', 'red')) return @@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter): logger.warning('unexpected source or destination CID') return - self.change_state(self.DISCONNECTED) + self._change_state(self.State.DISCONNECTED) if self.disconnection_result: self.disconnection_result.set_result(None) self.disconnection_result = None @@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter): return def write(self, data: bytes) -> None: - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: logger.warning('not connected, dropping data') return @@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter): def __str__(self) -> str: return ( f'CoC({self.source_cid}->{self.destination_cid}, ' - f'State={self.state_name(self.state)}, ' + f'State={self.state.name}, ' f'PSM={self.le_psm}, ' f'MTU={self.mtu}/{self.peer_mtu}, ' f'MPS={self.mps}/{self.peer_mps}, ' @@ -1571,7 +1538,7 @@ class ChannelManager: if connection_handle in self.identifiers: del self.identifiers[connection_handle] - def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None: + def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None: pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) logger.debug( f'{color(">>> Sending L2CAP PDU", "blue")} ' diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py index 96fce85..0f31512 100644 --- a/bumble/pandora/security.py +++ b/bumble/pandora/security.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import contextlib import grpc import logging @@ -27,8 +28,8 @@ from bumble.core import ( ) from bumble.device import Connection as BumbleConnection, Device from bumble.hci import HCI_Error +from bumble.utils import EventWatcher from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate -from contextlib import suppress from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error @@ -232,7 +233,11 @@ class SecurityService(SecurityServicer): sc=config.pairing_sc_enable, mitm=config.pairing_mitm_enable, bonding=config.pairing_bonding_enable, - identity_address_type=config.identity_address_type, + identity_address_type=( + PairingConfig.AddressType.PUBLIC + if connection.self_address.is_public + else config.identity_address_type + ), delegate=PairingDelegate( connection, self, @@ -294,23 +299,35 @@ class SecurityService(SecurityServicer): try: self.log.debug('Pair...') - if ( - connection.transport == BT_LE_TRANSPORT - and connection.role == BT_PERIPHERAL_ROLE - ): - wait_for_security: asyncio.Future[ - bool - ] = asyncio.get_running_loop().create_future() - connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore - connection.on("pairing_failure", wait_for_security.set_exception) + security_result = asyncio.get_running_loop().create_future() + + with contextlib.closing(EventWatcher()) as watcher: + + @watcher.on(connection, 'pairing') + def on_pairing(*_: Any) -> None: + security_result.set_result('success') - connection.request_pairing() + @watcher.on(connection, 'pairing_failure') + def on_pairing_failure(*_: Any) -> None: + security_result.set_result('pairing_failure') - await wait_for_security - else: - await connection.pair() + @watcher.on(connection, 'disconnection') + def on_disconnection(*_: Any) -> None: + security_result.set_result('connection_died') - self.log.debug('Paired') + if ( + connection.transport == BT_LE_TRANSPORT + and connection.role == BT_PERIPHERAL_ROLE + ): + connection.request_pairing() + else: + await connection.pair() + + result = await security_result + + self.log.debug(f'Pairing session complete, status={result}') + if result != 'success': + return SecureResponse(**{result: empty_pb2.Empty()}) except asyncio.CancelledError: self.log.warning("Connection died during encryption") return SecureResponse(connection_died=empty_pb2.Empty()) @@ -369,6 +386,7 @@ class SecurityService(SecurityServicer): str ] = asyncio.get_running_loop().create_future() authenticate_task: Optional[asyncio.Future[None]] = None + pair_task: Optional[asyncio.Future[None]] = None async def authenticate() -> None: assert connection @@ -415,6 +433,10 @@ class SecurityService(SecurityServicer): if authenticate_task is None: authenticate_task = asyncio.create_task(authenticate()) + def pair(*_: Any) -> None: + if self.need_pairing(connection, level): + pair_task = asyncio.create_task(connection.pair()) + listeners: Dict[str, Callable[..., None]] = { 'disconnection': set_failure('connection_died'), 'pairing_failure': set_failure('pairing_failure'), @@ -425,6 +447,7 @@ class SecurityService(SecurityServicer): 'connection_encryption_change': on_encryption_change, 'classic_pairing': try_set_success, 'classic_pairing_failure': set_failure('pairing_failure'), + 'security_request': pair, } # register event handlers @@ -452,6 +475,15 @@ class SecurityService(SecurityServicer): pass self.log.debug('Authenticated') + # wait for `pair` to finish if any + if pair_task is not None: + self.log.debug('Wait for authentication...') + try: + await pair_task # type: ignore + except: + pass + self.log.debug('paired') + return WaitSecurityResponse(**kwargs) def reached_security_level( @@ -523,7 +555,7 @@ class SecurityStorageService(SecurityStorageServicer): self.log.debug(f"DeleteBond: {address}") if self.device.keystore is not None: - with suppress(KeyError): + with contextlib.suppress(KeyError): await self.device.keystore.delete(str(address)) return empty_pb2.Empty() diff --git a/bumble/smp.py b/bumble/smp.py index 55b8359..f8bba40 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -37,6 +37,7 @@ from typing import ( Optional, Tuple, Type, + cast, ) from pyee import EventEmitter @@ -1771,7 +1772,26 @@ class Manager(EventEmitter): cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID connection.send_l2cap_pdu(cid, command.to_bytes()) + def on_smp_security_request_command( + self, connection: Connection, request: SMP_Security_Request_Command + ) -> None: + connection.emit('security_request', request.auth_req) + def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: + # Parse the L2CAP payload into an SMP Command object + command = SMP_Command.from_bytes(pdu) + logger.debug( + f'<<< Received SMP Command on connection [0x{connection.handle:04X}] ' + f'{connection.peer_address}: {command}' + ) + + # Security request is more than just pairing, so let applications handle them + if command.code == SMP_SECURITY_REQUEST_COMMAND: + self.on_smp_security_request_command( + connection, cast(SMP_Security_Request_Command, command) + ) + return + # Look for a session with this connection, and create one if none exists if not (session := self.sessions.get(connection.handle)): if connection.role == BT_CENTRAL_ROLE: @@ -1782,13 +1802,6 @@ class Manager(EventEmitter): ) self.sessions[connection.handle] = session - # Parse the L2CAP payload into an SMP Command object - command = SMP_Command.from_bytes(pdu) - logger.debug( - f'<<< Received SMP Command on connection [0x{connection.handle:04X}] ' - f'{connection.peer_address}: {command}' - ) - # Delegate the handling of the command to the session session.on_smp_command(command) diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 5ef0047..8d19a9e 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -18,6 +18,8 @@ import logging import grpc.aio +from typing import Optional, Union + from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport # pylint: disable=no-name-in-module @@ -33,7 +35,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_android_emulator_transport(spec: str | None) -> Transport: +async def open_android_emulator_transport(spec: Optional[str]) -> Transport: ''' Open a transport connection to an Android emulator via its gRPC interface. The parameter string has this syntax: @@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport: logger.debug(f'connecting to gRPC server at {server_address}') channel = grpc.aio.insecure_channel(server_address) - service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub + service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub] if mode == 'host': # Connect as a host service = EmulatedBluetoothServiceStub(channel) @@ -95,10 +97,13 @@ async def open_android_emulator_transport(spec: str | None) -> Transport: raise ValueError('invalid mode') # Create the transport object - transport = PumpedTransport( - PumpedPacketSource(hci_device.read), - PumpedPacketSink(hci_device.write), - channel.close, + class EmulatorTransport(PumpedTransport): + async def close(self): + await super().close() + await channel.close() + + transport = EmulatorTransport( + PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write) ) transport.start() diff --git a/bumble/transport/android_netsim.py b/bumble/transport/android_netsim.py index 76a7385..e9d36cd 100644 --- a/bumble/transport/android_netsim.py +++ b/bumble/transport/android_netsim.py @@ -18,11 +18,12 @@ import asyncio import atexit import logging -import grpc.aio import os import pathlib import sys -from typing import Optional +from typing import Dict, Optional + +import grpc.aio from .common import ( ParserSource, @@ -33,8 +34,8 @@ from .common import ( ) # pylint: disable=no-name-in-module -from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub from .grpc_protobuf.packet_streamer_pb2_grpc import ( + PacketStreamerStub, PacketStreamerServicer, add_PacketStreamerServicer_to_server, ) @@ -43,6 +44,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket from .grpc_protobuf.startup_pb2 import Chip, ChipInfo from .grpc_protobuf.common_pb2 import ChipKind + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -74,14 +76,20 @@ def get_ini_dir() -> Optional[pathlib.Path]: # ----------------------------------------------------------------------------- -def find_grpc_port() -> int: +def ini_file_name(instance_number: int) -> str: + suffix = f'_{instance_number}' if instance_number > 0 else '' + return f'netsim{suffix}.ini' + + +# ----------------------------------------------------------------------------- +def find_grpc_port(instance_number: int) -> int: if not (ini_dir := get_ini_dir()): logger.debug('no known directory for .ini file') return 0 - ini_file = ini_dir / 'netsim.ini' + ini_file = ini_dir / ini_file_name(instance_number) + logger.debug(f'Looking for .ini file at {ini_file}') if ini_file.is_file(): - logger.debug(f'Found .ini file at {ini_file}') with open(ini_file, 'r') as ini_file_data: for line in ini_file_data.readlines(): if '=' in line: @@ -90,12 +98,14 @@ def find_grpc_port() -> int: logger.debug(f'gRPC port = {value}') return int(value) + logger.debug('no grpc.port property found in .ini file') + # Not found return 0 # ----------------------------------------------------------------------------- -def publish_grpc_port(grpc_port) -> bool: +def publish_grpc_port(grpc_port: int, instance_number: int) -> bool: if not (ini_dir := get_ini_dir()): logger.debug('no known directory for .ini file') return False @@ -104,7 +114,7 @@ def publish_grpc_port(grpc_port) -> bool: logger.debug('ini directory does not exist') return False - ini_file = ini_dir / 'netsim.ini' + ini_file = ini_dir / ini_file_name(instance_number) try: ini_file.write_text(f'grpc.port={grpc_port}\n') logger.debug(f"published gRPC port at {ini_file}") @@ -122,14 +132,15 @@ def publish_grpc_port(grpc_port) -> bool: # ----------------------------------------------------------------------------- async def open_android_netsim_controller_transport( - server_host: str | None, server_port: int + server_host: Optional[str], server_port: int, options: Dict[str, str] ) -> Transport: if not server_port: raise ValueError('invalid port') if server_host == '_' or not server_host: server_host = 'localhost' - if not publish_grpc_port(server_port): + instance_number = int(options.get('instance', "0")) + if not publish_grpc_port(server_port, instance_number): logger.warning("unable to publish gRPC port") class HciDevice: @@ -186,15 +197,12 @@ async def open_android_netsim_controller_transport( logger.debug(f'<<< PACKET: {data.hex()}') self.on_data_received(data) - def send_packet(self, data): - async def send(): - await self.context.write( - PacketResponse( - hci_packet=HCIPacket(packet_type=data[0], packet=data[1:]) - ) + async def send_packet(self, data): + return await self.context.write( + PacketResponse( + hci_packet=HCIPacket(packet_type=data[0], packet=data[1:]) ) - - self.loop.create_task(send()) + ) def terminate(self): self.task.cancel() @@ -228,17 +236,17 @@ async def open_android_netsim_controller_transport( logger.debug('gRPC server cancelled') await self.grpc_server.stop(None) - def on_packet(self, packet): + async def send_packet(self, packet): if not self.device: logger.debug('no device, dropping packet') return - self.device.send_packet(packet) + return await self.device.send_packet(packet) async def StreamPackets(self, _request_iterator, context): logger.debug('StreamPackets request') - # Check that we won't already have a device + # Check that we don't already have a device if self.device: logger.debug('busy, already serving a device') return PacketResponse(error='Busy') @@ -261,15 +269,42 @@ async def open_android_netsim_controller_transport( await server.start() asyncio.get_running_loop().create_task(server.serve()) - class GrpcServerTransport(Transport): - async def close(self): - await super().close() + sink = PumpedPacketSink(server.send_packet) + sink.start() + return Transport(server, sink) + + +# ----------------------------------------------------------------------------- +async def open_android_netsim_host_transport_with_address( + server_host: Optional[str], + server_port: int, + options: Optional[Dict[str, str]] = None, +): + if server_host == '_' or not server_host: + server_host = 'localhost' + + if not server_port: + # Look for the gRPC config in a .ini file + instance_number = 0 if options is None else int(options.get('instance', '0')) + server_port = find_grpc_port(instance_number) + if not server_port: + raise RuntimeError('gRPC server port not found') + + # Connect to the gRPC server + server_address = f'{server_host}:{server_port}' + logger.debug(f'Connecting to gRPC server at {server_address}') + channel = grpc.aio.insecure_channel(server_address) - return GrpcServerTransport(server, server) + return await open_android_netsim_host_transport_with_channel( + channel, + options, + ) # ----------------------------------------------------------------------------- -async def open_android_netsim_host_transport(server_host, server_port, options): +async def open_android_netsim_host_transport_with_channel( + channel, options: Optional[Dict[str, str]] = None +): # Wrapper for I/O operations class HciDevice: def __init__(self, name, manufacturer, hci_device): @@ -288,10 +323,12 @@ async def open_android_netsim_host_transport(server_host, server_port, options): async def read(self): response = await self.hci_device.read() response_type = response.WhichOneof('response_type') + if response_type == 'error': logger.warning(f'received error: {response.error}') raise RuntimeError(response.error) - elif response_type == 'hci_packet': + + if response_type == 'hci_packet': return ( bytes([response.hci_packet.packet_type]) + response.hci_packet.packet @@ -306,24 +343,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options): ) ) - name = options.get('name', DEFAULT_NAME) + name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME) manufacturer = DEFAULT_MANUFACTURER - if server_host == '_' or not server_host: - server_host = 'localhost' - - if not server_port: - # Look for the gRPC config in a .ini file - server_host = 'localhost' - server_port = find_grpc_port() - if not server_port: - raise RuntimeError('gRPC server port not found') - - # Connect to the gRPC server - server_address = f'{server_host}:{server_port}' - logger.debug(f'Connecting to gRPC server at {server_address}') - channel = grpc.aio.insecure_channel(server_address) - # Connect as a host service = PacketStreamerStub(channel) hci_device = HciDevice( @@ -334,10 +356,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options): await hci_device.start() # Create the transport object - transport = PumpedTransport( + class GrpcTransport(PumpedTransport): + async def close(self): + await super().close() + await channel.close() + + transport = GrpcTransport( PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write), - channel.close, ) transport.start() @@ -345,7 +371,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options): # ----------------------------------------------------------------------------- -async def open_android_netsim_transport(spec): +async def open_android_netsim_transport(spec: Optional[str]) -> Transport: ''' Open a transport connection as a client or server, implementing Android's `netsim` simulator protocol over gRPC. @@ -359,6 +385,11 @@ async def open_android_netsim_transport(spec): to connect *to* a netsim server (netsim is the controller), or accept connections *as* a netsim-compatible server. + instance=<n> + Specifies an instance number, with <n> > 0. This is used to determine which + .init file to use. In `host` mode, it is ignored when the <host>:<port> + specifier is present, since in that case no .ini file is used. + In `host` mode: The <host>:<port> part is optional. When not specified, the transport looks for a netsim .ini file, from which it will read the `grpc.backend.port` @@ -387,14 +418,15 @@ async def open_android_netsim_transport(spec): params = spec.split(',') if spec else [] if params and ':' in params[0]: # Explicit <host>:<port> - host, port = params[0].split(':') + host, port_str = params[0].split(':') + port = int(port_str) params_offset = 1 else: host = None port = 0 params_offset = 0 - options = {} + options: Dict[str, str] = {} for param in params[params_offset:]: if '=' not in param: raise ValueError('invalid parameter, expected <name>=<value>') @@ -403,10 +435,12 @@ async def open_android_netsim_transport(spec): mode = options.get('mode', 'host') if mode == 'host': - return await open_android_netsim_host_transport(host, port, options) + return await open_android_netsim_host_transport_with_address( + host, port, options + ) if mode == 'controller': if host is None: raise ValueError('<host>:<port> missing') - return await open_android_netsim_controller_transport(host, port) + return await open_android_netsim_controller_transport(host, port, options) raise ValueError('invalid mode option') diff --git a/bumble/transport/common.py b/bumble/transport/common.py index c030308..2786a75 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -339,8 +339,9 @@ class PumpedPacketSource(ParserSource): try: packet = await self.receive_function() self.parser.feed_data(packet) - except asyncio.exceptions.CancelledError: + except asyncio.CancelledError: logger.debug('source pump task done') + self.terminated.set_result(None) break except Exception as error: logger.warning(f'exception while waiting for packet: {error}') @@ -370,7 +371,7 @@ class PumpedPacketSink: try: packet = await self.packet_queue.get() await self.send_function(packet) - except asyncio.exceptions.CancelledError: + except asyncio.CancelledError: logger.debug('sink pump task done') break except Exception as error: @@ -393,19 +394,13 @@ class PumpedTransport(Transport): self, source: PumpedPacketSource, sink: PumpedPacketSink, - close_function, ) -> None: super().__init__(source, sink) - self.close_function = close_function def start(self) -> None: self.source.start() self.sink.start() - async def close(self) -> None: - await super().close() - await self.close_function() - # ----------------------------------------------------------------------------- class SnoopingTransport(Transport): diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index 9891c5b..df9e885 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -23,6 +23,8 @@ import socket import ctypes import collections +from typing import Optional + from .common import Transport, ParserSource @@ -33,7 +35,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_hci_socket_transport(spec: str | None) -> Transport: +async def open_hci_socket_transport(spec: Optional[str]) -> Transport: ''' Open an HCI Socket (only available on some platforms). The parameter string is either empty (to use the first/default Bluetooth adapter) @@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport: # Create a raw HCI socket try: hci_socket = socket.socket( - socket.AF_BLUETOOTH, - socket.SOCK_RAW | socket.SOCK_NONBLOCK, - socket.BTPROTO_HCI, # type: ignore + socket.AF_BLUETOOTH, # type: ignore[attr-defined] + socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined] + socket.BTPROTO_HCI, # type: ignore[attr-defined] ) except AttributeError as error: # Not supported on this platform @@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport: bind_address = struct.pack( # pylint: disable=no-member '<HHH', - socket.AF_BLUETOOTH, + socket.AF_BLUETOOTH, # type: ignore[attr-defined] adapter_index, HCI_CHANNEL_USER, ) diff --git a/bumble/transport/pty.py b/bumble/transport/pty.py index 7765b09..2f46e75 100644 --- a/bumble/transport/pty.py +++ b/bumble/transport/pty.py @@ -23,6 +23,8 @@ import atexit import os import logging +from typing import Optional + from .common import Transport, StreamPacketSource, StreamPacketSink # ----------------------------------------------------------------------------- @@ -32,7 +34,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_pty_transport(spec: str | None) -> Transport: +async def open_pty_transport(spec: Optional[str]) -> Transport: ''' Open a PTY transport. The parameter string may be empty, or a path name where a symbolic link diff --git a/bumble/transport/vhci.py b/bumble/transport/vhci.py index 5795840..2b19085 100644 --- a/bumble/transport/vhci.py +++ b/bumble/transport/vhci.py @@ -17,6 +17,8 @@ # ----------------------------------------------------------------------------- import logging +from typing import Optional + from .common import Transport from .file import open_file_transport @@ -27,7 +29,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_vhci_transport(spec: str | None) -> Transport: +async def open_vhci_transport(spec: Optional[str]) -> Transport: ''' Open a VHCI transport (only available on some platforms). The parameter string is either empty (to use the default VHCI device diff --git a/bumble/transport/ws_client.py b/bumble/transport/ws_client.py index facd1c9..902001e 100644 --- a/bumble/transport/ws_client.py +++ b/bumble/transport/ws_client.py @@ -31,19 +31,21 @@ async def open_ws_client_transport(spec: str) -> Transport: ''' Open a WebSocket client transport. The parameter string has this syntax: - <remote-host>:<remote-port> + <websocket-url> - Example: 127.0.0.1:9001 + Example: ws://localhost:7681/v1/websocket/bt ''' - remote_host, remote_port = spec.split(':') - uri = f'ws://{remote_host}:{remote_port}' - websocket = await websockets.client.connect(uri) + websocket = await websockets.client.connect(spec) - transport = PumpedTransport( + class WsTransport(PumpedTransport): + async def close(self): + await super().close() + await websocket.close() + + transport = WsTransport( PumpedPacketSource(websocket.recv), PumpedPacketSink(websocket.send), - websocket.close, ) transport.start() return transport diff --git a/bumble/utils.py b/bumble/utils.py index 8a55684..dc03725 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -15,12 +15,24 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import traceback import collections import sys -from typing import Awaitable, Set, TypeVar +from typing import ( + Awaitable, + Set, + TypeVar, + List, + Tuple, + Callable, + Any, + Optional, + Union, + overload, +) from functools import wraps from pyee import EventEmitter @@ -65,6 +77,102 @@ def composite_listener(cls): # ----------------------------------------------------------------------------- +_Handler = TypeVar('_Handler', bound=Callable) + + +class EventWatcher: + '''A wrapper class to control the lifecycle of event handlers better. + + Usage: + ``` + watcher = EventWatcher() + + def on_foo(): + ... + watcher.on(emitter, 'foo', on_foo) + + @watcher.on(emitter, 'bar') + def on_bar(): + ... + + # Close all event handlers watching through this watcher + watcher.close() + ``` + + As context: + ``` + with contextlib.closing(EventWatcher()) as context: + @context.on(emitter, 'foo') + def on_foo(): + ... + # on_foo() has been removed here! + ``` + ''' + + handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]] + + def __init__(self) -> None: + self.handlers = [] + + @overload + def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]: + ... + + @overload + def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: + ... + + def on( + self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None + ) -> Union[_Handler, Callable[[_Handler], _Handler]]: + '''Watch an event until the context is closed. + + Args: + emitter: EventEmitter to watch + event: Event name + handler: (Optional) Event handler. When nothing is passed, this method works as a decorator. + ''' + + def wrapper(f: _Handler) -> _Handler: + self.handlers.append((emitter, event, f)) + emitter.on(event, f) + return f + + return wrapper if handler is None else wrapper(handler) + + @overload + def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]: + ... + + @overload + def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: + ... + + def once( + self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None + ) -> Union[_Handler, Callable[[_Handler], _Handler]]: + '''Watch an event for once. + + Args: + emitter: EventEmitter to watch + event: Event name + handler: (Optional) Event handler. When nothing passed, this method works as a decorator. + ''' + + def wrapper(f: _Handler) -> _Handler: + self.handlers.append((emitter, event, f)) + emitter.once(event, f) + return f + + return wrapper if handler is None else wrapper(handler) + + def close(self) -> None: + for emitter, event, handler in self.handlers: + if handler in emitter.listeners(event): + emitter.remove_listener(event, handler) + + +# ----------------------------------------------------------------------------- _T = TypeVar('_T') diff --git a/rust/Cargo.lock b/rust/Cargo.lock index bd168dc..c2d0cd3 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -131,6 +131,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] +name = "bstr" +version = "1.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a" +dependencies = [ + "memchr", + "serde", +] + +[[package]] name = "bumble" version = "0.1.0" dependencies = [ @@ -138,7 +148,9 @@ dependencies = [ "clap 4.4.1", "directories", "env_logger", + "file-header", "futures", + "globset", "hex", "itertools", "lazy_static", @@ -273,6 +285,73 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset 0.9.0", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] name = "directories" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -349,6 +428,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" [[package]] +name = "file-header" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f" +dependencies = [ + "crossbeam", + "lazy_static", + "license", + "thiserror", + "walkdir", +] + +[[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -485,6 +577,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] +name = "globset" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d" +dependencies = [ + "aho-corasick", + "bstr", + "fnv", + "log", + "regex", +] + +[[package]] name = "h2" version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -711,6 +816,17 @@ dependencies = [ ] [[package]] +name = "license" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946" +dependencies = [ + "reword", + "serde", + "serde_json", +] + +[[package]] name = "linux-raw-sys" version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -757,6 +873,15 @@ dependencies = [ ] [[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] name = "mime" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1201,6 +1326,15 @@ dependencies = [ ] [[package]] +name = "reword" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765" +dependencies = [ + "unicode-segmentation", +] + +[[package]] name = "rusb" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1242,6 +1376,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] name = "schannel" version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1590,6 +1733,12 @@ dependencies = [ ] [[package]] +name = "unicode-segmentation" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" + +[[package]] name = "unindent" version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1619,6 +1768,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] name = "want" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6c38c82..a553afd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -24,6 +24,10 @@ itertools = "0.11.0" lazy_static = "1.4.0" thiserror = "1.0.41" +# Dev tools +file-header = { version = "0.1.2", optional = true } +globset = { version = "0.4.13", optional = true } + # CLI anyhow = { version = "1.0.71", optional = true } clap = { version = "4.3.3", features = ["derive"], optional = true } @@ -53,9 +57,14 @@ env_logger = "0.10.0" rustdoc-args = ["--generate-link-to-definition"] [[bin]] +name = "file-header" +path = "tools/file_header.rs" +required-features = ["dev-tools"] + +[[bin]] name = "gen-assigned-numbers" path = "tools/gen_assigned_numbers.rs" -required-features = ["bumble-codegen"] +required-features = ["dev-tools"] [[bin]] name = "bumble" @@ -71,7 +80,7 @@ harness = false [features] anyhow = ["pyo3/anyhow"] pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"] -bumble-codegen = ["dep:anyhow"] +dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"] # separate feature for CLI so that dependencies don't spend time building these bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"] default = [] diff --git a/rust/README.md b/rust/README.md index 23dec03..15a19b9 100644 --- a/rust/README.md +++ b/rust/README.md @@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`. To regenerate the assigned number tables based on the Python codebase: ``` -PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen +PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools ```
\ No newline at end of file diff --git a/rust/pytests/assigned_numbers.rs b/rust/pytests/assigned_numbers.rs index 10e7f3e..7f8f1d1 100644 --- a/rust/pytests/assigned_numbers.rs +++ b/rust/pytests/assigned_numbers.rs @@ -1,3 +1,17 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use bumble::wrapper::{self, core::Uuid16}; use pyo3::{intern, prelude::*, types::PyDict}; use std::collections; diff --git a/rust/src/adv.rs b/rust/src/adv.rs index 8a4c979..6f84cc5 100644 --- a/rust/src/adv.rs +++ b/rust/src/adv.rs @@ -1,3 +1,17 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + //! BLE advertisements. use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS}; diff --git a/rust/src/wrapper/logging.rs b/rust/src/wrapper/logging.rs index 141cc04..bd932cb 100644 --- a/rust/src/wrapper/logging.rs +++ b/rust/src/wrapper/logging.rs @@ -1,3 +1,17 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + //! Bumble & Python logging use pyo3::types::PyDict; diff --git a/rust/tools/file_header.rs b/rust/tools/file_header.rs new file mode 100644 index 0000000..fb3286d --- /dev/null +++ b/rust/tools/file_header.rs @@ -0,0 +1,78 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::anyhow; +use clap::Parser as _; +use file_header::{ + add_headers_recursively, check_headers_recursively, + license::spdx::{YearCopyrightOwnerValue, APACHE_2_0}, +}; +use globset::{Glob, GlobSet, GlobSetBuilder}; +use std::{env, path::PathBuf}; + +fn main() -> anyhow::Result<()> { + let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + let ignore_globset = ignore_globset()?; + // Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127) + let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into())); + + let cli = Cli::parse(); + + match cli.subcommand { + Subcommand::CheckAll => { + let result = + check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?; + if result.has_failure() { + return Err(anyhow!( + "The following files do not have headers: {result:?}" + )); + } + } + Subcommand::AddAll => { + let files_with_new_header = + add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?; + files_with_new_header + .iter() + .for_each(|path| println!("Added header to: {path:?}")); + } + } + Ok(()) +} + +fn ignore_globset() -> anyhow::Result<GlobSet> { + Ok(GlobSetBuilder::new() + .add(Glob::new("**/.idea/**")?) + .add(Glob::new("**/target/**")?) + .add(Glob::new("**/.gitignore")?) + .add(Glob::new("**/CHANGELOG.md")?) + .add(Glob::new("**/Cargo.lock")?) + .add(Glob::new("**/Cargo.toml")?) + .add(Glob::new("**/README.md")?) + .add(Glob::new("*.bin")?) + .build()?) +} + +#[derive(clap::Parser)] +struct Cli { + #[clap(subcommand)] + subcommand: Subcommand, +} + +#[derive(clap::Subcommand, Debug, Clone)] +enum Subcommand { + /// Checks if a license is present in files that are not in the ignore list. + CheckAll, + /// Adds a license as needed to files that are not in the ignore list. + AddAll, +} @@ -36,6 +36,10 @@ install_requires = bt-test-interfaces >= 0.0.2; platform_system!='Emscripten' click == 8.1.3; platform_system!='Emscripten' cryptography == 39; platform_system!='Emscripten' + # Pyodide bundles a version of cryptography that is built for wasm, which may not match the + # versions available on PyPI. Relax the version requirement since it's better than being + # completely unable to import the package in case of version mismatch. + cryptography >= 39.0; platform_system=='Emscripten' grpcio == 1.57.0; platform_system!='Emscripten' humanize >= 4.6.0; platform_system!='Emscripten' libusb1 >= 2.0.1; platform_system!='Emscripten' @@ -84,7 +88,7 @@ development = black == 22.10 grpcio-tools >= 1.57.0 invoke >= 1.7.3 - mypy == 1.2.0 + mypy == 1.5.0 nox >= 2022 pylint == 2.15.8 types-appdirs >= 1.4.3 diff --git a/tests/gatt_test.py b/tests/gatt_test.py index dd0277e..d9f6d60 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -891,10 +891,10 @@ async def async_main(): # ----------------------------------------------------------------------------- -def test_attribute_string_to_permissions(): - assert Attribute.string_to_permissions('READABLE') == 1 - assert Attribute.string_to_permissions('WRITEABLE') == 2 - assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3 +def test_permissions_from_string(): + assert Attribute.Permissions.from_string('READABLE') == 1 + assert Attribute.Permissions.from_string('WRITEABLE') == 2 + assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3 # ----------------------------------------------------------------------------- diff --git a/tests/keystore_test.py b/tests/keystore_test.py index 2e73039..2a3d48d 100644 --- a/tests/keystore_test.py +++ b/tests/keystore_test.py @@ -18,6 +18,8 @@ import asyncio import json import logging +import pathlib +import pytest import tempfile import os @@ -83,87 +85,95 @@ JSON3 = """ # ----------------------------------------------------------------------------- -async def test_basic(): - with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file: - keystore = JsonKeyStore('my_namespace', file.name) +@pytest.fixture +def temporary_file(): + file = tempfile.NamedTemporaryFile(delete=False) + file.close() + yield file.name + pathlib.Path(file.name).unlink() + + +# ----------------------------------------------------------------------------- +async def test_basic(temporary_file): + with open(temporary_file, mode='w', encoding='utf-8') as file: file.write("{}") file.flush() - keys = await keystore.get_all() - assert len(keys) == 0 - - keys = PairingKeys() - await keystore.update('foo', keys) - foo = await keystore.get('foo') - assert foo is not None - assert foo.ltk is None - ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) - keys.ltk = PairingKeys.Key(ltk) - await keystore.update('foo', keys) - foo = await keystore.get('foo') - assert foo is not None - assert foo.ltk is not None - assert foo.ltk.value == ltk + keystore = JsonKeyStore('my_namespace', temporary_file) - file.flush() - with open(file.name, "r", encoding="utf-8") as json_file: - json_data = json.load(json_file) - assert 'my_namespace' in json_data - assert 'foo' in json_data['my_namespace'] - assert 'ltk' in json_data['my_namespace']['foo'] + keys = await keystore.get_all() + assert len(keys) == 0 + + keys = PairingKeys() + await keystore.update('foo', keys) + foo = await keystore.get('foo') + assert foo is not None + assert foo.ltk is None + ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) + keys.ltk = PairingKeys.Key(ltk) + await keystore.update('foo', keys) + foo = await keystore.get('foo') + assert foo is not None + assert foo.ltk is not None + assert foo.ltk.value == ltk + + with open(file.name, "r", encoding="utf-8") as json_file: + json_data = json.load(json_file) + assert 'my_namespace' in json_data + assert 'foo' in json_data['my_namespace'] + assert 'ltk' in json_data['my_namespace']['foo'] # ----------------------------------------------------------------------------- -async def test_parsing(): - with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: - keystore = JsonKeyStore('my_namespace', file.name) +async def test_parsing(temporary_file): + with open(temporary_file, mode='w', encoding='utf-8') as file: file.write(JSON1) file.flush() - foo = await keystore.get('14:7D:DA:4E:53:A8/P') - assert foo is not None - assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683') + keystore = JsonKeyStore('my_namespace', file.name) + foo = await keystore.get('14:7D:DA:4E:53:A8/P') + assert foo is not None + assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683') # ----------------------------------------------------------------------------- -async def test_default_namespace(): - with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: - keystore = JsonKeyStore(None, file.name) +async def test_default_namespace(temporary_file): + with open(temporary_file, mode='w', encoding='utf-8') as file: file.write(JSON1) file.flush() - all_keys = await keystore.get_all() - assert len(all_keys) == 1 - name, keys = all_keys[0] - assert name == '14:7D:DA:4E:53:A8/P' - assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') + keystore = JsonKeyStore(None, file.name) + all_keys = await keystore.get_all() + assert len(all_keys) == 1 + name, keys = all_keys[0] + assert name == '14:7D:DA:4E:53:A8/P' + assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') - with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: - keystore = JsonKeyStore(None, file.name) + with open(temporary_file, mode='w', encoding='utf-8') as file: file.write(JSON2) file.flush() - keys = PairingKeys() - ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) - keys.ltk = PairingKeys.Key(ltk) - await keystore.update('foo', keys) - file.flush() - with open(file.name, "r", encoding="utf-8") as json_file: - json_data = json.load(json_file) - assert '__DEFAULT__' in json_data - assert 'foo' in json_data['__DEFAULT__'] - assert 'ltk' in json_data['__DEFAULT__']['foo'] - - with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: - keystore = JsonKeyStore(None, file.name) + keystore = JsonKeyStore(None, file.name) + keys = PairingKeys() + ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) + keys.ltk = PairingKeys.Key(ltk) + await keystore.update('foo', keys) + with open(file.name, "r", encoding="utf-8") as json_file: + json_data = json.load(json_file) + assert '__DEFAULT__' in json_data + assert 'foo' in json_data['__DEFAULT__'] + assert 'ltk' in json_data['__DEFAULT__']['foo'] + + with open(temporary_file, mode='w', encoding='utf-8') as file: file.write(JSON3) file.flush() - all_keys = await keystore.get_all() - assert len(all_keys) == 1 - name, keys = all_keys[0] - assert name == '14:7D:DA:4E:53:A8/P' - assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') + keystore = JsonKeyStore(None, file.name) + all_keys = await keystore.get_all() + assert len(all_keys) == 1 + name, keys = all_keys[0] + assert name == '14:7D:DA:4E:53:A8/P' + assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') # ----------------------------------------------------------------------------- diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 0000000..d6f5780 --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,77 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import os + +from bumble import utils +from pyee import EventEmitter +from unittest.mock import MagicMock + + +def test_on() -> None: + emitter = EventEmitter() + with contextlib.closing(utils.EventWatcher()) as context: + mock = MagicMock() + context.on(emitter, 'event', mock) + + emitter.emit('event') + + assert not emitter.listeners('event') + assert mock.call_count == 1 + + +def test_on_decorator() -> None: + emitter = EventEmitter() + with contextlib.closing(utils.EventWatcher()) as context: + mock = MagicMock() + + @context.on(emitter, 'event') + def on_event(*_) -> None: + mock() + + emitter.emit('event') + + assert not emitter.listeners('event') + assert mock.call_count == 1 + + +def test_multiple_handlers() -> None: + emitter = EventEmitter() + with contextlib.closing(utils.EventWatcher()) as context: + mock = MagicMock() + + context.once(emitter, 'a', mock) + context.once(emitter, 'b', mock) + + emitter.emit('b', 'b') + + assert not emitter.listeners('a') + assert not emitter.listeners('b') + + mock.assert_called_once_with('b') + + +# ----------------------------------------------------------------------------- +def run_tests(): + test_on() + test_on_decorator() + test_multiple_handlers() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + run_tests() diff --git a/web/bumble.js b/web/bumble.js index c5fd6a3..b1243a5 100644 --- a/web/bumble.js +++ b/web/bumble.js @@ -74,7 +74,6 @@ export async function loadBumble(pyodide, bumblePackage) { await pyodide.loadPackage("micropip"); await pyodide.runPythonAsync(` import micropip - await micropip.install("cryptography") await micropip.install("${bumblePackage}") package_list = micropip.list() print(package_list) diff --git a/web/scanner/scanner.py b/web/scanner/scanner.py index dd53050..c0fc456 100644 --- a/web/scanner/scanner.py +++ b/web/scanner/scanner.py @@ -23,7 +23,7 @@ from bumble.device import Device # ----------------------------------------------------------------------------- class ScanEntry: def __init__(self, advertisement): - self.address = str(advertisement.address).replace("/P", "") + self.address = advertisement.address.to_string(False) self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[ advertisement.address.address_type ] diff --git a/web/speaker/speaker.py b/web/speaker/speaker.py index ddc2086..d9293a4 100644 --- a/web/speaker/speaker.py +++ b/web/speaker/speaker.py @@ -171,7 +171,7 @@ class Speaker: self.connection = connection connection.on('disconnection', self.on_bluetooth_disconnection) peer_name = '' if connection.peer_name is None else connection.peer_name - peer_address = str(connection.peer_address).replace('/P', '') + peer_address = connection.peer_address.to_string(False) self.emit_event( 'connection', {'peer_name': peer_name, 'peer_address': peer_address} ) |