1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
//! An implementation of the [embedded_nal] (Network Abstraction Layer) UDP traits based on RIOT
//! sockets

use core::convert::TryInto;
use core::mem::MaybeUninit;

use crate::error::{NegativeErrorExt, NumericError};
use crate::socket::UdpEp;

use embedded_nal::SocketAddr;

/// The operating system's network stack, used to get an implementation of
/// [`embedded_nal::UdpClientStack`] (and [`embedded_nal::UdpFullStack`]).
///
/// Using this is not trivial, as RIOT needs its sockets pinned to memory for their lifetime.
/// Without a heap allocator, this is achieved by allocating all the required UDP sockets in a
/// stack object. To ensure that it is not moved, sockets on it can only be created in (and live
/// only for the duration of) the [`run`](Stack::run) callback, which gives the actual
/// implementation of `UdpClientStack`.
///
/// The number of UDP sockets allocated is configurable using the `UDPCOUNT` const generic. Slots
/// are handed out in order; a slot whose socket has been closed is not reused.
pub struct Stack<const UDPCOUNT: usize> {
    /// Backing storage for the RIOT sockets created through a `StackAccessor` of this stack.
    udp_sockets: heapless::Vec<riot_sys::sock_udp_t, UDPCOUNT>,
}

impl<const UDPCOUNT: usize> core::fmt::Debug for Stack<UDPCOUNT> {
    // Only reports slot usage; the RIOT socket structures themselves are not printed.
    fn fmt(&self, fmt: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            fmt,
            "Stack {{ {used} of {total} sockets used }}",
            used = self.udp_sockets.len(),
            total = UDPCOUNT,
        )
    }
}

// FIXME: This should really just use Pin like socket_embedded_nal_tcp does; unfortunately, this
// doesn't align well with the .run() API, maybe that's best just to break.
/// Exclusive handle to a [`Stack`], only obtainable inside [`Stack::run`].
///
/// This is the type implementing [`embedded_nal::UdpClientStack`] and
/// [`embedded_nal::UdpFullStack`]. It mutably borrows the stack for `'a`, and the sockets it
/// produces ([`UdpSocket<'a>`]) carry the same lifetime. Because `run` takes a callback that is
/// generic over `'a`, neither the accessor nor its sockets can escape that callback.
#[derive(Debug)]
pub struct StackAccessor<'a, const UDPCOUNT: usize> {
    stack: &'a mut Stack<UDPCOUNT>,
}

impl<const UDPCOUNT: usize> Stack<UDPCOUNT> {
    pub fn new() -> Self {
        Self {
            udp_sockets: Default::default(),
        }
    }

    pub fn run(&mut self, runner: impl for<'a> FnOnce(StackAccessor<'a, UDPCOUNT>)) {
        let accessor = StackAccessor { stack: self };
        runner(accessor);
        // In particular, this would require tracking of whether the sockets are closed
        unimplemented!("Allocator does not have clean-up implemented");
    }
}

/// A UDP socket handle handed out by a [`StackAccessor`].
///
/// A handle obtained from `UdpClientStack::socket` is not yet backed by a RIOT socket. One is
/// created in a slot of the underlying [`Stack`] when the handle is connected or bound.
/// Sending or receiving on a handle without a RIOT socket fails with `ENOTCONN`.
pub struct UdpSocket<'a> {
    // This indirection -- not having the sock_udp_t inside UdpSocket -- is necessary because the
    // way they are created (embedded-nal .socket()) produces owned values and needs owned values
    // later -- while what we'd prefer would be producing owned values and needing pinned ones.
    //
    // See also https://github.com/rust-embedded-community/embedded-nal/issues/61
    //
    // `None` until the socket is connected or bound, and again after it has been closed.
    socket: Option<&'a mut riot_sys::sock_udp_t>,
}

