Actual source code: baijfact81.c
/*
  Factorization code for BAIJ format.
*/
#include <../src/mat/impls/baij/seq/baij.h>
#include <petsc/private/kernels/blockinvert.h>
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
  #include <immintrin.h>
#endif
/*
  Version for when blocks are 9 by 9
*/
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
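/*
  Numeric LU factorization of a SeqBAIJ matrix with 9x9 blocks and natural ordering.
  B already carries the symbolic factorization (b->i, b->j, b->diag); this routine fills
  in the numeric values and stores the diagonal blocks of U already inverted.
*/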
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  PetscInt        i, j, k, nz, nzL, row;
  const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
  const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
  MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
  PetscInt        flg;
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot, zeropivotdetected;

  PetscFunctionBegin;
  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization */
  PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
  PetscCall(PetscArrayzero(rtmp, bs2 * n));
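  /* rtmp holds one fully expanded block row (n blocks of bs2 = 81 entries each);
     mwork is a single 9x9 scratch block used by the dense block kernels */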

  for (i = 0; i < n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i + 1] - bi[i];
    bjtmp = bj + bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* U part */
    nz    = bdiag[i] - bdiag[i + 1];
    bjtmp = bj + bdiag[i + 1] + 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));

    /* load in initial (unfactored) row */
    nz    = ai[i + 1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2 * ai[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i + 1] - bi[i];
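    /* For each nonzero block in the L part of row i: scale it by the (already inverted)
       pivot block of its column, then subtract its product with that pivot row's U blocks
       from the working row rtmp */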
    for (k = 0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2 * row;
      for (flg = 0, j = 0; j < bs2; j++) {
        if (pc[j] != 0.0) {
          flg = 1;
          break;
        }
      }
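      /* skip the update entirely when the block is numerically zero (structural fill only) */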
      if (flg) {
        pv = b->a + bs2 * bdiag[row];
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscCall(PetscKernel_A_gets_A_times_B_9(pc, pv, mwork));

        pj = b->j + bdiag[row + 1] + 1;       /* beginning of U(row,:) */
        pv = b->a + bs2 * (bdiag[row + 1] + 1);
        nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:), excluding diag */
        for (j = 0; j < nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2 * pj[j];
          PetscCall(PetscKernel_A_gets_A_minus_B_times_C_9(v, pc, pv + 81 * j));
          /* pv incremented in PetscKernel_A_gets_A_minus_B_times_C_9 */
        }
        PetscCall(PetscLogFlops(1458 * nz + 1377)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2 */
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2 * bi[i];
    pj = b->j + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2 * bdiag[i];
    pj = b->j + bdiag[i];
    PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
    PetscCall(PetscKernel_A_gets_inverse_A_9(pv, shift, allowzeropivot, &zeropivotdetected));
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
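    /* when zero pivots are allowed, record the failure on the factor instead of erroring */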

    /* U part */
    pv = b->a + bs2 * (bdiag[i + 1] + 1);
    pj = b->j + bdiag[i + 1] + 1;
    nz = bdiag[i] - bdiag[i + 1] - 1;
    for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
  }
  PetscCall(PetscFree2(rtmp, mwork));

  C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
  C->assembled           = PETSC_TRUE;

  PetscCall(PetscLogFlops(1.333333333333 * 9 * 9 * 9 * n)); /* from inverting diagonal blocks */
  PetscFunctionReturn(PETSC_SUCCESS);
}
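
/*
  Triangular solve companion to the factorization above: x = U^{-1} L^{-1} b, with the
  diagonal blocks of U already stored inverted. Installed on the factored matrix as
  C->ops->solve, so it is reached through MatSolve() on a bs = 9 SeqBAIJ LU factor.
*/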
PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A, Vec bb, Vec xx)
{
  Mat_SeqBAIJ       *a = (Mat_SeqBAIJ *)A->data;
  const PetscInt    *ai = a->i, *aj = a->j, *adiag = a->diag, *vi;
  PetscInt           i, k, n = a->mbs;
  PetscInt           nz, bs = A->rmap->bs, bs2 = a->bs2;
  const MatScalar   *aa = a->a, *v;
  PetscScalar       *x, *s, *t, *ls;
  const PetscScalar *b;
  __m256d            a0, a1, a2, a3, a4, a5, w0, w1, w2, w3, s0, s1, s2, v0, v1, v2, v3;

  PetscFunctionBegin;
  PetscCall(VecGetArrayRead(bb, &b));
  PetscCall(VecGetArray(xx, &x));
  t = a->solve_work;

  /* forward solve the lower triangular */
  PetscCall(PetscArraycpy(t, b, bs)); /* copy 1st block of b to t */

  for (i = 1; i < n; i++) {
    v  = aa + bs2 * ai[i];
    vi = aj + ai[i];
    nz = ai[i + 1] - ai[i];
    s  = t + bs * i;
    PetscCall(PetscArraycpy(s, b + bs * i, bs)); /* copy i_th block of b to t */
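
    /* the 9-entry working block s is kept in registers: s0 = s[0..3], s1 = s[4..7],
       and s2 holds s[8] in its first lane via a masked load */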
    s0 = _mm256_loadu_pd(s + 0);
    s1 = _mm256_loadu_pd(s + 4);
    s2 = _mm256_maskload_pd(s + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));

    for (k = 0; k < nz; k++) {
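      /* s -= block(i, vi[k]) * t(vi[k]); the 9x9 block is stored column-major, so each
         scalar t(vi[k])[j] is broadcast and multiplied against column j (v[9*j .. 9*j+8]) */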
      w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[0]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[4]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[8]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[9]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[13]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[17]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[22]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[26]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[31]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[35]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[40]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[44]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[49]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[53]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[58]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[62]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[67]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[71]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[76]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
      s2 = _mm256_fnmadd_pd(a2, w0, s2);
      v += bs2;
    }
    _mm256_storeu_pd(&s[0], s0);
    _mm256_storeu_pd(&s[4], s1);
    _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
  }

  /* backward solve the upper triangular */
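  /* for each block row i (last to first): accumulate ls = t_i - sum_{j > i} U(i,j) * x_j,
     then multiply by the stored inverse of the diagonal block to get x_i */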
  ls = a->solve_work + A->cmap->n;
  for (i = n - 1; i >= 0; i--) {
    v  = aa + bs2 * (adiag[i + 1] + 1);
    vi = aj + adiag[i + 1] + 1;
    nz = adiag[i] - adiag[i + 1] - 1;
    PetscCall(PetscArraycpy(ls, t + i * bs, bs));

    s0 = _mm256_loadu_pd(ls + 0);
    s1 = _mm256_loadu_pd(ls + 4);
    s2 = _mm256_maskload_pd(ls + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));

    for (k = 0; k < nz; k++) {
      w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[0]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[4]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[8]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[9]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[13]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[17]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[22]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[26]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[31]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[35]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[40]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[44]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[49]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[53]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[58]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[62]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[67]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[71]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[76]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
      s2 = _mm256_fnmadd_pd(a2, w0, s2);
      v += bs2;
    }

    _mm256_storeu_pd(&ls[0], s0);
    _mm256_storeu_pd(&ls[4], s1);
    _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
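
    /* x_i = inv(U(i,i)) * ls: the diagonal block was stored already inverted by the
       factorization, so the back substitution reduces to a 9x9 block matrix-vector product,
       accumulated with one FMA group per entry of ls */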
    w0 = _mm256_setzero_pd();
    w1 = _mm256_setzero_pd();
    w2 = _mm256_setzero_pd();

    /* first row */
    v0 = _mm256_set1_pd(ls[0]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[0]);
    w0 = _mm256_fmadd_pd(a0, v0, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[4]);
    w1 = _mm256_fmadd_pd(a1, v0, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[8]);
    w2 = _mm256_fmadd_pd(a2, v0, w2);

    /* second row */
    v1 = _mm256_set1_pd(ls[1]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[9]);
    w0 = _mm256_fmadd_pd(a3, v1, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[13]);
    w1 = _mm256_fmadd_pd(a4, v1, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[17]);
    w2 = _mm256_fmadd_pd(a5, v1, w2);

    /* third row */
    v2 = _mm256_set1_pd(ls[2]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[18]);
    w0 = _mm256_fmadd_pd(a0, v2, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[22]);
    w1 = _mm256_fmadd_pd(a1, v2, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[26]);
    w2 = _mm256_fmadd_pd(a2, v2, w2);

    /* fourth row */
    v3 = _mm256_set1_pd(ls[3]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[27]);
    w0 = _mm256_fmadd_pd(a3, v3, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[31]);
    w1 = _mm256_fmadd_pd(a4, v3, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[35]);
    w2 = _mm256_fmadd_pd(a5, v3, w2);

    /* fifth row */
    v0 = _mm256_set1_pd(ls[4]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[36]);
    w0 = _mm256_fmadd_pd(a0, v0, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[40]);
    w1 = _mm256_fmadd_pd(a1, v0, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[44]);
    w2 = _mm256_fmadd_pd(a2, v0, w2);

    /* sixth row */
    v1 = _mm256_set1_pd(ls[5]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[45]);
    w0 = _mm256_fmadd_pd(a3, v1, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[49]);
    w1 = _mm256_fmadd_pd(a4, v1, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[53]);
    w2 = _mm256_fmadd_pd(a5, v1, w2);

    /* seventh row */
    v2 = _mm256_set1_pd(ls[6]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[54]);
    w0 = _mm256_fmadd_pd(a0, v2, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[58]);
    w1 = _mm256_fmadd_pd(a1, v2, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[62]);
    w2 = _mm256_fmadd_pd(a2, v2, w2);

    /* eighth row */
    v3 = _mm256_set1_pd(ls[7]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[63]);
    w0 = _mm256_fmadd_pd(a3, v3, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[67]);
    w1 = _mm256_fmadd_pd(a4, v3, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[71]);
    w2 = _mm256_fmadd_pd(a5, v3, w2);

    /* ninth row */
    v0 = _mm256_set1_pd(ls[8]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[72]);
    w0 = _mm256_fmadd_pd(a3, v0, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[76]);
    w1 = _mm256_fmadd_pd(a4, v0, w1);
    a2 = _mm256_maskload_pd((&(aa + bs2 * adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
    w2 = _mm256_fmadd_pd(a2, v0, w2);

    _mm256_storeu_pd(&(t + i * bs)[0], w0);
    _mm256_storeu_pd(&(t + i * bs)[4], w1);
    _mm256_maskstore_pd(&(t + i * bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), w2);

    PetscCall(PetscArraycpy(x + i * bs, t + i * bs, bs));
  }

  PetscCall(VecRestoreArrayRead(bb, &b));
  PetscCall(VecRestoreArray(xx, &x));
  PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif