diff options
author | Vladimir Gavrilov <105977161+0xA50C1A1@users.noreply.github.com> | 2024-05-10 23:43:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-10 22:43:59 +0200 |
commit | a813121e0a7021cdbfd64630960b330a23b1a4d2 (patch) | |
tree | 7021b987d621b6940ebccc6738967654c399d477 | |
parent | 4b4b358562a80b546d10f779dba8c56c5d0c6502 (diff) |
`ndpi_strnstr()` optimization (#2433)
-rw-r--r-- | example/ndpiReader.c | 38 | ||||
-rw-r--r-- | src/lib/ndpi_main.c | 47 | ||||
-rw-r--r-- | tests/performance/Makefile.in | 10 | ||||
-rw-r--r-- | tests/performance/strnstr.cpp | 155 |
4 files changed, 233 insertions, 17 deletions
diff --git a/example/ndpiReader.c b/example/ndpiReader.c
index 60d13f919..7a326f71b 100644
--- a/example/ndpiReader.c
+++ b/example/ndpiReader.c
@@ -5658,6 +5658,56 @@ void strlcpyUnitTest() {
 
 /* *********************************************** */
 
+void strnstrUnitTest(void) {
+  /* Test 1: null string */
+  assert(ndpi_strnstr(NULL, "find", 10) == NULL);
+  assert(ndpi_strnstr("string", NULL, 10) == NULL);
+
+  /* Test 2: empty substring */
+  assert(strcmp(ndpi_strnstr("string", "", 6), "string") == 0);
+
+  /* Test 3: single character substring */
+  assert(strcmp(ndpi_strnstr("string", "r", 6), "ring") == 0);
+  assert(ndpi_strnstr("string", "x", 6) == NULL);
+
+  /* Test 4: multiple character substring */
+  assert(strcmp(ndpi_strnstr("string", "ing", 6), "ing") == 0);
+  assert(ndpi_strnstr("string", "xyz", 6) == NULL);
+
+  /* Test 5: substring equal to the beginning of the string */
+  assert(strcmp(ndpi_strnstr("string", "str", 3), "string") == 0);
+
+  /* Test 6: substring at the end of the string */
+  assert(strcmp(ndpi_strnstr("string", "ing", 6), "ing") == 0);
+
+  /* Test 7: substring in the middle of the string */
+  assert(strcmp(ndpi_strnstr("hello world", "lo wo", 11), "lo world") == 0);
+
+  /* Test 8: repeated characters in the string */
+  assert(strcmp(ndpi_strnstr("aaaaaa", "aaa", 6), "aaaaaa") == 0);
+
+  /* Test 9: empty string and slen 0 */
+  assert(ndpi_strnstr("", "find", 0) == NULL);
+
+  /* Test 10: substring equal to the string */
+  assert(strcmp(ndpi_strnstr("string", "string", 6), "string") == 0);
+
+  /* Test 11: empty substring matches even when slen is 0 */
+  assert(strcmp(ndpi_strnstr("string", "", 0), "string") == 0);
+
+  /* Test 12: substring longer than slen must not match on its prefix */
+  assert(ndpi_strnstr("string", "string", 3) == NULL);
+
+  /* Test 13: a match that would cross the slen boundary is rejected */
+  assert(ndpi_strnstr("string", "ing", 5) == NULL);
+
+  /* Test 14: the search stops at the terminator even if slen is larger */
+  assert(ndpi_strnstr("abc", "d", 10) == NULL);
+  assert(ndpi_strnstr("ab\0cd", "cd", 5) == NULL);
+}
+
+/* *********************************************** */
+
 void filterUnitTest() {
   ndpi_filter* f = ndpi_filter_alloc();
   u_int32_t v, i;
@@ -6024,6 +6074,7 @@ int main(int argc, char **argv) {
     compressedBitmapUnitTest();
     strtonumUnitTest();
     strlcpyUnitTest();
+    strnstrUnitTest();
 #endif
   }
 
diff --git a/src/lib/ndpi_main.c b/src/lib/ndpi_main.c
index ada8129b2..2752aa55f 100644
--- a/src/lib/ndpi_main.c
+++ b/src/lib/ndpi_main.c
@@ -9676,25 +9676,53 @@ void ndpi_dump_risks_score(FILE *risk_out) {
  * first slen characters of s.
  */
 char *ndpi_strnstr(const char *s, const char *find, size_t slen) {
-  char c;
-  size_t len;
-
-  if((c = *find++) != '\0') {
-    len = strnlen(find, slen);
-    do {
-      char sc;
-
-      do {
-        if(slen-- < 1 || (sc = *s++) == '\0')
-          return(NULL);
-      } while(sc != c);
-      if(len > slen)
-        return(NULL);
-    } while(strncmp(s, find, len) != 0);
-    s--;
-  }
-
-  return((char *) s);
+  /*
+   * `find` is a NUL-terminated needle; `s` is searched up to its NUL
+   * terminator or `slen` bytes, whichever comes first.
+   */
+  if (s == NULL || find == NULL) {
+    return NULL;
+  }
+
+  /* An empty needle matches at the start, even when slen is 0. */
+  if (*find == '\0') {
+    return (char *)s;
+  }
+
+  /* Never look past the terminator: s may be shorter than slen. */
+  const size_t hay_len = strnlen(s, slen);
+
+  /*
+   * Measure at most hay_len bytes of the needle: if it has not ended by
+   * then it cannot fit. This also rejects every needle when hay_len is 0.
+   */
+  const size_t find_len = strnlen(find, hay_len);
+
+  if (find[find_len] != '\0') {
+    return NULL;
+  }
+
+  if (find_len == 1) {
+    return (char *)memchr(s, *find, hay_len);
+  }
+
+  /* Last position at which a full-length match can still start. */
+  const char *const last = s + (hay_len - find_len);
+
+  for (const char *p = s; p <= last; p++) {
+    /* Jump to the next candidate first byte inside [p, last]. */
+    p = (const char *)memchr(p, *find, (size_t)(last - p) + 1);
+
+    if (p == NULL) {
+      return NULL;
+    }
+
+    if (memcmp(p, find, find_len) == 0) {
+      return (char *)p;
+    }
+  }
+
+  return NULL;
 }
 
 /* ****************************************************** */
diff --git a/tests/performance/Makefile.in b/tests/performance/Makefile.in
index 54c5e9fd9..d342f8483 100644
--- a/tests/performance/Makefile.in
+++ b/tests/performance/Makefile.in
@@ -1,8 +1,8 @@
 INC=-I ../../src/include/ -I ../../src/lib/third_party/include/
 LIB=../../src/lib/libndpi.a @ADDITIONAL_LIBS@ @LIBS@
 
-TOOLS=substringsearch patriciasearch gcrypt-int gcrypt-gnu
-TESTS=substring_test patricia_test
+TOOLS=substringsearch patriciasearch gcrypt-int gcrypt-gnu strnstr
+TESTS=substring_test patricia_test strnstr_test
 
 all: $(TESTS)
 
@@ -18,9 +18,15 @@ gcrypt-gnu: gcrypt.c Makefile
 substringsearch: substringsearch.c Makefile
 	$(CC) $(INC) @CFLAGS@ substringsearch.c -o substringsearch $(LIB)
 
+strnstr: strnstr.cpp Makefile
+	$(CXX) $(INC) @CFLAGS@ strnstr.cpp -o strnstr
+
 substring_test: substringsearch top-1m.csv
 	./substringsearch
 
+strnstr_test: strnstr
+	./strnstr
+
 #
 
 patriciasearch: patriciasearch.c Makefile
diff --git
a/tests/performance/strnstr.cpp b/tests/performance/strnstr.cpp
new file mode 100644
index 000000000..84922150a
--- /dev/null
+++ b/tests/performance/strnstr.cpp
@@ -0,0 +1,170 @@
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstring>
+#include <functional>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <map>
+#include <random>
+#include <string>
+#include <vector>
+
+// Baseline: the original loop-based implementation the optimized version is
+// measured against.
+char *ndpi_strnstr(const char *s, const char *find, size_t slen) {
+  char c;
+  size_t len;
+
+  if ((c = *find++) != '\0') {
+    len = strnlen(find, slen);
+    do {
+      char sc;
+
+      do {
+        if (slen-- < 1 || (sc = *s++) == '\0') return (NULL);
+      } while (sc != c);
+      if (len > slen) return (NULL);
+    } while (strncmp(s, find, len) != 0);
+    s--;
+  }
+
+  return ((char *)s);
+}
+
+// Optimized candidate: memchr() locates the needle's first byte and memcmp()
+// confirms the match. `find` must be NUL-terminated; `s` is searched up to
+// its NUL terminator or `slen` bytes, whichever comes first.
+char *ndpi_strnstr_opt(const char *s, const char *find, size_t slen) {
+  if (s == NULL || find == NULL) {
+    return NULL;
+  }
+
+  // An empty needle matches at the start, even when slen is 0.
+  if (*find == '\0') {
+    return (char *)s;
+  }
+
+  // Never look past the terminator: s may be shorter than slen.
+  const size_t hay_len = strnlen(s, slen);
+
+  // Measure at most hay_len bytes of the needle: if it has not ended by then
+  // it cannot fit. This also rejects every needle when hay_len is 0.
+  const size_t find_len = strnlen(find, hay_len);
+
+  if (find[find_len] != '\0') {
+    return NULL;
+  }
+
+  if (find_len == 1) {
+    return (char *)memchr(s, *find, hay_len);
+  }
+
+  // Last position at which a full-length match can still start.
+  const char *const last = s + (hay_len - find_len);
+
+  for (const char *p = s; p <= last; p++) {
+    // Jump to the next candidate first byte inside [p, last].
+    p = (const char *)memchr(p, *find, (size_t)(last - p) + 1);
+
+    if (p == NULL) {
+      return NULL;
+    }
+
+    if (memcmp(p, find, find_len) == 0) {
+      return (char *)p;
+    }
+  }
+
+  return NULL;
+}
+
+// Builds a string of `length` random non-zero bytes. NUL is excluded because
+// both implementations receive c_str(): an embedded NUL would silently
+// truncate the needle and end the baseline's scan early.
+std::string random_string(size_t length, std::mt19937 &gen) {
+  std::uniform_int_distribution<> dis(1, 255);
+  std::string str(length, 0);
+  for (size_t i = 0; i < length; i++) {
+    str[i] = static_cast<char>(dis(gen));
+  }
+  return str;
+}
+
+// Times a single search of `needle` in `haystack`, in nanoseconds.
+double measure_time(const std::function<char *(const char *, const char *,
+                                               size_t)> &strnstr_impl,
+                    const std::string &haystack, const std::string &needle) {
+  auto start = std::chrono::high_resolution_clock::now();
+  // Store the result in a volatile so the call is not optimized away
+  volatile auto result =
+      strnstr_impl(haystack.c_str(), needle.c_str(), haystack.size());
+  auto end = std::chrono::high_resolution_clock::now();
+
+  return std::chrono::duration_cast<std::chrono::nanoseconds>(end - start)
+      .count();
+}
+
+int main() {
+  std::ios_base::sync_with_stdio(false);
+  std::mt19937 gen(std::random_device{}());
+
+  const std::vector<size_t> haystack_lengths = {
+      128, 256, 368, 448, 512, 640, 704, 768, 832, 896,
+      960, 1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472};
+  const std::vector<size_t> needle_lengths = {5, 10, 15, 20, 25, 30,
+                                              35, 40, 45, 50, 55, 60};
+
+  const std::vector<std::pair<
+      std::string, std::function<char *(const char *, const char *, size_t)>>>
+      strnstr_impls = {
+          {"ndpi_strnstr", ndpi_strnstr}, {"ndpi_strnstr_opt", ndpi_strnstr_opt}
+          // Add other implementations for comparison here
+      };
+
+  for (size_t haystack_len : haystack_lengths) {
+    for (size_t needle_len : needle_lengths) {
+      std::cout << "\nTest case - Haystack length: " << haystack_len
+                << ", Needle length: " << needle_len << "\n";
+
+      std::string haystack = random_string(haystack_len, gen);
+      std::string needle = random_string(needle_len, gen);
+
+      std::map<std::string, double> times;
+
+      for (const auto &impl : strnstr_impls) {
+        double time_sum = 0.0;
+        for (int i = 0; i < 100000; i++) {
+          time_sum += measure_time(impl.second, haystack, needle);
+        }
+        double average_time =
+            time_sum / 100000.0;  // Average time in nanoseconds
+
+        times[impl.first] = average_time;
+        std::cout << "Average time for " << impl.first << ": " << average_time
+                  << " ns\n";
+      }
+
+      // Compare execution times between implementations
+      std::string fastest_impl;
+      double fastest_time = std::numeric_limits<double>::max();
+      for (const auto &impl_time : times) {
+        if (impl_time.second < fastest_time) {
+          fastest_impl = impl_time.first;
+          fastest_time = impl_time.second;
+        }
+      }
+
+      for (const auto &impl_time : times) {
+        if (impl_time.first != fastest_impl) {
+          std::cout << fastest_impl << " is "
+                    << impl_time.second / fastest_time
+                    << " times faster than " << impl_time.first << "\n";
+        }
+      }
+    }
+  }
+
+  return 0;
+}