diff --git a/l2tpns.c b/l2tpns.c index 16748db..658b87c 100644 --- a/l2tpns.c +++ b/l2tpns.c @@ -39,6 +39,15 @@ #include #include #include +#include +#include +#include +#include + +#ifndef PPPIOCBRIDGECHAN +#define PPPIOCBRIDGECHAN _IOW('t', 53, int) +#define PPPIOCUNBRIDGECHAN _IO('t', 54) +#endif #include "md5.h" #include "dhcp6.h" @@ -66,6 +75,7 @@ uint32_t call_serial_number = 0; configt *config = NULL; // all configuration int rtnlfd = -1; // route netlink socket int genlfd = -1; // generic netlink socket +int genl_l2tp_id = -1; // L2TP generic netlink ID int tunfd = -1; // tun interface file handle. (network device) int udpfd[MAX_UDPFD + 1] = INIT_TABUDPFD; // array UDP file handle + 1 for lac udp int udplacfd = -1; // UDP LAC file handle @@ -461,6 +471,373 @@ void random_data(uint8_t *buf, int len) buf[n++] = (rand() >> 4) & 0xff; } +// +// Create tunnel in kernel +static int create_kernel_tunnel(uint32_t tid, uint32_t peer_tid) +{ + struct { + struct nlmsghdr nh; + struct genlmsghdr glh; + char data[64]; + } req; + + if (genl_l2tp_id < 0) + { + errno = ENOSYS; + return -1; + } + + LOG(3, 0, tid, "Creating kernel tunnel from %u to %u\n", tid, peer_tid); + + memset(&req, 0, sizeof(req)); + + req.nh.nlmsg_type = genl_l2tp_id; + req.nh.nlmsg_flags = NLM_F_REQUEST|NLM_F_ACK; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.glh)); + + req.glh.cmd = L2TP_CMD_TUNNEL_CREATE; + req.glh.version = L2TP_GENL_VERSION; + + uint32_t fd = udpfd[tunnel[tid].indexudp]; + genetlink_addattr(&req.nh, L2TP_ATTR_FD, &fd, sizeof(fd)); + genetlink_addattr(&req.nh, L2TP_ATTR_CONN_ID, &tid, sizeof(tid)); + genetlink_addattr(&req.nh, L2TP_ATTR_PEER_CONN_ID, &peer_tid, sizeof(peer_tid)); + uint8_t version = 2; + genetlink_addattr(&req.nh, L2TP_ATTR_PROTO_VERSION, &version, sizeof(version)); + uint16_t encap = L2TP_ENCAPTYPE_UDP; + genetlink_addattr(&req.nh, L2TP_ATTR_ENCAP_TYPE, &encap, sizeof(encap)); + + assert(req.nh.nlmsg_len < sizeof(req)); + + if (genetlink_send(&req.nh) < 0) + { + LOG(2, 0, tid, "Can't create tunnel %d to %d: %s\n", tid, peer_tid, strerror(errno)); + return -1; + } + + ssize_t size = genetlink_recv(&req, sizeof(req)); + if (size < 0) + { + LOG(1, 0, 0, "Can't receive answer for tunnel creation: %s\n", strerror(errno)); + return -1; + } + + if (netlink_handle_ack((struct nlmsghdr *)&req, 1, 0, NULL) < 0) + return -1; + + return 0; +} + +// +// Delete tunnel in kernel +static int delete_kernel_tunnel(uint32_t tid) +{ + struct { + struct nlmsghdr nh; + struct genlmsghdr glh; + char data[64]; + } req; + + if (genl_l2tp_id < 0) + { + errno = ENOSYS; + return -1; + } + + LOG(3, 0, tid, "Deleting kernel tunnel for %u\n", tid); + + memset(&req, 0, sizeof(req)); + + req.nh.nlmsg_type = genl_l2tp_id; + req.nh.nlmsg_flags = NLM_F_REQUEST|NLM_F_ACK; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.glh)); + + req.glh.cmd = L2TP_CMD_TUNNEL_DELETE; + req.glh.version = L2TP_GENL_VERSION; + + genetlink_addattr(&req.nh, L2TP_ATTR_CONN_ID, &tid, sizeof(tid)); + + assert(req.nh.nlmsg_len < sizeof(req)); + + if (genetlink_send(&req.nh) < 0) + { + LOG(2, 0, tid, "Can't delete tunnel %d: %s\n", tid, strerror(errno)); + return -1; + } + + ssize_t size = genetlink_recv(&req, sizeof(req)); + if (size < 0) + { + LOG(1, 0, 0, "Can't receive answer for tunnel deletion: %s\n", strerror(errno)); + return -1; + } + if (netlink_handle_ack((struct nlmsghdr *)&req, 1, 0, NULL) < 0) + return -1; + + return 0; +} + +// +// Create session in kernel +static int create_kernel_session(uint32_t tid, uint32_t peer_tid, uint32_t sid, uint32_t peer_sid) +{ + struct { + struct nlmsghdr nh; + struct genlmsghdr glh; + char data[64]; + } req; + + if (genl_l2tp_id < 0) + { + errno = ENOSYS; + return -1; + } + + LOG(3, sid, tid, "Creating kernel session from %u:%u to %u:%u\n", tid, sid, peer_tid, peer_sid); + + memset(&req, 0, sizeof(req)); + + req.nh.nlmsg_type = genl_l2tp_id; + req.nh.nlmsg_flags = NLM_F_REQUEST|NLM_F_ACK; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.glh)); + + req.glh.cmd = L2TP_CMD_SESSION_CREATE; + req.glh.version = L2TP_GENL_VERSION; + + genetlink_addattr(&req.nh, L2TP_ATTR_CONN_ID, &tid, sizeof(tid)); + genetlink_addattr(&req.nh, L2TP_ATTR_PEER_CONN_ID, &peer_tid, sizeof(peer_tid)); + genetlink_addattr(&req.nh, L2TP_ATTR_SESSION_ID, &sid, sizeof(sid)); + genetlink_addattr(&req.nh, L2TP_ATTR_PEER_SESSION_ID, &peer_sid, sizeof(peer_sid)); + uint16_t pwtype = L2TP_PWTYPE_PPP; + genetlink_addattr(&req.nh, L2TP_ATTR_PW_TYPE, &pwtype, sizeof(pwtype)); + + assert(req.nh.nlmsg_len < sizeof(req)); + + if (genetlink_send(&req.nh) < 0) + { + LOG(2, sid, tid, "Can't create session %d:%d to %d:%d: %s\n", tid, sid, peer_tid, peer_sid, strerror(errno)); + return -1; + } + + ssize_t size = genetlink_recv(&req, sizeof(req)); + if (size < 0) + { + LOG(1, 0, 0, "Can't receive answer for session creation: %s\n", strerror(errno)); + return -1; + } + if (netlink_handle_ack((struct nlmsghdr *)&req, 1, 0, NULL) < 0) + return -1; + + return 0; +} + +// +// Delete session in kernel +static int delete_kernel_session(uint32_t tid, uint32_t sid) +{ + struct { + struct nlmsghdr nh; + struct genlmsghdr glh; + char data[64]; + } req; + + if (genl_l2tp_id < 0) + { + errno = ENOSYS; + return -1; + } + + LOG(3, sid, tid, "Deleting kernel session for %u:%u\n", tid, sid); + + memset(&req, 0, sizeof(req)); + + req.nh.nlmsg_type = genl_l2tp_id; + req.nh.nlmsg_flags = NLM_F_REQUEST|NLM_F_ACK; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.glh)); + + req.glh.cmd = L2TP_CMD_SESSION_DELETE; + req.glh.version = L2TP_GENL_VERSION; + + genetlink_addattr(&req.nh, L2TP_ATTR_CONN_ID, &tid, sizeof(tid)); + genetlink_addattr(&req.nh, L2TP_ATTR_SESSION_ID, &sid, sizeof(sid)); + + assert(req.nh.nlmsg_len < sizeof(req)); + + if (genetlink_send(&req.nh) < 0) + { + LOG(2, sid, tid, "Can't delete session %d:%d: %s\n", tid, sid, strerror(errno)); + return -1; + } + + ssize_t size = genetlink_recv(&req, sizeof(req)); + if (size < 0) + { + LOG(1, 0, 0, "Can't receive answer for session deletion: %s\n", strerror(errno)); + return -1; + } + if (netlink_handle_ack((struct nlmsghdr *)&req, 1, 0, NULL) < 0) + return -1; + + return 0; +} + +// +// Create the kernel PPPoX socket +static int create_ppp_socket(int udp_fd, uint32_t tid, uint32_t peer_tid, uint32_t sid, uint32_t peer_sid, const struct sockaddr *dst, socklen_t addrlen) +{ + int pppox_fd; + int ret; + + if (genl_l2tp_id < 0) + return -1; + + LOG(3, sid, tid, "Creating PPPoL2TPsocket from %u:%u to %u:%u\n", tid, sid, peer_tid, peer_sid); + + pppox_fd = socket(AF_PPPOX, SOCK_DGRAM, PX_PROTO_OL2TP); + if (pppox_fd < 0) + { + LOG(2, sid, tid, "Can't create PPPoL2TP socket: %s\n", strerror(errno)); + return -1; + } + + struct sockaddr_pppol2tp sax; + memset(&sax, 0, sizeof(sax)); + + sax.sa_family = AF_PPPOX; + sax.sa_protocol = PX_PROTO_OL2TP; + sax.pppol2tp.fd = udp_fd; + memcpy(&sax.pppol2tp.addr, dst, addrlen); + sax.pppol2tp.s_tunnel = tid; + sax.pppol2tp.s_session = sid; + sax.pppol2tp.d_tunnel = peer_tid; + sax.pppol2tp.d_session = peer_sid; + + ret = connect(pppox_fd, (struct sockaddr *)&sax, sizeof(sax)); + if (ret < 0) + { + LOG(2, sid, tid, "Can't connect PPPoL2TP: %s\n", strerror(errno)); + close(pppox_fd); + return -1; + } + + return pppox_fd; +} + +// +// Get the kernel PPP channel +static int get_kernel_ppp_chan(sessionidt s, int pppox_fd) +{ + int ret; + int chindx; + + ret = ioctl(pppox_fd, PPPIOCGCHAN, &chindx); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't get pppox_fd chan: %s\n", strerror(errno)); + return -1; + } + + return chindx; +} + +// +// Get the kernel PPP channel fd +static int create_kernel_ppp_chan(sessionidt s, int pppox_fd) +{ + int chindx = get_kernel_ppp_chan(s, pppox_fd); + int ret; + + int ppp_chan_fd = open("/dev/ppp", O_RDWR); + + LOG(3, s, session[s].tunnel, "Creating PPP channel\n"); + + ret = fcntl(ppp_chan_fd, F_GETFL, NULL); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't get ppp chan flags: %s\n", strerror(errno)); + close(ppp_chan_fd); + return -1; + } + ret = fcntl(ppp_chan_fd, F_SETFL, ret | O_NONBLOCK); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't set ppp chan flags: %s\n", strerror(errno)); + close(ppp_chan_fd); + return -1; + } + + ret = ioctl(ppp_chan_fd, PPPIOCATTCHAN, &chindx); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't attach channel %d: %s\n", chindx, strerror(errno)); + close(ppp_chan_fd); + return -1; + } + + return ppp_chan_fd; +} + +// +// Create the kernel PPP interface +static int create_kernel_ppp_if(sessionidt s, int ppp_chan_fd, int *ifunit) +{ + int ppp_if_fd = open("/dev/ppp", O_RDWR); + int ret; + + LOG(3, s, session[s].tunnel, "Creating PPP interface\n"); + + ret = fcntl(ppp_if_fd, F_GETFL, NULL); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't get ppp if flags: %s\n", strerror(errno)); + close(ppp_if_fd); + return -1; + } + ret = fcntl(ppp_if_fd, F_SETFL, ret | O_NONBLOCK); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't set ppp if flags: %s\n", strerror(errno)); + close(ppp_if_fd); + return -1; + } + + ret = ioctl(ppp_if_fd, PPPIOCNEWUNIT, ifunit); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't create ppp interface: %s\n", strerror(errno)); + close(ppp_if_fd); + return -1; + } + + ret = ioctl(ppp_chan_fd, PPPIOCCONNECT, ifunit); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't attach channel to unit %d: %s\n", *ifunit, strerror(errno)); + close(ppp_if_fd); + return -1; + } + + return ppp_if_fd; +} + +// +// Bridge kernel channels to accelerate LAC +static int bridge_kernel_chans(sessionidt s, int pppox_fd, int pppox_fd2) +{ + int ppp_chan_fd = create_kernel_ppp_chan(s, pppox_fd); + int chindx2 = get_kernel_ppp_chan(s, pppox_fd2); + int ret; + + ret = ioctl(ppp_chan_fd, PPPIOCBRIDGECHAN, &chindx2); + close(ppp_chan_fd); + if (ret < 0) + { + LOG(2, s, session[s].tunnel, "Can't set LAC bridge: %s\n", strerror(errno)); + return -1; + } + return 0; +} + // Add a route // // This adds it to the routing table, advertises it @@ -615,6 +992,89 @@ void route6set(sessionidt s, struct in6_addr ip, int prefixlen, int add) return; } +// +// Get L2TP netlink id +static int16_t netlink_get_l2tp_id(void) +{ + struct { + struct nlmsghdr nh; + struct genlmsghdr glh; + char data[32]; + } req; + struct nlattr *ah; + int16_t ret; + + if (system("modprobe l2tp_ppp")) + LOG(3, 0, 0, "Can't modprobe l2tp_ppp: %s\n", strerror(errno)); + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_type = GENL_ID_CTRL; + req.nh.nlmsg_flags = NLM_F_REQUEST|NLM_F_ACK; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.glh)); + + req.glh.cmd = CTRL_CMD_GETFAMILY; + req.glh.version = 1; + + genetlink_addattr(&req.nh, CTRL_ATTR_FAMILY_NAME, L2TP_GENL_NAME, sizeof(L2TP_GENL_NAME)); + + assert(req.nh.nlmsg_len < sizeof(req)); + + if (genetlink_send(&req.nh) < 0) + { + LOG(2, 0, 0, "Can't send request for l2tp netlink name: %s\n", strerror(errno)); + return -1; + } + + + ssize_t size = genetlink_recv(&req.nh, sizeof(req)); + if (size < 0) + { + LOG(2, 0, 0, "Can't receive answer for l2tp netlink name: %s\n", strerror(errno)); + return -1; + } + if (size < sizeof(req.nh)) + { + LOG(2, 0, 0, "Short answer for l2tp netlink name\n"); + return -1; + } + + if (req.nh.nlmsg_type != GENL_ID_CTRL) + { + LOG(2, 0, 0, "Unexpected answer type %d for l2tp netlink name.\n" + "Does your Linux kernel have the l2tp_netlink module available?\n", req.nh.nlmsg_type); + return -1; + } + if (size < NLMSG_HDRLEN + GENL_HDRLEN) + { + LOG(2, 0, 0, "Short answer for l2tp netlink name\n"); + return -1; + } + + size -= NLMSG_HDRLEN + GENL_HDRLEN; + ret = -1; + char *data = &req.data[0]; + for (ah = (void*) data; (char*) ah < data + size; ah = (void*) ((char *) ah + NLA_ALIGN(ah->nla_len))) + { + if ((ah->nla_type & NLA_TYPE_MASK) == CTRL_ATTR_FAMILY_ID) + { + if (ah->nla_len < NLA_HDRLEN + 2) + LOG(2, 0, 0, "Short netlink family ID for l2tp\n"); + ret = *(uint16_t*) ((char*) ah + NLA_HDRLEN); + break; + } + } + if (ret == -1) + LOG(2, 0, 0, "Did not get netlink family ID for l2tp\n"); + + size = genetlink_recv(&req, sizeof(req)); + if (size < 0) + LOG(2, 0, 0, "Can't receive ack for family ID: %s\n", strerror(errno)); + else + netlink_handle_ack((struct nlmsghdr *)&req, 1, 0, NULL); + + return ret; +} + // // Set up netlink socket static void initnetlink(void) @@ -654,6 +1114,9 @@ static void initnetlink(void) LOG(0, 0, 0, "Can't bind generic netlink socket: %s\n", strerror(errno)); exit(1); } + + genl_l2tp_id = netlink_get_l2tp_id(); + LOG(3, 0, 0, "gen l2tp id is %d\n", genl_l2tp_id); } //