author    uael <uael@google.com>  2023-09-29 18:40:06 +0000
committer uael <uael@google.com>  2023-09-29 18:40:06 +0000
commit    1decd46067626cdfb075435e789b0bc9946d9923 (patch)
tree      aa98f29f308756c4a060711d20036a9e31d819fa
parent    400265218fe6cc5cfa834c8d608d879f058c90ca (diff)
parent    6f2b623e3ce909be4962b042c314ca03d12e8e2d (diff)
download  bumble-1decd46067626cdfb075435e789b0bc9946d9923.tar.gz

Merge remote-tracking branch 'aosp/upstream-main' into main

Change-Id: I4b924760be1b02d29cf933930830d00b9cf89663
 .github/workflows/code-check.yml        |   4
 .github/workflows/python-build-test.yml |   7
 .vscode/settings.json                   |   2
 apps/console.py                         |   2
 apps/controller_info.py                 |   3
 apps/pandora_server.py                  |  10
 apps/speaker/speaker.py                 |   4
 bumble/att.py                           | 128
 bumble/core.py                          |   2
 bumble/device.py                        |   8
 bumble/gatt.py                          |  42
 bumble/gatt_client.py                   |  99
 bumble/gatt_server.py                   | 106
 bumble/hci.py                           |   2
 bumble/l2cap.py                         | 223
 bumble/pandora/security.py              |  66
 bumble/smp.py                           |  27
 bumble/transport/android_emulator.py    |  17
 bumble/transport/android_netsim.py      | 134
 bumble/transport/common.py              |  11
 bumble/transport/hci_socket.py          |  12
 bumble/transport/pty.py                 |   4
 bumble/transport/vhci.py                |   4
 bumble/transport/ws_client.py           |  16
 bumble/utils.py                         | 110
 rust/Cargo.lock                         | 159
 rust/Cargo.toml                         |  13
 rust/README.md                          |   2
 rust/pytests/assigned_numbers.rs        |  14
 rust/src/adv.rs                         |  14
 rust/src/wrapper/logging.rs             |  14
 rust/tools/file_header.rs               |  78
 setup.cfg                               |   6
 tests/gatt_test.py                      |   8
 tests/keystore_test.py                  | 126
 tests/utils_test.py                     |  77
 web/bumble.js                           |   1
 web/scanner/scanner.py                  |   2
 web/speaker/speaker.py                  |   2
 39 files changed, 1110 insertions(+), 449 deletions(-)
diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml
index b6cf8fd..021b1e4 100644
--- a/.github/workflows/code-check.yml
+++ b/.github/workflows/code-check.yml
@@ -14,6 +14,10 @@ jobs:
check:
name: Check Code
runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
+ fail-fast: false
steps:
- name: Check out from Git
diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml
index c8a1031..4cc3e73 100644
--- a/.github/workflows/python-build-test.yml
+++ b/.github/workflows/python-build-test.yml
@@ -12,10 +12,10 @@ permissions:
jobs:
build:
-
- runs-on: ubuntu-latest
+ runs-on: ${{ matrix.os }}
strategy:
matrix:
+ os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
@@ -41,6 +41,7 @@ jobs:
run: |
inv build
inv build.mkdocs
+
build-rust:
runs-on: ubuntu-latest
strategy:
@@ -64,6 +65,8 @@ jobs:
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
+ - name: Check License Headers
+ run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 864fe69..57e682a 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -39,10 +39,12 @@
"libusb",
"MITM",
"NDIS",
+ "netsim",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
+ "protobuf",
"psms",
"pyee",
"pyusb",
diff --git a/apps/console.py b/apps/console.py
index 0ea9e5b..9a529dd 100644
--- a/apps/console.py
+++ b/apps/console.py
@@ -1172,7 +1172,7 @@ class ScanResult:
name = ''
# Remove any '/P' qualifier suffix from the address string
- address_str = str(self.address).replace('/P', '')
+ address_str = self.address.to_string(with_type_qualifier=False)
# RSSI bar
bar_string = rssi_bar(self.rssi)
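
This change, together with the matching ones in controller_info.py and speaker.py below, replaces the old str(address).replace('/P', '') trick with the Address.to_string() API. A minimal sketch of the difference, assuming bumble.hci.Address at this revision; the example address is arbitrary:

# Sketch only: drop the address-type qualifier with the new API instead of
# string surgery. The printed form depends on the address type ('/P' marks a
# public address).
from bumble.hci import Address

address = Address('F0:F1:F2:F3:F4:F5')                 # illustrative address
print(str(address))                                    # may end with a type qualifier
print(address.to_string(with_type_qualifier=False))    # always the bare address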
diff --git a/apps/controller_info.py b/apps/controller_info.py
index 4707983..5be4f3d 100644
--- a/apps/controller_info.py
+++ b/apps/controller_info.py
@@ -63,7 +63,8 @@ async def get_classic_info(host):
if command_succeeded(response):
print()
print(
- color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
+ color('Classic Address:', 'yellow'),
+ response.return_parameters.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
diff --git a/apps/pandora_server.py b/apps/pandora_server.py
index b577f82..16bc211 100644
--- a/apps/pandora_server.py
+++ b/apps/pandora_server.py
@@ -3,7 +3,7 @@ import click
import logging
import json
-from bumble.pandora import PandoraDevice, serve
+from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
bumble_config = retrieve_config(config)
- if 'transport' not in bumble_config.keys():
- bumble_config.update({'transport': transport})
+ bumble_config.setdefault('transport', transport)
device = PandoraDevice(bumble_config)
+ server_config = Config()
+ server_config.load_from_dict(bumble_config.get('server', {}))
+
logging.basicConfig(level=logging.DEBUG)
- asyncio.run(serve(device, port=grpc_port))
+ asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]:
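
For reference, a sketch of how the new Config object feeds into serve(); the transport string, port, and the empty 'server' section below are placeholders, not values taken from this change:

# Sketch only: programmatic equivalent of the updated entry point above.
import asyncio
from bumble.pandora import PandoraDevice, Config, serve

bumble_config = {'transport': 'tcp-client:127.0.0.1:6402', 'server': {}}   # placeholder values
device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
asyncio.run(serve(device, config=server_config, port=7999))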
diff --git a/apps/speaker/speaker.py b/apps/speaker/speaker.py
index 1a1eac3..e451c04 100644
--- a/apps/speaker/speaker.py
+++ b/apps/speaker/speaker.py
@@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput):
except HCI_StatusError:
pass
peer_name = '' if connection.peer_name is None else connection.peer_name
- peer_address = str(connection.peer_address).replace('/P', '')
+ peer_address = connection.peer_address.to_string(False)
await self.send_message(
'connection',
peer_address=peer_address,
@@ -376,7 +376,7 @@ class UiServer:
if connection := self.speaker().connection:
await self.send_message(
'connection',
- peer_address=str(connection.peer_address).replace('/P', ''),
+ peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
diff --git a/bumble/att.py b/bumble/att.py
index 55ae8a5..db8d2ba 100644
--- a/bumble/att.py
+++ b/bumble/att.py
@@ -23,13 +23,14 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
+import enum
import functools
import struct
from pyee import EventEmitter
-from typing import Dict, Type, TYPE_CHECKING
+from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
-from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
-from bumble.hci import HCI_Object, key_with_value, HCI_Constant
+from bumble.core import UUID, name_or_number, ProtocolError
+from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long
# pylint: disable=invalid-name
+
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
@@ -209,7 +211,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
- name = None
+ name: str
@staticmethod
def from_bytes(pdu):
@@ -720,47 +722,67 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
-class Attribute(EventEmitter):
- # Permission flags
- READABLE = 0x01
- WRITEABLE = 0x02
- READ_REQUIRES_ENCRYPTION = 0x04
- WRITE_REQUIRES_ENCRYPTION = 0x08
- READ_REQUIRES_AUTHENTICATION = 0x10
- WRITE_REQUIRES_AUTHENTICATION = 0x20
- READ_REQUIRES_AUTHORIZATION = 0x40
- WRITE_REQUIRES_AUTHORIZATION = 0x80
-
- PERMISSION_NAMES = {
- READABLE: 'READABLE',
- WRITEABLE: 'WRITEABLE',
- READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
- WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
- READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
- WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
- READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
- WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
- }
+class ConnectionValue(Protocol):
+ def read(self, connection) -> bytes:
+ ...
- @staticmethod
- def string_to_permissions(permissions_str: str):
- try:
- return functools.reduce(
- lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
- permissions_str.split(","),
- 0,
- )
- except TypeError as exc:
- raise TypeError(
- f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
- ) from exc
+ def write(self, connection, value: bytes) -> None:
+ ...
- def __init__(self, attribute_type, permissions, value=b''):
+
+# -----------------------------------------------------------------------------
+class Attribute(EventEmitter):
+ class Permissions(enum.IntFlag):
+ READABLE = 0x01
+ WRITEABLE = 0x02
+ READ_REQUIRES_ENCRYPTION = 0x04
+ WRITE_REQUIRES_ENCRYPTION = 0x08
+ READ_REQUIRES_AUTHENTICATION = 0x10
+ WRITE_REQUIRES_AUTHENTICATION = 0x20
+ READ_REQUIRES_AUTHORIZATION = 0x40
+ WRITE_REQUIRES_AUTHORIZATION = 0x80
+
+ @classmethod
+ def from_string(cls, permissions_str: str) -> Attribute.Permissions:
+ try:
+ return functools.reduce(
+ lambda x, y: x | Attribute.Permissions[y],
+ permissions_str.replace('|', ',').split(","),
+ Attribute.Permissions(0),
+ )
+ except TypeError as exc:
+ # The check for `p.name is not None` here is needed because for IntFlag
+ # enums, the .name property can be None when the enum value is 0,
+ # so the type hint for .name is Optional[str].
+ enum_list: List[str] = [p.name for p in cls if p.name is not None]
+ enum_list_str = ",".join(enum_list)
+ raise TypeError(
+ f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
+ ) from exc
+
+ # Permission flags (legacy use only)
+ READABLE = Permissions.READABLE
+ WRITEABLE = Permissions.WRITEABLE
+ READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
+ WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
+ READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
+ WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
+ READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
+ WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
+
+ value: Union[str, bytes, ConnectionValue]
+
+ def __init__(
+ self,
+ attribute_type: Union[str, bytes, UUID],
+ permissions: Union[str, Attribute.Permissions],
+ value: Union[str, bytes, ConnectionValue] = b'',
+ ) -> None:
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
if isinstance(permissions, str):
- self.permissions = self.string_to_permissions(permissions)
+ self.permissions = Attribute.Permissions.from_string(permissions)
else:
self.permissions = permissions
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
else:
self.value = value
- def encode_value(self, value):
+ def encode_value(self, value: Any) -> bytes:
return value
- def decode_value(self, value_bytes):
+ def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
- def read_value(self, connection: Connection):
+ def read_value(self, connection: Optional[Connection]) -> bytes:
if (
- self.permissions & self.READ_REQUIRES_ENCRYPTION
- ) and not connection.encryption:
+ (self.permissions & self.READ_REQUIRES_ENCRYPTION)
+ and connection is not None
+ and not connection.encryption
+ ):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
- self.permissions & self.READ_REQUIRES_AUTHENTICATION
- ) and not connection.authenticated:
+ (self.permissions & self.READ_REQUIRES_AUTHENTICATION)
+ and connection is not None
+ and not connection.authenticated
+ ):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
- if read := getattr(self.value, 'read', None):
+ if hasattr(self.value, 'read'):
try:
- value = read(connection) # pylint: disable=not-callable
+ value = self.value.read(connection)
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
- def write_value(self, connection: Connection, value_bytes):
+ def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes)
- if write := getattr(self.value, 'write', None):
+ if hasattr(self.value, 'write'):
try:
- write(connection, value) # pylint: disable=not-callable
+ self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
diff --git a/bumble/core.py b/bumble/core.py
index 4dff432..4a67d6e 100644
--- a/bumble/core.py
+++ b/bumble/core.py
@@ -80,7 +80,7 @@ class BaseError(Exception):
def __init__(
self,
- error_code: int | None,
+ error_code: Optional[int],
error_namespace: str = '',
error_name: str = '',
details: str = '',
diff --git a/bumble/device.py b/bumble/device.py
index 9a784e7..b01dc58 100644
--- a/bumble/device.py
+++ b/bumble/device.py
@@ -1186,8 +1186,8 @@ class Device(CompositeEventEmitter):
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
- def register_l2cap_server(self, psm, server):
- self.l2cap_channel_manager.register_server(psm, server)
+ def register_l2cap_server(self, psm, server) -> int:
+ return self.l2cap_channel_manager.register_server(psm, server)
def register_l2cap_channel_server(
self,
@@ -2758,7 +2758,9 @@ class Device(CompositeEventEmitter):
self.abort_on(
'flush',
self.start_advertising(
- advertising_type=self.advertising_type, auto_restart=True
+ advertising_type=self.advertising_type,
+ own_address_type=self.advertising_own_address_type,
+ auto_restart=True,
),
)
diff --git a/bumble/gatt.py b/bumble/gatt.py
index 067f31d..fe3e85c 100644
--- a/bumble/gatt.py
+++ b/bumble/gatt.py
@@ -28,7 +28,7 @@ import enum
import functools
import logging
import struct
-from typing import Optional, Sequence, List
+from typing import Optional, Sequence, Iterable, List, Union
from .colors import color
from .core import UUID, get_dict_key_by_value
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# -----------------------------------------------------------------------------
-def show_services(services):
+def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
@@ -210,11 +210,11 @@ class Service(Attribute):
def __init__(
self,
- uuid,
+ uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: List[Service] = [],
- ):
+ ) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
@@ -239,7 +239,7 @@ class Service(Attribute):
"""
return None
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -255,9 +255,11 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
- UUID: Optional[UUID] = None
+ UUID: UUID
- def __init__(self, characteristics, primary=True):
+ def __init__(
+ self, characteristics: List[Characteristic], primary: bool = True
+ ) -> None:
super().__init__(self.UUID, characteristics, primary)
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service
- def __init__(self, service):
+ def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
)
self.service = service
- def __str__(self):
+ def __str__(self) -> str:
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
- def __str__(self):
+ def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
def __init__(
self,
- uuid,
+ uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
- permissions,
- value=b'',
+ permissions: Union[str, Attribute.Permissions],
+ value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic
- def __init__(self, characteristic, value_handle):
+ def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle
self.characteristic = characteristic
- def __str__(self):
+ def __str__(self) -> str:
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber)
- def __str__(self):
+ def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
- def encode_value(self, value):
+ def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
- def decode_value(self, value):
+ def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
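
With the widened type hints above, a characteristic can still be declared with either the enum values or their string forms. A hedged sketch; the 16-bit UUIDs are the standard Battery Service/Battery Level values, used here only for illustration:

# Sketch only: a service built from a characteristic using a permissions string.
from bumble.gatt import Characteristic, Service

battery_level = Characteristic(
    '2A19',                                   # Battery Level, as a UUID string
    Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
    'READABLE',                               # parsed via Attribute.Permissions.from_string()
    bytes([100]),
)
battery_service = Service('180F', [battery_level])       # Battery Service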
diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py
index a33039e..e3b8bb2 100644
--- a/bumble/gatt_client.py
+++ b/bumble/gatt_client.py
@@ -28,7 +28,18 @@ import asyncio
import logging
import struct
from datetime import datetime
-from typing import List, Optional, Dict, Tuple, Callable, Union, Any
+from typing import (
+ List,
+ Optional,
+ Dict,
+ Tuple,
+ Callable,
+ Union,
+ Any,
+ Iterable,
+ Type,
+ TYPE_CHECKING,
+)
from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
+ TemplateService,
)
+if TYPE_CHECKING:
+ from bumble.device import Connection
+
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
- client: Client
-
- def __init__(self, client, handle, end_group_handle, attribute_type):
+ def __init__(
+ self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
+ ) -> None:
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
- async def read_value(self, no_long_read=False):
+ async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response
)
- def encode_value(self, value):
+ def encode_value(self, value: Any) -> bytes:
return value
- def decode_value(self, value_bytes):
+ def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
- def __str__(self):
+ def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
- def __str__(self):
+ def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
- subscribers: Dict[Any, Callable]
+ subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__(
self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(
- self, subscriber: Optional[Callable] = None, prefer_notify=True
+ self,
+ subscriber: Optional[Callable[[bytes], Any]] = None,
+ prefer_notify: bool = True,
):
if subscriber is not None:
if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
- def __str__(self):
+ def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
+ SERVICE_CLASS: Type[TemplateService]
+
@classmethod
- def from_client(cls, client):
+ def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
+ notification_subscribers: Dict[int, Callable[[bytes], Any]]
+ indication_subscribers: Dict[int, Callable[[bytes], Any]]
+ pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
+ pending_request: Optional[ATT_PDU]
- def __init__(self, connection):
+ def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = []
self.cached_values = {}
- def send_gatt_pdu(self, pdu):
+ def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
- async def send_command(self, command):
+ async def send_command(self, command: ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
- async def send_request(self, request):
+ async def send_request(self, request: ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
@@ -279,14 +302,14 @@ class Client:
return response
- def send_confirmation(self, confirmation):
+ def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
- async def request_mtu(self, mtu):
+ async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu
- def get_services_by_uuid(self, uuid):
+ def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
- def get_characteristics_by_uuid(self, uuid, service=None):
+ def get_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy] = None
+ ) -> List[CharacteristicProxy]:
services = [service] if service else self.services
return [
c
@@ -363,7 +388,7 @@ class Client:
if not already_known:
self.services.append(service)
- async def discover_services(self, uuids=None) -> List[ServiceProxy]:
+ async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -435,7 +460,7 @@ class Client:
return services
- async def discover_service(self, uuid):
+ async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
- return
+ return []
break
for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
- return
+ return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors
- async def discover_attributes(self):
+ async def discover_attributes(self) -> List[AttributeProxy]:
'''
Discover all attributes, regardless of type
'''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
- async def read_value(self, attribute, no_long_read=False):
+ async def read_value(
+ self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
+ ) -> Any:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes
return attribute_value
- async def read_characteristics_by_uuid(self, uuid, service):
+ async def read_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy]
+ ) -> List[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
'''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values
- async def write_value(self, attribute, value, with_response=False):
+ async def write_value(
+ self,
+ attribute: Union[int, AttributeProxy],
+ value: bytes,
+ with_response: bool = False,
+ ) -> None:
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
@@ -990,7 +1024,7 @@ class Client:
)
)
- def on_gatt_pdu(self, att_pdu):
+ def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
@@ -1013,6 +1047,7 @@ class Client:
return
# Return the response to the coroutine that is waiting for it
+ assert self.pending_response is not None
self.pending_response.set_result(att_pdu)
else:
handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
- def cache_value(self, attribute_handle: int, value: bytes):
+ def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = (
datetime.now(),
value,
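
The newly annotated client entry points are usually reached through a Peer from bumble.device, which delegates to the Client shown here. A sketch under that assumption; connection is an already established Connection:

# Sketch only: enumerate services and attributes over an existing connection.
from bumble.device import Peer

async def explore(connection) -> None:
    peer = Peer(connection)
    for service in await peer.discover_services():
        print(service)
    for attribute in await peer.discover_attributes():
        print(attribute)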
diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py
index 3624905..cdf1b5e 100644
--- a/bumble/gatt_server.py
+++ b/bumble/gatt_server.py
@@ -23,11 +23,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
import struct
-from typing import List, Tuple, Optional, TypeVar, Type
+from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from .colors import color
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
+ ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
@@ -73,6 +75,8 @@ from .gatt import (
Service,
)
+if TYPE_CHECKING:
+ from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
+ services: List[Service]
+ attributes_by_handle: Dict[int, Attribute]
+ subscribers: Dict[int, Dict[int, bytes]]
+ indication_semaphores: defaultdict[int, asyncio.Semaphore]
+ pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
- def __init__(self, device):
+ def __init__(self, device: Device) -> None:
super().__init__()
self.device = device
self.services = []
@@ -107,16 +116,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
- def __str__(self):
+ def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
- def send_gatt_pdu(self, connection_handle, pdu):
+ def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
- def next_handle(self):
+ def next_handle(self) -> int:
return 1 + len(self.attributes)
- def get_advertising_service_data(self):
+ def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return {
attribute: data
for attribute in self.attributes
@@ -124,7 +133,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data())
}
- def get_attribute(self, handle):
+ def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -173,12 +182,17 @@ class Server(EventEmitter):
return next(
(
- (attribute, self.get_attribute(attribute.characteristic.handle))
+ (
+ attribute,
+ self.get_attribute(attribute.characteristic.handle),
+ ) # type: ignore
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
- if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
+ if attribute is not None
+ and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
+ and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid
),
None,
@@ -197,7 +211,7 @@ class Server(EventEmitter):
return next(
(
- attribute
+ attribute # type: ignore
for attribute in map(
self.get_attribute,
range(
@@ -205,12 +219,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1,
),
)
- if attribute.type == descriptor_uuid
+ if attribute is not None and attribute.type == descriptor_uuid
),
None,
)
- def add_attribute(self, attribute):
+ def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = (
@@ -220,7 +234,7 @@ class Server(EventEmitter):
# Add this attribute to the list
self.attributes.append(attribute)
- def add_service(self, service: Service):
+ def add_service(self, service: Service) -> None:
# Add the service attribute to the DB
self.add_attribute(service)
@@ -285,11 +299,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle
self.services.append(service)
- def add_services(self, services):
+ def add_services(self, services: Iterable[Service]) -> None:
for service in services:
self.add_service(service)
- def read_cccd(self, connection, characteristic):
+ def read_cccd(
+ self, connection: Optional[Connection], characteristic: Characteristic
+ ) -> bytes:
if connection is None:
return bytes([0, 0])
@@ -300,7 +316,12 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
- def write_cccd(self, connection, characteristic, value):
+ def write_cccd(
+ self,
+ connection: Connection,
+ characteristic: Characteristic,
+ value: bytes,
+ ) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
@@ -327,13 +348,19 @@ class Server(EventEmitter):
indicate_enabled,
)
- def send_response(self, connection, response):
+ def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
- async def notify_subscriber(self, connection, attribute, value=None, force=False):
+ async def notify_subscriber(
+ self,
+ connection: Connection,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -370,7 +397,13 @@ class Server(EventEmitter):
)
self.send_gatt_pdu(connection.handle, bytes(notification))
- async def indicate_subscriber(self, connection, attribute, value=None, force=False):
+ async def indicate_subscriber(
+ self,
+ connection: Connection,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -411,15 +444,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
- self.pending_confirmations[
+ pending_confirmation = self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
- await asyncio.wait_for(
- self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
- )
+ await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +458,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(
- self, indicate, attribute, value=None, force=False
- ):
+ self,
+ indicate: bool,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
@@ -450,13 +485,23 @@ class Server(EventEmitter):
]
)
- async def notify_subscribers(self, attribute, value=None, force=False):
+ async def notify_subscribers(
+ self,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
- async def indicate_subscribers(self, attribute, value=None, force=False):
+ async def indicate_subscribers(
+ self,
+ attribute: Attribute,
+ value: Optional[bytes] = None,
+ force: bool = False,
+ ):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
- def on_disconnection(self, connection):
+ def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -464,7 +509,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
- def on_gatt_pdu(self, connection, att_pdu):
+ def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -506,7 +551,7 @@ class Server(EventEmitter):
#######################################################
# ATT handlers
#######################################################
- def on_att_request(self, connection, pdu):
+ def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
@@ -679,7 +724,6 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
-
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
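
On the server side, the typed notify_subscribers()/indicate_subscribers() entry points are normally driven from the owning Device. A hedged sketch; device, characteristic, and push_value are assumed names, with the characteristic already registered in the device's GATT database:

# Sketch only: push an updated value to the peers subscribed via the CCCD.
async def push_value(device, characteristic, new_value: bytes) -> None:
    characteristic.value = new_value
    await device.gatt_server.notify_subscribers(characteristic)
    # For acknowledged delivery, use indicate_subscribers() instead:
    # await device.gatt_server.indicate_subscribers(characteristic)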
diff --git a/bumble/hci.py b/bumble/hci.py
index b7014dd..41deed2 100644
--- a/bumble/hci.py
+++ b/bumble/hci.py
@@ -4397,7 +4397,7 @@ class HCI_Event(HCI_Packet):
if len(parameters) != length:
raise ValueError('invalid packet length')
- cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
+ cls: Any
if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call
# loops
diff --git a/bumble/l2cap.py b/bumble/l2cap.py
index fea8a1d..cccb172 100644
--- a/bumble/l2cap.py
+++ b/bumble/l2cap.py
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
+import enum
import logging
import struct
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class Channel(EventEmitter):
- # States
- CLOSED = 0x00
- WAIT_CONNECT = 0x01
- WAIT_CONNECT_RSP = 0x02
- OPEN = 0x03
- WAIT_DISCONNECT = 0x04
- WAIT_CREATE = 0x05
- WAIT_CREATE_RSP = 0x06
- WAIT_MOVE = 0x07
- WAIT_MOVE_RSP = 0x08
- WAIT_MOVE_CONFIRM = 0x09
- WAIT_CONFIRM_RSP = 0x0A
-
- # CONFIG substates
- WAIT_CONFIG = 0x10
- WAIT_SEND_CONFIG = 0x11
- WAIT_CONFIG_REQ_RSP = 0x12
- WAIT_CONFIG_RSP = 0x13
- WAIT_CONFIG_REQ = 0x14
- WAIT_IND_FINAL_RSP = 0x15
- WAIT_FINAL_RSP = 0x16
- WAIT_CONTROL_IND = 0x17
-
- STATE_NAMES = {
- CLOSED: 'CLOSED',
- WAIT_CONNECT: 'WAIT_CONNECT',
- WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
- OPEN: 'OPEN',
- WAIT_DISCONNECT: 'WAIT_DISCONNECT',
- WAIT_CREATE: 'WAIT_CREATE',
- WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
- WAIT_MOVE: 'WAIT_MOVE',
- WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
- WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
- WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
- WAIT_CONFIG: 'WAIT_CONFIG',
- WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
- WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
- WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
- WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
- WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
- WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
- WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
- }
+ class State(enum.IntEnum):
+ # States
+ CLOSED = 0x00
+ WAIT_CONNECT = 0x01
+ WAIT_CONNECT_RSP = 0x02
+ OPEN = 0x03
+ WAIT_DISCONNECT = 0x04
+ WAIT_CREATE = 0x05
+ WAIT_CREATE_RSP = 0x06
+ WAIT_MOVE = 0x07
+ WAIT_MOVE_RSP = 0x08
+ WAIT_MOVE_CONFIRM = 0x09
+ WAIT_CONFIRM_RSP = 0x0A
+
+ # CONFIG substates
+ WAIT_CONFIG = 0x10
+ WAIT_SEND_CONFIG = 0x11
+ WAIT_CONFIG_REQ_RSP = 0x12
+ WAIT_CONFIG_RSP = 0x13
+ WAIT_CONFIG_REQ = 0x14
+ WAIT_IND_FINAL_RSP = 0x15
+ WAIT_FINAL_RSP = 0x16
+ WAIT_CONTROL_IND = 0x17
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
- state: int
+ state: State
connection: Connection
def __init__(
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
self.manager = manager
self.connection = connection
self.signaling_cid = signaling_cid
- self.state = Channel.CLOSED
+ self.state = self.State.CLOSED
self.mtu = mtu
self.psm = psm
self.source_cid = source_cid
@@ -751,13 +731,11 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
- def change_state(self, new_state: int) -> None:
- logger.debug(
- f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
- )
+ def _change_state(self, new_state: State) -> None:
+ logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
- def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
+ def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
- if self.state != Channel.OPEN:
+ if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
)
async def connect(self) -> None:
- if self.state != Channel.CLOSED:
+ if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state')
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
- self.change_state(Channel.WAIT_CONNECT_RSP)
+ self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
self.connection_result = None
async def disconnect(self) -> None:
- if self.state != Channel.OPEN:
+ if self.state != self.State.OPEN:
raise InvalidStateError('invalid state')
- self.change_state(Channel.WAIT_DISCONNECT)
+ self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame(
L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
- if self.state == self.OPEN:
- self.change_state(self.CLOSED)
+ if self.state == self.State.OPEN:
+ self._change_state(self.State.CLOSED)
self.emit('close')
def send_configure_request(self) -> None:
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
- self.change_state(Channel.WAIT_CONNECT)
+ self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame(
L2CAP_Connection_Response(
identifier=request.identifier,
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
status=0x0000,
)
)
- self.change_state(Channel.WAIT_CONFIG)
+ self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
+ self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response):
- if self.state != Channel.WAIT_CONNECT_RSP:
+ if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red'))
return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid
- self.change_state(Channel.WAIT_CONFIG)
+ self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
+ self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass
else:
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
self.connection_result.set_exception(
ProtocolError(
response.result,
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None:
if self.state not in (
- Channel.WAIT_CONFIG,
- Channel.WAIT_CONFIG_REQ,
- Channel.WAIT_CONFIG_REQ_RSP,
+ self.State.WAIT_CONFIG,
+ self.State.WAIT_CONFIG_REQ,
+ self.State.WAIT_CONFIG_REQ_RSP,
):
logger.warning(color('invalid state', 'red'))
return
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly
)
)
- if self.state == Channel.WAIT_CONFIG:
- self.change_state(Channel.WAIT_SEND_CONFIG)
+ if self.state == self.State.WAIT_CONFIG:
+ self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_RSP)
- elif self.state == Channel.WAIT_CONFIG_REQ:
- self.change_state(Channel.OPEN)
+ self._change_state(self.State.WAIT_CONFIG_RSP)
+ elif self.state == self.State.WAIT_CONFIG_REQ:
+ self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.emit('open')
- elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
- self.change_state(Channel.WAIT_CONFIG_RSP)
+ elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
+ self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
- if self.state == Channel.WAIT_CONFIG_REQ_RSP:
- self.change_state(Channel.WAIT_CONFIG_REQ)
- elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
- self.change_state(Channel.OPEN)
+ if self.state == self.State.WAIT_CONFIG_REQ_RSP:
+ self._change_state(self.State.WAIT_CONFIG_REQ)
+ elif self.state in (
+ self.State.WAIT_CONFIG_RSP,
+ self.State.WAIT_CONTROL_IND,
+ ):
+ self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None:
- if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
+ if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid,
)
)
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
self.emit('close')
self.manager.on_channel_closed(self)
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None:
- if self.state != Channel.WAIT_DISCONNECT:
+ if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
- f'state={Channel.STATE_NAMES[self.state]})'
+ f'state={self.state.name})'
)
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
LE Credit-based Connection Oriented Channel
"""
- INIT = 0
- CONNECTED = 1
- CONNECTING = 2
- DISCONNECTING = 3
- DISCONNECTED = 4
- CONNECTION_ERROR = 5
-
- STATE_NAMES = {
- INIT: 'INIT',
- CONNECTED: 'CONNECTED',
- CONNECTING: 'CONNECTING',
- DISCONNECTING: 'DISCONNECTING',
- DISCONNECTED: 'DISCONNECTED',
- CONNECTION_ERROR: 'CONNECTION_ERROR',
- }
+ class State(enum.IntEnum):
+ INIT = 0
+ CONNECTED = 1
+ CONNECTING = 2
+ DISCONNECTING = 3
+ DISCONNECTED = 4
+ CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
- state: int
+ state: State
connection: Connection
- @staticmethod
- def state_name(state: int) -> str:
- return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
-
def __init__(
self,
manager: ChannelManager,
@@ -1083,22 +1052,20 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
if connected:
- self.state = LeConnectionOrientedChannel.CONNECTED
+ self.state = self.State.CONNECTED
else:
- self.state = LeConnectionOrientedChannel.INIT
+ self.state = self.State.INIT
- def change_state(self, new_state: int) -> None:
- logger.debug(
- f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
- )
+ def _change_state(self, new_state: State) -> None:
+ logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
- if new_state == self.CONNECTED:
+ if new_state == self.State.CONNECTED:
self.emit('open')
- elif new_state == self.DISCONNECTED:
+ elif new_state == self.State.DISCONNECTED:
self.emit('close')
- def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
+ def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
- if self.state != self.INIT:
+ if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
- self.change_state(self.CONNECTING)
+ self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
le_psm=self.le_psm,
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None:
# Check that we're connected
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
- self.change_state(self.DISCONNECTING)
+ self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
L2CAP_Disconnection_Request(
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
- if self.state == self.CONNECTED:
- self.change_state(self.DISCONNECTED)
+ if self.state == self.State.CONNECTED:
+ self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping')
# Manage the peer credits
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits
self.connected = True
self.connection_result.set_result(self)
- self.change_state(self.CONNECTED)
+ self._change_state(self.State.CONNECTED)
else:
self.connection_result.set_exception(
ProtocolError(
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
),
)
)
- self.change_state(self.CONNECTION_ERROR)
+ self._change_state(self.State.CONNECTION_ERROR)
# Cleanup
self.connection_result = None
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid,
)
)
- self.change_state(self.DISCONNECTED)
+ self._change_state(self.State.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response) -> None:
- if self.state != self.DISCONNECTING:
+ if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
- self.change_state(self.DISCONNECTED)
+ self._change_state(self.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return
def write(self, data: bytes) -> None:
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
- f'State={self.state_name(self.state)}, '
+ f'State={self.state.name}, '
f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, '
@@ -1571,7 +1538,7 @@ class ChannelManager:
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
- def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
+ def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py
index 96fce85..0f31512 100644
--- a/bumble/pandora/security.py
+++ b/bumble/pandora/security.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
+import contextlib
import grpc
import logging
@@ -27,8 +28,8 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
+from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
-from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -232,7 +233,11 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
- identity_address_type=config.identity_address_type,
+ identity_address_type=(
+ PairingConfig.AddressType.PUBLIC
+ if connection.self_address.is_public
+ else config.identity_address_type
+ ),
delegate=PairingDelegate(
connection,
self,
@@ -294,23 +299,35 @@ class SecurityService(SecurityServicer):
try:
self.log.debug('Pair...')
- if (
- connection.transport == BT_LE_TRANSPORT
- and connection.role == BT_PERIPHERAL_ROLE
- ):
- wait_for_security: asyncio.Future[
- bool
- ] = asyncio.get_running_loop().create_future()
- connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
- connection.on("pairing_failure", wait_for_security.set_exception)
+ security_result = asyncio.get_running_loop().create_future()
+
+ with contextlib.closing(EventWatcher()) as watcher:
+
+ @watcher.on(connection, 'pairing')
+ def on_pairing(*_: Any) -> None:
+ security_result.set_result('success')
- connection.request_pairing()
+ @watcher.on(connection, 'pairing_failure')
+ def on_pairing_failure(*_: Any) -> None:
+ security_result.set_result('pairing_failure')
- await wait_for_security
- else:
- await connection.pair()
+ @watcher.on(connection, 'disconnection')
+ def on_disconnection(*_: Any) -> None:
+ security_result.set_result('connection_died')
- self.log.debug('Paired')
+ if (
+ connection.transport == BT_LE_TRANSPORT
+ and connection.role == BT_PERIPHERAL_ROLE
+ ):
+ connection.request_pairing()
+ else:
+ await connection.pair()
+
+ result = await security_result
+
+ self.log.debug(f'Pairing session complete, status={result}')
+ if result != 'success':
+ return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -369,6 +386,7 @@ class SecurityService(SecurityServicer):
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
+ pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -415,6 +433,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
+ def pair(*_: Any) -> None:
+ if self.need_pairing(connection, level):
+ pair_task = asyncio.create_task(connection.pair())
+
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -425,6 +447,7 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
+ 'security_request': pair,
}
# register event handlers
@@ -452,6 +475,15 @@ class SecurityService(SecurityServicer):
pass
self.log.debug('Authenticated')
+ # wait for `pair` to finish, if any
+ if pair_task is not None:
+ self.log.debug('Waiting for pairing...')
+ try:
+ await pair_task # type: ignore
+ except:
+ pass
+ self.log.debug('Paired')
+
return WaitSecurityResponse(**kwargs)
def reached_security_level(
@@ -523,7 +555,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
- with suppress(KeyError):
+ with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()
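A note on the pairing flow above: handlers are registered through the `EventWatcher` added to `bumble/utils.py` later in this change, so they are detached as soon as the `with` block exits, and the future's string result doubles as the name of the `SecureResponse` oneof field (`SecureResponse(**{result: empty_pb2.Empty()})`). A standalone sketch of the same pattern, with a plain `pyee` emitter standing in for a `Connection`:
```
import asyncio
import contextlib

from pyee import EventEmitter

from bumble.utils import EventWatcher


async def wait_for_pairing_result(connection: EventEmitter) -> str:
    result = asyncio.get_running_loop().create_future()

    # Handlers registered through the watcher are removed automatically when
    # the `with` block exits, even if an exception is raised.
    with contextlib.closing(EventWatcher()) as watcher:

        @watcher.on(connection, 'pairing')
        def on_pairing(*_) -> None:
            result.set_result('success')

        @watcher.on(connection, 'pairing_failure')
        def on_pairing_failure(*_) -> None:
            result.set_result('pairing_failure')

        return await result


async def main() -> None:
    connection = EventEmitter()  # stands in for a bumble.device.Connection
    asyncio.get_running_loop().call_soon(connection.emit, 'pairing')
    print(await wait_for_pairing_result(connection))  # -> 'success'


asyncio.run(main())
```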
diff --git a/bumble/smp.py b/bumble/smp.py
index 55b8359..f8bba40 100644
--- a/bumble/smp.py
+++ b/bumble/smp.py
@@ -37,6 +37,7 @@ from typing import (
Optional,
Tuple,
Type,
+ cast,
)
from pyee import EventEmitter
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
+ def on_smp_security_request_command(
+ self, connection: Connection, request: SMP_Security_Request_Command
+ ) -> None:
+ connection.emit('security_request', request.auth_req)
+
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
+ # Parse the L2CAP payload into an SMP Command object
+ command = SMP_Command.from_bytes(pdu)
+ logger.debug(
+ f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
+ f'{connection.peer_address}: {command}'
+ )
+
+ # Security requests are more than just pairing, so let applications handle them
+ if command.code == SMP_SECURITY_REQUEST_COMMAND:
+ self.on_smp_security_request_command(
+ connection, cast(SMP_Security_Request_Command, command)
+ )
+ return
+
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
)
self.sessions[connection.handle] = session
- # Parse the L2CAP payload into an SMP Command object
- command = SMP_Command.from_bytes(pdu)
- logger.debug(
- f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
- f'{connection.peer_address}: {command}'
- )
-
# Delegate the handling of the command to the session
session.on_smp_command(command)
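The new `security_request` event gives applications a hook when a peer sends an SMP Security Request, instead of the stack silently creating a pairing session. A minimal sketch of a listener that reacts by starting pairing (the `connection` argument is a regular Bumble `Connection`; the `auth_req` value is the AuthReq field forwarded by `on_smp_security_request_command`):
```
import asyncio


def watch_security_requests(connection) -> None:
    def on_security_request(auth_req: int) -> None:
        # React to the peer's request by initiating pairing ourselves,
        # mirroring what the Pandora SecurityService does above.
        asyncio.create_task(connection.pair())

    connection.on('security_request', on_security_request)
```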
diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py
index 5ef0047..8d19a9e 100644
--- a/bumble/transport/android_emulator.py
+++ b/bumble/transport/android_emulator.py
@@ -18,6 +18,8 @@
import logging
import grpc.aio
+from typing import Optional, Union
+
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
# pylint: disable=no-name-in-module
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_android_emulator_transport(spec: str | None) -> Transport:
+async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
@@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
- service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
+ service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)
@@ -95,10 +97,13 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
raise ValueError('invalid mode')
# Create the transport object
- transport = PumpedTransport(
- PumpedPacketSource(hci_device.read),
- PumpedPacketSink(hci_device.write),
- channel.close,
+ class EmulatorTransport(PumpedTransport):
+ async def close(self):
+ await super().close()
+ await channel.close()
+
+ transport = EmulatorTransport(
+ PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
)
transport.start()
diff --git a/bumble/transport/android_netsim.py b/bumble/transport/android_netsim.py
index 76a7385..e9d36cd 100644
--- a/bumble/transport/android_netsim.py
+++ b/bumble/transport/android_netsim.py
@@ -18,11 +18,12 @@
import asyncio
import atexit
import logging
-import grpc.aio
import os
import pathlib
import sys
-from typing import Optional
+from typing import Dict, Optional
+
+import grpc.aio
from .common import (
ParserSource,
@@ -33,8 +34,8 @@ from .common import (
)
# pylint: disable=no-name-in-module
-from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
from .grpc_protobuf.packet_streamer_pb2_grpc import (
+ PacketStreamerStub,
PacketStreamerServicer,
add_PacketStreamerServicer_to_server,
)
@@ -43,6 +44,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
from .grpc_protobuf.common_pb2 import ChipKind
+
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -74,14 +76,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
# -----------------------------------------------------------------------------
-def find_grpc_port() -> int:
+def ini_file_name(instance_number: int) -> str:
+ suffix = f'_{instance_number}' if instance_number > 0 else ''
+ return f'netsim{suffix}.ini'
+
+
+# -----------------------------------------------------------------------------
+def find_grpc_port(instance_number: int) -> int:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return 0
- ini_file = ini_dir / 'netsim.ini'
+ ini_file = ini_dir / ini_file_name(instance_number)
+ logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file():
- logger.debug(f'Found .ini file at {ini_file}')
with open(ini_file, 'r') as ini_file_data:
for line in ini_file_data.readlines():
if '=' in line:
@@ -90,12 +98,14 @@ def find_grpc_port() -> int:
logger.debug(f'gRPC port = {value}')
return int(value)
+ logger.debug('no grpc.port property found in .ini file')
+
# Not found
return 0
# -----------------------------------------------------------------------------
-def publish_grpc_port(grpc_port) -> bool:
+def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return False
@@ -104,7 +114,7 @@ def publish_grpc_port(grpc_port) -> bool:
logger.debug('ini directory does not exist')
return False
- ini_file = ini_dir / 'netsim.ini'
+ ini_file = ini_dir / ini_file_name(instance_number)
try:
ini_file.write_text(f'grpc.port={grpc_port}\n')
logger.debug(f"published gRPC port at {ini_file}")
@@ -122,14 +132,15 @@ def publish_grpc_port(grpc_port) -> bool:
# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(
- server_host: str | None, server_port: int
+ server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport:
if not server_port:
raise ValueError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
- if not publish_grpc_port(server_port):
+ instance_number = int(options.get('instance', "0"))
+ if not publish_grpc_port(server_port, instance_number):
logger.warning("unable to publish gRPC port")
class HciDevice:
@@ -186,15 +197,12 @@ async def open_android_netsim_controller_transport(
logger.debug(f'<<< PACKET: {data.hex()}')
self.on_data_received(data)
- def send_packet(self, data):
- async def send():
- await self.context.write(
- PacketResponse(
- hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
- )
+ async def send_packet(self, data):
+ return await self.context.write(
+ PacketResponse(
+ hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
)
-
- self.loop.create_task(send())
+ )
def terminate(self):
self.task.cancel()
@@ -228,17 +236,17 @@ async def open_android_netsim_controller_transport(
logger.debug('gRPC server cancelled')
await self.grpc_server.stop(None)
- def on_packet(self, packet):
+ async def send_packet(self, packet):
if not self.device:
logger.debug('no device, dropping packet')
return
- self.device.send_packet(packet)
+ return await self.device.send_packet(packet)
async def StreamPackets(self, _request_iterator, context):
logger.debug('StreamPackets request')
- # Check that we won't already have a device
+ # Check that we don't already have a device
if self.device:
logger.debug('busy, already serving a device')
return PacketResponse(error='Busy')
@@ -261,15 +269,42 @@ async def open_android_netsim_controller_transport(
await server.start()
asyncio.get_running_loop().create_task(server.serve())
- class GrpcServerTransport(Transport):
- async def close(self):
- await super().close()
+ sink = PumpedPacketSink(server.send_packet)
+ sink.start()
+ return Transport(server, sink)
+
+
+# -----------------------------------------------------------------------------
+async def open_android_netsim_host_transport_with_address(
+ server_host: Optional[str],
+ server_port: int,
+ options: Optional[Dict[str, str]] = None,
+):
+ if server_host == '_' or not server_host:
+ server_host = 'localhost'
+
+ if not server_port:
+ # Look for the gRPC config in a .ini file
+ instance_number = 0 if options is None else int(options.get('instance', '0'))
+ server_port = find_grpc_port(instance_number)
+ if not server_port:
+ raise RuntimeError('gRPC server port not found')
+
+ # Connect to the gRPC server
+ server_address = f'{server_host}:{server_port}'
+ logger.debug(f'Connecting to gRPC server at {server_address}')
+ channel = grpc.aio.insecure_channel(server_address)
- return GrpcServerTransport(server, server)
+ return await open_android_netsim_host_transport_with_channel(
+ channel,
+ options,
+ )
# -----------------------------------------------------------------------------
-async def open_android_netsim_host_transport(server_host, server_port, options):
+async def open_android_netsim_host_transport_with_channel(
+ channel, options: Optional[Dict[str, str]] = None
+):
# Wrapper for I/O operations
class HciDevice:
def __init__(self, name, manufacturer, hci_device):
@@ -288,10 +323,12 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
async def read(self):
response = await self.hci_device.read()
response_type = response.WhichOneof('response_type')
+
if response_type == 'error':
logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error)
- elif response_type == 'hci_packet':
+
+ if response_type == 'hci_packet':
return (
bytes([response.hci_packet.packet_type])
+ response.hci_packet.packet
@@ -306,24 +343,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
)
)
- name = options.get('name', DEFAULT_NAME)
+ name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
manufacturer = DEFAULT_MANUFACTURER
- if server_host == '_' or not server_host:
- server_host = 'localhost'
-
- if not server_port:
- # Look for the gRPC config in a .ini file
- server_host = 'localhost'
- server_port = find_grpc_port()
- if not server_port:
- raise RuntimeError('gRPC server port not found')
-
- # Connect to the gRPC server
- server_address = f'{server_host}:{server_port}'
- logger.debug(f'Connecting to gRPC server at {server_address}')
- channel = grpc.aio.insecure_channel(server_address)
-
# Connect as a host
service = PacketStreamerStub(channel)
hci_device = HciDevice(
@@ -334,10 +356,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
await hci_device.start()
# Create the transport object
- transport = PumpedTransport(
+ class GrpcTransport(PumpedTransport):
+ async def close(self):
+ await super().close()
+ await channel.close()
+
+ transport = GrpcTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
- channel.close,
)
transport.start()
@@ -345,7 +371,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
# -----------------------------------------------------------------------------
-async def open_android_netsim_transport(spec):
+async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC.
@@ -359,6 +385,11 @@ async def open_android_netsim_transport(spec):
to connect *to* a netsim server (netsim is the controller), or accept
connections *as* a netsim-compatible server.
+ instance=<n>
+ Specifies an instance number, with <n> > 0. This is used to determine which
+ .ini file to use. In `host` mode, it is ignored when the <host>:<port>
+ specifier is present, since in that case no .ini file is used.
+
In `host` mode:
The <host>:<port> part is optional. When not specified, the transport
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
@@ -387,14 +418,15 @@ async def open_android_netsim_transport(spec):
params = spec.split(',') if spec else []
if params and ':' in params[0]:
# Explicit <host>:<port>
- host, port = params[0].split(':')
+ host, port_str = params[0].split(':')
+ port = int(port_str)
params_offset = 1
else:
host = None
port = 0
params_offset = 0
- options = {}
+ options: Dict[str, str] = {}
for param in params[params_offset:]:
if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>')
@@ -403,10 +435,12 @@ async def open_android_netsim_transport(spec):
mode = options.get('mode', 'host')
if mode == 'host':
- return await open_android_netsim_host_transport(host, port, options)
+ return await open_android_netsim_host_transport_with_address(
+ host, port, options
+ )
if mode == 'controller':
if host is None:
raise ValueError('<host>:<port> missing')
- return await open_android_netsim_controller_transport(host, port)
+ return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option')
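A sketch of how the reworked entry points can be called (the port numbers are placeholders and a running netsim instance is assumed; the `instance` option selects the `.ini` file as described in the docstring above):
```
import grpc.aio

from bumble.transport.android_netsim import (
    open_android_netsim_host_transport_with_address,
    open_android_netsim_host_transport_with_channel,
    open_android_netsim_transport,
)


async def examples() -> None:
    # Host mode, resolving the gRPC port from netsim_1.ini:
    transport = await open_android_netsim_transport('mode=host,instance=1')
    await transport.close()

    # Host mode with an explicit address (no .ini lookup, instance ignored):
    transport = await open_android_netsim_host_transport_with_address('localhost', 50051)
    await transport.close()

    # Host mode over a channel the caller already owns:
    channel = grpc.aio.insecure_channel('localhost:50051')
    transport = await open_android_netsim_host_transport_with_channel(channel)
    await transport.close()
```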
diff --git a/bumble/transport/common.py b/bumble/transport/common.py
index c030308..2786a75 100644
--- a/bumble/transport/common.py
+++ b/bumble/transport/common.py
@@ -339,8 +339,9 @@ class PumpedPacketSource(ParserSource):
try:
packet = await self.receive_function()
self.parser.feed_data(packet)
- except asyncio.exceptions.CancelledError:
+ except asyncio.CancelledError:
logger.debug('source pump task done')
+ self.terminated.set_result(None)
break
except Exception as error:
logger.warning(f'exception while waiting for packet: {error}')
@@ -370,7 +371,7 @@ class PumpedPacketSink:
try:
packet = await self.packet_queue.get()
await self.send_function(packet)
- except asyncio.exceptions.CancelledError:
+ except asyncio.CancelledError:
logger.debug('sink pump task done')
break
except Exception as error:
@@ -393,19 +394,13 @@ class PumpedTransport(Transport):
self,
source: PumpedPacketSource,
sink: PumpedPacketSink,
- close_function,
) -> None:
super().__init__(source, sink)
- self.close_function = close_function
def start(self) -> None:
self.source.start()
self.sink.start()
- async def close(self) -> None:
- await super().close()
- await self.close_function()
-
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py
index 9891c5b..df9e885 100644
--- a/bumble/transport/hci_socket.py
+++ b/bumble/transport/hci_socket.py
@@ -23,6 +23,8 @@ import socket
import ctypes
import collections
+from typing import Optional
+
from .common import Transport, ParserSource
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_hci_socket_transport(spec: str | None) -> Transport:
+async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
'''
Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
# Create a raw HCI socket
try:
hci_socket = socket.socket(
- socket.AF_BLUETOOTH,
- socket.SOCK_RAW | socket.SOCK_NONBLOCK,
- socket.BTPROTO_HCI, # type: ignore
+ socket.AF_BLUETOOTH, # type: ignore[attr-defined]
+ socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
+ socket.BTPROTO_HCI, # type: ignore[attr-defined]
)
except AttributeError as error:
# Not supported on this platform
@@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
bind_address = struct.pack(
# pylint: disable=no-member
'<HHH',
- socket.AF_BLUETOOTH,
+ socket.AF_BLUETOOTH, # type: ignore[attr-defined]
adapter_index,
HCI_CHANNEL_USER,
)
diff --git a/bumble/transport/pty.py b/bumble/transport/pty.py
index 7765b09..2f46e75 100644
--- a/bumble/transport/pty.py
+++ b/bumble/transport/pty.py
@@ -23,6 +23,8 @@ import atexit
import os
import logging
+from typing import Optional
+
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_pty_transport(spec: str | None) -> Transport:
+async def open_pty_transport(spec: Optional[str]) -> Transport:
'''
Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link
diff --git a/bumble/transport/vhci.py b/bumble/transport/vhci.py
index 5795840..2b19085 100644
--- a/bumble/transport/vhci.py
+++ b/bumble/transport/vhci.py
@@ -17,6 +17,8 @@
# -----------------------------------------------------------------------------
import logging
+from typing import Optional
+
from .common import Transport
from .file import open_file_transport
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_vhci_transport(spec: str | None) -> Transport:
+async def open_vhci_transport(spec: Optional[str]) -> Transport:
'''
Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device
diff --git a/bumble/transport/ws_client.py b/bumble/transport/ws_client.py
index facd1c9..902001e 100644
--- a/bumble/transport/ws_client.py
+++ b/bumble/transport/ws_client.py
@@ -31,19 +31,21 @@ async def open_ws_client_transport(spec: str) -> Transport:
'''
Open a WebSocket client transport.
The parameter string has this syntax:
- <remote-host>:<remote-port>
+ <websocket-url>
- Example: 127.0.0.1:9001
+ Example: ws://localhost:7681/v1/websocket/bt
'''
- remote_host, remote_port = spec.split(':')
- uri = f'ws://{remote_host}:{remote_port}'
- websocket = await websockets.client.connect(uri)
+ websocket = await websockets.client.connect(spec)
- transport = PumpedTransport(
+ class WsTransport(PumpedTransport):
+ async def close(self):
+ await super().close()
+ await websocket.close()
+
+ transport = WsTransport(
PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send),
- websocket.close,
)
transport.start()
return transport
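With the spec now being a full WebSocket URL, a caller sketch looks like this (the URL is the example from the docstring; a server must be listening there):
```
from bumble.transport.ws_client import open_ws_client_transport


async def connect() -> None:
    transport = await open_ws_client_transport('ws://localhost:7681/v1/websocket/bt')
    try:
        ...  # use transport.source / transport.sink
    finally:
        await transport.close()
```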
diff --git a/bumble/utils.py b/bumble/utils.py
index 8a55684..dc03725 100644
--- a/bumble/utils.py
+++ b/bumble/utils.py
@@ -15,12 +15,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import logging
import traceback
import collections
import sys
-from typing import Awaitable, Set, TypeVar
+from typing import (
+ Awaitable,
+ Set,
+ TypeVar,
+ List,
+ Tuple,
+ Callable,
+ Any,
+ Optional,
+ Union,
+ overload,
+)
from functools import wraps
from pyee import EventEmitter
@@ -65,6 +77,102 @@ def composite_listener(cls):
# -----------------------------------------------------------------------------
+_Handler = TypeVar('_Handler', bound=Callable)
+
+
+class EventWatcher:
+ '''A wrapper class to control the lifecycle of event handlers better.
+
+ Usage:
+ ```
+ watcher = EventWatcher()
+
+ def on_foo():
+ ...
+ watcher.on(emitter, 'foo', on_foo)
+
+ @watcher.on(emitter, 'bar')
+ def on_bar():
+ ...
+
+ # Remove all event handlers registered through this watcher
+ watcher.close()
+ ```
+
+ As context:
+ ```
+ with contextlib.closing(EventWatcher()) as context:
+ @context.on(emitter, 'foo')
+ def on_foo():
+ ...
+ # on_foo() has been removed here!
+ ```
+ '''
+
+ handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
+
+ def __init__(self) -> None:
+ self.handlers = []
+
+ @overload
+ def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
+ ...
+
+ @overload
+ def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
+ ...
+
+ def on(
+ self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
+ ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
+ '''Watch an event until the watcher is closed.
+
+ Args:
+ emitter: EventEmitter to watch
+ event: Event name
+ handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
+ '''
+
+ def wrapper(f: _Handler) -> _Handler:
+ self.handlers.append((emitter, event, f))
+ emitter.on(event, f)
+ return f
+
+ return wrapper if handler is None else wrapper(handler)
+
+ @overload
+ def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
+ ...
+
+ @overload
+ def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
+ ...
+
+ def once(
+ self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
+ ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
+ '''Watch an event once.
+
+ Args:
+ emitter: EventEmitter to watch
+ event: Event name
+ handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
+ '''
+
+ def wrapper(f: _Handler) -> _Handler:
+ self.handlers.append((emitter, event, f))
+ emitter.once(event, f)
+ return f
+
+ return wrapper if handler is None else wrapper(handler)
+
+ def close(self) -> None:
+ for emitter, event, handler in self.handlers:
+ if handler in emitter.listeners(event):
+ emitter.remove_listener(event, handler)
+
+
+# -----------------------------------------------------------------------------
_T = TypeVar('_T')
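A short usage sketch of the `once` variant, which the docstring above does not illustrate; the handler fires at most once, and anything still registered is dropped when the watcher is closed:
```
import contextlib

from pyee import EventEmitter

from bumble.utils import EventWatcher

emitter = EventEmitter()

with contextlib.closing(EventWatcher()) as watcher:

    @watcher.once(emitter, 'ready')
    def on_ready(value: str) -> None:
        print(f'ready: {value}')

    emitter.emit('ready', 'first')   # prints 'ready: first'
    emitter.emit('ready', 'second')  # no output, the handler already fired

assert not emitter.listeners('ready')
```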
diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index bd168dc..c2d0cd3 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -131,6 +131,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
[[package]]
+name = "bstr"
+version = "1.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
+dependencies = [
+ "memchr",
+ "serde",
+]
+
+[[package]]
name = "bumble"
version = "0.1.0"
dependencies = [
@@ -138,7 +148,9 @@ dependencies = [
"clap 4.4.1",
"directories",
"env_logger",
+ "file-header",
"futures",
+ "globset",
"hex",
"itertools",
"lazy_static",
@@ -273,6 +285,73 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
+name = "crossbeam"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
+dependencies = [
+ "cfg-if",
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
+dependencies = [
+ "autocfg",
+ "cfg-if",
+ "crossbeam-utils",
+ "memoffset 0.9.0",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
name = "directories"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -349,6 +428,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
[[package]]
+name = "file-header"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f"
+dependencies = [
+ "crossbeam",
+ "lazy_static",
+ "license",
+ "thiserror",
+ "walkdir",
+]
+
+[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -485,6 +577,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
[[package]]
+name = "globset"
+version = "0.4.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
+dependencies = [
+ "aho-corasick",
+ "bstr",
+ "fnv",
+ "log",
+ "regex",
+]
+
+[[package]]
name = "h2"
version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -711,6 +816,17 @@ dependencies = [
]
[[package]]
+name = "license"
+version = "3.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946"
+dependencies = [
+ "reword",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
name = "linux-raw-sys"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -757,6 +873,15 @@ dependencies = [
]
[[package]]
+name = "memoffset"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1201,6 +1326,15 @@ dependencies = [
]
[[package]]
+name = "reword"
+version = "7.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765"
+dependencies = [
+ "unicode-segmentation",
+]
+
+[[package]]
name = "rusb"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1242,6 +1376,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
+[[package]]
name = "schannel"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1590,6 +1733,12 @@ dependencies = [
]
[[package]]
+name = "unicode-segmentation"
+version = "1.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
+
+[[package]]
name = "unindent"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1619,6 +1768,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
+name = "walkdir"
+version = "2.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
+dependencies = [
+ "same-file",
+ "winapi-util",
+]
+
+[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 6c38c82..a553afd 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -24,6 +24,10 @@ itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
+# Dev tools
+file-header = { version = "0.1.2", optional = true }
+globset = { version = "0.4.13", optional = true }
+
# CLI
anyhow = { version = "1.0.71", optional = true }
clap = { version = "4.3.3", features = ["derive"], optional = true }
@@ -53,9 +57,14 @@ env_logger = "0.10.0"
rustdoc-args = ["--generate-link-to-definition"]
[[bin]]
+name = "file-header"
+path = "tools/file_header.rs"
+required-features = ["dev-tools"]
+
+[[bin]]
name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs"
-required-features = ["bumble-codegen"]
+required-features = ["dev-tools"]
[[bin]]
name = "bumble"
@@ -71,7 +80,7 @@ harness = false
[features]
anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
-bumble-codegen = ["dep:anyhow"]
+dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
# separate feature for CLI so that dependencies don't spend time building these
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
default = []
diff --git a/rust/README.md b/rust/README.md
index 23dec03..15a19b9 100644
--- a/rust/README.md
+++ b/rust/README.md
@@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`.
To regenerate the assigned number tables based on the Python codebase:
```
-PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen
+PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
``` \ No newline at end of file
diff --git a/rust/pytests/assigned_numbers.rs b/rust/pytests/assigned_numbers.rs
index 10e7f3e..7f8f1d1 100644
--- a/rust/pytests/assigned_numbers.rs
+++ b/rust/pytests/assigned_numbers.rs
@@ -1,3 +1,17 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict};
use std::collections;
diff --git a/rust/src/adv.rs b/rust/src/adv.rs
index 8a4c979..6f84cc5 100644
--- a/rust/src/adv.rs
+++ b/rust/src/adv.rs
@@ -1,3 +1,17 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
//! BLE advertisements.
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
diff --git a/rust/src/wrapper/logging.rs b/rust/src/wrapper/logging.rs
index 141cc04..bd932cb 100644
--- a/rust/src/wrapper/logging.rs
+++ b/rust/src/wrapper/logging.rs
@@ -1,3 +1,17 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
//! Bumble & Python logging
use pyo3::types::PyDict;
diff --git a/rust/tools/file_header.rs b/rust/tools/file_header.rs
new file mode 100644
index 0000000..fb3286d
--- /dev/null
+++ b/rust/tools/file_header.rs
@@ -0,0 +1,78 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use anyhow::anyhow;
+use clap::Parser as _;
+use file_header::{
+ add_headers_recursively, check_headers_recursively,
+ license::spdx::{YearCopyrightOwnerValue, APACHE_2_0},
+};
+use globset::{Glob, GlobSet, GlobSetBuilder};
+use std::{env, path::PathBuf};
+
+fn main() -> anyhow::Result<()> {
+ let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
+ let ignore_globset = ignore_globset()?;
+ // Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127)
+ let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into()));
+
+ let cli = Cli::parse();
+
+ match cli.subcommand {
+ Subcommand::CheckAll => {
+ let result =
+ check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?;
+ if result.has_failure() {
+ return Err(anyhow!(
+ "The following files do not have headers: {result:?}"
+ ));
+ }
+ }
+ Subcommand::AddAll => {
+ let files_with_new_header =
+ add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?;
+ files_with_new_header
+ .iter()
+ .for_each(|path| println!("Added header to: {path:?}"));
+ }
+ }
+ Ok(())
+}
+
+fn ignore_globset() -> anyhow::Result<GlobSet> {
+ Ok(GlobSetBuilder::new()
+ .add(Glob::new("**/.idea/**")?)
+ .add(Glob::new("**/target/**")?)
+ .add(Glob::new("**/.gitignore")?)
+ .add(Glob::new("**/CHANGELOG.md")?)
+ .add(Glob::new("**/Cargo.lock")?)
+ .add(Glob::new("**/Cargo.toml")?)
+ .add(Glob::new("**/README.md")?)
+ .add(Glob::new("*.bin")?)
+ .build()?)
+}
+
+#[derive(clap::Parser)]
+struct Cli {
+ #[clap(subcommand)]
+ subcommand: Subcommand,
+}
+
+#[derive(clap::Subcommand, Debug, Clone)]
+enum Subcommand {
+ /// Checks that a license header is present in files that are not in the ignore list.
+ CheckAll,
+ /// Adds a license header as needed to files that are not in the ignore list.
+ AddAll,
+}
diff --git a/setup.cfg b/setup.cfg
index 74a90e0..1ca73c7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -36,6 +36,10 @@ install_requires =
bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
click == 8.1.3; platform_system!='Emscripten'
cryptography == 39; platform_system!='Emscripten'
+ # Pyodide bundles a version of cryptography that is built for wasm, which may not match the
+ # versions available on PyPI. Relax the version requirement since it's better than being
+ # completely unable to import the package in case of version mismatch.
+ cryptography >= 39.0; platform_system=='Emscripten'
grpcio == 1.57.0; platform_system!='Emscripten'
humanize >= 4.6.0; platform_system!='Emscripten'
libusb1 >= 2.0.1; platform_system!='Emscripten'
@@ -84,7 +88,7 @@ development =
black == 22.10
grpcio-tools >= 1.57.0
invoke >= 1.7.3
- mypy == 1.2.0
+ mypy == 1.5.0
nox >= 2022
pylint == 2.15.8
types-appdirs >= 1.4.3
diff --git a/tests/gatt_test.py b/tests/gatt_test.py
index dd0277e..d9f6d60 100644
--- a/tests/gatt_test.py
+++ b/tests/gatt_test.py
@@ -891,10 +891,10 @@ async def async_main():
# -----------------------------------------------------------------------------
-def test_attribute_string_to_permissions():
- assert Attribute.string_to_permissions('READABLE') == 1
- assert Attribute.string_to_permissions('WRITEABLE') == 2
- assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
+def test_permissions_from_string():
+ assert Attribute.Permissions.from_string('READABLE') == 1
+ assert Attribute.Permissions.from_string('WRITEABLE') == 2
+ assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
# -----------------------------------------------------------------------------
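For callers migrating from `Attribute.string_to_permissions`, the renamed API is used exactly as in the updated test; a sketch (the integer comparison mirrors the assertions above):
```
from bumble.gatt import Attribute

# Replaces Attribute.string_to_permissions(...) from earlier revisions.
permissions = Attribute.Permissions.from_string('READABLE,WRITEABLE')
assert permissions == 3
```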
diff --git a/tests/keystore_test.py b/tests/keystore_test.py
index 2e73039..2a3d48d 100644
--- a/tests/keystore_test.py
+++ b/tests/keystore_test.py
@@ -18,6 +18,8 @@
import asyncio
import json
import logging
+import pathlib
+import pytest
import tempfile
import os
@@ -83,87 +85,95 @@ JSON3 = """
# -----------------------------------------------------------------------------
-async def test_basic():
- with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file:
- keystore = JsonKeyStore('my_namespace', file.name)
+@pytest.fixture
+def temporary_file():
+ file = tempfile.NamedTemporaryFile(delete=False)
+ file.close()
+ yield file.name
+ pathlib.Path(file.name).unlink()
+
+
+# -----------------------------------------------------------------------------
+async def test_basic(temporary_file):
+ with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write("{}")
file.flush()
- keys = await keystore.get_all()
- assert len(keys) == 0
-
- keys = PairingKeys()
- await keystore.update('foo', keys)
- foo = await keystore.get('foo')
- assert foo is not None
- assert foo.ltk is None
- ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
- keys.ltk = PairingKeys.Key(ltk)
- await keystore.update('foo', keys)
- foo = await keystore.get('foo')
- assert foo is not None
- assert foo.ltk is not None
- assert foo.ltk.value == ltk
+ keystore = JsonKeyStore('my_namespace', temporary_file)
- file.flush()
- with open(file.name, "r", encoding="utf-8") as json_file:
- json_data = json.load(json_file)
- assert 'my_namespace' in json_data
- assert 'foo' in json_data['my_namespace']
- assert 'ltk' in json_data['my_namespace']['foo']
+ keys = await keystore.get_all()
+ assert len(keys) == 0
+
+ keys = PairingKeys()
+ await keystore.update('foo', keys)
+ foo = await keystore.get('foo')
+ assert foo is not None
+ assert foo.ltk is None
+ ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
+ keys.ltk = PairingKeys.Key(ltk)
+ await keystore.update('foo', keys)
+ foo = await keystore.get('foo')
+ assert foo is not None
+ assert foo.ltk is not None
+ assert foo.ltk.value == ltk
+
+ with open(file.name, "r", encoding="utf-8") as json_file:
+ json_data = json.load(json_file)
+ assert 'my_namespace' in json_data
+ assert 'foo' in json_data['my_namespace']
+ assert 'ltk' in json_data['my_namespace']['foo']
# -----------------------------------------------------------------------------
-async def test_parsing():
- with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
- keystore = JsonKeyStore('my_namespace', file.name)
+async def test_parsing(temporary_file):
+ with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON1)
file.flush()
- foo = await keystore.get('14:7D:DA:4E:53:A8/P')
- assert foo is not None
- assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
+ keystore = JsonKeyStore('my_namespace', file.name)
+ foo = await keystore.get('14:7D:DA:4E:53:A8/P')
+ assert foo is not None
+ assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# -----------------------------------------------------------------------------
-async def test_default_namespace():
- with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
- keystore = JsonKeyStore(None, file.name)
+async def test_default_namespace(temporary_file):
+ with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON1)
file.flush()
- all_keys = await keystore.get_all()
- assert len(all_keys) == 1
- name, keys = all_keys[0]
- assert name == '14:7D:DA:4E:53:A8/P'
- assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
+ keystore = JsonKeyStore(None, file.name)
+ all_keys = await keystore.get_all()
+ assert len(all_keys) == 1
+ name, keys = all_keys[0]
+ assert name == '14:7D:DA:4E:53:A8/P'
+ assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
- with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
- keystore = JsonKeyStore(None, file.name)
+ with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON2)
file.flush()
- keys = PairingKeys()
- ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
- keys.ltk = PairingKeys.Key(ltk)
- await keystore.update('foo', keys)
- file.flush()
- with open(file.name, "r", encoding="utf-8") as json_file:
- json_data = json.load(json_file)
- assert '__DEFAULT__' in json_data
- assert 'foo' in json_data['__DEFAULT__']
- assert 'ltk' in json_data['__DEFAULT__']['foo']
-
- with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
- keystore = JsonKeyStore(None, file.name)
+ keystore = JsonKeyStore(None, file.name)
+ keys = PairingKeys()
+ ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
+ keys.ltk = PairingKeys.Key(ltk)
+ await keystore.update('foo', keys)
+ with open(file.name, "r", encoding="utf-8") as json_file:
+ json_data = json.load(json_file)
+ assert '__DEFAULT__' in json_data
+ assert 'foo' in json_data['__DEFAULT__']
+ assert 'ltk' in json_data['__DEFAULT__']['foo']
+
+ with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON3)
file.flush()
- all_keys = await keystore.get_all()
- assert len(all_keys) == 1
- name, keys = all_keys[0]
- assert name == '14:7D:DA:4E:53:A8/P'
- assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
+ keystore = JsonKeyStore(None, file.name)
+ all_keys = await keystore.get_all()
+ assert len(all_keys) == 1
+ name, keys = all_keys[0]
+ assert name == '14:7D:DA:4E:53:A8/P'
+ assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
diff --git a/tests/utils_test.py b/tests/utils_test.py
new file mode 100644
index 0000000..d6f5780
--- /dev/null
+++ b/tests/utils_test.py
@@ -0,0 +1,77 @@
+# Copyright 2021-2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import os
+
+from bumble import utils
+from pyee import EventEmitter
+from unittest.mock import MagicMock
+
+
+def test_on() -> None:
+ emitter = EventEmitter()
+ with contextlib.closing(utils.EventWatcher()) as context:
+ mock = MagicMock()
+ context.on(emitter, 'event', mock)
+
+ emitter.emit('event')
+
+ assert not emitter.listeners('event')
+ assert mock.call_count == 1
+
+
+def test_on_decorator() -> None:
+ emitter = EventEmitter()
+ with contextlib.closing(utils.EventWatcher()) as context:
+ mock = MagicMock()
+
+ @context.on(emitter, 'event')
+ def on_event(*_) -> None:
+ mock()
+
+ emitter.emit('event')
+
+ assert not emitter.listeners('event')
+ assert mock.call_count == 1
+
+
+def test_multiple_handlers() -> None:
+ emitter = EventEmitter()
+ with contextlib.closing(utils.EventWatcher()) as context:
+ mock = MagicMock()
+
+ context.once(emitter, 'a', mock)
+ context.once(emitter, 'b', mock)
+
+ emitter.emit('b', 'b')
+
+ assert not emitter.listeners('a')
+ assert not emitter.listeners('b')
+
+ mock.assert_called_once_with('b')
+
+
+# -----------------------------------------------------------------------------
+def run_tests():
+ test_on()
+ test_on_decorator()
+ test_multiple_handlers()
+
+
+# -----------------------------------------------------------------------------
+if __name__ == '__main__':
+ logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
+ run_tests()
diff --git a/web/bumble.js b/web/bumble.js
index c5fd6a3..b1243a5 100644
--- a/web/bumble.js
+++ b/web/bumble.js
@@ -74,7 +74,6 @@ export async function loadBumble(pyodide, bumblePackage) {
await pyodide.loadPackage("micropip");
await pyodide.runPythonAsync(`
import micropip
- await micropip.install("cryptography")
await micropip.install("${bumblePackage}")
package_list = micropip.list()
print(package_list)
diff --git a/web/scanner/scanner.py b/web/scanner/scanner.py
index dd53050..c0fc456 100644
--- a/web/scanner/scanner.py
+++ b/web/scanner/scanner.py
@@ -23,7 +23,7 @@ from bumble.device import Device
# -----------------------------------------------------------------------------
class ScanEntry:
def __init__(self, advertisement):
- self.address = str(advertisement.address).replace("/P", "")
+ self.address = advertisement.address.to_string(False)
self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[
advertisement.address.address_type
]
diff --git a/web/speaker/speaker.py b/web/speaker/speaker.py
index ddc2086..d9293a4 100644
--- a/web/speaker/speaker.py
+++ b/web/speaker/speaker.py
@@ -171,7 +171,7 @@ class Speaker:
self.connection = connection
connection.on('disconnection', self.on_bluetooth_disconnection)
peer_name = '' if connection.peer_name is None else connection.peer_name
- peer_address = str(connection.peer_address).replace('/P', '')
+ peer_address = connection.peer_address.to_string(False)
self.emit_event(
'connection', {'peer_name': peer_name, 'peer_address': peer_address}
)