mavtables  0.2.1
MAVLink router and firewall.
UnixUDPSocket.cpp
Go to the documentation of this file.
1 // MAVLink router and firewall.
2 // Copyright (C) 2018 Michael R. Shannon <mrshannon.aerospace@gmail.com>
3 //
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License as published by
6 // the Free Software Foundation; either version 2 of the License, or
7 // (at your option) any later version.
8 //
9 // This program is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 // GNU General Public License for more details.
13 //
14 // You should have received a copy of the GNU General Public License
15 // along with this program. If not, see <http://www.gnu.org/licenses/>.
16 
17 
18 #include <algorithm>
19 #include <chrono>
20 #include <cstdint>
21 #include <cstring>
22 #include <iterator>
23 #include <memory>
24 #include <optional>
25 #include <stdexcept>
26 #include <system_error>
27 #include <thread>
28 #include <utility>
29 #include <vector>
30 
31 #include <arpa/inet.h>
32 #include <errno.h>
33 
34 #include "IPAddress.hpp"
35 #include "UnixSyscalls.hpp"
36 #include "UnixUDPSocket.hpp"
37 
38 
39 using namespace std::chrono_literals;
40 
41 
42 /** Construct a UDP socket.
43  *
44  * \param port The port number to listen on.
45  * \param address The address to listen on (the port portion of the address is
46  * ignored). The default is to listen on any address.
47  * \param max_bitrate The maximum number of bits per second to transmit on the
48  * UDP interface. The default is 0, which indicates no limit.
49  * \param syscalls The object to use for unix system calls. It is default
50  * constructed to the production implementation. This argument is only
51  * used for testing.
52  * \throws std::system_error if a system call produces an error.
53  */
55  unsigned int port, std::optional<IPAddress> address,
56  unsigned long max_bitrate, std::unique_ptr<UnixSyscalls> syscalls)
57  : port_(port), address_(std::move(address)), max_bitrate_(max_bitrate),
58  syscalls_(std::move(syscalls)), socket_(-1),
59  next_time_(std::chrono::steady_clock::now())
60 {
61  create_socket_();
62 }
63 
64 
65 /** The socket destructor.
66  *
67  * Closes the underlying file descriptor of the UDP socket.
68  */
69 // LCOV_EXCL_START
71 {
72  syscalls_->close(socket_);
73 }
74 // LCOV_EXCL_STOP
75 
76 
77 /** \copydoc UDPSocket::send(const std::vector<uint8_t> &, const IPAddress &)
78  *
79  * \throws PartialSendError if it fails to write all the data it is given.
80  */
82  const std::vector<uint8_t> &data, const IPAddress &address)
83 {
84  if (max_bitrate_ != 0)
85  {
86  // Implement rate limit.
87  auto now = std::chrono::steady_clock::now();
88 
89  if (now < next_time_)
90  {
91  std::this_thread::sleep_for(next_time_ - now);
92  }
93 
94  next_time_ = now + (1000 * 1000 * data.size() * 8) / max_bitrate_ * 1us;
95  }
96 
97  // Destination address structure.
98  struct sockaddr_in addr;
99  addr.sin_family = AF_INET;
100  addr.sin_port = htons(static_cast<uint16_t>(address.port()));
101  addr.sin_addr.s_addr =
102  htonl(static_cast<uint32_t>(address.address()));
103  std::memset(addr.sin_zero, '\0', sizeof(addr.sin_zero));
104  // Send the packet.
105  auto err = syscalls_->sendto(
106  socket_, data.data(), data.size(), 0,
107  reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr));
108 
109  if (err < 0)
110  {
111  throw std::system_error(std::error_code(errno, std::system_category()));
112  }
113 }
114 
115 
116 /** \copydoc UDPSocket::receive(const std::chrono::nanoseconds &)
117  *
118  * \note The timeout precision of this implementation is 1 millisecond.
119  *
120  * \throws std::system_error if a system call produces an error.
121  */
122 std::pair<std::vector<uint8_t>, IPAddress> UnixUDPSocket::receive(
123  const std::chrono::nanoseconds &timeout)
124 {
125  std::chrono::milliseconds timeout_ms =
126  std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
127  struct pollfd fds = {socket_, POLLIN, 0};
128  auto result = syscalls_->poll(
129  &fds, 1, static_cast<int>(timeout_ms.count()));
130 
131  // Poll error
132  if (result < 0)
133  {
134  throw std::system_error(std::error_code(errno, std::system_category()));
135  }
136  // Success
137  else if (result > 0)
138  {
139  // Socket error
140  if (fds.revents & POLLERR)
141  {
142  syscalls_->close(socket_);
143  create_socket_();
144  return {std::vector<uint8_t>(), IPAddress(0)};
145  }
146  // Datagram available for reading.
147  else if (fds.revents & POLLIN)
148  {
149  return receive_();
150  }
151  }
152 
153  // Timed out
154  return {std::vector<uint8_t>(), IPAddress(0)};
155 }
156 
157 
158 /** Create socket using the `port_` and `address_` member variables.
159  *
160  * \throws std::system_error if a system call produces an error.
161  */
162 void UnixUDPSocket::create_socket_()
163 {
164  socket_ = -1;
165 
166  // Create socket.
167  if ((socket_ = syscalls_->socket(AF_INET, SOCK_DGRAM, 0)) < 0)
168  {
169  throw std::system_error(std::error_code(errno, std::system_category()));
170  }
171 
172  // Bind socket to port (and optionally an IP address).
173  struct sockaddr_in addr;
174  addr.sin_family = AF_INET;
175  addr.sin_port = htons(static_cast<uint16_t>(port_));
176 
177  if (address_)
178  {
179  addr.sin_addr.s_addr =
180  htonl(static_cast<uint32_t>(address_.value().address()));
181  }
182  else
183  {
184  addr.sin_addr.s_addr = htonl(INADDR_ANY);
185  }
186 
187  std::memset(addr.sin_zero, '\0', sizeof(addr.sin_zero));
188 
189  if ((syscalls_->bind(socket_, reinterpret_cast<struct sockaddr *>(&addr),
190  sizeof(addr))) < 0)
191  {
192  throw std::system_error(std::error_code(errno, std::system_category()));
193  }
194 }
195 
196 
197 /** Read data from socket.
198  *
199  * \note There must be a packet to receive, otherwise calling this method is
200  * undefined.
201  *
202  * \returns The data read from the socket and the IP address it was sent from.
203  * \throws std::system_error if a system call produces an error.
204  */
205 std::pair<std::vector<uint8_t>, IPAddress> UnixUDPSocket::receive_()
206 {
207  // Get needed buffer size.
208  int packet_size;
209 
210  if ((syscalls_->ioctl(socket_, FIONREAD, &packet_size)) < 0)
211  {
212  throw std::system_error(std::error_code(errno, std::system_category()));
213  }
214 
215  // Read datagram.
216  std::vector<uint8_t> buffer;
217  buffer.resize(static_cast<size_t>(packet_size));
218  struct sockaddr_in addr;
219  socklen_t addrlen = sizeof(addr);
220  auto size = syscalls_->recvfrom(
221  socket_, buffer.data(), buffer.size(), 0,
222  reinterpret_cast<struct sockaddr *>(&addr), &addrlen);
223 
224  // Handle errors and extract IP address.
225  if (size < 0)
226  {
227  throw std::system_error(std::error_code(errno, std::system_category()));
228  }
229  else if (size > 0)
230  {
231  if (addrlen <= sizeof(addr) && addr.sin_family == AF_INET)
232  {
233  auto ip =
234  IPAddress(ntohl(addr.sin_addr.s_addr), ntohs(addr.sin_port));
235  return {buffer, ip};
236  }
237  }
238 
239  // Failed to read datagram.
240  return {std::vector<uint8_t>(), IPAddress(0)};
241 }
242 
243 
244 /** \copydoc UDPSocket::print_(std::ostream &os)const
245  *
246  * An example:
247  * ```
248  * udp {
249  * port 14555;
250  * address 127.0.0.1;
251  * max_bitrate 262144;
252  * }
253  * ```
254  *
255  * \param os The output stream to print to.
256  */
257 std::ostream &UnixUDPSocket::print_(std::ostream &os) const
258 {
259  os << "udp {" << std::endl;
260  os << " port " << std::to_string(port_) << ";" << std::endl;
261 
262  if (address_.has_value())
263  {
264  os << " address " << address_.value() << ";" << std::endl;
265  }
266 
267  if (max_bitrate_ != 0)
268  {
269  os << " max_bitrate " << max_bitrate_ << ";" << std::endl;
270  }
271 
272  os << "}";
273  return os;
274 }
UnixUDPSocket(unsigned int port, std::optional< IPAddress > address={}, unsigned long max_bitrate=0, std::unique_ptr< UnixSyscalls > syscalls=std::make_unique< UnixSyscalls >())
unsigned int port() const
Definition: IPAddress.cpp:215
STL namespace.
virtual std::pair< std::vector< uint8_t >, IPAddress > receive(const std::chrono::nanoseconds &timeout=std::chrono::nanoseconds::zero()) final
virtual ~UnixUDPSocket()
virtual void send(const std::vector< uint8_t > &data, const IPAddress &address) final
unsigned long address() const
Definition: IPAddress.cpp:205
std::ostream & print_(std::ostream &os) const final