Skip to content

Commit a80ab3f

Browse files
committed
Implement bf16 compiler runtime library
1 parent 9ebacb7 commit a80ab3f

File tree

5 files changed

+256
-7
lines changed

5 files changed

+256
-7
lines changed

libc/intrin/extendbfsf2.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
2+
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
3+
╞══════════════════════════════════════════════════════════════════════════════╡
4+
│ Copyright 2024 Justine Alexandra Roberts Tunney │
5+
│ │
6+
│ Permission to use, copy, modify, and/or distribute this software for │
7+
│ any purpose with or without fee is hereby granted, provided that the │
8+
│ above copyright notice and this permission notice appear in all copies. │
9+
│ │
10+
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
11+
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
12+
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
13+
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
14+
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
15+
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
16+
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
17+
│ PERFORMANCE OF THIS SOFTWARE. │
18+
╚─────────────────────────────────────────────────────────────────────────────*/
19+
20+
/**
 * Widens a brain16 value to binary32.
 *
 * The sixteen bits of the operand become the upper half of the result's
 * bit pattern and the lower half is zero. A NaN operand always comes
 * back with its quiet bit set.
 *
 * @param f is the brain16 value to widen
 * @return the operand reinterpreted as a binary32 float
 */
float __extendbfsf2(__bf16 f) {
  union {
    __bf16 f;
    unsigned short i;
  } src;
  union {
    unsigned i;
    float f;
  } dst;

  // the brain16 bits are the high half of the binary32 word
  src.f = f;
  dst.i = (unsigned)src.i << 16;

  // magnitude above the infinity pattern means nan; set the quiet bit
  if ((dst.i & 0x7fffffff) > 0x7f800000)
    dst.i |= 0x00400000;

  return dst.f;
}

libc/intrin/truncdfbf2.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
2+
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
3+
╞══════════════════════════════════════════════════════════════════════════════╡
4+
│ Copyright 2024 Justine Alexandra Roberts Tunney │
5+
│ │
6+
│ Permission to use, copy, modify, and/or distribute this software for │
7+
│ any purpose with or without fee is hereby granted, provided that the │
8+
│ above copyright notice and this permission notice appear in all copies. │
9+
│ │
10+
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
11+
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
12+
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
13+
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
14+
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
15+
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
16+
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
17+
│ PERFORMANCE OF THIS SOFTWARE. │
18+
╚─────────────────────────────────────────────────────────────────────────────*/
19+
20+
/**
 * Narrows binary64 to brain16 with a single round-to-nearest-even.
 *
 * Converting double → float → brain16 with nearest rounding at both
 * steps rounds twice, which gives the wrong answer when the first step
 * lands exactly on a brain16 tie. For example 1 + 2^-8 + 2^-40 would
 * become 0x3f80 rather than 0x3f81. To avoid that, the intermediate
 * float is rounded to odd: it is truncated toward zero and its lowest
 * bit is set whenever the conversion was inexact. That sticky bit keeps
 * the "above or below the tie" information alive for the final
 * rounding, which is the same formula __truncsfbf2() uses.
 *
 * @param f is the binary64 value to narrow
 * @return nearest brain16 value, ties to even; NaN comes back quiet
 */
__bf16 __truncdfbf2(double f) {
  union {
    float f;
    unsigned i;
  } uf = {(float)f};
  unsigned x = uf.i;

  if ((x & 0x7fffffff) > 0x7f800000) {
    // nan stays nan through the float conversion; force it quiet
    x = (x | 0x00400000) >> 16;
  } else {
    if ((double)uf.f != f) {
      // inexact: the float is one of the two neighbors of f, so step
      // back toward zero if the hardware rounded away from it (this
      // also turns an overflow to infinity into the largest finite)
      if (f < 0 ? uf.f < f : uf.f > f)
        --x;
      // sticky bit records that precision was lost
      x |= 1;
    }
    // convert binary32 to brain16 with nearest rounding
    x = (x + (0x7fff + ((x >> 16) & 1))) >> 16;
  }

  // pun to bf16
  union {
    unsigned short i;
    __bf16 f;
  } ub = {x};
  return ub.f;
}

