| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- #! /usr/bin/env python3
- # SPDX-License-Identifier: GPL-2.0
- import argparse
- import ctypes
- import errno
- import hashlib
- import os
- import select
- import signal
- import socket
- import subprocess
- import sys
- import atexit
- from pwd import getpwuid
- from os import stat
# Make the shared selftest helpers (the ``lib.py.utils`` module, looked up
# one directory above this script) importable no matter which directory
# the test is launched from.
this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(this_dir, "../"))
from lib.py.utils import ip

# setns(2) is reached directly through libc via ctypes.
libc = ctypes.CDLL('libc.so.6')
setns = libc.setns
# Names of the two network namespaces and of the two veth ends that
# connect them (veth0 is moved into net0, veth1 into net1).
net0, net1 = 'net0', 'net1'
veth0, veth1 = 'veth0', 'veth1'
def netns_socket(netns, *args):
    """Create a socket inside network namespace *netns* and return it.

    We need this because otherwise RDS will detect that the two TCP
    sockets are on the same interface and use the loop transport instead
    of the TCP transport.

    A forked child enters the namespace with setns(2), creates
    ``socket.socket(*args)`` there and passes the descriptor back over a
    UNIX socketpair; the calling process never leaves its own namespace.

    :param netns: name of a namespace file under /var/run/netns.
    :param args: forwarded to socket.socket() (family, type, ...).
    :returns: a socket object for the descriptor created in *netns*.
    :raises RuntimeError: if the child could not enter the namespace or
        could not create/pass the socket.
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
    child = os.fork()
    if child == 0:
        # Child: always leave through os._exit() so that neither an
        # exception nor SystemExit can let it continue running the
        # parent's script, run atexit handlers or re-flush stdio buffers.
        status = 1
        try:
            u1.close()
            with open(f'/var/run/netns/{netns}') as f:
                # setns() is called through ctypes: failure is signalled by
                # a non-zero return value, never by a Python exception.
                if setns(f.fileno(), 0) != 0:
                    print(f'setns({netns}) failed', file=sys.stderr,
                          flush=True)
                else:
                    # create socket in target namespace and send it to
                    # the parent
                    s = socket.socket(*args)
                    socket.send_fds(u0, [], [s.fileno()])
                    status = 0
        except Exception as e:
            print(e, file=sys.stderr, flush=True)
        finally:
            os._exit(status)

    # Parent: drop our copy of the child's end so that recv_fds() sees EOF
    # (instead of blocking forever) if the child exits without sending.
    u0.close()
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)
    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns}')
    # socket.fromfd() duplicates the descriptor; close the received one so
    # it is not leaked.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])
def signal_handler(sig, frame):
    """Report that the test ran out of time and exit with status 1.

    Installed as the SIGALRM handler when a --timeout is given; *sig* and
    *frame* are the standard signal-handler arguments and are not used.
    """
    print('Test timed out')
    raise SystemExit(1)
def _percentage(value):
    """argparse ``type=`` converter: an integer percentage in [0, 100].

    Rejecting bad values here makes argparse report them up front rather
    than passing a meaningless rate to ``tc`` later, after the network
    namespaces have already been created.
    """
    try:
        pct = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid percentage: {value!r}") from None
    if not 0 <= pct <= 100:
        raise argparse.ArgumentTypeError(
            f"{pct} is not a percentage between 0 and 100")
    return pct


# Parse command line arguments: an optional log output folder, an
# optional timeout in seconds (0 disables it) and optional netem
# loss/corruption/duplication rates, given in percent.
parser = argparse.ArgumentParser(description="init script args",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--logdir", action="store",
                    help="directory to store logs", default="/tmp")
parser.add_argument('--timeout', help="timeout to terminate hung test",
                    type=int, default=0)
parser.add_argument('-l', '--loss', help="Simulate tcp packet loss (percent)",
                    type=_percentage, default=0)
parser.add_argument('-c', '--corruption',
                    help="Simulate tcp packet corruption (percent)",
                    type=_percentage, default=0)
parser.add_argument('-u', '--duplicate',
                    help="Simulate tcp packet duplication (percent)",
                    type=_percentage, default=0)
args = parser.parse_args()
logdir = args.logdir
# netem expects rates with an explicit '%' suffix.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'
# Create two network namespaces joined by a veth pair.
ip(f"netns add {net0}")
ip(f"netns add {net1}")
# Name both ends explicitly. With a bare "link add type veth" the kernel
# picks the next free vethN names, which are veth0/veth1 only when no such
# devices exist yet, while every command below assumes exactly those names.
ip(f"link add {veth0} type veth peer name {veth1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# add addresses
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# add routes
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")
# Capture all traffic of each namespace into <logdir>/<netns>.pcap. Every
# capture runs in its own forked child; the parent just moves on.
for net in (net0, net1):
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        continue
    # Child from here on.
    pcap = f'{logdir}/{net}.pcap'
    # Make sure the file exists, then hand tcpdump (-Z) the name of the
    # user owning it, so the capture is written as that user.
    subprocess.check_call(['touch', pcap])
    owner = getpwuid(stat(pcap).st_uid).pw_name
    ip(f"netns exec {net} /usr/sbin/tcpdump -Z {owner} -i any -w {pcap}")
    sys.exit(0)
# simulate packet loss, duplication and corruption: attach a netem root
# qdisc (i.e. on egress) to each veth end. The command is built with
# implicit string concatenation; a backslash-newline inside the literal
# would embed the continuation lines' indentation into the command.
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} "
       f"duplicate {packet_duplicate}")
# add a timeout: abort with a message if the test hangs. Install the
# handler *before* arming the alarm; otherwise a SIGALRM arriving in
# between would hit the default disposition and kill the process without
# the "Test timed out" message.
if args.timeout > 0:
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
# One RDS socket per namespace (net0 first, then net1), each created
# inside its namespace.
sockets = [netns_socket(net, socket.AF_RDS, socket.SOCK_SEQPACKET)
           for net in (net0, net1)]

# Bind socket i to addrs[i] and switch it to non-blocking mode.
for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(False)

# Lookup tables used by the traffic loop.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running digests keyed by (sender fd, receiver fd).
send_hashes, recv_hashes = {}, {}

# Wake up whenever either socket has normal data to read.
ep = select.epoll()
for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

# Total number of messages to send, and progress counters.
n = 50000
nr_send = nr_recv = 0
# Main traffic loop. Each round sends a non-blocking burst, then waits
# until everything sent so far has been received, then pokes the RDS TCP
# buffer sysctls. A running SHA-256 per (sender fd, receiver fd) pair is
# kept on both sides so the streams can be compared afterwards.
while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # Payload: hex SHA-256 of "packet <i>" (64 ASCII bytes).
        send_data = hashlib.sha256(
            f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8')
        # Pseudo-random pairing: the sender alternates 0,1,0,1,... while
        # the receiver follows a period-3 pattern, so every six messages
        # cover all four (sender, receiver) combinations, self-sends
        # included.
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]
        try:
            sender.sendto(send_data, socket_to_addr[receiver])
        except BlockingIOError:
            break
        except OSError as e:
            # These errnos end the burst; the same message is retried in
            # the next round. Anything else is fatal.
            if e.errno not in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                raise
            break
        pair = (sender.fileno(), receiver.fileno())
        send_hashes.setdefault(pair, hashlib.sha256()).update(
            f'<{send_data}>'.encode('utf-8'))
        nr_send += 1

    # Receive until everything sent so far has arrived
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]
            if not (eventmask & select.EPOLLRDNORM):
                continue
            # Drain the readable socket until it would block.
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break
                sender = addr_to_socket[address]
                pair = (sender.fileno(), receiver.fileno())
                recv_hashes.setdefault(pair, hashlib.sha256()).update(
                    f'<{recv_data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in (net0, net1):
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)
# Range of RDS_INFO_* getsockopt options to exercise; the Python socket
# module doesn't know these.
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError as e:
            nr_error += 1
            # ENOSPC is the expected failure when 1024 bytes is too small
            # (see above). Anything else is unexpected: report it, but keep
            # the sweep best-effort.
            if e.errno != errno.ENOSPC:
                print(f"getsockopt({optname}): unexpected error: {e}")
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
# Stop the packet captures started earlier. NOTE(review): killall matches
# by process name, so this also hits tcpdump processes this test did not
# start.
print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what we
# received is what we sent: for each (sender fd, receiver fd) pair the
# digest over everything received must equal the digest over everything
# sent, in the same order.
for (sender, receiver), send_hash in send_hashes.items():
    recv_hash = recv_hashes.get((sender, receiver))
    if recv_hash is None:
        print("FAIL: No data received")
        sys.exit(1)
    if send_hash.hexdigest() != recv_hash.hexdigest():
        print("FAIL: Send/recv mismatch")
        print("hash expected:", send_hash.hexdigest())
        print("hash received:", recv_hash.hexdigest())
        sys.exit(1)
    print(f"{sender}/{receiver}: ok")

# The loop above only inspects pairs that sent something. Data that
# arrived on any other (sender, receiver) pair is a failure too.
unexpected = sorted(recv_hashes.keys() - send_hashes.keys())
if unexpected:
    print("FAIL: Data received on unexpected sender/receiver pairs:",
          unexpected)
    sys.exit(1)

print("Success")
sys.exit(0)
|