impl<'a> UdpSocket<'a> {
    /// Variant of [`Self::socket`] for the stack operations: a handle without a RIOT socket is
    /// reported as an `ENOTCONN` error of type `NumericError` instead of `None`.
    fn access(&mut self) -> Result<*mut riot_sys::sock_udp_t, NumericError> {
        match self.socket() {
            Some(ptr) => Ok(ptr),
            None => Err(NumericError::from_constant(riot_sys::ENOTCONN as _)),
        }
    }

    /// Raw pointer to the RIOT socket behind this handle, if there is one.
    ///
    /// Users of the wrapper may use this to adjust socket properties, provided they do not
    /// interfere with the wrapper's own operation. Which properties the wrapper manages is not
    /// specified and may change in later versions, so use this with care.
    ///
    /// Handing out a `*mut` is safe in itself, because anything done through it (including
    /// calling the riot-sys functions) requires `unsafe`. A `&mut` is deliberately not offered:
    /// it would allow swapping the socket out of its place, which RIOT does not tolerate.
    pub fn socket(&mut self) -> Option<*mut riot_sys::sock_udp_t> {
        self.socket
            .as_deref_mut()
            .map(|inner| inner as *mut riot_sys::sock_udp_t)
    }

    /// Closes the RIOT socket behind this handle, if there is one, and leaves the handle empty.
    /// On an empty handle this does nothing.
    fn close(&mut self) {
        if let Some(inner) = self.socket.take() {
            let raw: *mut riot_sys::sock_udp_t = inner;
            // SAFETY: `inner` was stored only after a successful `sock_udp_create`, and taking it
            // out of the handle ensures it is closed at most once through this handle.
            unsafe { riot_sys::sock_udp_close(raw) };
        }
    }
}

impl<'a, const UDPCOUNT: usize> StackAccessor<'a, UDPCOUNT> {
    /// Take one of the stack's free socket slots.
    ///
    /// A default-initialized `sock_udp_t` is appended to the stack's storage, and a pointer to it
    /// is returned. Fails with `ENOMEM` when all `UDPCOUNT` slots are in use.
    fn allocate(&mut self) -> Result<*mut riot_sys::sock_udp_t, NumericError> {
        // This happens rarely enough that any MaybeUninit trickery is unwarranted
        self.stack
            .udp_sockets
            .push(Default::default())
            .map_err(|_| NumericError::from_constant(riot_sys::ENOMEM as _))?;

        let last = self.stack.udp_sockets.len() - 1;
        Ok(&mut self.stack.udp_sockets[last] as *mut _)
    }

    /// Wrapper around `sock_udp_create`.
    ///
    /// Closes whatever socket `handle` currently holds. It then creates a new RIOT socket in a
    /// fresh slot with the given `local` endpoint and optional default `remote`, and stores it in
    /// `handle`.
    ///
    /// On error, `handle` is left without a socket. If `sock_udp_create` itself fails, the slot
    /// taken for the attempt is given back. This keeps failed attempts (e.g. binding an address
    /// that is already in use) from permanently using up the stack's slots.
    fn create(
        &mut self,
        handle: &mut UdpSocket<'a>,
        local: &UdpEp,
        remote: Option<&UdpEp>,
    ) -> Result<(), NumericError> {
        handle.close();

        let socket = self.allocate()?;

        let remote_ptr: *const riot_sys::sock_udp_ep_t = match remote {
            Some(r) => {
                let r: &riot_sys::sock_udp_ep_t = r.as_ref();
                r as *const _
            }
            None => core::ptr::null(),
        };

        let created =
            (unsafe { riot_sys::sock_udp_create(socket, local.as_ref(), remote_ptr, 0) })
                .negative_to_error();

        if let Err(e) = created {
            // The slot pushed by `allocate` above is still the last one, and no handle refers to
            // it. NOTE(review): this relies on RIOT not retaining the socket pointer when
            // `sock_udp_create` reports an error.
            let _ = self.stack.udp_sockets.pop();
            return Err(e);
        }

        // unsafe: Having an 'a mutable reference for it is OK because the StackAccessor guarantees
        // that the stack is available for 'a and won't move.
        let socket: &'a mut _ = unsafe { &mut *socket };

        handle.socket = Some(socket);

        Ok(())
    }
}

