Actual source code: tomographyADMM.c

  1: #include <petsctao.h>
  2: /*
  3: Description:   ADMM tomography reconstruction example .
  4:                0.5*||Ax-b||^2 + lambda*g(x)
  5: Reference:     BRGN Tomography Example
  6: */

  8: static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\
  9:                       A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
 10:                       We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
 11:                       g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
 12:                       natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for  \n\
 13:                       either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
 14:                       Then, we augment both f and g, and solve it via ADMM. \n\
 15:                       D is the M*N transform matrix so that D*x is sparse. \n";

 17: typedef struct {
 18:   PetscInt  M, N, K, reg;
 19:   PetscReal lambda, eps, mumin;
 20:   Mat       A, ATA, H, Hx, D, Hz, DTD, HF;
 21:   Vec       c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
 22: } AppCtx;

 24: /*------------------------------------------------------------*/

 26: PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr)
 27: {
 28:   PetscFunctionBegin;
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: /*------------------------------------------------------------*/

 34: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
 35: {
 36:   PetscReal lambda, mu;
 37:   AppCtx   *user;
 38:   Vec       out, work, y, x;
 39:   Tao       admm_tao, misfit;

 41:   PetscFunctionBegin;
 42:   user = NULL;
 43:   mu   = 0;
 44:   PetscCall(TaoGetADMMParentTao(tao, &admm_tao));
 45:   PetscCall(TaoADMMGetMisfitSubsolver(admm_tao, &misfit));
 46:   PetscCall(TaoADMMGetSpectralPenalty(admm_tao, &mu));
 47:   PetscCall(TaoShellGetContext(tao, &user));

 49:   lambda = user->lambda;
 50:   work   = user->workN;
 51:   PetscCall(TaoGetSolution(tao, &out));
 52:   PetscCall(TaoGetSolution(misfit, &x));
 53:   PetscCall(TaoADMMGetDualVector(admm_tao, &y));

 55:   /* Dx + y/mu */
 56:   PetscCall(MatMult(user->D, x, work));
 57:   PetscCall(VecAXPY(work, 1 / mu, y));

 59:   /* soft thresholding */
 60:   PetscCall(TaoSoftThreshold(work, -lambda / mu, lambda / mu, out));
 61:   PetscFunctionReturn(PETSC_SUCCESS);
 62: }

 64: /*------------------------------------------------------------*/

 66: PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
 67: {
 68:   AppCtx *user = (AppCtx *)ptr;

 70:   PetscFunctionBegin;
 71:   /* Objective  0.5*||Ax-b||_2^2 */
 72:   PetscCall(MatMult(user->A, X, user->workM));
 73:   PetscCall(VecAXPY(user->workM, -1, user->b));
 74:   PetscCall(VecDot(user->workM, user->workM, f));
 75:   *f *= 0.5;
 76:   /* Gradient. ATAx-ATb */
 77:   PetscCall(MatMult(user->ATA, X, user->workN));
 78:   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
 79:   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
 80:   PetscFunctionReturn(PETSC_SUCCESS);
 81: }

 83: /*------------------------------------------------------------*/

 85: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
 86: {
 87:   AppCtx *user = (AppCtx *)ptr;

 89:   PetscFunctionBegin;
 90:   /* compute regularizer objective
 91:    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
 92:   PetscCall(VecCopy(X, user->workN2));
 93:   PetscCall(VecPow(user->workN2, 2.));
 94:   PetscCall(VecShift(user->workN2, user->eps * user->eps));
 95:   PetscCall(VecSqrtAbs(user->workN2));
 96:   PetscCall(VecCopy(user->workN2, user->workN3));
 97:   PetscCall(VecShift(user->workN2, -user->eps));
 98:   PetscCall(VecSum(user->workN2, f_reg));
 99:   *f_reg *= user->lambda;
100:   /* compute regularizer gradient = lambda*x */
101:   PetscCall(VecPointwiseDivide(G_reg, X, user->workN3));
102:   PetscCall(VecScale(G_reg, user->lambda));
103:   PetscFunctionReturn(PETSC_SUCCESS);
104: }

106: /*------------------------------------------------------------*/

108: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
109: {
110:   AppCtx   *user = (AppCtx *)ptr;
111:   PetscReal temp;

113:   PetscFunctionBegin;
114:   /* compute regularizer objective = lambda*|z|_2^2 */
115:   PetscCall(VecDot(X, X, &temp));
116:   *f_reg = 0.5 * user->lambda * temp;
117:   /* compute regularizer gradient = lambda*z */
118:   PetscCall(VecCopy(X, G_reg));
119:   PetscCall(VecScale(G_reg, user->lambda));
120:   PetscFunctionReturn(PETSC_SUCCESS);
121: }

123: /*------------------------------------------------------------*/

125: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
126: {
127:   PetscFunctionBegin;
128:   PetscFunctionReturn(PETSC_SUCCESS);
129: }

131: /*------------------------------------------------------------*/

133: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
134: {
135:   AppCtx *user = (AppCtx *)ptr;

137:   PetscFunctionBegin;
138:   PetscCall(MatMult(user->D, x, user->workN));
139:   PetscCall(VecPow(user->workN2, 2.));
140:   PetscCall(VecShift(user->workN2, user->eps * user->eps));
141:   PetscCall(VecSqrtAbs(user->workN2));
142:   PetscCall(VecShift(user->workN2, -user->eps));
143:   PetscCall(VecReciprocal(user->workN2));
144:   PetscCall(VecScale(user->workN2, user->eps * user->eps));
145:   PetscCall(MatDiagonalSet(H, user->workN2, INSERT_VALUES));
146:   PetscFunctionReturn(PETSC_SUCCESS);
147: }

149: /*------------------------------------------------------------*/

151: PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
152: {
153:   AppCtx   *user = (AppCtx *)ptr;
154:   PetscReal f_reg;

156:   PetscFunctionBegin;
157:   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
158:   PetscCall(MatMult(user->A, X, user->workM));
159:   PetscCall(VecAXPY(user->workM, -1, user->b));
160:   PetscCall(VecDot(user->workM, user->workM, f));
161:   PetscCall(VecNorm(X, NORM_2, &f_reg));
162:   *f *= 0.5;
163:   *f += user->lambda * f_reg * f_reg;
164:   /* Gradient. ATAx-ATb + 2*lambda*x */
165:   PetscCall(MatMult(user->ATA, X, user->workN));
166:   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
167:   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
168:   PetscCall(VecAXPY(g, 2 * user->lambda, X));
169:   PetscFunctionReturn(PETSC_SUCCESS);
170: }
171: /*------------------------------------------------------------*/

173: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
174: {
175:   PetscFunctionBegin;
176:   PetscFunctionReturn(PETSC_SUCCESS);
177: }
178: /*------------------------------------------------------------*/

180: PetscErrorCode InitializeUserData(AppCtx *user)
181: {
182:   char        dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
183:   PetscViewer fd;                                    /* used to load data from file */
184:   PetscInt    k, n;
185:   PetscScalar v;

187:   PetscFunctionBegin;
188:   /* Load the A matrix, b vector, and xGT vector from a binary file. */
189:   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd));
190:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->A));
191:   PetscCall(MatSetType(user->A, MATAIJ));
192:   PetscCall(MatLoad(user->A, fd));
193:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->b));
194:   PetscCall(VecLoad(user->b, fd));
195:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->xGT));
196:   PetscCall(VecLoad(user->xGT, fd));
197:   PetscCall(PetscViewerDestroy(&fd));

