Canoe
Comprehensive Atmosphere N' Ocean Engine
tridiag.c
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 
4 #include "linalg.h"
5 
/* Values accepted for tridiag()'s pivot_type argument. */
#define WITHOUT_PIVOTING 0
#define WITH_PIVOTING 1

/*======================= tridiag() =========================================*/

/*
 * Solve a tridiagonal matrix system.
 * Adapted from Numerical Recipes in C, 2nd ed., pp. 51-54.
 * If pivot_type = WITH_PIVOTING, use band_decomp() and band_back_sub().
 * Assumes zero-based indexing.
 */

/*
 * Accessors for the compact band storage used by the pivoting path:
 * each matrix row occupies (m1 + m2 + 1) consecutive doubles, with the
 * diagonal in column m1 (zero-based). They expand to references to the
 * locals aa/aaorig/aal/m1/m2 of the function that uses them.
 */
#undef AA
#define AA(i, j) aa[(m1 + m2 + 1) * (i) + (j)]
#undef AAORIG
#define AAORIG(i, j) aaorig[(m1 + m2 + 1) * (i) + (j)]
#undef AAL
#define AAL(i, j) aal[m1 * (i) + (j)]

/* Name used as the prefix of tridiag()'s error messages. */
static const char tridiag_dbmsname[] = "tridiag";

/*
 * Allocate a zero-initialized array of count elements of `size` bytes each.
 * calloc() checks count * size for overflow. On failure, print an error and
 * exit(1), matching the exit-on-error policy used throughout tridiag().
 */
static void *tridiag_calloc(size_t count, size_t size) {
  void *ptr = calloc(count, size);

  if (ptr == NULL) {
    fprintf(stderr, "**error:%s, out of memory allocating %zu elements\n",
            tridiag_dbmsname, count);
    exit(1);
  }
  return ptr;
}

/*
 * Solve the system by the Thomas algorithm (Gaussian elimination without
 * pivoting). Row j of the system is
 *     a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j],
 * and a[0] and c[n-1] are never read on this path.
 *
 * Returns 0 on success. Returns 1 if a zero pivot is met at some j >= 1; in
 * that case u holds only partial results and the caller should fall back to
 * the pivoting solver. Exits if b[0] == 0.
 */
static int tridiag_without_pivoting(int n, double *a, double *b, double *c,
                                    double *r, double *u) {
  double bet, *gam;
  int j;

  /* Check the first pivot before allocating anything. */
  if (b[0] == 0.0) {
    fprintf(stderr,
            "**error:%s,b[0] = 0\n"
            "Rewrite equations as a set of order n-1, with u[1]\n"
            "trivially eliminated\n",
            tridiag_dbmsname);
    exit(1);
  }

  /* gam[j] is the modified super-diagonal factor for row j; gam[0] is unused. */
  gam = tridiag_calloc((size_t)n, sizeof *gam);

  /* Forward elimination. */
  bet = b[0];
  u[0] = r[0] / bet;
  for (j = 1; j < n; j++) {
    gam[j] = c[j - 1] / bet;
    bet = b[j] - a[j] * gam[j];
    if (bet == 0.0) {
      /* Zero pivot: the caller retries with pivoting. */
      free(gam);
      return 1;
    }
    u[j] = (r[j] - a[j] * u[j - 1]) / bet;
  }

  /* Backsubstitution. */
  for (j = n - 2; j >= 0; j--) {
    u[j] -= gam[j + 1] * u[j + 1];
  }

  free(gam);
  return 0;
}

/*
 * Solve the system by banded LU decomposition with partial pivoting
 * (band_decomp() / band_back_sub()), followed by one band_improve()
 * refinement pass to reduce roundoff errors. All n entries of a, b and c
 * are copied into band storage, so each array must have n elements.
 */
