aboutsummaryrefslogtreecommitdiff
path: root/xdp_checksum.h
blob: 1d355659ed4cbae83484460c33b62d3a25a4e197 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#ifndef XDP_CHECKSUM_H
#define XDP_CHECKSUM_H 1

#include <linux/ip.h>
#include <linux/udp.h>
#include <stdlib.h>

/*
 * Compute the IPv4 header checksum of @iph.
 *
 * The header is summed as 16-bit words exactly as they sit in memory (network
 * byte order). Word 5 (the `check` field at byte offset 10) is skipped, the
 * carries are folded back in, and the one's complement is returned. Because no
 * word is byte-swapped, the result can be stored into iph->check as-is.
 *
 * NOTE: at most the first 10 words (20 bytes) are summed, so the result is only
 * correct for headers without options (ihl == 5). The fixed, constant-offset
 * bound presumably exists to satisfy the BPF verifier (every read stays inside
 * a 20-byte struct iphdr) — confirm before raising it.
 */
static __always_inline __u16 ip_checksum(struct iphdr *iph) {
  const __u16 *word = (const __u16 *)iph;
  const int limit = iph->ihl * 2; /* header length in 16-bit words */
  __u32 acc = 0;

#pragma unroll
  for (int w = 0; w < 10; w++) {
    if (w >= limit)
      break;
    if (w != 5) /* word 5 holds the checksum itself */
      acc += word[w];
  }

  /* End-around carry: fold everything above bit 15 back into the low half. */
  while (acc >> 16)
    acc = (acc >> 16) + (acc & 0xFFFF);

  return (__u16)~acc;
}

/*
 * Fold a 32-bit one's complement accumulator down to 16 bits (adding each
 * carry above bit 15 back into the low half) and return its complement,
 * i.e. the final Internet checksum value.
 */
static __always_inline __u16 csum_fold_helper(__u32 sum) {
  for (__u32 carry = sum >> 16; carry != 0; carry = sum >> 16)
    sum = carry + (sum & 0xffff);
  return (__u16)~sum;
}

/*
 * One's complement addition of @value to the partial checksum @sum.
 *
 * The 32-bit addition's carry out of bit 31 is fed back in (end-around carry).
 * The result is then folded to 16 bits, so no carry is lost when it is
 * narrowed to the __u16 return type. The returned value is the
 * un-complemented 16-bit partial sum; complement it (e.g. via
 * csum_fold_helper()) to obtain a final checksum.
 */
static __always_inline __u16 csum_add(__u32 sum, __u32 value) {
  sum += value;
  sum += (sum < value); // unsigned wrap-around means a carry out of bit 31
  // Two folds always suffice for 32 bits: the first leaves at most 0x1fffe,
  // the second at most 0xffff. Fixed steps keep the code loop-free.
  sum = (sum & 0xffff) + (sum >> 16);
  sum = (sum & 0xffff) + (sum >> 16);
  return (__u16)sum;
}

/*
 * Add one 16-bit word to a running one's complement sum. A carry out of
 * bit 15 is wrapped back into bit 0 by subtracting 0xffff. When @sum is at
 * most 0xffff (as every value returned here is), the result stays in that
 * range too.
 */
static __always_inline __u32 csum16_add(__u32 sum, __u16 val) {
  __u32 total = sum + val;
  return total > 0xffff ? total - 0xffff : total;
}

/*
 * Compute the UDP checksum (RFC 768) for the datagram at @udph carried in the
 * IPv4 packet @iph.
 *
 * The one's complement sum covers three parts:
 *   - the IPv4 pseudo header (source address, destination address, zero byte
 *     + protocol, UDP length);
 *   - the UDP header, with its `check` field treated as zero;
 *   - the UDP payload.
 * All 16-bit words are summed exactly as they are laid out in memory (network
 * byte order), so the returned value can be stored into udph->check without
 * byte swapping.
 *
 * @iph:      IPv4 header. saddr/daddr are read without a bounds check, so the
 *            caller must already have validated it against data_end.
 * @udph:     UDP header. udph->len is also read without a bounds check.
 * @data_end: end of packet data. Header and payload words are only read when
 *            they lie entirely before it.
 *
 * Returns the value to place in udph->check. A computed checksum of 0 is
 * returned as 0xFFFF, since an all-zero field means "no checksum" for UDP over
 * IPv4 (RFC 768).
 *
 * Limitations: the payload loop is bounded at 256 words, so payload bytes
 * beyond the first 512 (513 for odd lengths) are not summed. Words that would
 * extend past data_end are skipped. In both cases the result is wrong for
 * such datagrams.
 */
static __always_inline __u16 udp_checksum(struct iphdr *iph,
                                          struct udphdr *udph, void *data_end) {
  __u32 sum = 0;

  // IPv4 pseudo header. Splitting the raw 32-bit address into halves yields
  // the same two 16-bit words as loading them from memory, on either host
  // byte order.
  sum = csum16_add(sum, (__u16)(iph->saddr >> 16));
  sum = csum16_add(sum, (__u16)(iph->saddr & 0xffff));
  sum = csum16_add(sum, (__u16)(iph->daddr >> 16));
  sum = csum16_add(sum, (__u16)(iph->daddr & 0xffff));
  sum = csum16_add(sum, __constant_htons(IPPROTO_UDP));
  sum = csum16_add(sum, udph->len);

  // UDP header words: source port, dest port, length, checksum. The checksum
  // word (index 3) is skipped so the result does not depend on the field's
  // current contents, just as ip_checksum skips iph->check.
  __u16 *ptr = (__u16 *)udph;
#pragma unroll
  for (size_t i = 0; i < sizeof(struct udphdr) / 2; i++) {
    if ((void *)(ptr + 1) > data_end)
      break;
    if (i != 3)
      sum = csum16_add(sum, *ptr);
    ptr++;
  }

  // UDP payload. Compute the length in signed arithmetic so that a malformed
  // udph->len < 8 gives a negative length and sums no payload.
  int payload_len =
      (int)__constant_ntohs(udph->len) - (int)sizeof(struct udphdr);
  ptr = (__u16 *)((void *)udph + sizeof(struct udphdr));

  // Sum only complete 16-bit words inside the datagram. An odd trailing byte
  // is handled separately below, so bytes after the UDP length (e.g. Ethernet
  // frame padding) never enter the sum.
  for (int i = 0; i < 256; i++) { // max 512 bytes
    if (payload_len < 2)
      break;
    if ((void *)(ptr + 1) > data_end)
      break;
    sum = csum16_add(sum, *ptr);
    ptr++;
    payload_len -= 2;
  }

  // Odd length: pad the final byte with a zero byte. Building the padded word
  // in memory keeps it in the same byte order as every other summed word.
  // (A plain `byte << 8` is only right on big-endian hosts.)
  if (payload_len == 1) {
    __u8 *last = (__u8 *)ptr;
    if ((void *)(last + 1) <= data_end) {
      __u16 pad = 0;
      *(__u8 *)&pad = *last;
      sum = csum16_add(sum, pad);
    }
  }

  __u16 csum = csum_fold_helper(sum);
  // RFC 768: a computed checksum of zero is transmitted as all ones.
  return csum ? csum : 0xffff;
}

#endif