impl<'a, const UDPCOUNT: usize> embedded_nal::UdpClientStack for StackAccessor<'a, UDPCOUNT> {
    type UdpSocket = UdpSocket<'a>;
    type Error = NumericError;

    fn socket(&mut self) -> Result<UdpSocket<'a>, Self::Error> {
        Ok(UdpSocket { socket: None })
    }

    fn connect(
        &mut self,
        handle: &mut Self::UdpSocket,
        remote: SocketAddr,
    ) -> Result<(), Self::Error> {
        // unsafe: Side effect free C macros
        let local = unsafe {
            match remote {
                SocketAddr::V4(_) => riot_sys::macro_SOCK_IPV4_EP_ANY(),
                SocketAddr::V6(_) => riot_sys::macro_SOCK_IPV6_EP_ANY(),
            }
            .into()
        };

        let remote = remote.into();

        self.create(handle, &local, Some(&remote))
    }
    fn send(
        &mut self,
        socket: &mut Self::UdpSocket,
        buffer: &[u8],
    ) -> Result<(), nb::Error<Self::Error>> {
        let socket = socket.access()?;

        (unsafe {
            riot_sys::sock_udp_send(
                crate::inline_cast_mut(&mut *socket as *mut _),
                buffer.as_ptr() as _,
                buffer.len().try_into().unwrap(),
                0 as *const _,
            )
        })
        .negative_to_error()
        .map(|_| ())
        // Sending never blocks in RIOT sockets
        .map_err(|e| nb::Error::Other(e))
    }
    fn receive(
        &mut self,
        socket: &mut Self::UdpSocket,
        buffer: &mut [u8],
    ) -> Result<(usize, SocketAddr), nb::Error<Self::Error>> {
        let socket = socket.access()?;

        let mut remote = MaybeUninit::uninit();

        let read = (unsafe {
            riot_sys::sock_udp_recv(
                crate::inline_cast_mut(&mut *socket as *mut _),
                buffer.as_mut_ptr() as _,
                buffer.len().try_into().unwrap(),
                0,
                crate::inline_cast_mut(remote.as_mut_ptr() as *mut _),
            )
        })
        .negative_to_error()
        .map(|e| e as usize)
        .map_err(|e| e.again_is_wouldblock());

        // unsafe: Set by C function
        let remote = UdpEp(unsafe { remote.assume_init() });

        Ok((read?, remote.into()))
    }

    fn close(&mut self, mut socket: Self::UdpSocket) -> Result<(), Self::Error> {
        socket.close();
        Ok(())
    }
}

impl<'a, const UDPCOUNT: usize> embedded_nal::UdpFullStack for StackAccessor<'a, UDPCOUNT> {
    fn bind(&mut self, handle: &mut UdpSocket<'a>, port: u16) -> Result<(), Self::Error> {
        let local = UdpEp::ipv6_any().with_port(port);

        self.create(handle, &local, None)
    }

    fn send_to(
        &mut self,
        handle: &mut UdpSocket<'a>,
        remote: SocketAddr,
        buffer: &[u8],
    ) -> Result<(), nb::Error<Self::Error>> {
        let socket = handle.access()?;

        let remote: UdpEp = remote.into();

        (unsafe {
            riot_sys::sock_udp_send(
                crate::inline_cast_mut(&mut *socket as *mut _),
                buffer.as_ptr() as _,
                buffer.len().try_into().unwrap(),
                remote.as_ref(),
            )
        })
        .negative_to_error()
        .map(|_| ())
        // Sending never blocks in RIOT sockets
        .map_err(|e| nb::Error::Other(e))
    }
}