199:   PetscCall(MatGetSize(user->A, &user->M, &user->N));

201:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->D));
202:   PetscCall(MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
203:   PetscCall(MatSetFromOptions(user->D));
204:   PetscCall(MatSetUp(user->D));
205:   for (k = 0; k < user->N; k++) {
206:     v = 1.0;
207:     n = k + 1;
208:     if (k < user->N - 1) PetscCall(MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES));
209:     v = -1.0;
210:     PetscCall(MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES));
211:   }
212:   PetscCall(MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY));
213:   PetscCall(MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY));

215:   PetscCall(MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD));

217:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->Hz));
218:   PetscCall(MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
219:   PetscCall(MatSetFromOptions(user->Hz));
220:   PetscCall(MatSetUp(user->Hz));
221:   PetscCall(MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY));
222:   PetscCall(MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY));

224:   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->x)));
225:   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workM)));
226:   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN)));
227:   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN2)));
228:   PetscCall(VecSetSizes(user->x, PETSC_DECIDE, user->N));
229:   PetscCall(VecSetSizes(user->workM, PETSC_DECIDE, user->M));
230:   PetscCall(VecSetSizes(user->workN, PETSC_DECIDE, user->N));
231:   PetscCall(VecSetSizes(user->workN2, PETSC_DECIDE, user->N));
232:   PetscCall(VecSetFromOptions(user->x));
233:   PetscCall(VecSetFromOptions(user->workM));
234:   PetscCall(VecSetFromOptions(user->workN));
235:   PetscCall(VecSetFromOptions(user->workN2));

