diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2024-01-31 20:18:06 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2024-01-31 20:18:06 +0000 |
commit | 0a335d54e5e4f8fe6275534505641ddeca649c65 (patch) | |
tree | 31bab7dd66656c56a0ceed823ef1181dbdd63460 | |
parent | cd957ddea920508489e6afc6805b363596874835 (diff) | |
parent | 86c3148e3b162f9a46f9d86bba84f186eede0d20 (diff) | |
download | pica-0a335d54e5e4f8fe6275534505641ddeca649c65.tar.gz |
Snap for 11386054 from 86c3148e3b162f9a46f9d86bba84f186eede0d20 to build-tools-release
Change-Id: I81d15b75fd77b516eac95c585c9c5374fa838b0a
-rw-r--r-- | .github/workflows/build_and_test.yml (renamed from .github/workflows/build.yml) | 22 | ||||
-rw-r--r-- | .github/workflows/python_format.yml | 23 | ||||
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | Cargo.lock | 232 | ||||
-rw-r--r-- | Cargo.toml | 8 | ||||
-rw-r--r-- | OWNERS | 2 | ||||
-rw-r--r-- | README.md | 20 | ||||
-rwxr-xr-x | py/pica/console.py | 531 | ||||
-rw-r--r-- | py/pica/pica/__init__.py (renamed from scripts/pica/__init__.py) | 22 | ||||
-rw-r--r-- | py/pica/pica/packets/__init__.py (renamed from scripts/pica/packets/__init__.py) | 0 | ||||
-rw-r--r-- | py/pica/pica/packets/uci.py (renamed from scripts/pica/packets/uci.py) | 239 | ||||
-rw-r--r-- | py/pica/pyproject.toml | 5 | ||||
-rwxr-xr-x | scripts/console.py | 412 | ||||
-rwxr-xr-x | scripts/ranging_example.py | 297 | ||||
-rw-r--r-- | src/bin/server/mod.rs | 8 | ||||
-rw-r--r-- | src/bin/server/web.rs | 6 | ||||
-rw-r--r-- | src/device.rs | 39 | ||||
-rw-r--r-- | src/lib.rs | 239 | ||||
-rw-r--r-- | src/packets.rs | 163 | ||||
-rw-r--r-- | src/pcapng.rs | 30 | ||||
-rw-r--r-- | src/session.rs | 250 | ||||
-rw-r--r-- | src/uci_packets.pdl | 65 | ||||
-rw-r--r-- | tests/__init__.py | 0 | ||||
-rwxr-xr-x | tests/data_transfer.py | 203 | ||||
-rw-r--r-- | tests/helper.py | 30 | ||||
-rwxr-xr-x | tests/ranging.py | 275 | ||||
-rw-r--r-- | tests/test_runner.py | 66 |
27 files changed, 2145 insertions, 1045 deletions
diff --git a/.github/workflows/build.yml b/.github/workflows/build_and_test.yml index 6e763e5..7a19021 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build_and_test.yml @@ -1,5 +1,4 @@ -# Build, test and check the code against the linter and clippy -name: Build, Test, Format and Clippy +name: Build, Check, Test on: push: @@ -9,13 +8,12 @@ on: env: CARGO_TERM_COLOR: always + PY_COLORS: "1" jobs: build_and_test: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [macos-latest, ubuntu-latest, windows-latest] + name: Build, Check, Test + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Install Rust 1.67.1 @@ -24,6 +22,16 @@ jobs: toolchain: 1.67.1 override: true components: rustfmt, clippy + - name: Set Up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Install + run: | + pip install --upgrade pip + pip install ./py/pica/ + pip install pytest=="7.4.4" + pip install pytest_asyncio=="0.23.3" - name: Build run: cargo build - name: Test @@ -32,3 +40,5 @@ jobs: run: cargo fmt --check --quiet - name: Clippy run: cargo clippy --no-deps -- --deny warnings + - name: Run Python tests suite + run: pytest --log-cli-level=DEBUG -v diff --git a/.github/workflows/python_format.yml b/.github/workflows/python_format.yml new file mode 100644 index 0000000..1e551db --- /dev/null +++ b/.github/workflows/python_format.yml @@ -0,0 +1,23 @@ +name: Python Format + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + format: + name: Check format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set Up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Install + run: | + pip install --upgrade pip + pip install black=="23.12.1" + - run: black --check tests/ py/pica --exclude py/pica/pica/packets @@ -1,2 +1,5 @@ target/ __pycache__ +build/ +*.egg-info/ +artifacts @@ -3,6 +3,21 @@ version = 3 [[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] name = "anyhow" version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -46,6 +61,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -67,6 +97,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -201,6 +240,12 @@ dependencies = [ ] [[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] name = "glam" version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -277,7 +322,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -298,9 +343,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "log" @@ -318,26 +363,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] name = "mio" -version = "0.8.6" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", - "log", "wasi", - "windows-sys 0.45.0", + "windows-sys", ] [[package]] name = "num-derive" -version = "0.4.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb77679af88f8b125209d354a202862602672222e7f2313fdd6dc349bad4712" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 1.0.89", ] [[package]] @@ -360,6 +413,15 @@ dependencies = [ ] [[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + +[[package]] name = "once_cell" version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -447,7 +509,7 @@ dependencies = [ [[package]] name = "pica" -version = "0.1.3" +version = "0.1.7" dependencies = [ "anyhow", "bytes", @@ -468,9 +530,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.8" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -531,6 +593,12 @@ dependencies = [ ] [[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] name = "ryu" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -589,6 +657,16 @@ dependencies = [ ] [[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] name = "syn" version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -641,26 +719,26 @@ dependencies = [ [[package]] name = "tokio" -version = "1.28.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3c786bf8134e5a3a166db9b29ab8f48134739014a3eca7bc6bfa95d673b136f" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", "mio", "num_cpus", "pin-project-lite", - "socket2", + "socket2 0.5.5", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -810,132 +888,66 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" -version = "0.45.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" -dependencies = [ - "windows-targets 0.42.1", -] - -[[package]] -name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.5", + "windows-targets", ] [[package]] name = "windows-targets" -version = "0.42.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" -dependencies = [ - "windows_aarch64_gnullvm 0.42.1", - "windows_aarch64_msvc 0.42.1", - "windows_i686_gnu 0.42.1", - "windows_i686_msvc 0.42.1", - "windows_x86_64_gnu 0.42.1", - "windows_x86_64_gnullvm 0.42.1", - "windows_x86_64_msvc 0.42.1", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" [[package]] name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" [[package]] name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" [[package]] name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" [[package]] name = "windows_x86_64_msvc" -version = "0.48.5" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" @@ -1,6 +1,6 @@ [package] name = "pica" -version = "0.1.3" +version = "0.1.7" edition = "2021" description = "Pica is a virtual UWB Controller implementing the FiRa UCI specification." repository = "https://github.com/google/pica" @@ -39,13 +39,13 @@ web = ["hyper", "tokio/rt-multi-thread"] pdl-compiler = "0.2.3" [dependencies] -tokio = { version = "1.25.0", features = [ "fs", "io-util", "macros", "net", "rt" ] } +tokio = { version = "1.32.0", features = [ "fs", "io-util", "macros", "net", "rt" ] } tokio-stream = { version = "0.1.8", features = ["sync"] } bytes = "1" anyhow = "1.0.56" -num-derive = "0.4.1" +num-derive = "0.3.3" num-traits = "0.2.17" -pdl-runtime = "0.2.3" +pdl-runtime = "0.2.2" thiserror = "1.0.49" glam = "0.23.0" hyper = { version = "0.14", features = ["server", "stream", "http1", "tcp"], optional = true } @@ -1,4 +1,4 @@ -adrienl@google.com charliebout@google.com henrichataing@google.com licorne@google.com +ziyiw@google.com @@ -117,3 +117,23 @@ $> --> pica_create_anchor 00:01 # Create another one Pica also implements HTTP commands, the documentation is available at `http://0.0.0.0:3000/openapi`. The set of HTTP commands let the user interact with Pica amd modify its scene. + +# Tests + +Setup your python env: + +```bash +python3 -m venv venv +source venv/bin/activate +pip install pytest +pip install pytest_asyncio +pip install -e py/pica/ +``` + +Then run the tests + +```bash +pytest --log-cli-level=DEBUG -v +``` + +The tests are located in `./tests/` diff --git a/py/pica/console.py b/py/pica/console.py new file mode 100755 index 0000000..425db53 --- /dev/null +++ b/py/pica/console.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python3 + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import inspect +import json +import random +import readline +import socket +import sys +import time +import requests +import struct +import asyncio +from concurrent.futures import ThreadPoolExecutor + +from pica import Host +from pica.packets import uci + +MAX_DATA_PACKET_PAYLOAD_SIZE = 1024 + + +def encode_short_mac_address(mac_address: str) -> bytes: + return int(mac_address).to_bytes(2, byteorder="little") + + +def encode_mac_address(mac_address: str) -> bytes: + return int(mac_address).to_bytes(8, byteorder="little") + + +def parse_mac_address(mac_address: str) -> bytes: + bs = mac_address.split(":") + return bytes(int(b, 16) for b in bs) + + +class Device: + def __init__(self, reader, writer, http_address): + self.host = Host(reader, writer, bytes([0, 1])) + self.http_address = http_address + + def pica_get_state(self, **kargs): + """List the UCI devices and anchors currently existing within the Pica + virtual environment""" + r = requests.get(f"{self.http_address}/get-state") + print(f"{r.status_code}:\n{json.dumps(r.json(), indent=2)}") + + def pica_init_uci_device( + self, + mac_address: str = "00:00", + x: str = "0", + y: str = "0", + z: str = "0", + yaw: str = "0", + pitch: str = "0", + roll: str = "0", + **kargs, + ): + """Init Pica device""" + r = requests.post( + f"{self.http_address}/init-uci-device/{mac_address}", + data=json.dumps( + { + "x": int(x), + "y": int(y), + "z": int(z), + "yaw": int(yaw), + "pitch": int(pitch), + "roll": int(roll), + } + ), + ) + print(f"{r.status_code}: {r.text}") + + def pica_create_anchor( + self, + mac_address: str = "00:00", + x: str = "0", + y: str = "0", + z: str = "0", + yaw: str = "0", + pitch: str = "0", + roll: str = "0", + **kargs, + ): + """Create a Pica anchor""" + r = requests.post( + f"{self.http_address}/create-anchor/{mac_address}", + data=json.dumps( + { + "x": int(x), + "y": int(y), + "z": int(z), + "yaw": int(yaw), + "pitch": int(pitch), + "roll": int(roll), + } + ), + ) + print(f"{r.status_code}: {r.text}") + + def pica_destroy_anchor(self, mac_address: str = "00:00", **kargs): + """Destroy a Pica anchor""" + r = requests.post(f"{self.http_address}/destroy-anchor/{mac_address}") + print(f"{r.status_code}: {r.text}") + + def pica_set_position( + self, + mac_address: str = "00:00", + x: str = "0", + y: str = "0", + z: str = "0", + yaw: str = "0", + pitch: str = "0", + roll: str = "0", + **kargs, + ): + """Set Pica UCI device or anchor position""" + r = requests.post( + f"{self.http_address}/set-position/{mac_address}", + data=json.dumps( + { + "x": int(x), + "y": int(y), + "z": int(z), + "yaw": int(yaw), + "pitch": int(pitch), + "roll": int(roll), + } + ), + ) + print(f"{r.status_code}: {r.text}") + + def device_reset(self, **kargs): + """Reset the UWBS.""" + self.host.send_control( + uci.DeviceResetCmd(reset_config=uci.ResetConfig.UWBS_RESET) + ) + + def get_device_info(self, **kargs): + """Retrieve the device information like (UCI version and other vendor specific info).""" + self.host.send_control(uci.GetDeviceInfoCmd()) + + def get_caps_info(self, **kargs): + """Get the capability of the UWBS.""" + self.host.send_control(uci.GetCapsInfoCmd()) + + def set_config(self, low_power_mode: str = "0", **kargs): + """Set the configuration parameters on the UWBS.""" + self.host.send_control( + uci.SetConfigCmd( + tlvs=[ + uci.DeviceConfigTlv( + cfg_id=uci.DeviceConfigId.LOW_POWER_MODE, + v=bytes([int(low_power_mode)]), + ), + ] + ) + ) + + def get_config(self, **kargs): + """Retrieve the current configuration parameter(s) of the UWBS.""" + self.host.send_control( + uci.GetConfigCmd( + cfg_id=[ + uci.DeviceConfigId.LOW_POWER_MODE, + uci.DeviceConfigId.DEVICE_STATE, + ] + ) + ) + + def session_init(self, session_id: str = "0", **kargs): + """Initialize the session""" + self.host.send_control( + uci.SessionInitCmd( + session_id=int(session_id), + session_type=uci.SessionType.FIRA_RANGING_AND_IN_BAND_DATA_SESSION, + ) + ) + + def session_deinit(self, session_id: str = "0", **kargs): + """Deinitialize the session""" + self.host.send_control(uci.SessionDeinitCmd(session_token=int(session_id))) + + def session_set_app_config( + self, + session_id: str = "0", + ranging_interval: str = "200", + dst_mac_addresses: str = "", + **kargs, + ): + """set APP Configuration Parameters for the requested UWB session.""" + dst_mac_addresses = [ + parse_mac_address(a) for a in dst_mac_addresses.split(",") if a + ] + if any(len(a) > 2 for a in dst_mac_addresses): + mac_address_mode = 0x2 + mac_address_len = 8 + else: + mac_address_mode = 0x0 + mac_address_len = 2 + + encoded_dst_mac_addresses = bytes() + for mac_address in dst_mac_addresses: + encoded_dst_mac_addresses += mac_address + encoded_dst_mac_addresses += b"\0" * (mac_address_len - len(mac_address)) + + self.host.send_control( + uci.SessionSetAppConfigCmd( + session_token=int(session_id), + tlvs=[ + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, + v=bytes([mac_address_mode]), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.RANGING_DURATION, + v=int(ranging_interval).to_bytes(4, byteorder="little"), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, + v=bytes([len(dst_mac_addresses)]), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, + v=encoded_dst_mac_addresses, + ), + ], + ) + ) + + def session_get_app_config(self, session_id: str = "0", **kargs): + """retrieve the current APP Configuration Parameters of the requested UWB session.""" + self.host.send_control( + uci.SessionGetAppConfigCmd(session_token=int(session_id), app_cfg=[0x9]) + ) + + def session_get_count(self, **kargs): + """Retrieve number of UWB sessions in the UWBS.""" + self.host.send_control(uci.SessionGetCountCmd()) + + def session_get_state(self, session_id: str = "0", **kargs): + """Query the current state of the UWB session.""" + self.host.send_control(uci.SessionGetStateCmd(session_token=int(session_id))) + + def session_update_controller_multicast_list( + self, + session_id: str = "0", + action: str = "add", + mac_address: str = "0", + subsession_id: str = "0", + **kargs, + ): + """Update the controller multicast list.""" + + if action == "add": + encoded_action = uci.UpdateMulticastListAction.ADD_CONTROLEE + elif action == "remove": + encoded_action = uci.UpdateMulticastListAction.REMOVE_CONTROLEE + else: + print(f"Unexpected action: '{action}', expected add or remove") + return + + self.host.send_control( + uci.SessionUpdateControllerMulticastListCmd( + session_token=int(session_id), + action=encoded_action, + payload=uci.SessionUpdateControllerMulticastListCmdPayload( + controlees=[ + uci.Controlee( + short_address=encode_short_mac_address(mac_address), + subsession_id=int(subsession_id), + ) + ], + ).serialize(), + ) + ) + + def range_start(self, session_id: str = "0", **kargs): + """start a UWB session.""" + self.host.send_control(uci.SessionStartCmd(session_id=int(session_id))) + + def range_stop(self, session_id: str = "0", **kargs): + """Stop a UWB session.""" + self.host.send_control(uci.SessionStopCmd(session_id=int(session_id))) + + def get_ranging_count(self, session_id: str = "0", **kargs): + """Get the number of times ranging has been attempted during the ranging session..""" + self.host.send_control( + uci.SessionGetRangingCountCmd(session_id=int(session_id)) + ) + + def data_transfer( + self, + dst_mac_address, + file_name, + session_id: str = "0", + ): + """Initiates data transfer by sending (possibly segmented) UCI data packet(s).""" + + # Does not have flow control, i.e. waiting for data credit notifications in between sending packets + try: + with open(file_name, "rb") as f: + b = f.read() + seq_num = 0 + dst_mac_address = parse_mac_address(dst_mac_address) + + if len(b) > MAX_DATA_PACKET_PAYLOAD_SIZE: + for i in range(0, len(b), MAX_DATA_PACKET_PAYLOAD_SIZE): + section = b[i : i + MAX_DATA_PACKET_PAYLOAD_SIZE] + + if i + MAX_DATA_PACKET_PAYLOAD_SIZE >= len(b): + self.host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=section, + ) + ) + else: + self.host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + pbf=uci.PacketBoundaryFlag.NOT_COMPLETE, + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=section, + ) + ) + + seq_num += 1 + if seq_num >= 65535: + seq_num = 0 + else: + self.host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=b, + ) + ) + + except Exception as e: + print(e) + + async def read_responses_and_notifications(self): + def chunks(l, n): + for i in range(0, len(l), n): + yield l[i : i + n] + + while True: + packet = await self.host._recv_control() + + # Format and print raw response data + txt = "\n ".join( + [ + " ".join(["{:02x}".format(b) for b in shard]) + for shard in chunks(packet, 16) + ] + ) + + command_buffer = readline.get_line_buffer() + print("\r", end="") + print(f"Received UCI packet [{len(packet)}]:") + print(f" {txt}") + + try: + uci_packet = uci.ControlPacket.parse_all(packet) + uci_packet.show() + except Exception as exn: + pass + + print(f"--> {command_buffer}", end="", flush=True) + + +async def ainput(prompt: str = ""): + with ThreadPoolExecutor(1, "ainput") as executor: + return ( + await asyncio.get_event_loop().run_in_executor(executor, input, prompt) + ).rstrip() + + +async def get_stream_reader(pipe) -> asyncio.StreamReader: + loop = asyncio.get_event_loop() + reader = asyncio.StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: protocol, pipe) + return reader + + +async def command_line(device: Device): + commands = { + "pica_get_state": device.pica_get_state, + "pica_init_uci_device": device.pica_init_uci_device, + "pica_create_anchor": device.pica_create_anchor, + "pica_destroy_anchor": device.pica_destroy_anchor, + "pica_set_position": device.pica_set_position, + "device_reset": device.device_reset, + "get_device_info": device.get_device_info, + "get_config": device.get_config, + "set_config": device.set_config, + "get_caps_info": device.get_caps_info, + "session_init": device.session_init, + "session_deinit": device.session_deinit, + "session_set_app_config": device.session_set_app_config, + "session_get_app_config": device.session_get_app_config, + "session_get_count": device.session_get_count, + "session_get_state": device.session_get_state, + "session_update_controller_multicast_list": device.session_update_controller_multicast_list, + "range_start": device.range_start, + "range_stop": device.range_stop, + "data_transfer": device.data_transfer, + "get_ranging_count": device.get_ranging_count, + } + + def usage(): + for cmd, func in commands.items(): + print(f" {cmd.ljust(32)}{func.__doc__}") + + def complete(text, state): + tokens = readline.get_line_buffer().split() + if not tokens or readline.get_line_buffer()[-1] == " ": + tokens.append("") + + # Writing a command name, complete to ' ' + if len(tokens) == 1: + results = [cmd + " " for cmd in commands.keys() if cmd.startswith(text)] + + # Writing a keyword argument, no completion + elif "=" in tokens[-1]: + results = [] + + # Writing a keyword name, but unknown command, no completion + elif tokens[0] not in commands: + results = [] + + # Writing a keyword name, complete to '=' + else: + sig = inspect.signature(commands[tokens[0]]) + names = [ + name + for (name, p) in sig.parameters.items() + if ( + p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + or p.kind == inspect.Parameter.KEYWORD_ONLY + ) + ] + results = [name + "=" for name in names if name.startswith(tokens[-1])] + + results += [None] + return results[state] + + # Configure readline + readline.parse_and_bind("tab: complete") + readline.set_completer(complete) + + while True: + cmd = await ainput("--> ") + [cmd, *params] = cmd.split(" ") + args = [] + kargs = dict() + for param in params: + if len(param) == 0: + continue + elif "=" in param: + [key, value] = param.split("=") + kargs[key] = value + else: + args.append(param) + + if cmd in ["quit", "q"]: + break + if cmd not in commands: + print(f"Undefined command {cmd}") + usage() + continue + commands[cmd](*args, **kargs) + + +async def run(address: str, uci_port: int, http_port: int): + try: + # Connect to Pica + reader, writer = await asyncio.open_connection(address, uci_port) + except Exception as exn: + print( + f"Failed to connect to Pica server at address {address}:{uci_port}\n" + + "Make sure the server is running" + ) + exit(1) + + # Start input and receive loops + device = Device(reader, writer, f"http://{address}:{http_port}") + loop = asyncio.get_event_loop() + loop.create_task(device.read_responses_and_notifications()) + await command_line(device) + + +def main(): + """Start a Pica interactive console.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--address", + type=str, + default="127.0.0.1", + help="Select the pica server address", + ) + parser.add_argument( + "--uci-port", type=int, default=7000, help="Select the pica TCP UCI port" + ) + parser.add_argument( + "--http-port", type=int, default=3000, help="Select the pica HTTP port" + ) + asyncio.run(run(**vars(parser.parse_args()))) + + +if __name__ == "__main__": + main() diff --git a/scripts/pica/__init__.py b/py/pica/pica/__init__.py index c4d1cea..ba2b155 100644 --- a/scripts/pica/__init__.py +++ b/py/pica/pica/__init__.py @@ -1,8 +1,8 @@ - import asyncio from typing import Union from .packets import uci + class Host: def __init__(self, reader, writer, mac_address: bytes): self.reader = reader @@ -14,17 +14,17 @@ class Host: loop = asyncio.get_event_loop() self.reader_task = loop.create_task(self._read_packets()) - @staticmethod - async def connect(address: str, port: int, mac_address: bytes) -> 'Host': + async def connect(address: str, port: int, mac_address: bytes) -> "Host": reader, writer = await asyncio.open_connection(address, port) return Host(reader, writer, mac_address) def disconnect(self): self.writer.close() + self.reader_task.cancel() async def _read_exact(self, expected_len: int) -> bytes: - """ Read an exact number of bytes from the socket. + """Read an exact number of bytes from the socket. Raises an exception if the socket gets disconnected.""" received = bytes() @@ -34,7 +34,7 @@ class Host: return received async def _read_packet(self) -> bytes: - """ Read a single UCI packet from the socket. + """Read a single UCI packet from the socket. The packet is automatically re-assembled if segmented on the UCI transport.""" @@ -49,7 +49,7 @@ class Host: while True: # Read the common packet header. header_bytes = await self._read_exact(4) - header = uci.PacketHeader.parse_all(header_bytes) + header = uci.ControlPacketHeader.parse_all(header_bytes) # Read the packet payload. payload_bytes = await self._read_exact(header.payload_length) @@ -63,7 +63,7 @@ class Host: pass async def _read_packets(self): - """ Loop reading UCI packets from the socket. + """Loop reading UCI packets from the socket. Receiving packets are added to the control queue.""" try: while True: @@ -81,6 +81,14 @@ class Host: packet[3] = len(packet) - 4 self.writer.write(packet) + def send_data(self, packet: uci.DataPacket): + packet = bytearray(packet.serialize()) + size = len(packet) - 4 + size_bytes = size.to_bytes(2, byteorder="little") + packet[2] = size_bytes[0] + packet[3] = size_bytes[1] + self.writer.write(packet) + async def expect_control( self, expected: Union[type, uci.ControlPacket], timeout: float = 1.0 ) -> uci.ControlPacket: diff --git a/scripts/pica/packets/__init__.py b/py/pica/pica/packets/__init__.py index e69de29..e69de29 100644 --- a/scripts/pica/packets/__init__.py +++ b/py/pica/pica/packets/__init__.py diff --git a/scripts/pica/packets/uci.py b/py/pica/pica/packets/uci.py index 5e0e8a8..8bba6ce 100644 --- a/scripts/pica/packets/uci.py +++ b/py/pica/pica/packets/uci.py @@ -226,6 +226,7 @@ class DataTransferNtfStatusCode(enum.IntEnum): UCI_DATA_TRANSFER_STATUS_ERROR_REJECTED = 0x4 UCI_DATA_TRANSFER_STATUS_SESSION_TYPE_NOT_SUPPORTED = 0x5 UCI_DATA_TRANSFER_STATUS_ERROR_DATA_TRANSFER_IS_ONGOING = 0x6 + UCI_DATA_TRANSFER_STATUS_INVALID_FORMAT = 0x7 class ResetConfig(enum.IntEnum): UWBS_RESET = 0x0 @@ -451,7 +452,6 @@ class MessageType(enum.IntEnum): class PacketHeader(Packet): pbf: PacketBoundaryFlag = field(kw_only=True, default=PacketBoundaryFlag.COMPLETE) mt: MessageType = field(kw_only=True, default=MessageType.DATA) - payload_length: int = field(kw_only=True, default=0) def __post_init__(self): pass @@ -459,6 +459,38 @@ class PacketHeader(Packet): @staticmethod def parse(span: bytes) -> Tuple['PacketHeader', bytes]: fields = {'payload': None} + if len(span) < 1: + raise Exception('Invalid packet size') + fields['pbf'] = PacketBoundaryFlag((span[0] >> 4) & 0x1) + fields['mt'] = MessageType((span[0] >> 5) & 0x7) + span = span[1:] + return PacketHeader(**fields), span + + def serialize(self, payload: bytes = None) -> bytes: + _span = bytearray() + _value = ( + (self.pbf << 4) | + (self.mt << 5) + ) + _span.append(_value) + return bytes(_span) + + @property + def size(self) -> int: + return 1 + +@dataclass +class ControlPacketHeader(Packet): + pbf: PacketBoundaryFlag = field(kw_only=True, default=PacketBoundaryFlag.COMPLETE) + mt: MessageType = field(kw_only=True, default=MessageType.DATA) + payload_length: int = field(kw_only=True, default=0) + + def __post_init__(self): + pass + + @staticmethod + def parse(span: bytes) -> Tuple['ControlPacketHeader', bytes]: + fields = {'payload': None} if len(span) < 4: raise Exception('Invalid packet size') fields['pbf'] = PacketBoundaryFlag((span[0] >> 4) & 0x1) @@ -466,7 +498,7 @@ class PacketHeader(Packet): value_ = int.from_bytes(span[1:3], byteorder='little') fields['payload_length'] = span[3] span = span[4:] - return PacketHeader(**fields), span + return ControlPacketHeader(**fields), span def serialize(self, payload: bytes = None) -> bytes: _span = bytearray() @@ -477,7 +509,7 @@ class PacketHeader(Packet): _span.append(_value) _span.extend([0] * 2) if self.payload_length > 255: - print(f"Invalid value for field PacketHeader::payload_length: {self.payload_length} > 255; the value will be truncated") + print(f"Invalid value for field ControlPacketHeader::payload_length: {self.payload_length} > 255; the value will be truncated") self.payload_length &= 255 _span.append((self.payload_length << 0)) return bytes(_span) @@ -487,6 +519,45 @@ class PacketHeader(Packet): return 4 @dataclass +class DataPacketHeader(Packet): + pbf: PacketBoundaryFlag = field(kw_only=True, default=PacketBoundaryFlag.COMPLETE) + mt: MessageType = field(kw_only=True, default=MessageType.DATA) + payload_length: int = field(kw_only=True, default=0) + + def __post_init__(self): + pass + + @staticmethod + def parse(span: bytes) -> Tuple['DataPacketHeader', bytes]: + fields = {'payload': None} + if len(span) < 4: + raise Exception('Invalid packet size') + fields['pbf'] = PacketBoundaryFlag((span[0] >> 4) & 0x1) + fields['mt'] = MessageType((span[0] >> 5) & 0x7) + value_ = int.from_bytes(span[2:4], byteorder='little') + fields['payload_length'] = value_ + span = span[4:] + return DataPacketHeader(**fields), span + + def serialize(self, payload: bytes = None) -> bytes: + _span = bytearray() + _value = ( + (self.pbf << 4) | + (self.mt << 5) + ) + _span.append(_value) + _span.extend([0] * 1) + if self.payload_length > 65535: + print(f"Invalid value for field DataPacketHeader::payload_length: {self.payload_length} > 65535; the value will be truncated") + self.payload_length &= 65535 + _span.extend(int.to_bytes((self.payload_length << 0), length=2, byteorder='little')) + return bytes(_span) + + @property + def size(self) -> int: + return 4 + +@dataclass class ControlPacket(Packet): gid: GroupId = field(kw_only=True, default=GroupId.CORE) mt: MessageType = field(kw_only=True, default=MessageType.DATA) @@ -726,8 +797,168 @@ class ControlPacket(Packet): return len(self.payload) + 4 @dataclass -class UciCommand(ControlPacket): +class DataPacket(Packet): + dpf: DataPacketFormat = field(kw_only=True, default=DataPacketFormat.DATA_SND) + pbf: PacketBoundaryFlag = field(kw_only=True, default=PacketBoundaryFlag.COMPLETE) + mt: MessageType = field(kw_only=True, default=MessageType.DATA) + def __post_init__(self): + pass + + @staticmethod + def parse(span: bytes) -> Tuple['DataPacket', bytes]: + fields = {'payload': None} + if len(span) < 4: + raise Exception('Invalid packet size') + fields['dpf'] = DataPacketFormat((span[0] >> 0) & 0xf) + fields['pbf'] = PacketBoundaryFlag((span[0] >> 4) & 0x1) + fields['mt'] = MessageType((span[0] >> 5) & 0x7) + value_ = int.from_bytes(span[2:4], byteorder='little') + span = span[4:] + payload = span + span = bytes([]) + fields['payload'] = payload + try: + return DataMessageSnd.parse(fields.copy(), payload) + except Exception as exn: + pass + try: + return DataMessageRcv.parse(fields.copy(), payload) + except Exception as exn: + pass + return DataPacket(**fields), span + + def serialize(self, payload: bytes = None) -> bytes: + _span = bytearray() + _value = ( + (self.dpf << 0) | + (self.pbf << 4) | + (self.mt << 5) + ) + _span.append(_value) + _span.extend([0] * 1) + _span.extend([0] * 2) + _span.extend(payload or self.payload or []) + return bytes(_span) + + @property + def size(self) -> int: + return len(self.payload) + 4 + +@dataclass +class DataMessageSnd(DataPacket): + session_handle: int = field(kw_only=True, default=0) + destination_address: int = field(kw_only=True, default=0) + data_sequence_number: int = field(kw_only=True, default=0) + application_data: bytearray = field(kw_only=True, default_factory=bytearray) + + def __post_init__(self): + self.dpf = DataPacketFormat.DATA_SND + self.mt = MessageType.DATA + + @staticmethod + def parse(fields: dict, span: bytes) -> Tuple['DataMessageSnd', bytes]: + if fields['dpf'] != DataPacketFormat.DATA_SND or fields['mt'] != MessageType.DATA: + raise Exception("Invalid constraint field values") + if len(span) < 16: + raise Exception('Invalid packet size') + value_ = int.from_bytes(span[0:4], byteorder='little') + fields['session_handle'] = value_ + value_ = int.from_bytes(span[4:12], byteorder='little') + fields['destination_address'] = value_ + value_ = int.from_bytes(span[12:14], byteorder='little') + fields['data_sequence_number'] = value_ + value_ = int.from_bytes(span[14:16], byteorder='little') + application_data_size = value_ + span = span[16:] + if len(span) < application_data_size: + raise Exception('Invalid packet size') + fields['application_data'] = list(span[:application_data_size]) + span = span[application_data_size:] + return DataMessageSnd(**fields), span + + def serialize(self, payload: bytes = None) -> bytes: + _span = bytearray() + if self.session_handle > 4294967295: + print(f"Invalid value for field DataMessageSnd::session_handle: {self.session_handle} > 4294967295; the value will be truncated") + self.session_handle &= 4294967295 + _span.extend(int.to_bytes((self.session_handle << 0), length=4, byteorder='little')) + if self.destination_address > 18446744073709551615: + print(f"Invalid value for field DataMessageSnd::destination_address: {self.destination_address} > 18446744073709551615; the value will be truncated") + self.destination_address &= 18446744073709551615 + _span.extend(int.to_bytes((self.destination_address << 0), length=8, byteorder='little')) + if self.data_sequence_number > 65535: + print(f"Invalid value for field DataMessageSnd::data_sequence_number: {self.data_sequence_number} > 65535; the value will be truncated") + self.data_sequence_number &= 65535 + _span.extend(int.to_bytes((self.data_sequence_number << 0), length=2, byteorder='little')) + _span.extend(int.to_bytes(((len(self.application_data) * 1) << 0), length=2, byteorder='little')) + _span.extend(self.application_data) + return DataPacket.serialize(self, payload = bytes(_span)) + + @property + def size(self) -> int: + return len(self.application_data) * 1 + 16 + +@dataclass +class DataMessageRcv(DataPacket): + session_handle: int = field(kw_only=True, default=0) + status: StatusCode = field(kw_only=True, default=StatusCode.UCI_STATUS_OK) + source_address: int = field(kw_only=True, default=0) + data_sequence_number: int = field(kw_only=True, default=0) + application_data: bytearray = field(kw_only=True, default_factory=bytearray) + + def __post_init__(self): + self.dpf = DataPacketFormat.DATA_RCV + self.mt = MessageType.DATA + + @staticmethod + def parse(fields: dict, span: bytes) -> Tuple['DataMessageRcv', bytes]: + if fields['dpf'] != DataPacketFormat.DATA_RCV or fields['mt'] != MessageType.DATA: + raise Exception("Invalid constraint field values") + if len(span) < 17: + raise Exception('Invalid packet size') + value_ = int.from_bytes(span[0:4], byteorder='little') + fields['session_handle'] = value_ + fields['status'] = StatusCode(span[4]) + value_ = int.from_bytes(span[5:13], byteorder='little') + fields['source_address'] = value_ + value_ = int.from_bytes(span[13:15], byteorder='little') + fields['data_sequence_number'] = value_ + value_ = int.from_bytes(span[15:17], byteorder='little') + application_data_size = value_ + span = span[17:] + if len(span) < application_data_size: + raise Exception('Invalid packet size') + fields['application_data'] = list(span[:application_data_size]) + span = span[application_data_size:] + return DataMessageRcv(**fields), span + + def serialize(self, payload: bytes = None) -> bytes: + _span = bytearray() + if self.session_handle > 4294967295: + print(f"Invalid value for field DataMessageRcv::session_handle: {self.session_handle} > 4294967295; the value will be truncated") + self.session_handle &= 4294967295 + _span.extend(int.to_bytes((self.session_handle << 0), length=4, byteorder='little')) + _span.append((self.status << 0)) + if self.source_address > 18446744073709551615: + print(f"Invalid value for field DataMessageRcv::source_address: {self.source_address} > 18446744073709551615; the value will be truncated") + self.source_address &= 18446744073709551615 + _span.extend(int.to_bytes((self.source_address << 0), length=8, byteorder='little')) + if self.data_sequence_number > 65535: + print(f"Invalid value for field DataMessageRcv::data_sequence_number: {self.data_sequence_number} > 65535; the value will be truncated") + self.data_sequence_number &= 65535 + _span.extend(int.to_bytes((self.data_sequence_number << 0), length=2, byteorder='little')) + _span.extend(int.to_bytes(((len(self.application_data) * 1) << 0), length=2, byteorder='little')) + _span.extend(self.application_data) + return DataPacket.serialize(self, payload = bytes(_span)) + + @property + def size(self) -> int: + return len(self.application_data) * 1 + 17 + +@dataclass +class UciCommand(ControlPacket): + def __post_init__(self): self.mt = MessageType.COMMAND diff --git a/py/pica/pyproject.toml b/py/pica/pyproject.toml new file mode 100644 index 0000000..3f7acac --- /dev/null +++ b/py/pica/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "pica" +dynamic = ["version"] +description = "UCI host helpers" +requires-python = ">=3.10" diff --git a/scripts/console.py b/scripts/console.py deleted file mode 100755 index 2e21ae3..0000000 --- a/scripts/console.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import inspect -import json -import random -import readline -import socket -import sys -import time -import requests -import struct -import asyncio -from concurrent.futures import ThreadPoolExecutor - -from pica import Host -from pica.packets import uci - - -def encode_short_mac_address(mac_address: str) -> bytes: - return int(mac_address).to_bytes(2, byteorder='little') - - -def encode_mac_address(mac_address: str) -> bytes: - return int(mac_address).to_bytes(8, byteorder='little') - - -def parse_mac_address(mac_address: str) -> bytes: - bs = mac_address.split(':') - return bytes(int(b, 16) for b in bs) - - -class Device: - def __init__(self, reader, writer, http_address): - self.host = Host(reader, writer, bytes([0, 1])) - self.http_address = http_address - - def pica_get_state( - self, - **kargs): - """List the UCI devices and anchors currently existing within the Pica - virtual environment""" - r = requests.get(f'{self.http_address}/get-state') - print(f'{r.status_code}:\n{json.dumps(r.json(), indent=2)}') - - def pica_init_uci_device( - self, - mac_address: str = "00:00", - x: str = "0", - y: str = "0", - z: str = "0", - yaw: str = "0", - pitch: str = "0", - roll: str = "0", - **kargs): - """Init Pica device""" - r = requests.post(f'{self.http_address}/init-uci-device/{mac_address}', - data=json.dumps({ - 'x': int(x), 'y': int(y), 'z': int(z), - 'yaw': int(yaw), 'pitch': int(pitch), 'roll': int(roll) - })) - print(f'{r.status_code}: {r.text}') - - def pica_create_anchor( - self, - mac_address: str = "00:00", - x: str = "0", - y: str = "0", - z: str = "0", - yaw: str = "0", - pitch: str = "0", - roll: str = "0", - **kargs): - """Create a Pica anchor""" - r = requests.post(f'{self.http_address}/create-anchor/{mac_address}', - data=json.dumps({ - 'x': int(x), 'y': int(y), 'z': int(z), - 'yaw': int(yaw), 'pitch': int(pitch), 'roll': int(roll) - })) - print(f'{r.status_code}: {r.text}') - - def pica_destroy_anchor( - self, - mac_address: str = "00:00", - **kargs): - """Destroy a Pica anchor""" - r = requests.post(f'{self.http_address}/destroy-anchor/{mac_address}') - print(f'{r.status_code}: {r.text}') - - def pica_set_position( - self, - mac_address: str = "00:00", - x: str = "0", - y: str = "0", - z: str = "0", - yaw: str = "0", - pitch: str = "0", - roll: str = "0", - **kargs): - """Set Pica UCI device or anchor position""" - r = requests.post(f'{self.http_address}/set-position/{mac_address}', - data=json.dumps({ - 'x': int(x), 'y': int(y), 'z': int(z), - 'yaw': int(yaw), 'pitch': int(pitch), 'roll': int(roll) - })) - print(f'{r.status_code}: {r.text}') - - def device_reset(self, **kargs): - """Reset the UWBS.""" - self.host.send_control(uci.DeviceResetCmd(reset_config=uci.ResetConfig.UWBS_RESET)) - - def get_device_info(self, **kargs): - """Retrieve the device information like (UCI version and other vendor specific info).""" - self.host.send_control(uci.GetDeviceInfoCmd()) - - def get_caps_info(self, **kargs): - """Get the capability of the UWBS.""" - self.host.send_control(uci.GetCapsInfoCmd()) - - def set_config(self, low_power_mode: str = '0', **kargs): - """Set the configuration parameters on the UWBS.""" - self.host.send_control(uci.SetConfigCmd(tlvs=[ - uci.DeviceConfigTlv( - cfg_id=uci.DeviceConfigId.LOW_POWER_MODE, - v=bytes([int(low_power_mode)])), - ])) - - def get_config(self, **kargs): - """Retrieve the current configuration parameter(s) of the UWBS.""" - self.host.send_control(uci.GetConfigCmd(cfg_id=[ - uci.DeviceConfigId.LOW_POWER_MODE, - uci.DeviceConfigId.DEVICE_STATE, - ])) - - def session_init(self, session_id: str = '0', **kargs): - """Initialize the session""" - self.host.send_control(uci.SessionInitCmd( - session_id=int(session_id), - session_type=uci.SessionType.FIRA_RANGING_SESSION)) - - def session_deinit(self, session_id: str = '0', **kargs): - """Deinitialize the session""" - self.host.send_control(uci.SessionDeinitCmd( - session_token=int(session_id))) - - def session_set_app_config( - self, - session_id: str = '0', - ranging_interval: str = '200', - dst_mac_addresses: str = '', - **kargs): - """set APP Configuration Parameters for the requested UWB session.""" - dst_mac_addresses = [parse_mac_address(a) for a in dst_mac_addresses.split(',') if a] - if any(len(a) > 2 for a in dst_mac_addresses): - mac_address_mode = 0x2 - mac_address_len = 8 - else: - mac_address_mode = 0x0 - mac_address_len = 2 - - encoded_dst_mac_addresses = bytes() - for mac_address in dst_mac_addresses: - encoded_dst_mac_addresses += mac_address - encoded_dst_mac_addresses += b'\0' * (mac_address_len - len(mac_address)) - - self.host.send_control(uci.SessionSetAppConfigCmd( - session_token=int(session_id), - tlvs=[ - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, - v=bytes([mac_address_mode])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.RANGING_DURATION, - v=int(ranging_interval).to_bytes(4, byteorder='little')), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, - v=bytes([len(dst_mac_addresses)])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, - v=encoded_dst_mac_addresses), - ])) - - def session_get_app_config(self, session_id: str = '0', **kargs): - """retrieve the current APP Configuration Parameters of the requested UWB session.""" - self.host.send_control(uci.SessionGetAppConfigCmd( - session_token=int(session_id), app_cfg=[0x9])) - - def session_get_count(self, **kargs): - """Retrieve number of UWB sessions in the UWBS.""" - self.host.send_control(uci.SessionGetCountCmd()) - - def session_get_state(self, session_id: str = '0', **kargs): - """Query the current state of the UWB session.""" - self.host.send_control(uci.SessionGetStateCmd(session_token=int(session_id))) - - def session_update_controller_multicast_list( - self, - session_id: str = '0', - action: str = 'add', - mac_address: str = '0', - subsession_id: str = '0', - **kargs): - """Update the controller multicast list.""" - - if action == 'add': - encoded_action = uci.UpdateMulticastListAction.ADD_CONTROLEE - elif action == 'remove': - encoded_action = uci.UpdateMulticastListAction.REMOVE_CONTROLEE - else: - print(f"Unexpected action: '{action}', expected add or remove") - return - - self.host.send_control(uci.SessionUpdateControllerMulticastListCmd( - session_token=int(session_id), - action=encoded_action, - payload=uci.SessionUpdateControllerMulticastListCmdPayload( - controlees=[ - uci.Controlee( - short_address=encode_short_mac_address(mac_address), - subsession_id=int(subsession_id), - ) - ], - ).serialize())) - - def range_start(self, session_id: str = '0', **kargs): - """start a UWB session.""" - self.host.send_control(uci.SessionStartCmd(session_id=int(session_id))) - - def range_stop(self, session_id: str = '0', **kargs): - """Stop a UWB session.""" - self.host.send_control(uci.SessionStopCmd(session_id=int(session_id))) - - def get_ranging_count(self, session_id: str = '0', **kargs): - """Get the number of times ranging has been attempted during the ranging session..""" - self.host.send_control(uci.SessionGetRangingCountCmd(session_id=int(session_id))) - - async def read_responses_and_notifications(self): - def chunks(l, n): - for i in range(0, len(l), n): - yield l[i:i + n] - - while True: - packet = await self.host._recv_control() - - # Format and print raw response data - txt = '\n '.join([ - ' '.join(['{:02x}'.format(b) for b in shard]) for - shard in chunks(packet, 16)]) - - command_buffer = readline.get_line_buffer() - print('\r', end='') - print(f'Received UCI packet [{len(packet)}]:') - print(f' {txt}') - - try: - uci_packet = uci.ControlPacket.parse_all(packet) - uci_packet.show() - except Exception as exn: - pass - - print(f'--> {command_buffer}', end='', flush=True) - - -async def ainput(prompt: str = ''): - with ThreadPoolExecutor(1, 'ainput') as executor: - return (await asyncio.get_event_loop().run_in_executor(executor, input, prompt)).rstrip() - - -async def get_stream_reader(pipe) -> asyncio.StreamReader: - loop = asyncio.get_event_loop() - reader = asyncio.StreamReader(loop=loop) - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, pipe) - return reader - - -async def command_line(device: Device): - commands = { - 'pica_get_state': device.pica_get_state, - 'pica_init_uci_device': device.pica_init_uci_device, - 'pica_create_anchor': device.pica_create_anchor, - 'pica_destroy_anchor': device.pica_destroy_anchor, - 'pica_set_position': device.pica_set_position, - 'device_reset': device.device_reset, - 'get_device_info': device.get_device_info, - 'get_config': device.get_config, - 'set_config': device.set_config, - 'get_caps_info': device.get_caps_info, - 'session_init': device.session_init, - 'session_deinit': device.session_deinit, - 'session_set_app_config': device.session_set_app_config, - 'session_get_app_config': device.session_get_app_config, - 'session_get_count': device.session_get_count, - 'session_get_state': device.session_get_state, - 'session_update_controller_multicast_list': device.session_update_controller_multicast_list, - 'range_start': device.range_start, - 'range_stop': device.range_stop, - 'get_ranging_count': device.get_ranging_count, - } - - def usage(): - for (cmd, func) in commands.items(): - print(f' {cmd.ljust(32)}{func.__doc__}') - - def complete(text, state): - tokens = readline.get_line_buffer().split() - if not tokens or readline.get_line_buffer()[-1] == ' ': - tokens.append('') - - # Writing a command name, complete to ' ' - if len(tokens) == 1: - results = [cmd + ' ' for cmd in commands.keys() if - cmd.startswith(text)] - - # Writing a keyword argument, no completion - elif '=' in tokens[-1]: - results = [] - - # Writing a keyword name, but unknown command, no completion - elif tokens[0] not in commands: - results = [] - - # Writing a keyword name, complete to '=' - else: - sig = inspect.signature(commands[tokens[0]]) - names = [name for (name, p) in sig.parameters.items() - if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD or - p.kind == inspect.Parameter.KEYWORD_ONLY)] - results = [ - name + '=' for name in names if name.startswith(tokens[-1])] - - results += [None] - return results[state] - - # Configure readline - readline.parse_and_bind("tab: complete") - readline.set_completer(complete) - - while True: - cmd = await ainput('--> ') - [cmd, *params] = cmd.split(' ') - args = [] - kargs = dict() - for param in params: - if len(param) == 0: - continue - elif '=' in param: - [key, value] = param.split('=') - kargs[key] = value - else: - args.append(param) - - if cmd in ['quit', 'q']: - break - if cmd not in commands: - print(f'Undefined command {cmd}') - usage() - continue - commands[cmd](*args, **kargs) - - -async def run(address: str, uci_port: int, http_port: int): - try: - # Connect to Pica - reader, writer = await asyncio.open_connection(address, uci_port) - except Exception as exn: - print( - f'Failed to connect to Pica server at address {address}:{uci_port}\n' + - 'Make sure the server is running') - exit(1) - - # Start input and receive loops - device = Device(reader, writer, f'http://{address}:{http_port}') - loop = asyncio.get_event_loop() - loop.create_task(device.read_responses_and_notifications()) - await command_line(device) - - -def main(): - """Start a Pica interactive console.""" - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('--address', - type=str, - default='127.0.0.1', - help='Select the pica server address') - parser.add_argument('--uci-port', - type=int, - default=7000, - help='Select the pica TCP UCI port') - parser.add_argument('--http-port', - type=int, - default=3000, - help='Select the pica HTTP port') - asyncio.run(run(**vars(parser.parse_args()))) - - -if __name__ == '__main__': - main() diff --git a/scripts/ranging_example.py b/scripts/ranging_example.py deleted file mode 100755 index 497b2e9..0000000 --- a/scripts/ranging_example.py +++ /dev/null @@ -1,297 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import argparse -from pica import Host -from pica.packets import uci - -async def controller(host: Host, peer: Host): - await host.expect_control( - uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY)) - - host.send_control( - uci.DeviceResetCmd(reset_config=uci.ResetConfig.UWBS_RESET)) - - await host.expect_control( - uci.DeviceResetRsp(status=uci.StatusCode.UCI_STATUS_OK)) - - host.send_control( - uci.SessionInitCmd( - session_id=0, - session_type=uci.SessionType.FIRA_RANGING_SESSION)) - - await host.expect_control( - uci.SessionInitRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_INIT, - reason_code=0)) - - mac_address_mode = 0x0 - ranging_duration = int(1000).to_bytes(4, byteorder='little') - device_role_initiator = bytes([0]) - device_type_controller = bytes([1]) - host.send_control( - uci.SessionSetAppConfigCmd( - session_token=0, - tlvs=[ - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_ROLE, - v=device_role_initiator), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_TYPE, - v=device_type_controller), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_MAC_ADDRESS, - v=host.mac_address), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, - v=bytes([mac_address_mode])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.RANGING_DURATION, - v=ranging_duration), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, - v=bytes([1])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, - v=peer.mac_address), - ])) - - await host.expect_control( - uci.SessionSetAppConfigRsp( - status=uci.StatusCode.UCI_STATUS_OK, - cfg_status=[])) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_IDLE, - reason_code=0)) - - host.send_control( - uci.SessionStartCmd( - session_id=0)) - - await host.expect_control( - uci.SessionStartRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_ACTIVE, - reason_code=0)) - - await host.expect_control( - uci.DeviceStatusNtf( - device_state=uci.DeviceState.DEVICE_STATE_ACTIVE)) - - for n in range(1, 3): - event = await host.expect_control( - uci.ShortMacTwoWaySessionInfoNtf, - timeout=2.0) - event.show() - - host.send_control( - uci.SessionStopCmd( - session_id=0)) - - await host.expect_control( - uci.SessionStopRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_IDLE, - reason_code=0)) - - await host.expect_control( - uci.DeviceStatusNtf( - device_state=uci.DeviceState.DEVICE_STATE_READY)) - - host.send_control( - uci.SessionDeinitCmd( - session_token=0)) - - await host.expect_control( - uci.SessionDeinitRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - -async def controlee(host: Host, peer: Host): - await host.expect_control( - uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY)) - - host.send_control( - uci.DeviceResetCmd(reset_config=uci.ResetConfig.UWBS_RESET)) - - await host.expect_control( - uci.DeviceResetRsp(status=uci.StatusCode.UCI_STATUS_OK)) - - host.send_control( - uci.SessionInitCmd( - session_id=0, - session_type=uci.SessionType.FIRA_RANGING_SESSION)) - - await host.expect_control( - uci.SessionInitRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_INIT, - reason_code=0)) - - mac_address_mode = 0x0 - ranging_duration = int(1000).to_bytes(4, byteorder='little') - device_role_responder = bytes([1]) - device_type_controlee = bytes([0]) - host.send_control( - uci.SessionSetAppConfigCmd( - session_token=0, - tlvs=[ - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_ROLE, - v=device_role_responder), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_TYPE, - v=device_type_controlee), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DEVICE_MAC_ADDRESS, - v=host.mac_address), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, - v=bytes([mac_address_mode])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.RANGING_DURATION, - v=ranging_duration), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, - v=bytes([1])), - uci.AppConfigTlv( - cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, - v=peer.mac_address), - ])) - - await host.expect_control( - uci.SessionSetAppConfigRsp( - status=uci.StatusCode.UCI_STATUS_OK, - cfg_status=[])) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_IDLE, - reason_code=0)) - - host.send_control( - uci.SessionStartCmd( - session_id=0)) - - await host.expect_control( - uci.SessionStartRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_ACTIVE, - reason_code=0)) - - await host.expect_control( - uci.DeviceStatusNtf( - device_state=uci.DeviceState.DEVICE_STATE_ACTIVE)) - - for n in range(1, 3): - event = await host.expect_control( - uci.ShortMacTwoWaySessionInfoNtf, - timeout=2.0) - event.show() - - host.send_control( - uci.SessionStopCmd( - session_id=0)) - - await host.expect_control( - uci.SessionStopRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - await host.expect_control( - uci.SessionStatusNtf( - session_token=0, - session_state=uci.SessionState.SESSION_STATE_IDLE, - reason_code=0)) - - await host.expect_control( - uci.DeviceStatusNtf( - device_state=uci.DeviceState.DEVICE_STATE_READY)) - - host.send_control( - uci.SessionDeinitCmd( - session_token=0)) - - await host.expect_control( - uci.SessionDeinitRsp( - status=uci.StatusCode.UCI_STATUS_OK)) - - -async def run(address: str, uci_port: int, http_port: int): - try: - host0 = await Host.connect(address, uci_port, bytes([0, 1])) - host1 = await Host.connect(address, uci_port, bytes([0, 2])) - except Exception as exn: - print( - f'Failed to connect to Pica server at address {address}:{uci_port}\n' + - 'Make sure the server is running') - exit(1) - - async with asyncio.TaskGroup() as tg: - task0 = tg.create_task(controller(host0, host1)) - task1 = tg.create_task(controlee(host1, host0)) - - host0.disconnect() - host1.disconnect() - - print('Ranging test completed') - - -def main(): - """Start a Pica interactive console.""" - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('--address', - type=str, - default='127.0.0.1', - help='Select the pica server address') - parser.add_argument('--uci-port', - type=int, - default=7000, - help='Select the pica TCP UCI port') - parser.add_argument('--http-port', - type=int, - default=3000, - help='Select the pica HTTP port') - asyncio.run(run(**vars(parser.parse_args()))) - - -if __name__ == '__main__': - main() diff --git a/src/bin/server/mod.rs b/src/bin/server/mod.rs index bfa0bdb..616ef5b 100644 --- a/src/bin/server/mod.rs +++ b/src/bin/server/mod.rs @@ -26,7 +26,7 @@ use pica::{Pica, PicaCommand}; use std::net::{Ipv4Addr, SocketAddrV4}; use std::path::PathBuf; use tokio::net::TcpListener; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::mpsc; use tokio::try_join; const DEFAULT_UCI_PORT: u16 = 7000; @@ -67,16 +67,16 @@ async fn main() -> Result<()> { args.uci_port, args.web_port, "UCI port and Web port shall be different." ); - let (event_tx, _) = broadcast::channel(16); - let mut pica = Pica::new(event_tx.clone(), args.pcapng_dir); + let mut pica = Pica::new(args.pcapng_dir); let pica_tx = pica.tx(); + let pica_events = pica.events(); #[cfg(feature = "web")] try_join!( accept_incoming(pica_tx.clone(), args.uci_port), pica.run(), - web::serve(pica_tx, event_tx, args.web_port) + web::serve(pica_tx, pica_events, args.web_port) )?; #[cfg(not(feature = "web"))] diff --git a/src/bin/server/web.rs b/src/bin/server/web.rs index 0af876f..1481e54 100644 --- a/src/bin/server/web.rs +++ b/src/bin/server/web.rs @@ -120,7 +120,7 @@ fn event_name(event: &PicaEvent) -> &'static str { async fn handle( mut req: Request<Body>, tx: mpsc::Sender<PicaCommand>, - events: broadcast::Sender<PicaEvent>, + events: broadcast::Receiver<PicaEvent>, ) -> Result<Response<Body>, Infallible> { let static_file = STATIC_FILES .iter() @@ -168,7 +168,7 @@ async fn handle( .collect::<Vec<_>>()[..] { ["events"] => { - let stream = BroadcastStream::new(events.subscribe()).map(|result| { + let stream = BroadcastStream::new(events).map(|result| { result.map(|event| { format!( "event: {}\ndata: {}\n\n", @@ -256,7 +256,7 @@ pub async fn serve( let events = events.clone(); async move { Ok::<_, Infallible>(service_fn(move |req| { - handle(req, tx.clone(), events.clone()) + handle(req, tx.clone(), events.subscribe()) })) } }); diff --git a/src/device.rs b/src/device.rs index b09a493..e1cbfcb 100644 --- a/src/device.rs +++ b/src/device.rs @@ -27,7 +27,7 @@ use tokio::time; use super::session::{Session, MAX_SESSION}; pub const MAX_DEVICE: usize = 4; -const UCI_VERSION: u16 = 0x1001; // Version 1.1.0 +const UCI_VERSION: u16 = 0x0002; // Version 2.0 const MAC_VERSION: u16 = 0x3001; // Version 1.3.0 const PHY_VERSION: u16 = 0x3001; // Version 1.3.0 const TEST_VERSION: u16 = 0x1001; // Version 1.1 @@ -42,7 +42,7 @@ pub const DEFAULT_CAPS_INFO: &[(CapTlvType, &[u8])] = &[ (CapTlvType::SupportedFiraMacVersionRange, &[1, 1, 1, 3]), // 1.1 - 1.3 (CapTlvType::SupportedDeviceRoles, &[0x3]), // INTIATOR | RESPONDER (CapTlvType::SupportedRangingMethod, &[0x1f]), // DS_TWR_NON_DEFERRED | SS_TWR_NON_DEFERRED | DS_TWR_DEFERRED | SS_TWR_DEFERRED | OWR - (CapTlvType::SupportedStsConfig, &[0x7]), // STATIC_STS | DYNAMIC_STS | DYNAMIC_STS_RESPONDER_SPECIFIC_SUBSESSION_KEY + (CapTlvType::SupportedStsConfig, &[0x1f]), // STATIC_STS | DYNAMIC_STS | DYNAMIC_STS_RESPONDER_SPECIFIC_SUBSESSION_KEY | PROVISIONED_STS | PROVISIONED_STS_RESPONDER_SPECIFIC_SUBSESSION_KEY (CapTlvType::SupportedMultiNodeModes, &[0xff]), (CapTlvType::SupportedRangingTimeStruct, &[0x01]), // Block Based Scheduling (default) (CapTlvType::SupportedScheduledMode, &[0x01]), // Time scheduled ranging (default) @@ -358,6 +358,41 @@ impl Device { .build() } + pub fn data_message_snd(&mut self, data: DataPacket) -> SessionControlNotification { + match data.specialize() { + DataPacketChild::DataMessageSnd(data_msg_snd) => { + let session_token = data_msg_snd.get_session_handle(); + if let Some(session) = self.get_session_mut(session_token) { + session.data_message_snd(data_msg_snd) + } else { + DataTransferStatusNtfBuilder { + session_token, + status: DataTransferNtfStatusCode::UciDataTransferStatusErrorRejected, + tx_count: 1, // TODO: support for retries? + uci_sequence_number: 0, + } + .build() + .into() + } + } + DataPacketChild::DataMessageRcv(data_msg_rcv) => { + // This function should not be passed anything besides DataMessageSnd + let session_token = data_msg_rcv.get_session_handle(); + DataTransferStatusNtfBuilder { + session_token, + status: DataTransferNtfStatusCode::UciDataTransferStatusInvalidFormat, + tx_count: 1, // TODO: support for retries? + uci_sequence_number: 0, + } + .build() + .into() + } + _ => { + unimplemented!() + } + } + } + pub fn command(&mut self, cmd: UciCommand) -> UciResponse { match cmd.specialize() { // Handle commands for this device @@ -20,7 +20,6 @@ use std::collections::HashMap; use std::fmt::Display; use std::path::PathBuf; use thiserror::Error; -use tokio::io::AsyncReadExt; use tokio::net::TcpStream; use tokio::sync::{broadcast, mpsc, oneshot}; @@ -45,104 +44,6 @@ pub use mac_address::MacAddress; use crate::session::RangeDataNtfConfig; -/// Size of UCI packet headers. -const HEADER_SIZE: usize = 4; -/// Maximum size of an UCI packet payload. -const MAX_PAYLOAD_SIZE: usize = 255; - -struct Connection { - socket: TcpStream, - pcapng_file: Option<pcapng::File>, -} - -impl Connection { - fn new(socket: TcpStream, pcapng_file: Option<pcapng::File>) -> Self { - Connection { - socket, - pcapng_file, - } - } - - /// Read a single UCI packet from the socket. The packet is automatically - /// re-assembled if segmented on the UCI transport. - async fn read(&mut self) -> Result<Vec<u8>> { - let mut complete_packet = vec![0; HEADER_SIZE]; - - // Note on reassembly: - // For each segment of a Control Message, the - // header of the Control Packet SHALL contain the same MT, GID and OID - // values. - // It is correct to keep only the last header of the segmented packet. - loop { - // Read the common packet header. - self.socket - .read_exact(&mut complete_packet[0..HEADER_SIZE]) - .await?; - let header = PacketHeader::parse(&complete_packet[0..HEADER_SIZE])?; - - // Read the packet payload. - let payload_length = header.get_payload_length() as usize; - let mut payload_bytes = vec![0; payload_length]; - self.socket.read_exact(&mut payload_bytes).await?; - complete_packet.extend(&payload_bytes); - - if let Some(ref mut pcapng_file) = self.pcapng_file { - let mut packet_bytes = vec![]; - packet_bytes.extend(&complete_packet[0..HEADER_SIZE]); - packet_bytes.extend(&payload_bytes); - pcapng_file - .write(&packet_bytes, pcapng::Direction::Tx) - .await?; - } - - // Check the Packet Boundary Flag. - match header.get_pbf() { - PacketBoundaryFlag::Complete => return Ok(complete_packet), - PacketBoundaryFlag::NotComplete => (), - } - } - } - - /// Write a single UCI packet to the writer. The packet is automatically - /// segmented if the payload exceeds the maximum size limit. - async fn write(&mut self, mut packet: &[u8]) -> Result<()> { - let mut header_bytes = [packet[0], packet[1], packet[2], 0]; - packet = &packet[HEADER_SIZE..]; - - loop { - // Update header with framing information. - let chunk_length = std::cmp::min(MAX_PAYLOAD_SIZE, packet.len()); - let pbf = if chunk_length < packet.len() { - PacketBoundaryFlag::NotComplete - } else { - PacketBoundaryFlag::Complete - }; - const PBF_MASK: u8 = 0x10; - header_bytes[0] &= !PBF_MASK; - header_bytes[0] |= (pbf as u8) << 4; - header_bytes[3] = chunk_length as u8; - - if let Some(ref mut pcapng_file) = self.pcapng_file { - let mut packet_bytes = vec![]; - packet_bytes.extend(&header_bytes); - packet_bytes.extend(&packet[..chunk_length]); - pcapng_file - .write(&packet_bytes, pcapng::Direction::Rx) - .await? - } - - // Write the header and payload segment bytes. - self.socket.try_write(&header_bytes)?; - self.socket.try_write(&packet[..chunk_length])?; - packet = &packet[chunk_length..]; - - if packet.is_empty() { - return Ok(()); - } - } - } -} - pub type PicaCommandStatus = Result<(), PicaCommandError>; #[derive(Error, Debug, Clone, PartialEq, Eq)] @@ -163,8 +64,10 @@ pub enum PicaCommand { Ranging(usize, u32), // Send an in-band request to stop ranging to a peer controlee identified by address and session id. StopRanging(MacAddress, u32), + // Execute data message send for selected device and data. + UciData(usize, DataPacket), // Execute UCI command received for selected device. - Command(usize, UciCommand), + UciCommand(usize, UciCommand), // Init Uci Device InitUciDevice(MacAddress, Position, oneshot::Sender<PicaCommandStatus>), // Set Position @@ -184,7 +87,8 @@ impl Display for PicaCommand { PicaCommand::Disconnect(_) => "Disconnect", PicaCommand::Ranging(_, _) => "Ranging", PicaCommand::StopRanging(_, _) => "StopRanging", - PicaCommand::Command(_, _) => "Command", + PicaCommand::UciData(_, _) => "UciData", + PicaCommand::UciCommand(_, _) => "UciCommand", PicaCommand::InitUciDevice(_, _, _) => "InitUciDevice", PicaCommand::SetPosition(_, _, _) => "SetPosition", PicaCommand::CreateAnchor(_, _, _) => "CreateAnchor", @@ -252,7 +156,8 @@ pub struct Pica { /// Result of UCI packet parsing. enum UciParseResult { - Ok(UciCommand), + UciCommand(UciCommand), + UciData(DataPacket), Err(Bytes), Skip, } @@ -260,46 +165,51 @@ enum UciParseResult { /// Parse incoming UCI packets. /// Handle parsing errors by crafting a suitable error response packet. fn parse_uci_packet(bytes: &[u8]) -> UciParseResult { - match ControlPacket::parse(bytes) { - // Parsing error. Determine what error response should be - // returned to the host: - // - response and notifications are ignored, no response - // - if the group id is not known, STATUS_UNKNOWN_GID, - // - otherwise, and to simplify the code, STATUS_UNKNOWN_OID is - // always returned. That means that malformed commands - // get the same status code, instead of - // STATUS_SYNTAX_ERROR. - Err(_) => { - let message_type = (bytes[0] >> 5) & 0x7; - let group_id = bytes[0] & 0xf; - let opcode_id = bytes[1] & 0x3f; - - let status = match ( - MessageType::try_from(message_type), - GroupId::try_from(group_id), - ) { - (Ok(MessageType::Command), Ok(_)) => UciStatusCode::UciStatusUnknownOid, - (Ok(MessageType::Command), Err(_)) => UciStatusCode::UciStatusUnknownGid, - _ => return UciParseResult::Skip, - }; - // The PDL generated code cannot be used to generate - // responses with invalid group identifiers. - let response = vec![ - (u8::from(MessageType::Response) << 5) | group_id, - opcode_id, - 0, - 1, - status.into(), - ]; - UciParseResult::Err(response.into()) - } + let message_type = parse_message_type(bytes[0]); + match message_type { + MessageType::Data => match DataPacket::parse(bytes) { + Ok(packet) => UciParseResult::UciData(packet), + Err(_) => UciParseResult::Skip, + }, + _ => { + match ControlPacket::parse(bytes) { + // Parsing error. Determine what error response should be + // returned to the host: + // - response and notifications are ignored, no response + // - if the group id is not known, STATUS_UNKNOWN_GID, + // - otherwise, and to simplify the code, STATUS_UNKNOWN_OID is + // always returned. That means that malformed commands + // get the same status code, instead of + // STATUS_SYNTAX_ERROR. + Err(_) => { + let group_id = bytes[0] & 0xf; + let opcode_id = bytes[1] & 0x3f; + + let status = match (message_type, GroupId::try_from(group_id)) { + (MessageType::Command, Ok(_)) => UciStatusCode::UciStatusUnknownOid, + (MessageType::Command, Err(_)) => UciStatusCode::UciStatusUnknownGid, + _ => return UciParseResult::Skip, + }; + // The PDL generated code cannot be used to generate + // responses with invalid group identifiers. + let response = vec![ + (u8::from(MessageType::Response) << 5) | group_id, + opcode_id, + 0, + 1, + status.into(), + ]; + UciParseResult::Err(response.into()) + } - // Parsing success, ignore non command packets. - Ok(packet) => { - if let Ok(cmd) = packet.try_into() { - UciParseResult::Ok(cmd) - } else { - UciParseResult::Skip + // Parsing success, ignore non command packets. + Ok(packet) => { + if let Ok(cmd) = packet.try_into() { + UciParseResult::UciCommand(cmd) + } else { + UciParseResult::Skip + } + } } } } @@ -333,8 +243,9 @@ fn make_measurement( } impl Pica { - pub fn new(event_tx: broadcast::Sender<PicaEvent>, pcapng_dir: Option<PathBuf>) -> Self { + pub fn new(pcapng_dir: Option<PathBuf>) -> Self { let (tx, rx) = mpsc::channel(MAX_SESSION * MAX_DEVICE); + let (event_tx, _) = broadcast::channel(16); Pica { devices: HashMap::new(), anchors: HashMap::new(), @@ -346,6 +257,10 @@ impl Pica { } } + pub fn events(&self) -> broadcast::Sender<PicaEvent> { + self.event_tx.clone() + } + pub fn tx(&self) -> mpsc::Sender<PicaCommand> { self.tx.clone() } @@ -439,8 +354,8 @@ impl Pica { // Spawn and detach the connection handling task. // The task notifies pica when exiting to let it clean // the state. - tokio::spawn(async move { - let pcapng_file: Option<pcapng::File> = if let Some(dir) = pcapng_dir { + tokio::task::spawn(async move { + let mut pcapng_file = if let Some(dir) = pcapng_dir { let full_path = dir.join(format!("device-{}.pcapng", device_handle)); println!("Recording pcapng to file {}", full_path.as_path().display()); Some(pcapng::File::create(full_path).await.unwrap()) @@ -448,19 +363,26 @@ impl Pica { None }; - let mut connection = Connection::new(stream, pcapng_file); + let (uci_rx, uci_tx) = stream.into_split(); + let mut uci_reader = packets::uci::Reader::new(uci_rx); + let mut uci_writer = packets::uci::Writer::new(uci_tx); + 'outer: loop { tokio::select! { // Read command packet sent from connected UWB host. // Run associated command. - result = connection.read() => + result = uci_reader.read(&mut pcapng_file) => match result { Ok(packet) => match parse_uci_packet(&packet) { - UciParseResult::Ok(cmd) => - pica_tx.send(PicaCommand::Command(device_handle, cmd)).await.unwrap(), + UciParseResult::UciCommand(cmd) => { + pica_tx.send(PicaCommand::UciCommand(device_handle, cmd)).await.unwrap() + }, + UciParseResult::UciData(data) => { + pica_tx.send(PicaCommand::UciData(device_handle, data)).await.unwrap() + }, UciParseResult::Err(response) => - connection.write(&response).await.unwrap(), + uci_writer.write(&response, &mut pcapng_file).await.unwrap(), UciParseResult::Skip => (), }, Err(_) => break 'outer @@ -468,7 +390,7 @@ impl Pica { // Send response packets to the connected UWB host. Some(packet) = packet_rx.recv() => - if connection.write(&packet.to_bytes()).await.is_err() { + if uci_writer.write(&packet.to_bytes(), &mut pcapng_file).await.is_err() { break 'outer } } @@ -562,6 +484,20 @@ impl Pica { } } + async fn uci_data(&mut self, device_handle: usize, data: DataPacket) { + match self + .get_device_mut(device_handle) + .ok_or_else(|| PicaCommandError::DeviceNotFound(device_handle.into())) + { + Ok(device) => { + let response: SessionControlNotification = device.data_message_snd(data); + device.tx.send(response.into()).await.unwrap_or_else(|err| { + println!("Failed to send UCI data packet response: {}", err) + }); + } + Err(err) => println!("{}", err), + } + } async fn command(&mut self, device_handle: usize, cmd: UciCommand) { match self .get_device_mut(device_handle) @@ -593,7 +529,8 @@ impl Pica { Some(StopRanging(mac_address, session_id)) => { self.stop_controlee_ranging(&mac_address, session_id).await; } - Some(Command(device_handle, cmd)) => self.command(device_handle, cmd).await, + Some(UciData(device_handle, data)) => self.uci_data(device_handle, data).await, + Some(UciCommand(device_handle, cmd)) => self.command(device_handle, cmd).await, Some(SetPosition(mac_address, position, pica_cmd_rsp_tx)) => { self.set_position(mac_address, position, pica_cmd_rsp_tx) } diff --git a/src/packets.rs b/src/packets.rs index 118570d..f0c10b3 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -24,4 +24,167 @@ pub mod uci { #![allow(missing_docs)] include!(concat!(env!("OUT_DIR"), "/uci_packets.rs")); + + /// Size of common UCI packet header. + pub const COMMON_HEADER_SIZE: usize = 1; + /// Size of UCI packet headers. + pub const HEADER_SIZE: usize = 4; + /// Maximum size of an UCI control packet payload. + pub const MAX_CTRL_PACKET_PAYLOAD_SIZE: usize = 255; + /// Maximum size of an UCI data packet payload. + pub const MAX_DATA_PACKET_PAYLOAD_SIZE: usize = 1024; + + // Extract the message type from the first 3 bits of the passed (header) byte + pub fn parse_message_type(byte: u8) -> MessageType { + MessageType::try_from((byte >> 5) & 0x7).unwrap_or(MessageType::Command) + } + + use crate::pcapng; + use std::pin::Pin; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::sync::Mutex; + + /// Read UCI Control and Data packets received on the UCI transport. + /// Performs recombination of the segmented packets. + pub struct Reader { + socket: Pin<Box<dyn AsyncRead + Send>>, + } + + /// Write UCI Control and Data packets received to the UCI transport. + /// Performs segmentation of the packets. + pub struct Writer { + socket: Pin<Box<dyn AsyncWrite + Send>>, + } + + impl Reader { + /// Create an UCI reader from an UCI transport. + pub fn new<T: AsyncRead + Send + 'static>(rx: T) -> Self { + Reader { + socket: Box::pin(rx), + } + } + + /// Read a single UCI packet from the reader. The packet is automatically + /// re-assembled if segmented on the UCI transport. Data segments + /// are _not_ re-assembled but returned immediatly for credit + /// acknowledgment. + pub async fn read(&mut self, pcapng: &mut Option<pcapng::File>) -> anyhow::Result<Vec<u8>> { + use tokio::io::AsyncReadExt; + + let mut complete_packet = vec![0; HEADER_SIZE]; + + // Note on reassembly: + // For each segment of a Control Message, the + // header of the Control Packet SHALL contain the same MT, GID and OID + // values. It is correct to keep only the last header of the segmented packet. + loop { + // Read the common packet header. + self.socket + .read_exact(&mut complete_packet[0..HEADER_SIZE]) + .await?; + let common_packet_header = + PacketHeader::parse(&complete_packet[0..COMMON_HEADER_SIZE])?; + + // Read the packet payload. + let payload_length = match common_packet_header.get_mt() { + MessageType::Data => { + let data_packet_header = + DataPacketHeader::parse(&complete_packet[0..HEADER_SIZE])?; + data_packet_header.get_payload_length() as usize + } + _ => { + let control_packet_header = + ControlPacketHeader::parse(&complete_packet[0..HEADER_SIZE])?; + control_packet_header.get_payload_length() as usize + } + }; + let mut payload_bytes = vec![0; payload_length]; + self.socket.read_exact(&mut payload_bytes).await?; + complete_packet.extend(&payload_bytes); + + if let Some(ref mut pcapng) = pcapng { + let mut packet_bytes = vec![]; + packet_bytes.extend(&complete_packet[0..HEADER_SIZE]); + packet_bytes.extend(&payload_bytes); + pcapng.write(&packet_bytes, pcapng::Direction::Tx).await?; + } + + if common_packet_header.get_mt() == MessageType::Data { + return Ok(complete_packet); + } + + // Check the Packet Boundary Flag. + match common_packet_header.get_pbf() { + PacketBoundaryFlag::Complete => return Ok(complete_packet), + PacketBoundaryFlag::NotComplete => (), + } + } + } + } + + impl Writer { + /// Create an UCI writer from an UCI transport. + pub fn new<T: AsyncWrite + Send + 'static>(rx: T) -> Self { + Writer { + socket: Box::pin(rx), + } + } + + /// Write a single UCI packet to the writer. The packet is automatically + /// segmented if the payload exceeds the maximum size limit. + pub async fn write( + &mut self, + mut packet: &[u8], + pcapng: &mut Option<pcapng::File>, + ) -> anyhow::Result<()> { + use tokio::io::AsyncWriteExt; + + let mut header_bytes = [packet[0], packet[1], packet[2], 0]; + packet = &packet[HEADER_SIZE..]; + + loop { + let message_type = parse_message_type(header_bytes[0]); + let chunk_length = std::cmp::min( + packet.len(), + match message_type { + MessageType::Data => MAX_DATA_PACKET_PAYLOAD_SIZE, + _ => MAX_CTRL_PACKET_PAYLOAD_SIZE, + }, + ); + // Update header with framing information. + let pbf = if chunk_length < packet.len() { + PacketBoundaryFlag::NotComplete + } else { + PacketBoundaryFlag::Complete + }; + const PBF_MASK: u8 = 0x10; + header_bytes[0] &= !PBF_MASK; + header_bytes[0] |= (pbf as u8) << 4; + + match message_type { + MessageType::Data => { + let chunk_le_bytes = (chunk_length as u16).to_le_bytes(); + header_bytes[2..4].copy_from_slice(&chunk_le_bytes); + } + _ => header_bytes[3] = chunk_length as u8, + } + + if let Some(ref mut pcapng) = pcapng { + let mut packet_bytes = vec![]; + packet_bytes.extend(&header_bytes); + packet_bytes.extend(&packet[..chunk_length]); + pcapng.write(&packet_bytes, pcapng::Direction::Rx).await? + } + + // Write the header and payload segment bytes. + self.socket.write_all(&header_bytes).await?; + self.socket.write_all(&packet[..chunk_length]).await?; + packet = &packet[chunk_length..]; + + if packet.is_empty() { + return Ok(()); + } + } + } + } } diff --git a/src/pcapng.rs b/src/pcapng.rs index fc2d882..3119bb1 100644 --- a/src/pcapng.rs +++ b/src/pcapng.rs @@ -60,28 +60,20 @@ impl File { let packet_data_padding: usize = 4 - packet.len() % 4; let block_total_length: u32 = packet.len() as u32 + packet_data_padding as u32 + 32; let timestamp = self.start_time.elapsed().as_micros(); + let file = &mut self.file; // Wrap the packet inside an Enhanced Packet Block. - self.file.write(&u32::to_le_bytes(0x00000006)).await?; // Block Type - self.file - .write(&u32::to_le_bytes(block_total_length)) - .await?; - self.file.write(&u32::to_le_bytes(0)).await?; // Interface ID - self.file - .write(&u32::to_le_bytes((timestamp >> 32) as u32)) + file.write(&u32::to_le_bytes(0x00000006)).await?; // Block Type + file.write(&u32::to_le_bytes(block_total_length)).await?; + file.write(&u32::to_le_bytes(0)).await?; // Interface ID + file.write(&u32::to_le_bytes((timestamp >> 32) as u32)) .await?; // Timestamp (High) - self.file.write(&u32::to_le_bytes(timestamp as u32)).await?; // Timestamp (Low) - self.file - .write(&u32::to_le_bytes(packet.len() as u32)) - .await?; // Captured Packet Length - self.file - .write(&u32::to_le_bytes(packet.len() as u32)) - .await?; // Original Packet Length - self.file.write(packet).await?; - self.file.write(&vec![0; packet_data_padding]).await?; - self.file - .write(&u32::to_le_bytes(block_total_length)) - .await?; // Block Total Length + file.write(&u32::to_le_bytes(timestamp as u32)).await?; // Timestamp (Low) + file.write(&u32::to_le_bytes(packet.len() as u32)).await?; // Captured Packet Length + file.write(&u32::to_le_bytes(packet.len() as u32)).await?; // Original Packet Length + file.write(packet).await?; + file.write(&vec![0; packet_data_padding]).await?; + file.write(&u32::to_le_bytes(block_total_length)).await?; // Block Total Length Ok(()) } } diff --git a/src/session.rs b/src/session.rs index c5a05d6..8425b1b 100644 --- a/src/session.rs +++ b/src/session.rs @@ -16,7 +16,7 @@ //! - [MAC] FiRa Consortium UWB MAC Technical Requirements //! - [UCI] FiRa Consortium UWB Command Interface Generic Technical specification -use crate::packets::uci::*; +use crate::packets::uci::{self, *}; use crate::{MacAddress, PicaCommand}; use std::collections::HashMap; use std::time::Duration; @@ -32,6 +32,8 @@ pub const DEFAULT_RANGING_INTERVAL: Duration = time::Duration::from_millis(200); pub const DEFAULT_SLOT_DURATION: u16 = 2400; // RTSU unit /// cf. [UCI] 8.3 Table 29 pub const MAX_NUMBER_OF_CONTROLEES: usize = 8; +pub const FIRA_1_1_INITIATION_TIME_SIZE: usize = 4; +pub const FIRA_2_0_INITIATION_TIME_SIZE: usize = 8; #[derive(Copy, Clone, FromPrimitive, PartialEq, Eq)] pub enum DeviceType { @@ -93,6 +95,8 @@ enum MultiNodeMode { enum UpdateMulticastListAction { Add = 0x00, Delete = 0x01, + AddWithShortSubSessionKey = 0x02, + AddwithExtendedSubSessionKey = 0x03, } #[derive(Copy, Clone, FromPrimitive, ToPrimitive, PartialEq)] @@ -233,6 +237,27 @@ pub enum RangeDataNtfConfig { EnableAoaEdgeTrig = 0x06, EnableProximityAoaEdgeTrig = 0x07, } + +#[derive(Copy, Clone, FromPrimitive, ToPrimitive, PartialEq)] +#[repr(u8)] +pub enum LinkLayerMode { + Bypass = 0x00, + Assigned = 0x01, +} + +#[derive(Copy, Clone, FromPrimitive, ToPrimitive, PartialEq)] +#[repr(u8)] +pub enum DataRepetitionCount { + NoRepetition = 0x00, + Infinite = 0xFF, +} + +#[derive(Copy, Clone, FromPrimitive, ToPrimitive, PartialEq)] +#[repr(u8)] +pub enum SessionDataTransferStatusNtfConfig { + Disable = 0x00, + Enable = 0x01, +} /// cf. [UCI] 8.3 Table 29 #[derive(Clone)] pub struct AppConfig { @@ -280,9 +305,16 @@ pub struct AppConfig { bprf_phr_data_rate: BprfPhrDataRate, max_number_of_measurements: u8, sts_length: StsLength, - uwb_initiation_time: u32, + uwb_initiation_time: u64, vendor_id: Option<Vec<u8>>, static_sts_iv: Option<Vec<u8>>, + session_key: Option<Vec<u8>>, + sub_session_key: Option<Vec<u8>>, + sub_session_id: u32, + link_layer_mode: LinkLayerMode, + data_repetition_count: DataRepetitionCount, + session_data_transfer_status_ntf_config: SessionDataTransferStatusNtfConfig, + application_data_endpoint: u8, } impl Default for AppConfig { @@ -333,6 +365,13 @@ impl Default for AppConfig { uwb_initiation_time: 0, vendor_id: None, static_sts_iv: None, + session_key: None, + sub_session_key: None, + sub_session_id: 0, + link_layer_mode: LinkLayerMode::Bypass, + data_repetition_count: DataRepetitionCount::NoRepetition, + session_data_transfer_status_ntf_config: SessionDataTransferStatusNtfConfig::Disable, + application_data_endpoint: 0, } } } @@ -537,7 +576,16 @@ impl AppConfig { self.max_rr_retry = u16::from_le_bytes(value[..].try_into().unwrap()) } AppConfigTlvType::UwbInitiationTime => { - self.uwb_initiation_time = u32::from_le_bytes(value[..].try_into().unwrap()) + self.uwb_initiation_time = match value.len() { + // Backward compatible with Fira 1.1 Version UCI host. + FIRA_1_1_INITIATION_TIME_SIZE => { + u32::from_le_bytes(value[..].try_into().unwrap()) as u64 + } + FIRA_2_0_INITIATION_TIME_SIZE => { + u64::from_le_bytes(value[..].try_into().unwrap()) + } + _ => panic!("Invalid initiation time!"), + } } AppConfigTlvType::HoppingMode => { self.hopping_mode = HoppingMode::from_u8(value[0]).unwrap() @@ -558,6 +606,22 @@ impl AppConfig { AppConfigTlvType::InBandTerminationAttemptCount => { self.in_band_termination_attempt_count = value[0] } + AppConfigTlvType::SessionKey => self.session_key = Some(value.to_vec()), + AppConfigTlvType::SubSessionId => { + self.sub_session_id = u32::from_le_bytes(value[..].try_into().unwrap()) + } + AppConfigTlvType::SubsessionKey => self.sub_session_key = Some(value.to_vec()), + AppConfigTlvType::LinkLayerMode => { + self.link_layer_mode = LinkLayerMode::from_u8(value[0]).unwrap() + } + AppConfigTlvType::DataRepetitionCount => { + self.data_repetition_count = DataRepetitionCount::from_u8(value[0]).unwrap() + } + AppConfigTlvType::SessionDataTransferStatusNtfConfig => { + self.session_data_transfer_status_ntf_config = + SessionDataTransferStatusNtfConfig::from_u8(value[0]).unwrap() + } + AppConfigTlvType::ApplicationDataEndpoint => self.application_data_endpoint = value[0], id => { println!("Ignored AppConfig parameter {:?}", id); return Err(StatusCode::UciStatusInvalidParam); @@ -605,6 +669,48 @@ impl AppConfig { } } +enum SubSessionKey { + None, + Short([u8; 16]), + Extended([u8; 32]), +} +struct Controlee { + short_address: MacAddress, + sub_session_id: u32, + #[allow(dead_code)] + session_key: SubSessionKey, +} + +impl From<&uci::Controlee> for Controlee { + fn from(value: &uci::Controlee) -> Self { + Controlee { + short_address: MacAddress::Short(value.short_address), + sub_session_id: value.subsession_id, + session_key: SubSessionKey::None, + } + } +} + +impl From<&uci::Controlee_V2_0_16_Byte_Version> for Controlee { + fn from(value: &uci::Controlee_V2_0_16_Byte_Version) -> Self { + Controlee { + short_address: MacAddress::Short(value.short_address), + sub_session_id: value.subsession_id, + session_key: SubSessionKey::Short(value.subsession_key), + } + } +} + +impl From<&uci::Controlee_V2_0_32_Byte_Version> for Controlee { + fn from(value: &uci::Controlee_V2_0_32_Byte_Version) -> Self { + Controlee { + short_address: MacAddress::Short(value.short_address), + sub_session_id: value.subsession_id, + session_key: SubSessionKey::Extended(value.subsession_key), + } + } +} + pub struct Session { /// cf. [UCI] 7.1 pub state: SessionState, @@ -693,7 +799,12 @@ impl Session { self.device_handle, self.id ); assert_eq!(self.id, cmd.get_session_token()); - assert_eq!(self.session_type, SessionType::FiraRangingSession); + assert!( + self.session_type.eq(&SessionType::FiraRangingSession) + || self + .session_type + .eq(&SessionType::FiraRangingAndInBandDataSession) + ); if self.state == SessionState::SessionStateActive { const IMMUTABLE_PARAMETERS: &[AppConfigTlvType] = &[AppConfigTlvType::AoaResultReq]; @@ -810,29 +921,97 @@ impl Session { } let action = UpdateMulticastListAction::from_u8(cmd.get_action().into()).unwrap(); let mut dst_addresses = self.app_config.dst_mac_addresses.clone(); - let packet = - SessionUpdateControllerMulticastListCmdPayload::parse(cmd.get_payload()).unwrap(); - let new_controlees = packet.controlees; + let new_controlees: Vec<Controlee> = match action { + UpdateMulticastListAction::Add | UpdateMulticastListAction::Delete => { + if let Ok(packet) = + SessionUpdateControllerMulticastListCmdPayload::parse(cmd.get_payload()) + { + packet + .controlees + .iter() + .map(|controlee| controlee.into()) + .collect() + } else { + return SessionUpdateControllerMulticastListRspBuilder { + status: StatusCode::UciStatusSyntaxError, + } + .build(); + } + } + UpdateMulticastListAction::AddWithShortSubSessionKey => { + if let Ok(packet) = + SessionUpdateControllerMulticastListCmd_2_0_16_Byte_Payload::parse( + cmd.get_payload(), + ) + { + packet + .controlees + .iter() + .map(|controlee| controlee.into()) + .collect() + } else { + return SessionUpdateControllerMulticastListRspBuilder { + status: StatusCode::UciStatusSyntaxError, + } + .build(); + } + } + UpdateMulticastListAction::AddwithExtendedSubSessionKey => { + if let Ok(packet) = + SessionUpdateControllerMulticastListCmd_2_0_32_Byte_Payload::parse( + cmd.get_payload(), + ) + { + packet + .controlees + .iter() + .map(|controlee| controlee.into()) + .collect() + } else { + return SessionUpdateControllerMulticastListRspBuilder { + status: StatusCode::UciStatusSyntaxError, + } + .build(); + } + } + }; let mut controlee_status = Vec::new(); let session_id = self.id; let mut status = StatusCode::UciStatusOk; match action { - UpdateMulticastListAction::Add => { + UpdateMulticastListAction::Add + | UpdateMulticastListAction::AddWithShortSubSessionKey + | UpdateMulticastListAction::AddwithExtendedSubSessionKey => { new_controlees.iter().for_each(|controlee| { let mut update_status = MulticastUpdateStatusCode::StatusOkMulticastListUpdate; - if !dst_addresses.contains(&MacAddress::Short(controlee.short_address)) { + if !dst_addresses.contains(&controlee.short_address) { if dst_addresses.len() == MAX_NUMBER_OF_CONTROLEES { status = StatusCode::UciStatusMulticastListFull; update_status = MulticastUpdateStatusCode::StatusErrorMulticastListFull; + } else if (action == UpdateMulticastListAction::AddWithShortSubSessionKey + || action == UpdateMulticastListAction::AddwithExtendedSubSessionKey) + && self.app_config.sts_config + != StsConfig::ProvisionedForControleeIndividualKey + { + // If Action is 0x02 or 0x03 for STS_CONFIG values other than + // 0x04, the UWBS shall return SESSION_UPDATE_CONTROLLER_MULTICAST_LIST_NTF + // with Status set to STATUS_ERROR_SUB_SESSION_KEY_NOT_APPLICABLE for each + // Controlee in the Controlee List. + status = StatusCode::UciStatusFailed; + update_status = + MulticastUpdateStatusCode::StatusErrorSubSessionKeyNotApplicable; } else { - dst_addresses.push(MacAddress::Short(controlee.short_address)); + dst_addresses.push(controlee.short_address); }; } controlee_status.push(ControleeStatus { - mac_address: controlee.short_address, - subsession_id: controlee.subsession_id, + mac_address: match controlee.short_address { + MacAddress::Short(address) => address, + MacAddress::Extend(_) => panic!("Extended address is not supported!"), + }, + subsession_id: controlee.sub_session_id, status: update_status, }); }); @@ -843,11 +1022,11 @@ impl Session { let address = controlee.short_address; let attempt_count = self.app_config.in_band_termination_attempt_count; let mut update_status = MulticastUpdateStatusCode::StatusOkMulticastListUpdate; - if !dst_addresses.contains(&MacAddress::Short(address)) { + if !dst_addresses.contains(&address) { status = StatusCode::UciStatusAddressNotFound; update_status = MulticastUpdateStatusCode::StatusErrorKeyFetchFail; } else { - dst_addresses.retain(|value| *value != MacAddress::Short(address)); + dst_addresses.retain(|value| *value != address); // If IN_BAND_TERMINATION_ATTEMPT_COUNT is not equal to 0x00, then the // UWBS shall transmit the RCM with the “Stop Ranging” bit set to ‘1’ // for IN_BAND_TERMINATION_ATTEMPT_COUNT times to the corresponding @@ -856,10 +1035,7 @@ impl Session { tokio::spawn(async move { for _ in 0..attempt_count { pica_tx - .send(PicaCommand::StopRanging( - MacAddress::Short(address), - session_id, - )) + .send(PicaCommand::StopRanging(address, session_id)) .await .unwrap() } @@ -867,8 +1043,11 @@ impl Session { } } controlee_status.push(ControleeStatus { - mac_address: address, - subsession_id: controlee.subsession_id, + mac_address: match address { + MacAddress::Short(addr) => addr, + MacAddress::Extend(_) => panic!("Extended address is not supported!"), + }, + subsession_id: controlee.sub_session_id, status: update_status, }); }); @@ -1003,6 +1182,37 @@ impl Session { _ => panic!("Unsupported ranging command"), } } + + pub fn data_message_snd(&mut self, data: DataMessageSnd) -> SessionControlNotification { + let session_token = data.get_session_handle(); + let uci_sequence_number = data.get_data_sequence_number() as u8; + + if self.session_type != SessionType::FiraRangingAndInBandDataSession { + return DataTransferStatusNtfBuilder { + session_token, + status: DataTransferNtfStatusCode::UciDataTransferStatusSessionTypeNotSupported, + tx_count: 1, // TODO: support for retries? + uci_sequence_number, + } + .build() + .into(); + } + + assert_eq!(self.id, session_token); + + // TODO: perform actual data transfer across devices + println!( + "Data packet received, payload bytes: {:?}", + data.get_application_data() + ); + + DataCreditNtfBuilder { + credit_availability: CreditAvailability::CreditAvailable, + session_token, + } + .build() + .into() + } } impl Drop for Session { diff --git a/src/uci_packets.pdl b/src/uci_packets.pdl index 61a5002..6111a87 100644 --- a/src/uci_packets.pdl +++ b/src/uci_packets.pdl @@ -187,6 +187,7 @@ enum DataTransferNtfStatusCode : 8 { UCI_DATA_TRANSFER_STATUS_ERROR_REJECTED = 0x04, UCI_DATA_TRANSFER_STATUS_SESSION_TYPE_NOT_SUPPORTED = 0x05, UCI_DATA_TRANSFER_STATUS_ERROR_DATA_TRANSFER_IS_ONGOING = 0x06, + UCI_DATA_TRANSFER_STATUS_INVALID_FORMAT = 0x07, } enum ResetConfig : 8 { @@ -233,6 +234,7 @@ enum AppConfigTlvType : 8 { RESPONDER_SLOT_INDEX = 0x1E, PRF_MODE = 0x1F, CAP_SIZE_RANGE = 0x20, + TX_JITTER_WINDOW_SIZE = 0x21, SCHEDULED_MODE = 0x22, KEY_ROTATION = 0x23, KEY_ROTATION_RATE = 0x24, @@ -260,11 +262,24 @@ enum AppConfigTlvType : 8 { MIN_FRAMES_PER_RR = 0x3A, MTU_SIZE = 0x3B, INTER_FRAME_INTERVAL = 0x3C, - RFU_APP_CFG_TLV_TYPE_RANGE_1 = 0x3D..0x44, + DL_TDOA_RANGING_METHOD = 0x3D, + DL_TDOA_TX_TIMESTAMP_CONF = 0x3E, + DL_TDOA_HOP_COUNT = 0x3F, + DL_TDOA_ANCHOR_CFO = 0x40, + DL_TDOA_ANCHOR_LOCATION = 0x41, + DL_TDOA_TX_ACTIVE_RANGING_ROUNDS = 0x42, + DL_TDOA_BLOCK_STRIDING = 0x43, + DL_TDOA_TIME_REFERENCE_ANCHOR = 0x44, SESSION_KEY = 0x45, SUBSESSION_KEY = 0x46, SESSION_DATA_TRANSFER_STATUS_NTF_CONFIG = 0x47, - RFU_APP_CFG_TLV_TYPE_RANGE_2 = 0x48..0x9F, + SESSION_TIME_BASE = 0x48, + DL_TDOA_RESPONDER_TOF = 0x49, + SECURE_RANGING_NEFA_LEVEL = 0x4A, + SECURE_RANGING_CSW_LENGTH = 0x4B, + APPLICATION_DATA_ENDPOINT = 0x4C, + OWR_AOA_MEASUREMENT_NTF_PERIOD = 0x4D, + RFU_APP_CFG_TLV_TYPE_RANGE = 0x4E..0x9F, VENDOR_SPECIFIC_APP_CFG_TLV_TYPE_RANGE_1 = 0xA0..0xDF { // CCC specific @@ -462,17 +477,31 @@ enum MessageType: 3 { RESERVED_FOR_TESTING_2 = 0x05, } -// Generic format for UCI packet headers. -// No data packets are defined, the header fields are taken from the -// Control packets. +// Used to parse message type packet PacketHeader { _reserved_ : 4, pbf : PacketBoundaryFlag, mt : MessageType, +} + +// Used to parse control packet length +packet ControlPacketHeader { + _reserved_ : 4, + pbf : PacketBoundaryFlag, + mt : MessageType, _reserved_ : 16, payload_length : 8, } +// Used to parse data packet length +packet DataPacketHeader { + _reserved_ : 4, + pbf : PacketBoundaryFlag, + mt : MessageType, + _reserved_ : 8, + payload_length : 16, +} + // Unframed UCI control packet. The framing information is masked // including the payload length. The user must handle segmentation and // reassembly on the raw bytes before attempting to parse the packet. @@ -486,6 +515,32 @@ packet ControlPacket { _payload_, } +packet DataPacket { + dpf : DataPacketFormat, + pbf : PacketBoundaryFlag, + mt: MessageType, + _reserved_ : 8, + _reserved_ : 16, + _payload_, +} + +packet DataMessageSnd : DataPacket (dpf = DATA_SND, mt = DATA) { + session_handle: 32, + destination_address: 64, + data_sequence_number: 16, + _size_(application_data): 16, + application_data: 8[] +} + +packet DataMessageRcv : DataPacket (dpf = DATA_RCV, mt = DATA) { + session_handle: 32, + status : StatusCode, + source_address: 64, + data_sequence_number: 16, + _size_(application_data): 16, + application_data: 8[] +} + // TODO(b/202760099): Handle fragmentation of packets if the size exceed max allowed. packet UciCommand : ControlPacket (mt = COMMAND) { _payload_, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/data_transfer.py b/tests/data_transfer.py new file mode 100755 index 0000000..f7d731c --- /dev/null +++ b/tests/data_transfer.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import argparse +from pica import Host +from pica.packets import uci +from .helper import init +from pathlib import Path + +MAX_DATA_PACKET_PAYLOAD_SIZE = 1024 + + +async def data_message_send(host: Host, peer: Host, file: Path): + await init(host) + + host.send_control( + uci.SessionInitCmd( + session_id=0, + session_type=uci.SessionType.FIRA_RANGING_AND_IN_BAND_DATA_SESSION, + ) + ) + + await host.expect_control(uci.SessionInitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_INIT, + reason_code=0, + ) + ) + + mac_address_mode = 0x0 + ranging_duration = int(1000).to_bytes(4, byteorder="little") + device_role_initiator = bytes([0]) + device_type_controller = bytes([1]) + host.send_control( + uci.SessionSetAppConfigCmd( + session_token=0, + tlvs=[ + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_ROLE, v=device_role_initiator + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_TYPE, v=device_type_controller + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_MAC_ADDRESS, v=host.mac_address + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, + v=bytes([mac_address_mode]), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.RANGING_DURATION, v=ranging_duration + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, v=bytes([1]) + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, v=peer.mac_address + ), + ], + ) + ) + + await host.expect_control( + uci.SessionSetAppConfigRsp(status=uci.StatusCode.UCI_STATUS_OK, cfg_status=[]) + ) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_IDLE, + reason_code=0, + ) + ) + + await data_transfer(host, peer.mac_address, file, 0) + + host.send_control(uci.SessionDeinitCmd(session_token=0)) + + await host.expect_control(uci.SessionDeinitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + +async def data_transfer( + host: Host, dst_mac_address: bytes, file: Path, session_id: int +): + try: + with open(file, "rb") as f: + b = f.read() + seq_num = 0 + + if len(b) > MAX_DATA_PACKET_PAYLOAD_SIZE: + for i in range(0, len(b), MAX_DATA_PACKET_PAYLOAD_SIZE): + chunk = b[i : i + MAX_DATA_PACKET_PAYLOAD_SIZE] + + if i + MAX_DATA_PACKET_PAYLOAD_SIZE >= len(b): + host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=chunk, + ) + ) + else: + host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + pbf=uci.PacketBoundaryFlag.NOT_COMPLETE, + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=chunk, + ) + ) + + seq_num += 1 + if seq_num >= 65535: + seq_num = 0 + + event = await host.expect_control( + uci.DataCreditNtf( + session_token=int(session_id), + credit_availability=uci.CreditAvailability.CREDIT_AVAILABLE, + ) + ) + event.show() + else: + host.send_data( + uci.DataMessageSnd( + session_handle=int(session_id), + destination_address=int.from_bytes(dst_mac_address), + data_sequence_number=seq_num, + application_data=b, + ) + ) + event = await host.expect_control( + uci.DataCreditNtf( + session_token=int(session_id), + credit_availability=uci.CreditAvailability.CREDIT_AVAILABLE, + ) + ) + event.show() + + except Exception as e: + print(e) + + +async def run(address: str, uci_port: int, file: Path): + try: + host0 = await Host.connect(address, uci_port, bytes([0, 1])) + host1 = await Host.connect(address, uci_port, bytes([0, 2])) + except Exception: + print( + f"Failed to connect to Pica server at address {address}:{uci_port}\n" + + "Make sure the server is running" + ) + exit(1) + + async with asyncio.TaskGroup() as tg: + tg.create_task(data_message_send(host0, host1, file)) + + host0.disconnect() + host1.disconnect() + + print("Data transfer test completed") + + +def main(): + """Start a Pica interactive console.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--address", + type=str, + default="127.0.0.1", + help="Select the pica server address", + ) + parser.add_argument( + "--uci-port", type=int, default=7000, help="Select the pica TCP UCI port" + ) + parser.add_argument( + "--file", type=Path, required=True, help="Select the file to transfer" + ) + asyncio.run(run(**vars(parser.parse_args()))) + + +if __name__ == "__main__": + main() diff --git a/tests/helper.py b/tests/helper.py new file mode 100644 index 0000000..bbce7ad --- /dev/null +++ b/tests/helper.py @@ -0,0 +1,30 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pica import Host +from pica.packets import uci + + +async def init(host: Host): + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY) + ) + + host.send_control(uci.DeviceResetCmd(reset_config=uci.ResetConfig.UWBS_RESET)) + + await host.expect_control(uci.DeviceResetRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY) + ) diff --git a/tests/ranging.py b/tests/ranging.py new file mode 100755 index 0000000..7af8224 --- /dev/null +++ b/tests/ranging.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import argparse +import logging + +from pica import Host +from pica.packets import uci +from .helper import init + + +async def controller(host: Host, peer: Host): + await init(host) + + host.send_control( + uci.SessionInitCmd( + session_id=0, session_type=uci.SessionType.FIRA_RANGING_SESSION + ) + ) + + await host.expect_control(uci.SessionInitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_INIT, + reason_code=0, + ) + ) + + mac_address_mode = 0x0 + ranging_duration = int(1000).to_bytes(4, byteorder="little") + device_role_initiator = bytes([0]) + device_type_controller = bytes([1]) + host.send_control( + uci.SessionSetAppConfigCmd( + session_token=0, + tlvs=[ + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_ROLE, v=device_role_initiator + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_TYPE, v=device_type_controller + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_MAC_ADDRESS, v=host.mac_address + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, + v=bytes([mac_address_mode]), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.RANGING_DURATION, v=ranging_duration + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, v=bytes([1]) + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, v=peer.mac_address + ), + ], + ) + ) + + await host.expect_control( + uci.SessionSetAppConfigRsp(status=uci.StatusCode.UCI_STATUS_OK, cfg_status=[]) + ) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_IDLE, + reason_code=0, + ) + ) + + host.send_control(uci.SessionStartCmd(session_id=0)) + + await host.expect_control(uci.SessionStartRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_ACTIVE, + reason_code=0, + ) + ) + + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_ACTIVE) + ) + + for _ in range(1, 3): + event = await host.expect_control(uci.ShortMacTwoWaySessionInfoNtf, timeout=2.0) + event.show() + + host.send_control(uci.SessionStopCmd(session_id=0)) + + await host.expect_control(uci.SessionStopRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_IDLE, + reason_code=0, + ) + ) + + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY) + ) + + host.send_control(uci.SessionDeinitCmd(session_token=0)) + + await host.expect_control(uci.SessionDeinitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + +async def controlee(host: Host, peer: Host): + await init(host) + + host.send_control( + uci.SessionInitCmd( + session_id=0, session_type=uci.SessionType.FIRA_RANGING_SESSION + ) + ) + + await host.expect_control(uci.SessionInitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_INIT, + reason_code=0, + ) + ) + + mac_address_mode = 0x0 + ranging_duration = int(1000).to_bytes(4, byteorder="little") + device_role_responder = bytes([1]) + device_type_controlee = bytes([0]) + host.send_control( + uci.SessionSetAppConfigCmd( + session_token=0, + tlvs=[ + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_ROLE, v=device_role_responder + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_TYPE, v=device_type_controlee + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DEVICE_MAC_ADDRESS, v=host.mac_address + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.MAC_ADDRESS_MODE, + v=bytes([mac_address_mode]), + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.RANGING_DURATION, v=ranging_duration + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.NO_OF_CONTROLEE, v=bytes([1]) + ), + uci.AppConfigTlv( + cfg_id=uci.AppConfigTlvType.DST_MAC_ADDRESS, v=peer.mac_address + ), + ], + ) + ) + + await host.expect_control( + uci.SessionSetAppConfigRsp(status=uci.StatusCode.UCI_STATUS_OK, cfg_status=[]) + ) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_IDLE, + reason_code=0, + ) + ) + + host.send_control(uci.SessionStartCmd(session_id=0)) + + await host.expect_control(uci.SessionStartRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_ACTIVE, + reason_code=0, + ) + ) + + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_ACTIVE) + ) + + for _ in range(1, 3): + event = await host.expect_control(uci.ShortMacTwoWaySessionInfoNtf, timeout=2.0) + event.show() + + host.send_control(uci.SessionStopCmd(session_id=0)) + + await host.expect_control(uci.SessionStopRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + await host.expect_control( + uci.SessionStatusNtf( + session_token=0, + session_state=uci.SessionState.SESSION_STATE_IDLE, + reason_code=0, + ) + ) + + await host.expect_control( + uci.DeviceStatusNtf(device_state=uci.DeviceState.DEVICE_STATE_READY) + ) + + host.send_control(uci.SessionDeinitCmd(session_token=0)) + + await host.expect_control(uci.SessionDeinitRsp(status=uci.StatusCode.UCI_STATUS_OK)) + + +async def run(address: str, uci_port: int): + try: + host0 = await Host.connect(address, uci_port, bytes([0, 1])) + host1 = await Host.connect(address, uci_port, bytes([0, 2])) + except Exception: + logging.debug( + f"Failed to connect to Pica server at address {address}:{uci_port}\n" + + "Make sure the server is running" + ) + exit(1) + + async with asyncio.TaskGroup() as tg: + tg.create_task(controller(host0, host1)) + tg.create_task(controlee(host1, host0)) + + host0.disconnect() + host1.disconnect() + + logging.debug("Ranging test completed") + + +def main(): + """Start a Pica interactive console.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--address", + type=str, + default="127.0.0.1", + help="Select the pica server address", + ) + parser.add_argument( + "--uci-port", type=int, default=7000, help="Select the pica TCP UCI port" + ) + asyncio.run(run(**vars(parser.parse_args()))) + + +if __name__ == "__main__": + main() diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 0000000..aa7a709 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,66 @@ +import asyncio +from asyncio.subprocess import Process +import pytest +import pytest_asyncio +import logging +import os + +from datetime import datetime +from pathlib import Path +from typing import Tuple + +from . import ranging, data_transfer + +PICA_BIN = Path("./target/debug/pica-server") +DATA_FILE = Path("README.md") +PICA_LOCALHOST = "127.0.0.1" + +logging.basicConfig(level=os.environ.get("PICA_LOGLEVEL", "DEBUG").upper()) + + +def setup_artifacts(test_name: str) -> Tuple[Path, Path]: + artifacts = Path("./artifacts") + artifacts.mkdir(parents=True, exist_ok=True) + + current_dt = datetime.now() + formatted_date = current_dt.strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3] + + f1 = artifacts / f"{formatted_date}_pica_{test_name}_stdout.txt" + f1.touch(exist_ok=True) + + f2 = artifacts / f"{formatted_date}_pica_{test_name}_stderr.txt" + f2.touch(exist_ok=True) + + return (f1, f2) + + +@pytest_asyncio.fixture +async def pica_port(request, unused_tcp_port): + (stdout, stderr) = setup_artifacts(request.node.name) + if not PICA_BIN.exists(): + raise FileNotFoundError(f"{PICA_BIN} not found") + + with stdout.open("w") as fstdout, stderr.open("w") as fstderr: + process = await asyncio.create_subprocess_exec( + PICA_BIN, + "--uci-port", + str(unused_tcp_port), + stdout=fstdout, + stderr=fstderr, + ) + await asyncio.sleep(100 / 1000) # Let pica boot up + + yield unused_tcp_port + + process.terminate() + await process.wait() + + +@pytest.mark.asyncio +async def test_ranging(pica_port): + await ranging.run(PICA_LOCALHOST, pica_port) + + +@pytest.mark.asyncio +async def test_data_transfer(pica_port): + await data_transfer.run(PICA_LOCALHOST, pica_port, DATA_FILE) |