function [x] = symtrisolve(TD,TS,y)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SYMTRISOLVE Symmetric Tridiagonal Solver.
% x = symtrisolve(TD, TS, y) attempts to solve the system of linear
% equations Tx = y, where T is symmetric tridiagonal with diagonal
% nx1 vector TD and subdiagonal and superdiagonal (n-1)x1 vector TS.
%
% The solve uses a factorization T = L*B*L', where L is unit lower
% triangular (nonzero only on its first two subdiagonals) and B is
% block diagonal with 1x1 and 2x2 blocks.  The pivot size at each step
% is chosen by the Bunch-Marcia criterion with constant
% alpha = (sqrt(5)-1)/2.  No row/column interchanges are performed.
%
% Inputs:
%   TD - diagonal of T (n entries)
%   TS - sub/superdiagonal of T (n-1 entries)
%   y  - right-hand side; n is taken as length(y)
% Output:
%   x  - solution of Tx = y as an n x 1 column (for n == 1, x = y/TD(1)).
%        If T, or a block of the computed factor B, is found to be (near)
%        singular, a message is printed and x is returned as zeros(n,1).
%
% Author: Roummel F. Marcia
% Date: July 10, 2014
%
% Based on the paper, "A simplified pivoting strategy for
% symmetric tridiagonal matrices", by James R. Bunch and
% Roummel F. Marcia, Numer. Linear Algebra Appl., vol. 13,
% p. 865-867 (2006).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
n = length(y);
% Default result; returned unchanged whenever singularity is detected.
x = zeros(n,1);

% Solve directly for n <= 2
if n < 1
    fprintf('n must be a positive integer\n')
    return
elseif n == 1
    if abs(TD(1)) < eps
        fprintf('T is a scalar and near 0\n')
        return
    end
    x = y/TD(1);
    return
elseif n == 2
    Delta = TD(1)*TD(2) - TS(1)*TS(1);
    if abs(Delta) < eps
        fprintf('T is (near) singular\n')
        return
    end
    % Cramer's rule with the explicit determinant.  Unlike dividing by
    % TS(1), this stays valid when T is diagonal (TS(1) == 0).
    x(1) = ( TD(2)*y(1) - TS(1)*y(2))/Delta;
    x(2) = (-TS(1)*y(1) + TD(1)*y(2))/Delta;
    return
end

% L1 and L2 are the vectors corresponding to the 1st and 2nd subdiagonal
% of the unit lower triangular factor L, respectively.
% B1 is a vector of diagonal entries in the block diagonal matrix B,
% and B2 is a vector of sub-diagonal entries in B (nonzero only inside
% 2x2 blocks).
% is2x2(k) is true when a 2x2 block of B starts at row k.  Recording this
% explicitly avoids guessing block structure from the size of B2(k),
% which would misread a 2x2 block whose off-diagonal entry is tiny.
L1 = zeros(n-1,1);
L2 = zeros(n-2,1);
B1 = zeros(n,1);
B2 = zeros(n-1,1);
is2x2 = false(n,1);

% Set constant alpha for pivot size criterion
alpha = (sqrt(5)-1)/2;
% Initialize: alpha1 is the current (Schur-complement-updated) diagonal.
j = 1;
alpha1 = TD(j);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Main factorization: Compute unit lower triangular factor L and
% block diagonal matrix B with 1x1 and 2x2 blocks
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
while j <= n-2
    beta2 = TS(j);
    beta3 = TS(j+1);
    alpha2 = TD(j+1);
    Delta = alpha1*alpha2 - beta2*beta2;
    % decide which pivot size to use
    if (abs(alpha1*alpha2) >= alpha*beta2*beta2) || ...
            (abs(Delta) <= alpha*abs(alpha1*beta3)) || ...
            (abs(beta2*Delta) <= alpha*abs(alpha1*alpha1*beta3))
        % use a 1x1 pivot
        if abs(alpha1) < eps
            % Negligible pivot: record a zero block (reported as singular
            % during the B solve) and do not eliminate.
            L1(j) = 0.0;
            B1(j) = 0.0;
            alpha1 = TD(j+1);
        else
            L1(j) = beta2/alpha1;
            B1(j) = alpha1;
            alpha1 = TD(j+1) - (beta2*beta2/alpha1);
        end
        j = j+1;
    else
        % use a 2x2 pivot [alpha1 beta2; beta2 alpha2]; row j+2 of L is
        % [0 beta3] * inv(block), and the Schur update touches TD(j+2) only.
        L2(j) = -beta2*beta3/Delta;
        L1(j+1) = alpha1*beta3/Delta;
        B1(j) = alpha1;
        B2(j) = beta2;
        B1(j+1) = alpha2;
        is2x2(j) = true;
        alpha1 = TD(j+2) - (alpha1*beta3*beta3/Delta);
        j = j+2;
    end
end
% Be careful near end of factorization: the loop leaves j = n or j = n-1.
if j == n
    B1(j) = alpha1;
else
    % j = n-1: only the trailing 2x2 submatrix remains.
    beta2 = TS(j);
    alpha2 = TD(j+1);
    if abs(alpha1*alpha2) >= alpha*beta2*beta2
        % use a 1x1 pivot
        if abs(alpha1) < eps
            % Negligible pivot: zero block, and the last diagonal entry of
            % B is the untouched TD(j+1) (same convention as the main loop).
            L1(j) = 0.0;
            B1(j) = 0.0;
            B1(j+1) = TD(j+1);
        else
            L1(j) = beta2/alpha1;
            B1(j) = alpha1;
            B1(j+1) = TD(j+1) - (beta2*beta2/alpha1);
        end
    else
        % use a final 2x2 pivot
        B1(j) = alpha1;
        B2(j) = beta2;
        B1(j+1) = alpha2;
        is2x2(j) = true;
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% End main factorization
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Solve Tx = y by computing LBLTx = y
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% First solve Lw = y (forward substitution; L has two subdiagonals)
w = zeros(n,1);
w(1) = y(1);
w(2) = y(2) - L1(1)*w(1);
for k = 3:n
    w(k) = y(k) - L1(k-1)*w(k-1) - L2(k-2)*w(k-2);
end
% Next, solve Bz = w block by block
z = zeros(n,1);
k = 1;
while k <= n
    if is2x2(k)
        Delta = B1(k)*B1(k+1) - B2(k)*B2(k);
        if abs(Delta) <= eps
            fprintf('B is singular\n');
            return
        end
        % Inverse of [a b; b d] is [d -b; -b a]/Delta
        z(k)   = ( B1(k+1)*w(k) - B2(k)*w(k+1))/Delta;
        z(k+1) = (-B2(k)*w(k)   + B1(k)*w(k+1))/Delta;
        k = k+2;
    else
        if abs(B1(k)) <= eps
            fprintf('B is singular\n');
            return
        end
        z(k) = w(k)/B1(k);
        k = k+1;
    end
end
% Finally, solve LTx = z (back substitution with L transposed)
x(n) = z(n);
x(n-1) = z(n-1) - L1(n-1)*x(n);
for k = n-2:-1:1
    x(k) = z(k) - L1(k)*x(k+1) - L2(k)*x(k+2);
end
% The vector x is the solution to Tx = y.
return