#ifndef XDP_CHECKSUM_H
#define XDP_CHECKSUM_H 1

/*
 * RFC 1071 Internet-checksum helpers for XDP programs: the IPv4 header
 * checksum and the UDP-over-IPv4 checksum.
 *
 * Every 16-bit word is loaded exactly as it sits in packet memory (network
 * byte order) and summed as a host integer. The one's-complement sum is
 * independent of byte order (RFC 1071, section 2(B)), so the returned
 * checksums are already in network byte order and can be stored directly
 * into iph->check / udph->check.
 */

/* NOTE(review): the original #include targets were lost; these are the
 * headers declaring the names used below -- confirm against the build. */
#include <linux/in.h>  /* IPPROTO_UDP */
#include <linux/ip.h>  /* struct iphdr */
#include <linux/udp.h> /* struct udphdr */

#ifndef __always_inline
#define __always_inline inline __attribute__((always_inline))
#endif

/*
 * Upper bound on the number of 16-bit UDP payload words udp_checksum() sums.
 * Kept as a compile-time constant so the loop stays bounded for the BPF
 * verifier; override with -D to allow larger payloads.
 */
#ifndef XDP_CSUM_MAX_UDP_PAYLOAD_WORDS
#define XDP_CSUM_MAX_UDP_PAYLOAD_WORDS 256 /* 512 bytes */
#endif

/*
 * csum_fold_helper - fold a 32-bit one's-complement accumulator to 16 bits
 * and return its complement.
 * @sum: running sum of 16-bit words.
 *
 * Two folds always suffice for a 32-bit input: the first leaves at most
 * 0x1fffe, the second at most 0xffff. A fixed number of folds keeps the
 * code loop-free.
 */
static __always_inline __u16 csum_fold_helper(__u32 sum)
{
	sum = (sum & 0xffff) + (sum >> 16);
	sum = (sum & 0xffff) + (sum >> 16);
	return (__u16)~sum;
}

/*
 * ip_checksum - compute the IPv4 header checksum.
 * @iph: IPv4 header; the caller must already have bounds-checked at least
 *       sizeof(struct iphdr) bytes at this address.
 *
 * Sums min(iph->ihl * 2, 10) 16-bit words, skipping word 5 (iph->check,
 * bytes 10-11), so the current checksum field need not be zeroed first.
 *
 * Limitation: only the fixed 20-byte header is covered. There is no
 * data_end to bound further reads, so IP options (ihl > 5) are NOT summed
 * and the result is wrong for such headers; callers must handle that case.
 *
 * Return: the checksum in network byte order, ready for iph->check.
 */
static __always_inline __u16 ip_checksum(const struct iphdr *iph)
{
	const __u16 *words = (const __u16 *)iph;
	const int nwords = iph->ihl * 2; /* header length in 16-bit words */
	__u32 sum = 0;

#pragma unroll
	for (int i = 0; i < (int)(sizeof(struct iphdr) / sizeof(__u16)); i++) {
		if (i >= nwords)
			break;
		if (i == 5) /* iph->check */
			continue;
		sum += words[i]; /* at most 10 * 0xffff: no 32-bit overflow */
	}

	return csum_fold_helper(sum);
}

/*
 * csum_add - one's-complement add of two 32-bit partial sums.
 * @sum:   running partial sum.
 * @value: partial sum to add.
 *
 * The carry out of bit 31 is added back in (end-around carry), so no
 * information is lost before the result is folded with csum_fold_helper().
 *
 * Return: the combined 32-bit partial sum.
 */
static __always_inline __u32 csum_add(__u32 sum, __u32 value)
{
	sum += value;
	return sum + (sum < value);
}

/*
 * csum16_add - add one 16-bit word to a 16-bit one's-complement sum.
 * @sum: running sum; must be <= 0xffff (as produced by this function).
 * @val: word to add.
 *
 * Return: the new sum, again <= 0xffff.
 */
