1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
#ifndef XDP_CHECKSUM_H
#define XDP_CHECKSUM_H 1
#include <linux/ip.h>
#include <linux/udp.h>
#include <stdlib.h>
/*
 * ip_checksum - compute the IPv4 header checksum of @iph.
 *
 * The header is summed as 16-bit words read straight from memory (RFC 1071
 * one's-complement sum, byte-order independent), so the returned value can
 * be stored into iph->check as-is. Word 5 (bytes 10-11, the checksum field
 * itself) is left out, so whatever is currently stored in iph->check does
 * not affect the result.
 *
 * Only the first 10 words (20 bytes, the option-less header) are ever read:
 * for iph->ihl > 5 the IP options are NOT included and the result will not
 * match the real header checksum. For iph->ihl < 5 only ihl * 2 words are
 * summed.
 *
 * NOTE(review): assumes the caller has already verified that 20 bytes at
 * @iph are readable (e.g. iph + 1 <= data_end) — confirm at call sites.
 */
static __always_inline __u16 ip_checksum(struct iphdr *iph) {
	const __u16 *words = (const __u16 *)iph;
	/* ihl counts 32-bit words, i.e. the header is ihl * 2 16-bit words long. */
	int hdr_words = iph->ihl * 2;
	__u32 acc = 0;

#pragma unroll
	for (int w = 0; w < 10; w++) {
		if (w >= hdr_words)
			break;
		if (w != 5) /* skip the checksum field */
			acc += words[w];
	}

	/* Two end-around-carry folds always reduce a 32-bit sum to 16 bits. */
	acc = (acc & 0xFFFF) + (acc >> 16);
	acc = (acc & 0xFFFF) + (acc >> 16);
	return (__u16)~acc;
}
/*
 * csum_fold_helper - turn a 32-bit one's-complement partial sum into the
 * final 16-bit Internet checksum.
 *
 * The high half is folded into the low half with end-around carry, and the
 * result is complemented. Two folds are always sufficient for a 32-bit input:
 * the first leaves at most 0x1fffe and the second at most 0xffff. Using a
 * fixed number of folds also keeps the helper loop-free.
 */
static __always_inline __u16 csum_fold_helper(__u32 sum) {
	__u32 folded = (sum & 0xffff) + (sum >> 16);
	folded = (folded & 0xffff) + (folded >> 16);
	return (__u16)~folded;
}
/*
 * csum_add - one's-complement addition of two 32-bit partial checksums.
 *
 * @sum:   running, not-yet-folded partial checksum.
 * @value: 32-bit quantity to add to it.
 *
 * A carry out of bit 31 is wrapped back into bit 0 (end-around carry,
 * RFC 1071), so the result is still a valid 32-bit partial sum. Fold it to
 * 16 bits only once, at the very end (e.g. with csum_fold_helper()).
 *
 * Returns the updated 32-bit partial sum.
 */
static __always_inline __u32 csum_add(__u32 sum, __u32 value) {
	__u32 res = sum + value;
	/* Unsigned addition wrapped iff the result is smaller than an operand. */
	return res + (res < value);
}
/*
 * csum16_add - add one 16-bit word to a 16-bit one's-complement sum.
 *
 * If the addition carries past 16 bits, the carry is folded back in by
 * subtracting 0xffff (equivalent to dropping bit 16 and adding 1). This keeps
 * the sum within 16 bits as long as @sum itself is at most 0xffff.
 */
static __always_inline __u32 csum16_add(__u32 sum, __u16 val) {
	__u32 total = sum + val;
	return total > 0xffff ? total - 0xffff : total;
}
/*
 * udp_checksum - compute the UDP checksum (RFC 768) of an IPv4 datagram.
 *
 * @iph:      IPv4 header; only saddr and daddr are read (pseudo-header).
 * @udph:     UDP header. udph->len is read before any bounds check in this
 *            function. NOTE(review): the caller is assumed to have verified
 *            udph + 1 <= data_end already — confirm at call sites.
 * @data_end: end of packet data; no byte at or beyond it is read.
 *
 * The sum covers:
 *   - the pseudo-header (saddr, daddr, protocol, UDP length);
 *   - the UDP header, with its checksum field treated as zero, so the current
 *     value of udph->check does not matter;
 *   - the payload, whose length is taken from udph->len.
 * Every 16-bit word is summed in memory order (RFC 1071), so the result can
 * be stored into udph->check as-is on either endianness.
 *
 * Limitations: at most 256 payload words (512 bytes, plus one trailing odd
 * byte) are summed. Longer payloads, or payloads cut short by @data_end,
 * produce an incorrect checksum.
 *
 * Returns the checksum. A computed value of 0 is returned as 0xffff, because
 * 0 means "no checksum" for UDP over IPv4.
 */
static __always_inline __u16 udp_checksum(struct iphdr *iph,
					  struct udphdr *udph, void *data_end) {
	__u32 sum = 0;

	/* IPv4 pseudo-header: source/destination address, protocol, UDP length. */
	sum = csum16_add(sum, (__u16)(iph->saddr >> 16));
	sum = csum16_add(sum, (__u16)(iph->saddr & 0xffff));
	sum = csum16_add(sum, (__u16)(iph->daddr >> 16));
	sum = csum16_add(sum, (__u16)(iph->daddr & 0xffff));
	sum = csum16_add(sum, __constant_htons(IPPROTO_UDP));
	sum = csum16_add(sum, udph->len);

	/*
	 * UDP header (source, dest, len, check). Word 3 is the checksum field
	 * itself; skip it so a stale stored value does not leak into the result
	 * (the same approach ip_checksum() takes for the IP header).
	 */
	__u16 *ptr = (__u16 *)udph;
#pragma unroll
	for (size_t i = 0; i < sizeof(struct udphdr) / 2; i++) {
		if ((void *)(ptr + 1) > data_end)
			break;
		if (i != 3)
			sum = csum16_add(sum, *ptr);
		ptr++;
	}

	/* UDP payload: whole 16-bit words only, bounded to 256 iterations. */
	int payload_len = (int)__constant_ntohs(udph->len) - (int)sizeof(struct udphdr);
	ptr = (__u16 *)(udph + 1);
	for (int i = 0; i < 256; i++) {
		/*
		 * Stop once fewer than 2 bytes remain. Otherwise an odd trailing
		 * byte would be summed together with whatever follows the datagram.
		 */
		if (payload_len < 2)
			break;
		if ((void *)(ptr + 1) > data_end)
			break;
		sum = csum16_add(sum, *ptr);
		ptr++;
		payload_len -= 2;
	}

	/*
	 * Odd payload length: the last byte is padded on the right with a zero
	 * byte (RFC 768). Build that word in memory order, exactly like the words
	 * read above, so the padding lands on the correct side on both little-
	 * and big-endian targets.
	 */
	if (payload_len == 1) {
		__u8 *last = (__u8 *)ptr;
		if ((void *)(last + 1) <= data_end) {
			__u16 tail = 0;
			*(__u8 *)&tail = *last;
			sum = csum16_add(sum, tail);
		}
	}

	__u16 csum = csum_fold_helper(sum);
	/* RFC 768: an all-zero computed checksum is transmitted as all ones. */
	return csum ? csum : 0xffff;
}
#endif
|