test.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. #! /usr/bin/env python3
  2. # SPDX-License-Identifier: GPL-2.0
  3. import argparse
  4. import ctypes
  5. import errno
  6. import hashlib
  7. import os
  8. import select
  9. import signal
  10. import socket
  11. import subprocess
  12. import sys
  13. import atexit
  14. from pwd import getpwuid
  15. from os import stat
  16. # Allow utils module to be imported from different directory
  17. this_dir = os.path.dirname(os.path.realpath(__file__))
  18. sys.path.append(os.path.join(this_dir, "../"))
  19. from lib.py.utils import ip
  20. libc = ctypes.cdll.LoadLibrary('libc.so.6')
  21. setns = libc.setns
  22. net0 = 'net0'
  23. net1 = 'net1'
  24. veth0 = 'veth0'
  25. veth1 = 'veth1'
  26. # Helper function for creating a socket inside a network namespace.
  27. # We need this because otherwise RDS will detect that the two TCP
  28. # sockets are on the same interface and use the loop transport instead
  29. # of the TCP transport.
  30. def netns_socket(netns, *args):
  31. u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
  32. child = os.fork()
  33. if child == 0:
  34. # change network namespace
  35. with open(f'/var/run/netns/{netns}') as f:
  36. try:
  37. ret = setns(f.fileno(), 0)
  38. except IOError as e:
  39. print(e.errno)
  40. print(e)
  41. # create socket in target namespace
  42. s = socket.socket(*args)
  43. # send resulting socket to parent
  44. socket.send_fds(u0, [], [s.fileno()])
  45. sys.exit(0)
  46. # receive socket from child
  47. _, s, _, _ = socket.recv_fds(u1, 0, 1)
  48. os.waitpid(child, 0)
  49. u0.close()
  50. u1.close()
  51. return socket.fromfd(s[0], *args)
  52. def signal_handler(sig, frame):
  53. print('Test timed out')
  54. sys.exit(1)
  55. #Parse out command line arguments. We take an optional
  56. # timeout parameter and an optional log output folder
  57. parser = argparse.ArgumentParser(description="init script args",
  58. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  59. parser.add_argument("-d", "--logdir", action="store",
  60. help="directory to store logs", default="/tmp")
  61. parser.add_argument('--timeout', help="timeout to terminate hung test",
  62. type=int, default=0)
  63. parser.add_argument('-l', '--loss', help="Simulate tcp packet loss",
  64. type=int, default=0)
  65. parser.add_argument('-c', '--corruption', help="Simulate tcp packet corruption",
  66. type=int, default=0)
  67. parser.add_argument('-u', '--duplicate', help="Simulate tcp packet duplication",
  68. type=int, default=0)
  69. args = parser.parse_args()
  70. logdir=args.logdir
  71. packet_loss=str(args.loss)+'%'
  72. packet_corruption=str(args.corruption)+'%'
  73. packet_duplicate=str(args.duplicate)+'%'
  74. ip(f"netns add {net0}")
  75. ip(f"netns add {net1}")
  76. ip(f"link add type veth")
  77. addrs = [
  78. # we technically don't need different port numbers, but this will
  79. # help identify traffic in the network analyzer
  80. ('10.0.0.1', 10000),
  81. ('10.0.0.2', 20000),
  82. ]
  83. # move interfaces to separate namespaces so they can no longer be
  84. # bound directly; this prevents rds from switching over from the tcp
  85. # transport to the loop transport.
  86. ip(f"link set {veth0} netns {net0} up")
  87. ip(f"link set {veth1} netns {net1} up")
  88. # add addresses
  89. ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
  90. ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")
  91. # add routes
  92. ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
  93. ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")
  94. # sanity check that our two interfaces/addresses are correctly set up
  95. # and communicating by doing a single ping
  96. ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")
  97. # Start a packet capture on each network
  98. for net in [net0, net1]:
  99. tcpdump_pid = os.fork()
  100. if tcpdump_pid == 0:
  101. pcap = logdir+'/'+net+'.pcap'
  102. subprocess.check_call(['touch', pcap])
  103. user = getpwuid(stat(pcap).st_uid).pw_name
  104. ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
  105. sys.exit(0)
  106. # simulate packet loss, duplication and corruption
  107. for net, iface in [(net0, veth0), (net1, veth1)]:
  108. ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem \
  109. corrupt {packet_corruption} loss {packet_loss} duplicate \
  110. {packet_duplicate}")
  111. # add a timeout
  112. if args.timeout > 0:
  113. signal.alarm(args.timeout)
  114. signal.signal(signal.SIGALRM, signal_handler)
  115. sockets = [
  116. netns_socket(net0, socket.AF_RDS, socket.SOCK_SEQPACKET),
  117. netns_socket(net1, socket.AF_RDS, socket.SOCK_SEQPACKET),
  118. ]
  119. for s, addr in zip(sockets, addrs):
  120. s.bind(addr)
  121. s.setblocking(0)
  122. fileno_to_socket = {
  123. s.fileno(): s for s in sockets
  124. }
  125. addr_to_socket = {
  126. addr: s for addr, s in zip(addrs, sockets)
  127. }
  128. socket_to_addr = {
  129. s: addr for addr, s in zip(addrs, sockets)
  130. }
  131. send_hashes = {}
  132. recv_hashes = {}
  133. ep = select.epoll()
  134. for s in sockets:
  135. ep.register(s, select.EPOLLRDNORM)
  136. n = 50000
  137. nr_send = 0
  138. nr_recv = 0
  139. while nr_send < n:
  140. # Send as much as we can without blocking
  141. print("sending...", nr_send, nr_recv)
  142. while nr_send < n:
  143. send_data = hashlib.sha256(
  144. f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')
  145. # pseudo-random send/receive pattern
  146. sender = sockets[nr_send % 2]
  147. receiver = sockets[1 - (nr_send % 3) % 2]
  148. try:
  149. sender.sendto(send_data, socket_to_addr[receiver])
  150. send_hashes.setdefault((sender.fileno(), receiver.fileno()),
  151. hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8'))
  152. nr_send = nr_send + 1
  153. except BlockingIOError as e:
  154. break
  155. except OSError as e:
  156. if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]:
  157. break
  158. raise
  159. # Receive as much as we can without blocking
  160. print("receiving...", nr_send, nr_recv)
  161. while nr_recv < nr_send:
  162. for fileno, eventmask in ep.poll():
  163. receiver = fileno_to_socket[fileno]
  164. if eventmask & select.EPOLLRDNORM:
  165. while True:
  166. try:
  167. recv_data, address = receiver.recvfrom(1024)
  168. sender = addr_to_socket[address]
  169. recv_hashes.setdefault((sender.fileno(),
  170. receiver.fileno()), hashlib.sha256()).update(
  171. f'<{recv_data}>'.encode('utf-8'))
  172. nr_recv = nr_recv + 1
  173. except BlockingIOError as e:
  174. break
  175. # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
  176. for net in [net0, net1]:
  177. ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
  178. ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")
  179. print("done", nr_send, nr_recv)
  180. # the Python socket module doesn't know these
  181. RDS_INFO_FIRST = 10000
  182. RDS_INFO_LAST = 10017
  183. nr_success = 0
  184. nr_error = 0
  185. for s in sockets:
  186. for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
  187. # Sigh, the Python socket module doesn't allow us to pass
  188. # buffer lengths greater than 1024 for some reason. RDS
  189. # wants multiple pages.
  190. try:
  191. s.getsockopt(socket.SOL_RDS, optname, 1024)
  192. nr_success = nr_success + 1
  193. except OSError as e:
  194. nr_error = nr_error + 1
  195. if e.errno == errno.ENOSPC:
  196. # ignore
  197. pass
  198. print(f"getsockopt(): {nr_success}/{nr_error}")
  199. print("Stopping network packet captures")
  200. subprocess.check_call(['killall', '-q', 'tcpdump'])
  201. # We're done sending and receiving stuff, now let's check if what
  202. # we received is what we sent.
  203. for (sender, receiver), send_hash in send_hashes.items():
  204. recv_hash = recv_hashes.get((sender, receiver))
  205. if recv_hash is None:
  206. print("FAIL: No data received")
  207. sys.exit(1)
  208. if send_hash.hexdigest() != recv_hash.hexdigest():
  209. print("FAIL: Send/recv mismatch")
  210. print("hash expected:", send_hash.hexdigest())
  211. print("hash received:", recv_hash.hexdigest())
  212. sys.exit(1)
  213. print(f"{sender}/{receiver}: ok")
  214. print("Success")
  215. sys.exit(0)