function [h,stats] = regressPlot(X,Y,varargin);
% [h,stats] = regressPlot(X,Y,[options]);
%
% Plots the data points in X and Y against each other, fits a linear
% (first-order) regression line, and reports R^2 and p next to the plot
% (or in its title) as well as on the command line.
%
% X and Y should be numeric arrays of the same size. If either contains
% NaNs, every point where X or Y is NaN is dropped, and the remaining data
% are treated as single vectors.
%
% Outputs:
%   h: handle to the current axes (GCA). The axes are left with HOLD ON.
%   stats: struct with fields
%       R2, F, p          -- R^2, F statistic and p-value from REGRESS
%       betas             -- [slope; intercept] from REGRESS
%       betaConfIntervals, residuals, resConfIntervals
%                         -- the remaining REGRESS outputs
%       lineCoefficients, linefitStruct
%                         -- the [P,S] outputs of POLYFIT
%       alpha             -- the alpha value used
%
% Options:
%
% 'alpha',[val]: change the confidence parameter for the regression
% fitting (passed to REGRESS). It defaults to 0.05.
%
% 'nolabel': don't add text labels for R^2 and p to the graph.
%
% 'title': add the stats label as a title, instead of text to the side.
% Unless a font size is given explicitly, the title uses a 9-pt. font.
%
% 'font',[name]: set font name for reporting R^2 and p (default: 'Arial').
%
% 'fontsz',[size]: set font size for reporting R^2 and p (default: 12-pt.).
%
% 'symbol',[sym]: set symbol for data points (default: 'o'). Can also enter
% it directly if it is one of the following: +, o, ., x, *, $.
%
% 'color',[c]: color specification passed to SETLINECOLORS (default: 'b').
%
% 'onecolor': if matrices are passed in, make them into vectors before
% plotting, so symbols are all the same color. Not doing this is used to
% e.g. color-code data from different subjects, to eyeball
% inter-subject-driven regressions. NOTE: will do this automatically if
% either X or Y has NaNs in it.
%
% 'legend',{leglabels}: add a legend with the specified labels
% (cell-of-strings), placed outside the axes at the upper right.
%
% 'symbols',[{values}] or 'multisymbol',[{values}]: if X is a matrix of
% columns, instead of coloring each column, use different symbols (useful
% if preparing black&white plots). If a non-empty cell is passed as the
% next argument, it specifies the symbols for each column (symbol shape
% and color strings -- see HELP PLOT for options). Otherwise red
% o, s, x, *, +, d, <, >, ^, p, h are used. Symbols are reused cyclically
% if there are more columns than symbols.
%
% 'confidencebounds': plot the +/- DELTA error bounds returned by POLYVAL
% as dotted magenta lines (with error bars) around the regression line.
% See HELP POLYVAL for how DELTA should be interpreted.
%
% 04/03 ras
% 10/03 ras: updated to deal w/ NaNs, added several optional arguments.
% 12/03 ras: intelligently places label depending on regression line slope,
% added multisymbol option, confidencebounds option.
alpha = 0.05;
labelOn = 1;        % 0 = no label, 1 = text beside the plot, 2 = title
font = 'Arial';
fontsz = 12;
userFontsz = 0;     % set to 1 once the caller gives an explicit font size
symbol = 'o';
multisymbol = 0;
symbols = {};       % per-column symbols for the multisymbol option
confbounds = 0;
leg = {};           % if empty, no legend
colors = {'b'};

%%%%% parse the option flags
% Every argument is scanned. Non-string arguments, and strings that are not
% flags (e.g. option values), fall through without effect, so
% 'flag',value pairs work without explicitly skipping the value.
for i = 1:length(varargin)
    if ~ischar(varargin{i})
        continue
    end
    switch lower(varargin{i})
        case 'alpha'
            alpha = varargin{i+1};
        case {'nolabel','nolabels'}
            labelOn = 0;
        case {'title','ttl','ttllabel'}
            labelOn = 2;
            % smaller default font for titles, but never override a font
            % size the caller specified earlier in the argument list
            if ~userFontsz
                fontsz = 9;
            end
        case {'labelfont','font'}
            font = varargin{i+1};
        case {'labelfontsz','fontsz','fontsize'}
            fontsz = varargin{i+1};
            userFontsz = 1;
        case {'o','+','x','*','.','$'}
            symbol = varargin{i};
        case 'legend'
            leg = varargin{i+1};
        case 'symbol'
            symbol = varargin{i+1};
        case 'onecolor'
            % flatten so all points are plotted as a single series
            X = X(:);
            Y = Y(:);
        case {'color','colors'}
            colors = {varargin{i+1}};
        case 'extend'
            % accepted for backward compatibility; its value is not used
        case {'multisymbol','symbols'}
            multisymbol = 1;
            if i < length(varargin) && iscell(varargin{i+1})
                symbols = varargin{i+1};
            end
        case {'confidencebounds','confbounds'}
            confbounds = 1;
        otherwise
            % unrecognized strings are ignored
    end