237:   PetscCall(VecDuplicate(user->workN, &(user->workN3)));
238:   PetscCall(VecDuplicate(user->x, &(user->xlb)));
239:   PetscCall(VecDuplicate(user->x, &(user->xub)));
240:   PetscCall(VecDuplicate(user->x, &(user->c)));
241:   PetscCall(VecSet(user->xlb, 0.0));
242:   PetscCall(VecSet(user->c, 0.0));
243:   PetscCall(VecSet(user->xub, PETSC_INFINITY));

245:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA)));
246:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx)));
247:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF)));

249:   PetscCall(MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY));
250:   PetscCall(MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY));
251:   PetscCall(MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY));
252:   PetscCall(MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY));
253:   PetscCall(MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY));
254:   PetscCall(MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY));

256:   user->lambda = 1.e-8;
257:   user->eps    = 1.e-3;
258:   user->reg    = 2;
259:   user->mumin  = 5.e-6;

261:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
262:   PetscCall(PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL));
263:   PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL));
264:   PetscCall(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL));
265:   PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL));
266:   PetscOptionsEnd();
267:   PetscFunctionReturn(PETSC_SUCCESS);
268: }

270: /*------------------------------------------------------------*/

272: PetscErrorCode DestroyContext(AppCtx *user)
273: {
274:   PetscFunctionBegin;
275:   PetscCall(MatDestroy(&user->A));
276:   PetscCall(MatDestroy(&user->ATA));
277:   PetscCall(MatDestroy(&user->Hx));
278:   PetscCall(MatDestroy(&user->Hz));
279:   PetscCall(MatDestroy(&user->HF));
280:   PetscCall(MatDestroy(&user->D));
281:   PetscCall(MatDestroy(&user->DTD));
282:   PetscCall(VecDestroy(&user->xGT));
283:   PetscCall(VecDestroy(&user->xlb));
284:   PetscCall(VecDestroy(&user->xub));
285:   PetscCall(VecDestroy(&user->b));
286:   PetscCall(VecDestroy(&user->x));
287:   PetscCall(VecDestroy(&user->c));
288:   PetscCall(VecDestroy(&user->workN3));
289:   PetscCall(VecDestroy(&user->workN2));
290:   PetscCall(VecDestroy(&user->workN));
291:   PetscCall(VecDestroy(&user->workM));
292:   PetscFunctionReturn(PETSC_SUCCESS);
293: }

295: /*------------------------------------------------------------*/