static void tridiag_with_pivoting(int n, double *a, double *b, double *c,
                                  double *r, double *u) {
  /* One sub-diagonal (m1) and one super-diagonal (m2). */
  const int m1 = 1;
  const int m2 = 1;
  const size_t rows = (size_t)n;
  double *aa, *aaorig, *aal, d;
  int *index, j;

  aa = tridiag_calloc(rows, (size_t)(m1 + m2 + 1) * sizeof *aa);
  aaorig = tridiag_calloc(rows, (size_t)(m1 + m2 + 1) * sizeof *aaorig);
  aal = tridiag_calloc(rows, (size_t)m1 * sizeof *aal);
  index = tridiag_calloc(rows, sizeof *index);

  /*
   * Load matrix AA and keep an untouched copy AAORIG. band_decomp()
   * works on aa; band_improve() needs the original matrix.
   */
  for (j = 0; j < n; j++) {
    AA(j, m1 + 1) = AAORIG(j, m1 + 1) = c[j];
    AA(j, m1) = AAORIG(j, m1) = b[j];
    AA(j, m1 - 1) = AAORIG(j, m1 - 1) = a[j];
  }

  band_decomp(n, m1, m2, aa, aal, index, &d);

  /*
   * tridiag() does not overwrite the input rhs vector r with the answer,
   * but band_back_sub() overwrites its last argument, so copy r into u
   * first.
   */
  for (j = 0; j < n; j++) {
    u[j] = r[j];
  }

  band_back_sub(n, m1, m2, aa, aal, index, u);

  /* Reduce roundoff errors. */
  band_improve(n, m1, m2, aaorig, aa, aal, index, r, u);

  free(aa);
  free(aaorig);
  free(aal);
  free(index);
}

/*
 * tridiag(): solve the n x n tridiagonal system
 *     a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j],   j = 0..n-1.
 *
 * n          - system order; must be > 0, otherwise the program exits.
 * a, b, c    - sub-diagonal, diagonal and super-diagonal coefficients
 *              (n elements each; a[0] and c[n-1] lie outside the matrix).
 * r          - right-hand side (n elements); not written by this file's code.
 * u          - output solution (n elements).
 * pivot_type - WITHOUT_PIVOTING (Thomas algorithm; falls back to pivoting
 *              with a warning if a zero pivot is met) or WITH_PIVOTING.
 *              Any other value prints an error and exits.
 *
 * Errors are reported on stderr and terminate the program with exit(1).
 */
void tridiag(int n, double *a, double *b, double *c, double *r, double *u,
             int pivot_type) {
  if (n <= 0) {
    fprintf(stderr, "**error:%s, n = %d\n", tridiag_dbmsname, n);
    exit(1);
  }

  switch (pivot_type) {
  case WITHOUT_PIVOTING:
    if (tridiag_without_pivoting(n, a, b, c, r, u) != 0) {
      fprintf(stderr, "Warning: tridiag(): retrying with pivoting.\n");
      tridiag_with_pivoting(n, a, b, c, r, u);
    }
    break;
  case WITH_PIVOTING:
    tridiag_with_pivoting(n, a, b, c, r, u);
    break;
  default:
    fprintf(stderr, "**error:%s, unrecognized pivot_type=%d\n",
            tridiag_dbmsname, pivot_type);
    exit(1);
  }
}

/*======================= end of tridiag() ===================================*/
void band_back_sub(int n, int m1, int m2, double *a, double *al, int *index, double *b)
Definition: band_back_sub.c:22
void band_decomp(int n, int m1, int m2, double *a, double *al, int *index, double *d)
Definition: band_decomp.c:28
void band_improve(int n, int m1, int m2, double *aorig, double *a, double *al, int *index, double *b, double *x)
Definition: band_improve.c:36
#define AAORIG(i, j)
Definition: tridiag.c:21
#define WITH_PIVOTING
Definition: tridiag.c:7
void tridiag(int n, double *a, double *b, double *c, double *r, double *u, int pivot_type)
Definition: tridiag.c:25
#define WITHOUT_PIVOTING
Definition: tridiag.c:6
#define AA(i, j)
Definition: tridiag.c:19