end

% default symbol set (red markers); also guards against an empty cell,
% which would otherwise leave no symbols to cycle through
if multisymbol && isempty(symbols)
    symbols = {'ro','rs','rx','r*','r+','rd','r<','r>','r^','rp','rh'};
end

% remove any NaNs in the inputs (pairwise); logical indexing returns
% the surviving points as a vector
if any(isnan(X(:))) || any(isnan(Y(:)))
    keep = ~isnan(X) & ~isnan(Y);
    X = X(keep);
    Y = Y(keep);
end
if isempty(X)
    error('regressPlot: no valid (non-NaN) data points to plot.');
end

%%%%% plot the points
hold on
if multisymbol
    % one symbol per column, cycling through the symbol list as needed.
    % Drawing directly (no CLA) preserves anything already in the axes.
    nsym = numel(symbols);
    for j = 1:size(X,2)
        plot(X(:,j),Y(:,j),symbols{mod(j-1,nsym)+1});
    end
else
    plot(X,Y,symbol);
end

%%%%%% set symbol colors
setLineColors(colors);

%%%%% add legend if specified
if ~isempty(leg)
    % 'NorthEastOutside' is the named equivalent of the obsolete
    % numeric position -1 (to the right of the axes)
    legend(leg,'Location','NorthEastOutside');
end

%%%%% do regression analysis, add fitted line
Y = Y(:)';
X = X(:)';

% x-values for the fitted line: extend well past the data on both sides
% (abs() on both ends so this holds whatever the sign of X), so the line
% spans the axis limits set below. At least 2 points are needed to draw it.
npts = max(length(X),2);
pts = linspace(-3*abs(min(X)),3*abs(max(X)),npts);
[P, S] = polyfit(X,Y,1); % this only does a first-order fit
[LN,delta] = polyval(P,pts,S);

if confbounds
    hold on, errorbar(pts,LN,delta,'m:');
    plot(pts,LN+delta,'m:');
    plot(pts,LN-delta,'m:');
end

hold on, plot(pts,LN,'k','linewidth',2);

% the main regression: design matrix [X 1] gives B = [slope; intercept]
[B,BINT,R,RINT,stats] = regress(Y',[X' ones(length(X),1)],alpha);

% set axes, text label
% Pad each limit by 10% of its own magnitude. For positive data this is
% 0.9*min / 1.1*max; using abs() also keeps negative data inside the axes.
AX = axis;
AX(1) = min(X) - 0.1*abs(min(X)); AX(2) = max(X) + 0.1*abs(max(X));
AX(3) = min(Y) - 0.1*abs(min(Y)); AX(4) = max(Y) + 0.1*abs(max(Y));
% limits collapse only when the data are all zero; AXIS requires
% strictly increasing limits
if AX(1) >= AX(2), AX(1:2) = AX(1:2) + [-1 1]; end
if AX(3) >= AX(4), AX(3:4) = AX(3:4) + [-1 1]; end
axis(AX);
axis square
if labelOn > 0
    % right edge of the axes; left-aligned text therefore sits outside it
    xloc = AX(1)+1*(AX(2)-AX(1));
    % put the label low for positive slopes, high for negative ones,
    % to keep it away from the regression line
    if B(1) > 0
        yloc = AX(3)+0.2*(AX(4)-AX(3));
    else
        yloc = AX(3)+0.9*(AX(4)-AX(3));
    end

    if labelOn==1 % label on side
        msg = sprintf('R^2: %1.2f \n%s',stats(1),pvalText(stats(3),1));

        text(xloc,yloc,msg,'Color','k','HorizontalAlignment','left',...
             'FontSize',fontsz,'FontName',font);
    else
        msg = sprintf('R^2: %1.2f, %s',stats(1),pvalText(stats(3),1));

        title(msg,'FontSize',fontsz,'FontName',font);
    end
end

% add some text in the stdout displaying the results nicely
if stats(3) < 10^(-2)
    bound = num2str(floor(log10(stats(3))));
    ptxt = sprintf('p < 10e%s',bound);
else
    ptxt = sprintf('p = %1.2f',stats(3));
end
fprintf('Regression Results: \t')
fprintf('Y = %3.2f*X + %3.2f, R^2 = %3.2f, F = %3.2f, %s\n',...
            P(1),P(2),stats(1),stats(2),ptxt);

% set the output arguments nicely
h = gca;

% repackage REGRESS's stats vector [R^2 F p ...] as a struct
tmp = stats; clear stats;
stats.R2 = tmp(1);
stats.F = tmp(2);
stats.p = tmp(3);
stats.betas = B;
stats.betaConfIntervals = BINT;
stats.residuals = R;
stats.resConfIntervals = RINT;
stats.lineCoefficients = P;
stats.linefitStruct = S;
stats.alpha = alpha;

return