Skip to content

Commit c22b413

Browse files
committed
Make strcasestr() faster
1 parent 22094ae commit c22b413

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

libc/str/strcasestr.c

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,16 @@
1717
│ PERFORMANCE OF THIS SOFTWARE. │
1818
╚─────────────────────────────────────────────────────────────────────────────*/
1919
#include "libc/str/str.h"
20+
#include "libc/ctype.h"
2021
#include "libc/mem/alloca.h"
2122
#include "libc/runtime/stack.h"
2223
#include "libc/str/tab.h"
24+
#include "third_party/aarch64/arm_neon.internal.h"
25+
#include "third_party/intel/immintrin.internal.h"
26+
27+
static int ToUpper(int c) {
28+
return 'a' <= c && c <= 'z' ? c - ('a' - 'A') : c;
29+
}
2330

2431
static void computeLPS(const char *pattern, long M, long *lps) {
2532
long len = 0;
@@ -84,5 +91,104 @@ static char *kmp(const char *s, size_t n, const char *ss, size_t m) {
8491
* @see strstr()
8592
*/
8693
char *strcasestr(const char *haystack, const char *needle) {
94+
if (haystack == needle || !*needle)
95+
return (char *)haystack;
96+
#if defined(__x86_64__) && !defined(__chibicc__)
97+
size_t i;
98+
unsigned k, m;
99+
const __m128i *p;
100+
long progress = 0;
101+
__m128i v, nl, nu, z = _mm_setzero_si128();
102+
const char *hay = haystack;
103+
char first_lower = kToLower[*needle & 255];
104+
char first_upper = ToUpper(*needle);
105+
nl = _mm_set1_epi8(first_lower);
106+
nu = _mm_set1_epi8(first_upper);
107+
for (;;) {
108+
k = (uintptr_t)hay & 15;
109+
p = (const __m128i *)((uintptr_t)hay & -16);
110+
v = _mm_load_si128(p);
111+
m = _mm_movemask_epi8(_mm_or_si128(
112+
_mm_or_si128(_mm_cmpeq_epi8(v, z), // Check for null terminator
113+
_mm_cmpeq_epi8(v, nl)), // Check lowercase
114+
_mm_cmpeq_epi8(v, nu))); // Check uppercase
115+
m >>= k;
116+
m <<= k;
117+
while (!m) {
118+
progress += 16;
119+
v = _mm_load_si128(++p);
120+
m = _mm_movemask_epi8(_mm_or_si128(
121+
_mm_or_si128(_mm_cmpeq_epi8(v, z), _mm_cmpeq_epi8(v, nl)),
122+
_mm_cmpeq_epi8(v, nu)));
123+
}
124+
int offset = __builtin_ctzl(m);
125+
progress += offset;
126+
hay = (const char *)p + offset;
127+
for (i = 0;; ++i) {
128+
if (--progress <= -512)
129+
goto OfferPathologicalAssurances;
130+
if (!needle[i])
131+
return (char *)hay;
132+
if (!hay[i])
133+
break;
134+
if (kToLower[needle[i] & 255] != kToLower[hay[i] & 255])
135+
break;
136+
}
137+
if (!*hay++)
138+
break;
139+
}
140+
return 0;
141+
#elif defined(__aarch64__) && defined(__ARM_NEON)
142+
size_t i;
143+
const char *hay = haystack;
144+
uint8_t first_lower = kToLower[*needle & 255];
145+
uint8_t first_upper = ToUpper(*needle);
146+
uint8x16_t nl = vdupq_n_u8(first_lower);
147+
uint8x16_t nu = vdupq_n_u8(first_upper);
148+
uint8x16_t z = vdupq_n_u8(0);
149+
long progress = 0;
150+
for (;;) {
151+
int k = (uintptr_t)hay & 15;
152+
hay = (const char *)((uintptr_t)hay & -16);
153+
uint8x16_t v = vld1q_u8((const uint8_t *)hay);
154+
uint8x16_t cmp_lower = vceqq_u8(v, nl);
155+
uint8x16_t cmp_upper = vceqq_u8(v, nu);
156+
uint8x16_t cmp_null = vceqq_u8(v, z);
157+
uint8x16_t cmp = vorrq_u8(vorrq_u8(cmp_lower, cmp_upper), cmp_null);
158+
uint8x8_t mask = vshrn_n_u16(vreinterpretq_u16_u8(cmp), 4);
159+
uint64_t m;
160+
vst1_u8((uint8_t *)&m, mask);
161+
m >>= k * 4;
162+
m <<= k * 4;
163+
while (!m) {
164+
hay += 16;
165+
progress += 16;
166+
v = vld1q_u8((const uint8_t *)hay);
167+
cmp_lower = vceqq_u8(v, nl);
168+
cmp_upper = vceqq_u8(v, nu);
169+
cmp_null = vceqq_u8(v, z);
170+
cmp = vorrq_u8(vorrq_u8(cmp_lower, cmp_upper), cmp_null);
171+
mask = vshrn_n_u16(vreinterpretq_u16_u8(cmp), 4);
172+
vst1_u8((uint8_t *)&m, mask);
173+
}
174+
int offset = __builtin_ctzll(m) >> 2;
175+
progress += offset;
176+
hay += offset;
177+
for (i = 0;; ++i) {
178+
if (--progress <= -512)
179+
goto OfferPathologicalAssurances;
180+
if (!needle[i])
181+
return (char *)hay;
182+
if (!hay[i])
183+
break;
184+
if (kToLower[needle[i] & 255] != kToLower[hay[i] & 255])
185+
break;
186+
}
187+
if (!*hay++)
188+
break;
189+
}
190+
return 0;
191+
#endif
192+
OfferPathologicalAssurances:
87193
return kmp(haystack, strlen(haystack), needle, strlen(needle));
88194
}

test/libc/str/strcasestr_test.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@
1717
│ PERFORMANCE OF THIS SOFTWARE. │
1818
╚─────────────────────────────────────────────────────────────────────────────*/
1919
#include "libc/str/str.h"
20+
#include "libc/assert.h"
21+
#include "libc/calls/calls.h"
2022
#include "libc/dce.h"
23+
#include "libc/intrin/safemacros.h"
2124
#include "libc/mem/alg.h"
2225
#include "libc/mem/gc.h"
2326
#include "libc/mem/mem.h"
2427
#include "libc/nexgen32e/x86feature.h"
28+
#include "libc/runtime/runtime.h"
29+
#include "libc/runtime/sysconf.h"
30+
#include "libc/stdio/rand.h"
2531
#include "libc/str/tab.h"
32+
#include "libc/sysv/consts/map.h"
33+
#include "libc/sysv/consts/prot.h"
2634
#include "libc/testlib/ezbench.h"
2735
#include "libc/testlib/hyperion.h"
2836
#include "libc/testlib/testlib.h"
@@ -54,6 +62,25 @@ TEST(strcasestr, tester) {
5462
ASSERT_STREQ(haystack, strcasestr(haystack, "win"));
5563
}
5664

65+
TEST(strcasestr, safety) {
66+
int pagesz = sysconf(_SC_PAGESIZE);
67+
char *map = (char *)mmap(0, pagesz * 2, PROT_READ | PROT_WRITE,
68+
MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
69+
npassert(map != MAP_FAILED);
70+
npassert(!mprotect(map + pagesz, pagesz, PROT_NONE));
71+
for (int haylen = 1; haylen < 128; ++haylen) {
72+
char *hay = map + pagesz - (haylen + 1);
73+
for (int i = 0; i < haylen; ++i)
74+
hay[i] = max(rand() & 255, 1);
75+
hay[haylen] = 0;
76+
for (int neelen = 1; neelen < haylen; ++neelen) {
77+
char *nee = hay + (haylen + 1) - (neelen + 1);
78+
ASSERT_EQ(strcasestr_naive(hay, nee), strcasestr(hay, nee));
79+
}
80+
}
81+
munmap(map, pagesz * 2);
82+
}
83+
5784
TEST(strcasestr, test_emptyString_isFoundAtBeginning) {
5885
MAKESTRING(haystack, "abc123def");
5986
ASSERT_STREQ(&haystack[0], strcasestr(haystack, gc(strdup(""))));

0 commit comments

Comments
 (0)