libc/intrin/truncsfbf2.c

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
2+
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
3+
╞══════════════════════════════════════════════════════════════════════════════╡
4+
│ Copyright 2024 Justine Alexandra Roberts Tunney │
5+
│ │
6+
│ Permission to use, copy, modify, and/or distribute this software for │
7+
│ any purpose with or without fee is hereby granted, provided that the │
8+
│ above copyright notice and this permission notice appear in all copies. │
9+
│ │
10+
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
11+
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
12+
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
13+
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
14+
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
15+
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
16+
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
17+
│ PERFORMANCE OF THIS SOFTWARE. │
18+
╚─────────────────────────────────────────────────────────────────────────────*/
19+
20+
/**
 * Narrows binary32 to brain16.
 *
 * Ordinary values, including infinities, are rounded to nearest with
 * ties going to the even result. NaN keeps its upper sixteen bits and
 * gets the quiet bit set.
 *
 * @param f is the binary32 value to narrow
 * @return brain16 holding the rounded upper half of the operand
 */
__bf16 __truncsfbf2(float f) {
  union {
    float f;
    unsigned i;
  } src;
  union {
    unsigned short i;
    __bf16 f;
  } dst;
  unsigned bits, bias;

  src.f = f;
  bits = src.i;

  if ((bits & 0x7fffffff) <= 0x7f800000) {
    // adding 0x7fff plus the lowest kept bit before discarding the
    // low half rounds to nearest and sends ties to the even result
    bias = 0x7fff + ((bits >> 16) & 1);
    bits = (bits + bias) >> 16;
  } else {
    // nan: set the quiet bit and drop the low half
    bits = (bits | 0x00400000) >> 16;
  }

  // pun to bf16
  dst.i = bits;
  return dst.f;
}

test/libc/tinymath/fdot_test.cc

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "libc/stdio/stdio.h"
1111
#include "libc/testlib/benchmark.h"
1212
#include "libc/x/xasprintf.h"
13+
#include "third_party/aarch64/arm_neon.internal.h"
14+
#include "third_party/intel/immintrin.internal.h"
1315

1416
#define EXPENSIVE_TESTS 0
1517

@@ -18,12 +20,11 @@
1820
#define FASTMATH __attribute__((__optimize__("-O3,-ffast-math")))
1921
#define PORTABLE __target_clones("avx512f,avx")
2022

21-
static unsigned long long lcg = 1;
22-
2323
int rand32(void) {
2424
/* Knuth, D.E., "The Art of Computer Programming," Vol 2,
2525
Seminumerical Algorithms, Third Edition, Addison-Wesley, 1998,
2626
p. 106 (line 26) & p. 108 */
27+
static unsigned long long lcg = 1;
2728
lcg *= 6364136223846793005;
2829
lcg += 1442695040888963407;
2930
return lcg >> 32;
@@ -122,6 +123,34 @@ float fdotf_recursive(const float *A, const float *B, size_t n) {
122123
}
123124
}
124125