297: int main(int argc, char **argv)
298: {
299:   Tao         tao, misfit, reg;
300:   PetscReal   v1, v2;
301:   AppCtx     *user;
302:   PetscViewer fd;
303:   char        resultFile[] = "tomographyResult_x";

305:   PetscFunctionBeginUser;
306:   PetscCall(PetscInitialize(&argc, &argv, (char *)0, help));
307:   PetscCall(PetscNew(&user));
308:   PetscCall(InitializeUserData(user));

310:   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
311:   PetscCall(TaoSetType(tao, TAOADMM));
312:   PetscCall(TaoSetSolution(tao, user->x));
313:   /* f(x) + g(x) for parent tao */
314:   PetscCall(TaoADMMSetSpectralPenalty(tao, 1.));
315:   PetscCall(TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user));
316:   PetscCall(MatShift(user->HF, user->lambda));
317:   PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user));

319:   /* f(x) for misfit tao */
320:   PetscCall(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user));
321:   PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user));
322:   PetscCall(TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE));
323:   PetscCall(TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user));

325:   /* g(x) for regularizer tao */
326:   if (user->reg == 1) {
327:     PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user));
328:     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user));
329:     PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
330:   } else if (user->reg == 2) {
331:     PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user));
332:     PetscCall(MatShift(user->Hz, 1));
333:     PetscCall(MatScale(user->Hz, user->lambda));
334:     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user));
335:     PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
336:   } else PetscCheck(user->reg == 3, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */

338:   /* Set type for the misfit solver */
339:   PetscCall(TaoADMMGetMisfitSubsolver(tao, &misfit));
340:   PetscCall(TaoADMMGetRegularizationSubsolver(tao, &reg));
341:   PetscCall(TaoSetType(misfit, TAONLS));
342:   if (user->reg == 3) {
343:     PetscCall(TaoSetType(reg, TAOSHELL));
344:     PetscCall(TaoShellSetContext(reg, (void *)user));
345:     PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold));
346:   } else {
347:     PetscCall(TaoSetType(reg, TAONLS));
348:   }
349:   PetscCall(TaoSetVariableBounds(misfit, user->xlb, user->xub));

351:   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
352:   PetscCall(TaoADMMSetRegularizerCoefficient(tao, user->lambda));
353:   PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user));
354:   PetscCall(TaoADMMSetMinimumSpectralPenalty(tao, user->mumin));

356:   PetscCall(TaoADMMSetConstraintVectorRHS(tao, user->c));
357:   PetscCall(TaoSetFromOptions(tao));
358:   PetscCall(TaoSolve(tao));

360:   /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
361:   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd));
362:   PetscCall(VecView(user->x, fd));
363:   PetscCall(PetscViewerDestroy(&fd));

365:   /* compute the error */
366:   PetscCall(VecAXPY(user->x, -1, user->xGT));
367:   PetscCall(VecNorm(user->x, NORM_2, &v1));
368:   PetscCall(VecNorm(user->xGT, NORM_2, &v2));
369:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2)));

371:   /* Free TAO data structures */
372:   PetscCall(TaoDestroy(&tao));
373:   PetscCall(DestroyContext(user));
374:   PetscCall(PetscFree(user));
375:   PetscCall(PetscFinalize());
376:   return 0;
377: }

379: /*TEST

381:    build:
382:       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)

384:    test:
385:       suffix: 1
386:       localrunfiles: tomographyData_A_b_xGT
387:       args:  -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc

389:    test:
390:       suffix: 2
391:       localrunfiles: tomographyData_A_b_xGT
392:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8  -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor

394:    test:
395:       suffix: 3
396:       localrunfiles: tomographyData_A_b_xGT
397:       args:  -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor

399:    test:
400:       suffix: 4
401:       localrunfiles: tomographyData_A_b_xGT
402:       args:  -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc

404:    test:
405:       suffix: 5
406:       localrunfiles: tomographyData_A_b_xGT
407:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

409:    test:
410:       suffix: 6
411:       localrunfiles: tomographyData_A_b_xGT
412:       args:  -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

414: TEST*/