x86マシンコード(128ビットSSE1&AVXを使用したSIMD 4x float)94バイト
x86マシンコード(256ビットAVXを使用したSIMD 4xダブル)123バイト
float
問題のテストケースに合格しますが、それを実現するのに十分なループ終了しきい値を使用すると、ランダム入力の無限ループに陥りやすくなります。
SSE1パック単精度命令の長さは3バイトですが、SSE2および単純なAVX命令の長さは4バイトです。(スカラー-シングル命令のようなsqrtss
ものも4バイトの長さsqrtps
です。そのため、low要素だけを気にしているのに使用します。最新のハードウェアのsqrtssよりも遅くありません)。非破壊的な宛先にAVXを使用して、movaps + opに対して2バイトを節約しました。
ダブルバージョンでは、movlhps
64ビットチャンクをコピーするためにカップルを行うことができます(多くの場合、水平方向の合計の低い要素のみを考慮するため)。256ビットSIMDベクトルの水平方向の合計はまた、余分に必要とvextractf128
対、高の半分を取得するために、ゆっくりではあるが、小さな2倍のhaddps
フロートのための戦略。のdouble
バージョンには、2バイトの4バイトではなく、2バイトの8バイト定数も必要です。全体的には、float
バージョンのサイズの4/3近くになります。
mean(a,b) = mean(a,a,b,b)
これらの4つの手段すべてについて、入力を4つの要素まで単純に複製でき、length = 2を実装する必要はありません。したがって、幾何平均を4th-root = sqrt(sqrt)などとしてハードコーディングできます。そして、必要なFP定数は1つだけです4.0
。
4つすべての単一のSIMDベクトルがあります[a_i, b_i, c_i, d_i]
。それから、4つの平均を別々のレジスタのスカラーとして計算し、次の反復のためにそれらを一緒にシャッフルします。 (SIMDベクトルの水平方向の操作は不便ですが、バランスをとるのに十分な場合には4つの要素すべてに対して同じことを行う必要があります。
のループ終了条件}while(quadratic - harmonic > 4e-5)
(またはのより小さい定数double
)は、@ RobinRyderのR回答、およびKevin CruijssenのJava回答に基づいています:二次平均は常に最大の大きさであり、調和平均は常に最小です(丸め誤差を無視)。したがって、これら2つの間のデルタをチェックして、収束を検出できます。算術平均をスカラー結果として返します。通常はこの2つの間にあり、おそらく丸め誤差の影響を受けにくいものです。
浮動バージョン:float meanmean_float_avx(__m128);
argと同様に呼び出し可能で、xmm0の値を返します。(つまり、x86-64 System V、またはWindows x64 vectorcallで、x64 fastcallではありません。)または、return-typeを宣言して、__m128
テスト用に2次平均と調和平均を取得できるようにします。
これfloat
にxmm0とxmm1の2つの個別の引数を使用すると、1バイト余分にコストがかかります。一緒にシャッフルして2つの入力を複製するshufps
には、(ちょうどの代わりにunpcklps xmm0,xmm0
)imm8 が必要です。
40 address align 32
41 code bytes global meanmean_float_avx
42 meanmean_float_avx:
43 00000000 B9[52000000] mov ecx, .arith_mean ; allows 2-byte call reg, and a base for loading constants
44 00000005 C4E2791861FC vbroadcastss xmm4, [rcx-4] ; float 4.0
45
46 ;; mean(a,b) = mean(a,b,a,b) for all 4 types of mean
47 ;; so we only ever have to do the length=4 case
48 0000000B 0F14C0 unpcklps xmm0,xmm0 ; [b,a] => [b,b,a,a]
49
50 ; do{ ... } while(quadratic - harmonic > threshold);
51 .loop:
52 ;;; XMM3 = geometric mean: not based on addition. (Transform to log would be hard. AVX512ER has exp with 23-bit accuracy, but not log. vgetexp = floor(lofg2(x)), so that's no good.)
53 ;; sqrt once *first*, making magnitudes closer to 1.0 to reduce rounding error. Numbers are all positive so this is safe.
54 ;; both sqrts first was better behaved, I think.
55 0000000E 0F51D8 sqrtps xmm3, xmm0 ; xmm3 = 4th root(x)
56 00000011 F30F16EB movshdup xmm5, xmm3 ; bring odd elements down to even
57 00000015 0F59EB mulps xmm5, xmm3
58 00000018 0F12DD movhlps xmm3, xmm5 ; high half -> low
59 0000001B 0F59DD mulps xmm3, xmm5 ; xmm3[0] = hproduct(sqrt(xmm))
60 ; sqrtps xmm3, xmm3 ; sqrt(hprod(sqrt)) = 4th root(hprod)
61 ; common final step done after interleaving with quadratic mean
62
63 ;;; XMM2 = quadratic mean = max of the means
64 0000001E C5F859E8 vmulps xmm5, xmm0,xmm0
65 00000022 FFD1 call rcx ; arith mean of squares
66 00000024 0F14EB unpcklps xmm5, xmm3 ; [quad^2, geo^2, ?, ?]
67 00000027 0F51D5 sqrtps xmm2, xmm5 ; [quad, geo, ?, ?]
68
69 ;;; XMM1 = harmonic mean = min of the means
70 0000002A C5D85EE8 vdivps xmm5, xmm4, xmm0 ; 4/x
71 0000002E FFD1 call rcx ; arithmetic mean (under inversion)
72 00000030 C5D85ECD vdivps xmm1, xmm4, xmm5 ; 4/. (the factor of 4 cancels out)
73
74 ;;; XMM5 = arithmetic mean
75 00000034 0F28E8 movaps xmm5, xmm0
76 00000037 FFD1 call rcx
77
78 00000039 0F14E9 unpcklps xmm5, xmm1 ; [arith, harm, ?,?]
79 0000003C C5D014C2 vunpcklps xmm0, xmm5,xmm2 ; x = [arith, harm, quad, geo]
80
81 00000040 0F5CD1 subps xmm2, xmm1 ; largest - smallest mean: guaranteed non-negative
82 00000043 0F2E51F8 ucomiss xmm2, [rcx-8] ; quad-harm > convergence_threshold
83 00000047 73C5 jae .loop
84
85 ; return with the arithmetic mean in the low element of xmm0 = scalar return value
86 00000049 C3 ret
87
88 ;;; "constant pool" between the main function and the helper, like ARM literal pools
89 0000004A ACC52738 .fpconst_threshold: dd 4e-5 ; 4.3e-5 is the highest we can go and still pass the main test cases
90 0000004E 00008040 .fpconst_4: dd 4.0
91 .arith_mean: ; returns XMM5 = hsum(xmm5)/4.
92 00000052 C5D37CED vhaddps xmm5, xmm5 ; slow but small
93 00000056 C5D37CED vhaddps xmm5, xmm5
94 0000005A 0F5EEC divps xmm5, xmm4 ; divide before/after summing doesn't matter mathematically or numerically; divisor is a power of 2
95 0000005D C3 ret
96 0000005E 5E000000 .size: dd $ - meanmean_float_avx
0x5e = 94 bytes
(で作成されたNASMリストnasm -felf64 mean-mean.asm -l/dev/stdout | cut -b -34,$((34+6))-
。リスト部分を削除し、でソースを回復します。cut -b 34- > mean-mean.asm
)
SIMDの水平方向の合計と4による除算(算術平均)は、別の関数で実装されcall
ます(アドレスのコストを償却する関数ポインターを使用)。4/x
前/後、またはx^2
前とsqrtの後、私たちは調和平均と二次の平均値を取得します。(div
正確に表現可能なを乗算する代わりに、これらの命令を記述するのは苦痛0.25
でした。)
幾何平均は、乗算および連鎖sqrtを使用して個別に実装されます。または、最初に1つのsqrtを使用して、指数の大きさを減らし、数値の精度を上げることができます。ログはfloor(log2(x))
AVX512経由でのみ利用可能ですvgetexpps/pd
。ExpはAVX512ER(Xeon Phiのみ)を介して利用できますが、精度は2 ^ -23しかありません。
128ビットAVX命令とレガシーSSEを混在させることは、パフォーマンスの問題ではありません。256ビットAVXとレガシーSSEを混在させることはHaswellで可能ですが、Skylakeでは、SSE命令の潜在的な誤った依存関係を潜在的に作成するだけです。私のdouble
バージョンでは、不必要なループ搬送depチェーン、およびdiv / sqrtレイテンシ/スループットのボトルネックを回避できると思います。
ダブルバージョン:
108 global meanmean_double_avx
109 meanmean_double_avx:
110 00000080 B9[E8000000] mov ecx, .arith_mean
111 00000085 C4E27D1961F8 vbroadcastsd ymm4, [rcx-8] ; float 4.0
112
113 ;; mean(a,b) = mean(a,b,a,b) for all 4 types of mean
114 ;; so we only ever have to do the length=4 case
115 0000008B C4E37D18C001 vinsertf128 ymm0, ymm0, xmm0, 1 ; [b,a] => [b,a,b,a]
116
117 .loop:
118 ;;; XMM3 = geometric mean: not based on addition.
119 00000091 C5FD51D8 vsqrtpd ymm3, ymm0 ; sqrt first to get magnitude closer to 1.0 for better(?) numerical precision
120 00000095 C4E37D19DD01 vextractf128 xmm5, ymm3, 1 ; extract high lane
121 0000009B C5D159EB vmulpd xmm5, xmm3
122 0000009F 0F12DD movhlps xmm3, xmm5 ; extract high half
123 000000A2 F20F59DD mulsd xmm3, xmm5 ; xmm3 = hproduct(sqrt(xmm0))
124 ; sqrtsd xmm3, xmm3 ; xmm3 = 4th root = geomean(xmm0) ;deferred until quadratic
125
126 ;;; XMM2 = quadratic mean = max of the means
127 000000A6 C5FD59E8 vmulpd ymm5, ymm0,ymm0
128 000000AA FFD1 call rcx ; arith mean of squares
129 000000AC 0F16EB movlhps xmm5, xmm3 ; [quad^2, geo^2]
130 000000AF 660F51D5 sqrtpd xmm2, xmm5 ; [quad , geo]
131
132 ;;; XMM1 = harmonic mean = min of the means
133 000000B3 C5DD5EE8 vdivpd ymm5, ymm4, ymm0 ; 4/x
134 000000B7 FFD1 call rcx ; arithmetic mean under inversion
135 000000B9 C5DB5ECD vdivsd xmm1, xmm4, xmm5 ; 4/. (the factor of 4 cancels out)
136
137 ;;; XMM5 = arithmetic mean
138 000000BD C5FC28E8 vmovaps ymm5, ymm0
139 000000C1 FFD1 call rcx
140
141 000000C3 0F16E9 movlhps xmm5, xmm1 ; [arith, harm]
142 000000C6 C4E35518C201 vinsertf128 ymm0, ymm5, xmm2, 1 ; x = [arith, harm, quad, geo]
143
144 000000CC C5EB5CD1 vsubsd xmm2, xmm1 ; largest - smallest mean: guaranteed non-negative
145 000000D0 660F2E51F0 ucomisd xmm2, [rcx-16] ; quad - harm > threshold
146 000000D5 77BA ja .loop
147
148 ; vzeroupper ; not needed for correctness, only performance
149 ; return with the arithmetic mean in the low element of xmm0 = scalar return value
150 000000D7 C3 ret
151
152 ; "literal pool" between the function
153 000000D8 95D626E80B2E113E .fpconst_threshold: dq 1e-9
154 000000E0 0000000000001040 .fpconst_4: dq 4.0 ; TODO: golf these zeros? vpbroadcastb and convert?
155 .arith_mean: ; returns YMM5 = hsum(ymm5)/4.
156 000000E8 C4E37D19EF01 vextractf128 xmm7, ymm5, 1
157 000000EE C5D158EF vaddpd xmm5, xmm7
158 000000F2 C5D17CED vhaddpd xmm5, xmm5 ; slow but small
159 000000F6 C5D35EEC vdivsd xmm5, xmm4 ; only low element matters
160 000000FA C3 ret
161 000000FB 7B000000 .size: dd $ - meanmean_double_avx
0x7b = 123 bytes
Cテストハーネス
#include <immintrin.h>
#include <stdio.h>
#include <math.h>
static const struct ab_avg {
double a,b;
double mean;
} testcases[] = {
{1, 1, 1},
{1, 2, 1.45568889},
{100, 200, 145.568889},
{2.71, 3.14, 2.92103713},
{0.57, 1.78, 1.0848205},
{1.61, 2.41, 1.98965438},
{0.01, 100, 6.7483058},
};
// see asm comments for order of arith, harm, quad, geo
__m128 meanmean_float_avx(__m128); // or float ...
__m256d meanmean_double_avx(__m128d); // or double ...
int main(void) {
int len = sizeof(testcases) / sizeof(testcases[0]);
for(int i=0 ; i<len ; i++) {
const struct ab_avg *p = &testcases[i];
#if 1
__m128 arg = _mm_set_ps(0,0, p->b, p->a);
double res = meanmean_float_avx(arg)[0];
#else
__m128d arg = _mm_loadu_pd(&p->a);
double res = meanmean_double_avx(arg)[0];
#endif
double allowed_diff = (p->b - p->a) / 100000.0;
double delta = fabs(p->mean - res);
if (delta > 1e-3 || delta > allowed_diff) {
printf("%f %f => %.9f but we got %.9f. delta = %g allowed=%g\n",
p->a, p->b, p->mean, res, p->mean - res, allowed_diff);
}
}
while(1) {
double a = drand48(), b = drand48(); // range= [0..1)
if (a>b) {
double tmp=a;
a=b;
b=tmp; // sorted
}
// a *= 0.00000001;
// b *= 123156;
// a += 1<<11; b += (1<<12)+1; // float version gets stuck inflooping on 2048.04, 4097.18 at fpthreshold = 4e-5
// a *= 1<<11 ; b *= 1<<11; // scaling to large magnitude makes sum of squares loses more precision
//a += 1<<11; b+= 1<<11; // adding to large magnitude is hard for everything, catastrophic cancellation
#if 1
printf("testing float %g, %g\n", a, b);
__m128 arg = _mm_set_ps(0,0, b, a);
__m128 res = meanmean_float_avx(arg);
double quad = res[2], harm = res[1]; // same order as double... for now
#else
printf("testing double %g, %g\n", a, b);
__m128d arg = _mm_set_pd(b, a);
__m256d res = meanmean_double_avx(arg);
double quad = res[2], harm = res[1];
#endif
double delta = fabs(quad - harm);
double allowed_diff = (b - a) / 100000.0; // calculated in double even for the float case.
// TODO: use the double res as a reference for float res
// instead of just checking quadratic vs. harmonic mean
if (delta > 1e-3 || delta > allowed_diff) {
printf("%g %g we got q=%g, h=%g, a=%g. delta = %g, allowed=%g\n",
a, b, quad, harm, res[0], quad-harm, allowed_diff);
}
}
}
ビルド:
nasm -felf64 mean-mean.asm &&
gcc -no-pie -fno-pie -g -O2 -march=native mean-mean.c mean-mean.o
明らかに、AVXをサポートするCPU、またはIntel SDEのようなエミュレーターが必要です。ネイティブAVXサポートのないホストでコンパイルするには、-march=sandybridge
またはを使用します-mavx
実行:ハードコーディングされたテストケースを渡しますが、フロートバージョンの場合、ランダムテストケース(b-a)/10000
は質問で設定されたしきい値を満たさないことがよくあります。
$ ./a.out
(note: empty output before the first "testing float" means clean pass on the constant test cases)
testing float 3.90799e-14, 0.000985395
3.90799e-14 0.000985395 we got q=3.20062e-10, h=3.58723e-05, a=2.50934e-05. delta = -3.5872e-05, allowed=9.85395e-09
testing float 0.041631, 0.176643
testing float 0.0913306, 0.364602
testing float 0.0922976, 0.487217
testing float 0.454433, 0.52675
0.454433 0.52675 we got q=0.48992, h=0.489927, a=0.489925. delta = -6.79493e-06, allowed=7.23169e-07
testing float 0.233178, 0.831292
testing float 0.56806, 0.931731
testing float 0.0508319, 0.556094
testing float 0.0189148, 0.767051
0.0189148 0.767051 we got q=0.210471, h=0.210484, a=0.21048. delta = -1.37389e-05, allowed=7.48136e-06
testing float 0.25236, 0.298197
0.25236 0.298197 we got q=0.274796, h=0.274803, a=0.274801. delta = -6.19888e-06, allowed=4.58374e-07
testing float 0.531557, 0.875981
testing float 0.515431, 0.920261
testing float 0.18842, 0.810429
testing float 0.570614, 0.886314
testing float 0.0767746, 0.815274
testing float 0.118352, 0.984891
0.118352 0.984891 we got q=0.427845, h=0.427872, a=0.427863. delta = -2.66135e-05, allowed=8.66539e-06
testing float 0.784484, 0.893906
0.784484 0.893906 we got q=0.838297, h=0.838304, a=0.838302. delta = -7.09295e-06, allowed=1.09422e-06
FPエラーは十分であるため、一部の入力ではクワッドハルが0未満になります。
またはa += 1<<11; b += (1<<12)+1;
コメントなしで:
testing float 2048, 4097
testing float 2048.04, 4097.18
^C (stuck in an infinite loop).
これらの問題はいずれもで発生しませんdouble
。printf
各テスト の前にコメントアウトして、出力が空であることを確認します(if(delta too high)
ブロックからは何もありません)。
TODO:単にクワッドハームとの収束方法を調べるのではなくdouble
、float
バージョンのリファレンスとしてバージョンを使用します。