126+
// Dot product of two float arrays using hand-written SIMD intrinsics.
//
// When AVX512F or AArch64 is available, CHUNK independent vector
// accumulators are fed with fused multiply-adds over full groups of
// CHUNK vectors, then summed horizontally one accumulator at a time.
// Whatever elements remain (or all of them, on other targets) are
// handled by the scalar loop at the bottom.
optimizespeed float fdotf_intrin(const float *A, const float *B, size_t n) {
  size_t i = 0;
#ifdef __AVX512F__
  // sixteen floats per 512-bit vector
  __m512 acc[CHUNK] = {};
  while (i + CHUNK * 16 <= n) {
    for (int k = 0; k < CHUNK; ++k) {
      __m512 a = _mm512_loadu_ps(A + i + k * 16);
      __m512 b = _mm512_loadu_ps(B + i + k * 16);
      acc[k] = _mm512_fmadd_ps(a, b, acc[k]);
    }
    i += CHUNK * 16;
  }
  float res = 0;
  for (int k = 0; k < CHUNK; ++k)
    res += _mm512_reduce_add_ps(acc[k]);
#elif defined(__aarch64__)
  // four floats per 128-bit vector
  float32x4_t acc[CHUNK] = {};
  while (i + CHUNK * 4 <= n) {
    for (int k = 0; k < CHUNK; ++k) {
      float32x4_t a = vld1q_f32(A + i + k * 4);
      float32x4_t b = vld1q_f32(B + i + k * 4);
      acc[k] = vfmaq_f32(acc[k], a, b);
    }
    i += CHUNK * 4;
  }
  float res = 0;
  for (int k = 0; k < CHUNK; ++k)
    res += vaddvq_f32(acc[k]);
#else
  float res = 0;
#endif
  // scalar tail
  while (i < n) {
    res += A[i] * B[i];
    ++i;
  }
  return res;
}
153+
125154
FASTMATH float fdotf_ruler(const float *A, const float *B, size_t n) {
126155
int rule, step = 2;
127156
size_t chunk, sp = 0;
@@ -179,6 +208,8 @@ void test_fdotf_ruler(void) {
179208
}
180209

181210
PORTABLE float fdotf_hefty(const float *A, const float *B, size_t n) {
211+
if (1)
212+
return 0;
182213
unsigned i, par, len = 0;
183214
float sum, res[n / CHUNK + 1];
184215
for (res[0] = i = 0; i + CHUNK <= n; i += CHUNK)
@@ -244,7 +275,7 @@ int main() {
244275
#if EXPENSIVE_TESTS
245276
size_t n = 512 * 1024;
246277
#else
247-
size_t n = 1024;
278+
size_t n = 4096;
248279
#endif
249280

250281
float *A = new float[n];
@@ -253,22 +284,24 @@ int main() {
253284
A[i] = numba();
254285
B[i] = numba();
255286
}
256-
float kahan, naive, dubble, recursive, hefty, ruler;
287+
float kahan, naive, dubble, recursive, ruler, intrin;
257288
test_fdotf_naive();
258-
test_fdotf_hefty();
289+
// test_fdotf_hefty();
259290
test_fdotf_ruler();
260291
BENCHMARK(20, 1, (kahan = barrier(fdotf_kahan(A, B, n))));
261292
BENCHMARK(20, 1, (dubble = barrier(fdotf_dubble(A, B, n))));
262293
BENCHMARK(20, 1, (naive = barrier(fdotf_naive(A, B, n))));
263294
BENCHMARK(20, 1, (recursive = barrier(fdotf_recursive(A, B, n))));
295+
BENCHMARK(20, 1, (intrin = barrier(fdotf_intrin(A, B, n))));
264296
BENCHMARK(20, 1, (ruler = barrier(fdotf_ruler(A, B, n))));
265-
BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
297+
// BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
266298
printf("dubble = %f (%g)\n", dubble, fabs(dubble - dubble));
267299
printf("kahan = %f (%g)\n", kahan, fabs(kahan - dubble));
268300
printf("naive = %f (%g)\n", naive, fabs(naive - dubble));
269301
printf("recursive = %f (%g)\n", recursive, fabs(recursive - dubble));
302+
printf("intrin = %f (%g)\n", intrin, fabs(intrin - dubble));
270303
printf("ruler = %f (%g)\n", ruler, fabs(ruler - dubble));
271-
printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
304+
// printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
272305
delete[] B;
273306
delete[] A;
274307

test/math/bf16_test.c

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
2+
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
3+
╞══════════════════════════════════════════════════════════════════════════════╡
4+
│ Copyright 2024 Justine Alexandra Roberts Tunney │
5+
│ │
6+
│ Permission to use, copy, modify, and/or distribute this software for │
7+
│ any purpose with or without fee is hereby granted, provided that the │
8+
│ above copyright notice and this permission notice appear in all copies. │
9+
│ │
10+
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
11+
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
12+
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
13+
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
14+
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
15+
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
16+
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
17+
│ PERFORMANCE OF THIS SOFTWARE. │
18+
╚─────────────────────────────────────────────────────────────────────────────*/
19+
#include "libc/math.h"
20+
21+
// Test assertions. Each one makes the enclosing function return the
// current line number when the expectation fails, so they may only be
// used inside a function returning an integer type.
//
// All three expand to a single `do { } while (0)` statement so they
// behave like an ordinary statement in front of `;` and do not capture
// a following `else`.

// Fails unless `x` is true.
#define CHECK(x)       \
  do {                 \
    if (!(x))          \
      return __LINE__; \
  } while (0)

// Fails if `x` is true. The result is stored in a volatile before it
// is tested (presumably so the comparison can't be folded away).
#define FALSE(x)            \
  do {                      \
    volatile bool x_ = (x); \
    if (x_)                 \
      return __LINE__;      \
  } while (0)

// Fails unless `x` is true; same volatile capture as FALSE().
#define TRUE(x)             \
  do {                      \
    volatile bool x_ = (x); \
    if (!x_)                \
      return __LINE__;      \
  } while (0)
36+
37+
__bf16 identity(__bf16 x) {
38+
return x;
39+
}
40+
__bf16 (*half)(__bf16) = identity;
41+
42+
// Returns the raw bit pattern of a float as an unsigned integer.
unsigned toint(float f) {
  union {
    unsigned i;
    float f;
  } pun;
  pun.f = f;
  return pun.i;
}
49+
50+
int main() {
51+
volatile float f;
52+
volatile double d;
53+
volatile __bf16 pi = 3.141;
54+
55+
// half → float → half
56+
f = pi;
57+
pi = f;
58+
59+
// half → float
60+
float __extendbfsf2(__bf16);
61+
CHECK(0.f == __extendbfsf2(0));
62+
CHECK(3.140625f == __extendbfsf2(pi));
63+
CHECK(3.140625f == pi);
64+
65+
// half → double → half
66+
d = pi;
67+
pi = d;
68+
69+
// float → half
70+
__bf16 __truncsfbf2(float);
71+
CHECK(0 == (float)__truncsfbf2(0));
72+
CHECK(pi == (float)__truncsfbf2(3.141f));
73+
CHECK(3.140625f == (float)__truncsfbf2(3.141f));
74+
75+
// double → half
76+
__bf16 __truncdfbf2(double);
77+
CHECK(0 == (double)__truncdfbf2(0));
78+
CHECK(3.140625 == (double)__truncdfbf2(3.141));
79+
80+
// specials
81+
volatile __bf16 nan = NAN;
82+
volatile __bf16 positive_infinity = +INFINITY;
83+
volatile __bf16 negative_infinity = -INFINITY;
84+
CHECK(isnan(nan));
85+
CHECK(!isinf(pi));
86+
CHECK(!isnan(pi));
87+
CHECK(isinf(positive_infinity));
88+
CHECK(isinf(negative_infinity));
89+
CHECK(!isnan(positive_infinity));
90+
CHECK(!isnan(negative_infinity));
91+
CHECK(!signbit(pi));
92+
CHECK(signbit(half(-pi)));
93+
CHECK(!signbit(half(+0.)));
94+
CHECK(signbit(half(-0.)));
95+
96+
// arithmetic
97+
CHECK(half(-3) == -half(3));
98+
CHECK(half(9) == half(3) * half(3));
99+
CHECK(half(0) == half(pi) - half(pi));
100+
CHECK(half(6.28125) == half(pi) + half(pi));
101+
102+
// comparisons
103+
CHECK(half(3) > half(2));
104+
CHECK(half(3) < half(4));
105+
CHECK(half(3) <= half(3));
106+
CHECK(half(3) >= half(3));
107+
TRUE(half(NAN) != half(NAN));
108+
FALSE(half(NAN) == half(NAN));
109+
TRUE(half(3) != half(NAN));
110+
FALSE(half(3) == half(NAN));
111+
TRUE(half(NAN) != half(3));
112+
FALSE(half(NAN) == half(3));
113+
}

0 commit comments

Comments
 (0)