mptcp_inq.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. // SPDX-License-Identifier: GPL-2.0
  2. #define _GNU_SOURCE
  3. #include <assert.h>
  4. #include <errno.h>
  5. #include <fcntl.h>
  6. #include <limits.h>
  7. #include <string.h>
  8. #include <stdarg.h>
  9. #include <stdbool.h>
  10. #include <stdint.h>
  11. #include <inttypes.h>
  12. #include <stdio.h>
  13. #include <stdlib.h>
  14. #include <strings.h>
  15. #include <unistd.h>
  16. #include <time.h>
  17. #include <sys/ioctl.h>
  18. #include <sys/random.h>
  19. #include <sys/socket.h>
  20. #include <sys/types.h>
  21. #include <sys/wait.h>
  22. #include <netdb.h>
  23. #include <netinet/in.h>
  24. #include <linux/tcp.h>
  25. #include <linux/sockios.h>
  26. #include <linux/compiler.h>
  /* Fallback protocol/level numbers for libc headers that lack MPTCP. */
  #if !defined(IPPROTO_MPTCP)
  #define IPPROTO_MPTCP 262
  #endif

  #if !defined(SOL_MPTCP)
  #define SOL_MPTCP 284
  #endif

  /* Address family for both endpoints; -6 switches it to AF_INET6. */
  static int pf = AF_INET;

  /*
   * Socket protocols: proto_tx for the connecting (client) socket, set
   * with -t; proto_rx for the listening (server) socket, set with -r.
   */
  static int proto_tx = IPPROTO_MPTCP;
  static int proto_rx = IPPROTO_MPTCP;
  36. static void __noreturn die_perror(const char *msg)
  37. {
  38. perror(msg);
  39. exit(1);
  40. }
  /*
   * Print the command-line synopsis to stderr and exit with status @r
   * (0 when -h was given, 1 for invalid arguments).
   */
  static void die_usage(int r)
  {
  	static const char usage[] =
  		"Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

  	fputs(usage, stderr);
  	exit(r);
  }
  46. static void __noreturn xerror(const char *fmt, ...)
  47. {
  48. va_list ap;
  49. va_start(ap, fmt);
  50. vfprintf(stderr, fmt, ap);
  51. va_end(ap);
  52. fputc('\n', stderr);
  53. exit(1);
  54. }
  /*
   * Turn a getaddrinfo() error code into a message. For EAI_SYSTEM the
   * real cause lives in errno, so strerror(errno) is used instead of
   * gai_strerror().
   */
  static const char *getxinfo_strerr(int err)
  {
  	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
  }
  /*
   * getaddrinfo() wrapper that exits the process on failure.
   *
   * If the lookup fails with EAI_SOCKTYPE (presumably a C library that
   * does not recognise IPPROTO_MPTCP in hints->ai_protocol), it is retried
   * with IPPROTO_TCP. Note that this modifies *@hints in place. The retry
   * happens at most once: if the protocol is already IPPROTO_TCP, the
   * error is fatal rather than looping forever.
   *
   * On success *@res must be released with freeaddrinfo().
   */
  static void xgetaddrinfo(const char *node, const char *service,
  			 struct addrinfo *hints,
  			 struct addrinfo **res)
  {
  	const char *errstr;
  	int err;

  	for (;;) {
  		err = getaddrinfo(node, service, hints, res);
  		if (!err)
  			return;

  		if (err == EAI_SOCKTYPE &&
  		    hints->ai_protocol != IPPROTO_TCP) {
  			hints->ai_protocol = IPPROTO_TCP;
  			continue;
  		}
  		break;
  	}

  	errstr = getxinfo_strerr(err);
  	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
  		node ? node : "", service ? service : "", errstr);
  	exit(1);
  }
  /*
   * Create a listening socket of protocol proto_rx, family pf, bound to
   * @listenaddr:@port (@listenaddr must be a numeric address).
   *
   * Each address returned by the lookup is tried in turn; the first one
   * that binds is used. Exits the process if none can be bound or
   * listen() fails. Returns the listening descriptor.
   */
  static int sock_listen_mptcp(const char * const listenaddr,
  			     const char * const port)
  {
  	struct addrinfo hints = {
  		.ai_protocol = IPPROTO_MPTCP,
  		.ai_socktype = SOCK_STREAM,
  		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
  		.ai_family = pf,
  	};
  	struct addrinfo *a, *addr;
  	int sock = -1;
  	int one = 1;

  	xgetaddrinfo(listenaddr, port, &hints, &addr);

  	for (a = addr; a; a = a->ai_next) {
  		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
  		if (sock < 0)
  			continue;

  		/* the port is fixed, so allow rebinding it on quick re-runs */
  		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
  			       sizeof(one)) == -1)
  			perror("setsockopt");

  		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
  			break; /* success */

  		perror("bind");
  		close(sock);
  		sock = -1;
  	}

  	freeaddrinfo(addr);

  	if (sock < 0)
  		xerror("could not create listen socket");

  	if (listen(sock, 20))
  		die_perror("listen");

  	return sock;
  }
  /*
   * Create a socket of protocol @proto, family pf, connected to
   * @remoteaddr:@port.
   *
   * Every address returned by the lookup is tried in order; a socket whose
   * connect() fails is reported and closed before the next candidate is
   * tried, the same way sock_listen_mptcp() walks its addresses. Exits
   * the process if no address connects. Returns the connected descriptor.
   */
  static int sock_connect_mptcp(const char * const remoteaddr,
  			      const char * const port, int proto)
  {
  	struct addrinfo hints = {
  		.ai_protocol = IPPROTO_MPTCP,
  		.ai_socktype = SOCK_STREAM,
  		.ai_family = pf,
  	};
  	struct addrinfo *a, *addr;
  	int sock = -1;

  	xgetaddrinfo(remoteaddr, port, &hints, &addr);

  	for (a = addr; a; a = a->ai_next) {
  		sock = socket(a->ai_family, a->ai_socktype, proto);
  		if (sock < 0)
  			continue;

  		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
  			break; /* success */

  		perror("connect");
  		close(sock);
  		sock = -1;
  	}

  	/* release the list before any exit path */
  	freeaddrinfo(addr);

  	if (sock < 0)
  		xerror("could not create connect socket");

  	return sock;
  }
  /*
   * Map a -t/-r argument ("tcp" or "mptcp", case-insensitive) to its
   * IPPROTO_* number. Any other string prints usage and exits with 1.
   */
  static int protostr_to_num(const char *s)
  {
  	static const struct {
  		const char *name;
  		int proto;
  	} protos[] = {
  		{ "tcp",   IPPROTO_TCP },
  		{ "mptcp", IPPROTO_MPTCP },
  	};
  	size_t idx;

  	for (idx = 0; idx < sizeof(protos) / sizeof(protos[0]); idx++) {
  		if (!strcasecmp(s, protos[idx].name))
  			return protos[idx].proto;
  	}

  	die_usage(1);
  	return 0; /* not reached: die_usage() exits */
  }
  147. static void parse_opts(int argc, char **argv)
  148. {
  149. int c;
  150. while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
  151. switch (c) {
  152. case 'h':
  153. die_usage(0);
  154. break;
  155. case '6':
  156. pf = AF_INET6;
  157. break;
  158. case 't':
  159. proto_tx = protostr_to_num(optarg);
  160. break;
  161. case 'r':
  162. proto_rx = protostr_to_num(optarg);
  163. break;
  164. default:
  165. die_usage(1);
  166. break;
  167. }
  168. }
  169. }
  /*
   * Wait until @fd's send queue drains, polling about once per millisecond
   * for at most @timeout polls (so roughly @timeout ms plus ioctl cost).
   *
   * TIOCOUTQ is the amount of data still queued for transmission (this
   * loop treats 0 as "the peer acked everything"); SIOCOUTQNSD is the part
   * of it not yet sent, so it may never exceed the TIOCOUTQ value.
   * @total is the most the caller can have outstanding.
   *
   * Exits the process if the queue ever exceeds @total or is still
   * non-empty after @timeout polls.
   */
  static void wait_for_ack(int fd, int timeout, size_t total)
  {
  	int i;

  	for (i = 0; i < timeout; i++) {
  		int nsd = -1, queued = -1;
  		struct timespec req;
  		int ret;

  		ret = ioctl(fd, TIOCOUTQ, &queued);
  		if (ret < 0)
  			die_perror("TIOCOUTQ");

  		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
  		if (ret < 0)
  			die_perror("SIOCOUTQNSD");

  		if ((size_t)queued > total)
  			xerror("TIOCOUTQ %d, but only %zu expected",
  			       queued, total);
  		assert(nsd <= queued);

  		if (queued == 0)
  			return;

  		/* wait for peer to ack rx of all data */
  		req.tv_sec = 0;
  		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
  		nanosleep(&req, NULL);
  	}

  	xerror("still tx data queued after %d ms", timeout);
  }
  /*
   * Sending side of the test. @fd is the connected data socket and
   * @unixfd the control channel shared with process_one_client().
   *
   * 1. On "xmit": announce a random length of 128..4094 bytes over @unixfd,
   *    then write that many random upper-case letters to @fd.
   * 2. On "huge": announce a random total of at least 1 MiB and below
   *    17 MiB, wait until the small payload is acked, then stream the total
   *    (only the byte count matters, not the buffer contents).
   * 3. On "shut": wait until the bulk data is acked, send one final byte,
   *    close @fd, report "closed" and close @unixfd.
   * Failures exit the process or trip an assert().
   */
  static void connect_one_server(int fd, int unixfd)
  {
  	size_t len, i, total, sent;
  	char buf[4096], buf2[4096];
  	ssize_t ret;

  	len = rand() % (sizeof(buf) - 1);
  	if (len < 128)
  		len = 128;

  	for (i = 0; i < len; i++) {
  		buf[i] = rand() % 26;
  		buf[i] += 'A';
  	}
  	buf[i] = '\n';

  	/* un-block server */
  	ret = read(unixfd, buf2, 4);
  	assert(ret == 4);
  	assert(strncmp(buf2, "xmit", 4) == 0);

  	ret = write(unixfd, &len, sizeof(len));
  	assert(ret == (ssize_t)sizeof(len));

  	ret = write(fd, buf, len);
  	if (ret < 0)
  		die_perror("write");
  	if (ret != (ssize_t)len)
  		xerror("short write");

  	/* check the length like every other control-channel read */
  	ret = read(unixfd, buf2, 4);
  	assert(ret == 4);
  	assert(strncmp(buf2, "huge", 4) == 0);

  	total = rand() % (16 * 1024 * 1024);
  	total += (1 * 1024 * 1024);
  	sent = total;

  	ret = write(unixfd, &total, sizeof(total));
  	assert(ret == (ssize_t)sizeof(total));

  	wait_for_ack(fd, 5000, len);

  	while (total > 0) {
  		len = total > sizeof(buf) ? sizeof(buf) : total;

  		ret = write(fd, buf, len);
  		if (ret < 0)
  			die_perror("write");
  		/* we don't have to care about buf content, only
  		 * number of total bytes sent
  		 */
  		total -= ret;
  	}

  	ret = read(unixfd, buf2, 4);
  	assert(ret == 4);
  	assert(strncmp(buf2, "shut", 4) == 0);

  	wait_for_ack(fd, 5000, sent);

  	ret = write(fd, buf, 1);
  	assert(ret == 1);
  	close(fd);

  	ret = write(unixfd, "closed", 6);
  	assert(ret == 6);
  	close(unixfd);
  }
  /*
   * Scan the control messages of @msgh for the IPPROTO_TCP / TCP_CM_INQ
   * entry and copy its value (the amount of data still queued for reading,
   * as checked by process_one_client()) into *@inqv. Exits the process if
   * no such control message is present.
   */
  static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
  {
  	struct cmsghdr *c = CMSG_FIRSTHDR(msgh);

  	while (c != NULL) {
  		if (c->cmsg_level == IPPROTO_TCP &&
  		    c->cmsg_type == TCP_CM_INQ) {
  			memcpy(inqv, CMSG_DATA(c), sizeof(*inqv));
  			return;
  		}
  		c = CMSG_NXTHDR(msgh, c);
  	}

  	xerror("could not find TCP_CM_INQ cmsg type");
  }
  /*
   * Receiving side of the test. @fd is the accepted data socket (TCP_INQ
   * enabled by server()); @unixfd is the control channel to
   * connect_one_server().
   *
   * Control sequence: send "xmit" and read the small payload length; send
   * "huge" and read the bulk length; send "shut", then read "closed" once
   * the peer has sent its final byte and closed.
   *
   * On @fd it checks that FIONREAD and the TCP_CM_INQ control message agree
   * with the announced lengths:
   *  - small payload: after reading one byte the INQ value is len - 1;
   *    after reading the rest it is 0;
   *  - bulk payload: the INQ value never exceeds what is still outstanding;
   *  - after the peer closed: the last data byte and the following EOF read
   *    both report an INQ value of 1.
   * Any mismatch exits the process or trips an assert().
   *
   * recvmsg() overwrites msg.msg_controllen with the length of the control
   * data it returned, so it is reset to the full buffer before every call.
   */
  static void process_one_client(int fd, int unixfd)
  {
  	unsigned int tcp_inq;
  	size_t expect_len;
  	char msg_buf[4096];
  	char buf[4096];
  	char tmp[16];
  	struct iovec iov = {
  		.iov_base = buf,
  		.iov_len = 1,
  	};
  	struct msghdr msg = {
  		.msg_iov = &iov,
  		.msg_iovlen = 1,
  		.msg_control = msg_buf,
  		.msg_controllen = sizeof(msg_buf),
  	};
  	ssize_t ret, tot;

  	ret = write(unixfd, "xmit", 4);
  	assert(ret == 4);

  	ret = read(unixfd, &expect_len, sizeof(expect_len));
  	assert(ret == (ssize_t)sizeof(expect_len));

  	if (expect_len > sizeof(buf))
  		xerror("expect len %zu exceeds buffer size", expect_len);

  	/* wait until the whole small payload is queued on our side */
  	for (;;) {
  		struct timespec req;
  		unsigned int queued;

  		ret = ioctl(fd, FIONREAD, &queued);
  		if (ret < 0)
  			die_perror("FIONREAD");
  		if (queued > expect_len)
  			xerror("FIONREAD returned %u, but only %zu expected",
  			       queued, expect_len);
  		if (queued == expect_len)
  			break;

  		req.tv_sec = 0;
  		req.tv_nsec = 1000 * 1000ul;
  		nanosleep(&req, NULL);
  	}

  	/* read one byte, expect cmsg to return expected - 1 */
  	iov.iov_len = 1;
  	msg.msg_controllen = sizeof(msg_buf);
  	ret = recvmsg(fd, &msg, 0);
  	if (ret < 0)
  		die_perror("recvmsg");
  	if (msg.msg_controllen == 0)
  		xerror("msg_controllen is 0");

  	get_tcp_inq(&msg, &tcp_inq);
  	assert((size_t)tcp_inq == (expect_len - 1));

  	iov.iov_len = sizeof(buf);
  	msg.msg_controllen = sizeof(msg_buf);
  	ret = recvmsg(fd, &msg, 0);
  	if (ret < 0)
  		die_perror("recvmsg");

  	/* should have gotten exact remainder of all pending data */
  	assert(ret == (ssize_t)tcp_inq);

  	/* should be 0, all drained */
  	get_tcp_inq(&msg, &tcp_inq);
  	assert(tcp_inq == 0);

  	/* request a large swath of data. */
  	ret = write(unixfd, "huge", 4);
  	assert(ret == 4);

  	ret = read(unixfd, &expect_len, sizeof(expect_len));
  	assert(ret == (ssize_t)sizeof(expect_len));

  	/* peer should send us a few mb of data */
  	if (expect_len <= sizeof(buf))
  		xerror("expect len %zu too small", expect_len);

  	tot = 0;
  	do {
  		iov.iov_len = sizeof(buf);
  		msg.msg_controllen = sizeof(msg_buf);
  		ret = recvmsg(fd, &msg, 0);
  		if (ret < 0)
  			die_perror("recvmsg");
  		/* a premature EOF would otherwise spin here until SIGALRM */
  		if (ret == 0)
  			xerror("unexpected EOF after %zd of %zu bytes",
  			       tot, expect_len);

  		tot += ret;
  		get_tcp_inq(&msg, &tcp_inq);
  		if (tcp_inq > expect_len - (size_t)tot)
  			xerror("inq %u, remaining %zu total_len %zu",
  			       tcp_inq, expect_len - (size_t)tot, expect_len);
  	} while ((size_t)tot < expect_len);

  	ret = write(unixfd, "shut", 4);
  	assert(ret == 4);

  	/* wait for hangup. Should have received one more byte of data. */
  	ret = read(unixfd, tmp, sizeof(tmp));
  	assert(ret == 6);
  	assert(strncmp(tmp, "closed", 6) == 0);

  	/* presumably gives the peer's FIN time to arrive */
  	sleep(1);

  	iov.iov_len = 1;
  	msg.msg_controllen = sizeof(msg_buf);
  	ret = recvmsg(fd, &msg, 0);
  	if (ret < 0)
  		die_perror("recvmsg");
  	assert(ret == 1);

  	get_tcp_inq(&msg, &tcp_inq);
  	/* tcp_inq should be 1 due to received fin. */
  	assert(tcp_inq == 1);

  	iov.iov_len = 1;
  	msg.msg_controllen = sizeof(msg_buf);
  	ret = recvmsg(fd, &msg, 0);
  	if (ret < 0)
  		die_perror("recvmsg");

  	/* expect EOF */
  	assert(ret == 0);
  	get_tcp_inq(&msg, &tcp_inq);
  	assert(tcp_inq == 1);

  	close(fd);
  }
  /* accept() one connection on listening socket @s; exits on failure. */
  static int xaccept(int s)
  {
  	int conn;

  	conn = accept(s, NULL, 0);
  	if (conn < 0)
  		die_perror("accept");

  	return conn;
  }
  371. static int server(int unixfd)
  372. {
  373. int fd = -1, r, on = 1;
  374. switch (pf) {
  375. case AF_INET:
  376. fd = sock_listen_mptcp("127.0.0.1", "15432");
  377. break;
  378. case AF_INET6:
  379. fd = sock_listen_mptcp("::1", "15432");
  380. break;
  381. default:
  382. xerror("Unknown pf %d\n", pf);
  383. break;
  384. }
  385. r = write(unixfd, "conn", 4);
  386. assert(r == 4);
  387. alarm(15);
  388. r = xaccept(fd);
  389. if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
  390. die_perror("setsockopt");
  391. process_one_client(r, unixfd);
  392. close(fd);
  393. return 0;
  394. }
  395. static int client(int unixfd)
  396. {
  397. int fd = -1;
  398. alarm(15);
  399. switch (pf) {
  400. case AF_INET:
  401. fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
  402. break;
  403. case AF_INET6:
  404. fd = sock_connect_mptcp("::1", "15432", proto_tx);
  405. break;
  406. default:
  407. xerror("Unknown pf %d\n", pf);
  408. }
  409. connect_one_server(fd, unixfd);
  410. return 0;
  411. }
  /* Seed rand() with a value from getrandom(); exits (status 1) on error. */
  static void init_rng(void)
  {
  	unsigned int seed;
  	ssize_t got = getrandom(&seed, sizeof(seed), 0);

  	if (got == -1) {
  		perror("getrandom");
  		exit(1);
  	}

  	srand(seed);
  }
  /*
   * fork() that exits on failure. The child reseeds rand() via init_rng()
   * before returning 0; the parent gets the child's pid.
   */
  static pid_t xfork(void)
  {
  	pid_t pid = fork();

  	if (pid < 0)
  		die_perror("fork");
  	if (pid == 0)
  		init_rng();

  	return pid;
  }
  /*
   * Evaluate the waitpid() status @wstatus of the child named @what:
   *  - normal exit: returns the exit status, logging it when non-zero;
   *  - killed or stopped by a signal: fatal, the process exits;
   *  - anything else: returns 111.
   */
  static int rcheck(int wstatus, const char *what)
  {
  	int code;

  	if (WIFSIGNALED(wstatus))
  		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
  	if (WIFSTOPPED(wstatus))
  		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
  	if (!WIFEXITED(wstatus))
  		return 111;

  	code = WEXITSTATUS(wstatus);
  	if (code != 0)
  		fprintf(stderr, "%s exited, status=%d\n", what, code);

  	return code;
  }
  /*
   * Run the test: fork a server and a client that talk over loopback and
   * coordinate through an AF_UNIX datagram socketpair, then reap both.
   * Returns the server's non-zero status if it had one, else the client's.
   */
  int main(int argc, char *argv[])
  {
  	int e1, e2, wstatus;
  	pid_t s, c, ret;
  	int unixfds[2];
  	char sync_msg[4];

  	parse_opts(argc, argv);

  	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
  	if (e1 < 0)
  		die_perror("socketpair");

  	s = xfork();
  	if (s == 0) {
  		int rc;

  		close(unixfds[0]);
  		rc = server(unixfds[1]);
  		close(unixfds[1]);
  		return rc;
  	}

  	close(unixfds[1]);

  	/* wait until server bound a socket: it sends "conn" after listen() */
  	e1 = read(unixfds[0], sync_msg, sizeof(sync_msg));
  	assert(e1 == 4);
  	assert(strncmp(sync_msg, "conn", 4) == 0);

  	c = xfork();
  	if (c == 0)
  		return client(unixfds[0]);

  	close(unixfds[0]);

  	ret = waitpid(s, &wstatus, 0);
  	if (ret == -1)
  		die_perror("waitpid");
  	e1 = rcheck(wstatus, "server");

  	ret = waitpid(c, &wstatus, 0);
  	if (ret == -1)
  		die_perror("waitpid");
  	e2 = rcheck(wstatus, "client");

  	return e1 ? e1 : e2;
  }