aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Gavrilov <105977161+0xA50C1A1@users.noreply.github.com>2024-05-10 23:43:59 +0300
committerGitHub <noreply@github.com>2024-05-10 22:43:59 +0200
commita813121e0a7021cdbfd64630960b330a23b1a4d2 (patch)
tree7021b987d621b6940ebccc6738967654c399d477
parent4b4b358562a80b546d10f779dba8c56c5d0c6502 (diff)
`ndpi_strnstr()` optimization (#2433)
-rw-r--r--example/ndpiReader.c38
-rw-r--r--src/lib/ndpi_main.c47
-rw-r--r--tests/performance/Makefile.in10
-rw-r--r--tests/performance/strnstr.cpp155
4 files changed, 233 insertions, 17 deletions
diff --git a/example/ndpiReader.c b/example/ndpiReader.c
index 60d13f919..7a326f71b 100644
--- a/example/ndpiReader.c
+++ b/example/ndpiReader.c
@@ -5658,6 +5658,43 @@ void strlcpyUnitTest() {
/* *********************************************** */
+void strnstrUnitTest(void) {
+ /* Test 1: null string */
+ assert(ndpi_strnstr(NULL, "find", 10) == NULL);
+ assert(ndpi_strnstr("string", NULL, 10) == NULL);
+
+ /* Test 2: empty substring */
+ assert(strcmp(ndpi_strnstr("string", "", 6), "string") == 0);
+
+ /* Test 3: single character substring */
+ assert(strcmp(ndpi_strnstr("string", "r", 6), "ring") == 0);
+ assert(ndpi_strnstr("string", "x", 6) == NULL);
+
+ /* Test 4: multiple character substring */
+ assert(strcmp(ndpi_strnstr("string", "ing", 6), "ing") == 0);
+ assert(ndpi_strnstr("string", "xyz", 6) == NULL);
+
+ /* Test 5: substring equal to the beginning of the string */
+ assert(strcmp(ndpi_strnstr("string", "str", 3), "string") == 0);
+
+ /* Test 6: substring at the end of the string */
+ assert(strcmp(ndpi_strnstr("string", "ing", 6), "ing") == 0);
+
+ /* Test 7: substring in the middle of the string */
+ assert(strcmp(ndpi_strnstr("hello world", "lo wo", 11), "lo world") == 0);
+
+ /* Test 8: repeated characters in the string */
+ assert(strcmp(ndpi_strnstr("aaaaaa", "aaa", 6), "aaaaaa") == 0);
+
+ /* Test 9: empty string and slen 0 */
+ assert(ndpi_strnstr("", "find", 0) == NULL);
+
+ /* Test 10: substring equal to the string */
+ assert(strcmp(ndpi_strnstr("string", "string", 6), "string") == 0);
+}
+
+/* *********************************************** */
+
void filterUnitTest() {
ndpi_filter* f = ndpi_filter_alloc();
u_int32_t v, i;
@@ -6024,6 +6061,7 @@ int main(int argc, char **argv) {
compressedBitmapUnitTest();
strtonumUnitTest();
strlcpyUnitTest();
+ strnstrUnitTest();
#endif
}
diff --git a/src/lib/ndpi_main.c b/src/lib/ndpi_main.c
index ada8129b2..2752aa55f 100644
--- a/src/lib/ndpi_main.c
+++ b/src/lib/ndpi_main.c
@@ -9676,25 +9676,42 @@ void ndpi_dump_risks_score(FILE *risk_out) {
* first slen characters of s.
*/
char *ndpi_strnstr(const char *s, const char *find, size_t slen) {
- char c;
- size_t len;
+ if (s == NULL || find == NULL || slen == 0) {
+ return NULL;
+ }
+
+ char c = *find;
+
+ if (c == '\0') {
+ return (char *)s;
+ }
+
+ if (*(find + 1) == '\0') {
+ return (char *)memchr(s, c, slen);
+ }
+
+ size_t find_len = strnlen(find, slen);
- if((c = *find++) != '\0') {
- len = strnlen(find, slen);
- do {
- char sc;
+ if (find_len > slen) {
+ return NULL;
+ }
+
+ const char *end = s + slen - find_len;
+
+ while (s <= end) {
+ if (memcmp(s, find, find_len) == 0) {
+ return (char *)s;
+ }
- do {
- if(slen-- < 1 || (sc = *s++) == '\0')
- return(NULL);
- } while(sc != c);
- if(len > slen)
- return(NULL);
- } while(strncmp(s, find, len) != 0);
- s--;
+ size_t remaining_length = end - s;
+ s = (char *)memchr(s + 1, c, remaining_length);
+
+ if (s == NULL || s > end) {
+ return NULL;
+ }
}
- return((char *) s);
+ return NULL;
}
/* ****************************************************** */
diff --git a/tests/performance/Makefile.in b/tests/performance/Makefile.in
index 54c5e9fd9..d342f8483 100644
--- a/tests/performance/Makefile.in
+++ b/tests/performance/Makefile.in
@@ -1,8 +1,8 @@
INC=-I ../../src/include/ -I ../../src/lib/third_party/include/
LIB=../../src/lib/libndpi.a @ADDITIONAL_LIBS@ @LIBS@
-TOOLS=substringsearch patriciasearch gcrypt-int gcrypt-gnu
-TESTS=substring_test patricia_test
+TOOLS=substringsearch patriciasearch gcrypt-int gcrypt-gnu strnstr
+TESTS=substring_test patricia_test strnstr_test
all: $(TESTS)
@@ -18,9 +18,15 @@ gcrypt-gnu: gcrypt.c Makefile
substringsearch: substringsearch.c Makefile
$(CC) $(INC) @CFLAGS@ substringsearch.c -o substringsearch $(LIB)
+strnstr: strnstr.cpp Makefile
+ $(CXX) $(INC) @CFLAGS@ strnstr.cpp -o strnstr
+
substring_test: substringsearch top-1m.csv
./substringsearch
+strnstr_test: strnstr
+ ./strnstr
+
#
patriciasearch: patriciasearch.c Makefile
diff --git a/tests/performance/strnstr.cpp b/tests/performance/strnstr.cpp
new file mode 100644
index 000000000..84922150a
--- /dev/null
+++ b/tests/performance/strnstr.cpp
@@ -0,0 +1,155 @@
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstring>
+#include <functional>
+#include <iomanip>
+#include <iostream>
+#include <map>
+#include <random>
+#include <string>
+#include <vector>
+
+char *ndpi_strnstr(const char *s, const char *find, size_t slen) {
+ char c;
+ size_t len;
+
+ if ((c = *find++) != '\0') {
+ len = strnlen(find, slen);
+ do {
+ char sc;
+
+ do {
+ if (slen-- < 1 || (sc = *s++) == '\0') return (NULL);
+ } while (sc != c);
+ if (len > slen) return (NULL);
+ } while (strncmp(s, find, len) != 0);
+ s--;
+ }
+
+ return ((char *)s);
+}
+
+char *ndpi_strnstr_opt(const char *s, const char *find, size_t slen) {
+ if (s == NULL || find == NULL || slen == 0) {
+ return NULL;
+ }
+
+ char c = *find;
+
+ if (c == '\0') {
+ return (char *)s;
+ }
+
+ if (*(find + 1) == '\0') {
+ return (char *)memchr(s, c, slen);
+ }
+
+ size_t find_len = strnlen(find, slen);
+
+ if (find_len > slen) {
+ return NULL;
+ }
+
+ const char *end = s + slen - find_len;
+
+ while (s <= end) {
+ if (memcmp(s, find, find_len) == 0) {
+ return (char *)s;
+ }
+
+ size_t remaining_length = end - s;
+ s = (char *)memchr(s + 1, c, remaining_length);
+
+ if (s == NULL || s > end) {
+ return NULL;
+ }
+ }
+
+ return NULL;
+}
+
+std::string random_string(size_t length, std::mt19937 &gen) {
+ std::uniform_int_distribution<> dis(0, 255);
+ std::string str(length, 0);
+ for (size_t i = 0; i < length; i++) {
+ str[i] = static_cast<char>(dis(gen));
+ }
+ return str;
+}
+
+double measure_time(const std::function<char *(const char *, const char *,
+ size_t)> &strnstr_impl,
+ const std::string &haystack, const std::string &needle,
+ std::mt19937 &gen) {
+ auto start = std::chrono::high_resolution_clock::now();
+ // Call the function to prevent optimization
+ volatile auto result =
+ strnstr_impl(haystack.c_str(), needle.c_str(), haystack.size());
+ auto end = std::chrono::high_resolution_clock::now();
+
+ return std::chrono::duration_cast<std::chrono::nanoseconds>(end - start)
+ .count();
+}
+
+int main() {
+ std::ios_base::sync_with_stdio(false);
+ std::mt19937 gen(std::random_device{}());
+
+ const std::vector<size_t> haystack_lengths = {
+ 128, 256, 368, 448, 512, 640, 704, 768, 832, 896,
+ 960, 1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472};
+ const std::vector<size_t> needle_lengths = {5, 10, 15, 20, 25, 30,
+ 35, 40, 45, 50, 55, 60};
+
+ const std::vector<std::pair<
+ std::string, std::function<char *(const char *, const char *, size_t)>>>
+ strnstr_impls = {
+ {"ndpi_strnstr", ndpi_strnstr}, {"ndpi_strnstr_opt", ndpi_strnstr_opt}
+ // Add other implementations for comparison here
+ };
+
+ for (size_t haystack_len : haystack_lengths) {
+ for (size_t needle_len : needle_lengths) {
+ std::cout << "\nTest case - Haystack length: " << haystack_len
+ << ", Needle length: " << needle_len << "\n";
+
+ std::string haystack = random_string(haystack_len, gen);
+ std::string needle = random_string(needle_len, gen);
+
+ std::map<std::string, double> times;
+
+ for (const auto &impl : strnstr_impls) {
+ double time_sum = 0.0;
+ for (int i = 0; i < 100000; i++) {
+ time_sum += measure_time(impl.second, haystack, needle, gen);
+ }
+ double average_time =
+ time_sum / 100000.0; // Average time in nanoseconds
+
+ times[impl.first] = average_time;
+ std::cout << "Average time for " << impl.first << ": " << average_time
+ << " ns\n";
+ }
+
+ // Compare execution times between implementations
+ std::string fastest_impl;
+ double fastest_time = std::numeric_limits<double>::max();
+ for (const auto &impl_time : times) {
+ if (impl_time.second < fastest_time) {
+ fastest_impl = impl_time.first;
+ fastest_time = impl_time.second;
+ }
+ }
+
+ for (const auto &impl_time : times) {
+ if (impl_time.first != fastest_impl) {
+ std::cout << fastest_impl << " is " << impl_time.second / fastest_time
+ << " times faster than " << impl_time.first << "\n";
+ }
+ }
+ }
+ }
+
+ return 0;
+}