diff options
Diffstat (limited to 'xdp_checksum.h')
-rw-r--r-- | xdp_checksum.h | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/xdp_checksum.h b/xdp_checksum.h new file mode 100644 index 0000000..1d35565 --- /dev/null +++ b/xdp_checksum.h @@ -0,0 +1,94 @@ +#ifndef XDP_CHECKSUM_H +#define XDP_CHECKSUM_H 1 + +#include <linux/ip.h> +#include <linux/udp.h> +#include <stdlib.h> + +static __always_inline __u16 ip_checksum(struct iphdr *iph) { + __u32 sum = 0; + __u16 *ptr = (__u16 *)iph; + +// IP header is iph->ihl * 4 bytes in size --> iph->ihl * 2 16 bit works +#pragma unroll + for (int i = 0; i < 10; i++) { // max 20 bytes (5 * 4) + if (i >= iph->ihl * 2) + break; + if (i == 5) + continue; // checksum field at offset 10–11, skip for calc + sum += (__u32)ptr[i]; + } + + while (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16); + + return ~sum; +} + +static __always_inline __u16 csum_fold_helper(__u32 sum) { + while (sum >> 16) + sum = (sum & 0xffff) + (sum >> 16); + return ~sum; +} + +static __always_inline __u16 csum_add(__u32 sum, __u32 value) { + sum += value; + return sum; +} + +static __always_inline __u32 csum16_add(__u32 sum, __u16 val) { + sum += val; + if (sum > 0xffff) + sum -= 0xffff; + return sum; +} + +static __always_inline __u16 udp_checksum(struct iphdr *iph, + struct udphdr *udph, void *data_end) { + __u32 sum = 0; + + // IP header source/dest address + layer4 protocol + sum = csum16_add(sum, (__u16)(iph->saddr >> 16)); + sum = csum16_add(sum, (__u16)(iph->saddr & 0xffff)); + sum = csum16_add(sum, (__u16)(iph->daddr >> 16)); + sum = csum16_add(sum, (__u16)(iph->daddr & 0xffff)); + sum = csum16_add(sum, __constant_htons(IPPROTO_UDP)); + // UDP header length + sum = csum16_add(sum, udph->len); + // UDP header checksum + __u16 *ptr = (__u16 *)udph; +#pragma unroll + for (size_t i = 0; i < sizeof(struct udphdr) / 2; i++) { + if ((void *)(ptr + 1) > data_end) + break; + sum = csum16_add(sum, *ptr); + ptr++; + } + + // UDP payload checksum + void *payload = (void *)udph + sizeof(struct udphdr); + int payload_len = __constant_ntohs(udph->len) - sizeof(struct udphdr); + ptr = (__u16 *)payload; + + for (int i = 0; i < 256; i++) { // max 512 bytes + if (payload_len <= 0) + break; + if ((void *)(ptr + 1) > data_end) + break; + sum = csum16_add(sum, *ptr); + ptr++; + payload_len -= 2; + } + + // if payload length off: pad wiuth null + if (payload_len == 1) { + __u8 *last = (__u8 *)ptr; + if ((void *)(last + 1) <= data_end) { + sum = csum16_add(sum, (*last) << 8); + } + } + + return csum_fold_helper(sum); +} + +#endif |