static __always_inline __u32 csum16_add(__u32 sum, __u16 val)
{
	sum += val;
	if (sum > 0xffff)
		sum -= 0xffff; /* end-around carry: drop 0x10000, add 1 */
	return sum;
}

/*
 * udp_checksum - compute the UDP checksum of an IPv4 UDP datagram.
 * @iph:      IPv4 header supplying saddr/daddr for the pseudo-header.
 * @udph:     UDP header. udph->len is read before any data_end check, so
 *            the caller must already have bounds-checked the UDP header.
 * @data_end: end of packet data; no word is read past it.
 *
 * The current value of udph->check is ignored, so the result can be stored
 * directly or compared with the received field.
 *
 * Limitations: at most XDP_CSUM_MAX_UDP_PAYLOAD_WORDS words (plus one odd
 * trailing byte) of payload are summed, and payload cut short by data_end
 * is skipped. In either case the result is silently wrong.
 *
 * Return: the checksum in network byte order. A computed value of zero is
 * returned as 0xffff (RFC 768), because zero means "no checksum".
 */
static __always_inline __u16 udp_checksum(const struct iphdr *iph,
					  const struct udphdr *udph,
					  void *data_end)
{
	__u32 sum = 0;

	/* Pseudo-header. Each 16-bit half of saddr/daddr extracted below equals
	 * one of the two address words as laid out in memory. Which is which
	 * depends on host endianness, but the one's-complement sum does not
	 * depend on order. The protocol word is {0x00, IPPROTO_UDP} in memory. */
	sum = csum16_add(sum, (__u16)(iph->saddr >> 16));
	sum = csum16_add(sum, (__u16)(iph->saddr & 0xffff));
	sum = csum16_add(sum, (__u16)(iph->daddr >> 16));
	sum = csum16_add(sum, (__u16)(iph->daddr & 0xffff));
	sum = csum16_add(sum, __constant_htons(IPPROTO_UDP));
	sum = csum16_add(sum, udph->len);

	/* UDP header: source port, destination port, length. Word 3 is
	 * udph->check itself and is skipped, mirroring ip_checksum(). */
	const __u16 *hdr = (const __u16 *)udph;
#pragma unroll
	for (int i = 0; i < (int)(sizeof(struct udphdr) / sizeof(__u16)); i++) {
		if ((const void *)(hdr + i + 1) > data_end)
			break;
		if (i == 3) /* udph->check */
			continue;
		sum = csum16_add(sum, hdr[i]);
	}

	/* Payload, whole 16-bit words only. Stopping when fewer than 2 bytes
	 * remain keeps a word load from pulling in the byte after the datagram
	 * (e.g. Ethernet padding). Signed arithmetic makes a bogus
	 * udph->len < 8 give a negative count instead of wrapping. */
	const __u16 *ptr = (const __u16 *)(udph + 1);
	int remaining = (int)__constant_ntohs(udph->len) - (int)sizeof(struct udphdr);

	for (int i = 0; i < XDP_CSUM_MAX_UDP_PAYLOAD_WORDS; i++) {
		if (remaining < 2)
			break;
		if ((const void *)(ptr + 1) > data_end)
			break;
		sum = csum16_add(sum, *ptr);
		ptr++;
		remaining -= 2;
	}

	/* Odd length: the final byte is padded on the right with a zero byte.
	 * Build that word in memory so it is loaded exactly like every other
	 * word, independent of host endianness. */
	if (remaining == 1) {
		const __u8 *last = (const __u8 *)ptr;

		if ((const void *)(last + 1) <= data_end) {
			__u16 word = 0;

			*(__u8 *)&word = *last;
			sum = csum16_add(sum, word);
		}
	}

	__u16 csum = csum_fold_helper(sum);

	/* RFC 768: a computed checksum of zero is transmitted as all ones. */
	return csum ? csum : 0xffff;
}

#endif /* XDP_CHECKSUM_H */