This post was kindly contributed by SAS Programming for Data Mining Applications - go there to comment and to read the full post. |
Test Stochastic Gradient Descent (SGD) logistic regression in SAS. The logic and code follow the code by Ravi Varadhan, Ph.D., from this discussion on R-help. The blog SAS Die Hard also has a post about SGD logistic regression in SAS.
/* Read the low-birth-weight data directly from the remote URL. */
filename foo url "http://www.biostat.jhsph.edu/~ririzarr/Teaching/754/lbw.dat" ;
data temp;
infile foo length=len;
input low age lwt race smoke ptl ht ui ftv bwt;
/* Echo every record to the log as it is read. */
put low age lwt race smoke ptl ht ui ftv bwt;
/* Drop the first record (presumably a header line - confirm against the file). */
if _n_>1;
run;
/* Centre and scale the covariates used in the model to mean 0, std 1. */
proc standard data=temp out=temp mean=0 std=1;
var age lwt smoke ht ui;
run;
proc contents data=temp out=vars(keep=varnum name type) noprint; run;
/*
   Build the covariate list (&covars), the matching list of working
   coefficient names (&covars2, each covariate prefixed with b_) and the
   parameter count including the intercept (&nparms).
   The filter selects exactly the variables that are standardized above and
   used in the PROC LOGISTIC model below; none of the variables in TEMP
   starts with "x", so a prefix filter on "x" would leave the lists empty.
   ORDER BY VARNUM keeps both lists in the same, deterministic order.
*/
proc sql noprint;
select name into :covars separated by " "
from vars
where lowcase(name) in ("age", "lwt", "smoke", "ht", "ui")
order by varnum
;
select cats("b_", name) into :covars2 separated by " "
from vars
where lowcase(name) in ("age", "lwt", "smoke", "ht", "ui")
order by varnum
;
select count(*)+1 into :nparms
from vars
where lowcase(name) in ("age", "lwt", "smoke", "ht", "ui")
;
quit;
%put &covars2;
/* Release _XBETA in case an earlier run left it loaded (presumably a cleanup step - confirm). */
sasfile _xbeta close;
/*
   Fit by gradient descent. The response is LOW, the same variable modelled
   by PROC LOGISTIC below. The macro %lr_sgd must already be compiled in
   the session before this call is submitted.
*/
%lr_sgd(temp, beta, low, &covars, alpha=0.008, decay=0.8, criterion=0.00001, maxiter=1000);
options fullstimer;
/* Reference fit: estimates are written to _BETA for comparison with BETA. */
proc logistic data=temp outest=_beta desc noprint;
model low = age lwt smoke ht ui;
run;
The macro %LR_SGD.
/*
SAS macro:
Logistic Regression using Stochastic Gradient Descent.
Name:
%lr_sgd();
Copyright (c) 2009, Liang Xie (Contact me @ xie1978 at gmail dot com)
The SAS macro is a demonstration of an implementation of logistic
regression models trained by Stochastic Gradient Descent (SGD). This
program reads a training set specified as &dsn_in, trains a logistic
regression model, and outputs the estimated coefficients to &outest.
Example usage:
%let inputdata=train_data;
%let beta=coefficient;
%let response=Event;
%lr_sgd(&inputdata, &beta, &response, &covars,
alpha=0.008, decay=0.8,
criterion=0.00001, maxiter=1000);
The following topics are not covered for simplicity:
- bias term
- regularization
- multiclass logistic regression (maximum entropy model)
- calibration of learning rate
Distributed under GNU Affero General Public License version 3. This
program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, only version 3 of the
License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
*/
/*
   %logistic: one pass over the training data that updates the working
   coefficient data set _x&outest in place.

   Parameters:
     dsn_in    training data set; must contain INTERCEPT, the covariates
               and the response variable
     outest    TYPE=PARMS data set with the current coefficients, used by
               PROC SCORE to compute the linear predictor
     response  name of the response variable (assumed to be coded 0/1 -
               confirm against the calling code)
     alpha     step size applied to every update

   The macro variables covars (covariate names) and covars2 (working
   coefficient names, in the same order) are not parameters; they must
   resolve in the calling environment.

   Data sets written: SCORE, the view _XTEMP, and _x&outest (updated).
*/
%macro logistic(dsn_in, outest, response, alpha=0.0005);
/* Linear predictor for each observation under the current coefficients. */
proc score data=&dsn_in score=&outest type=parms out=score(keep=score);
var intercept &covars;
run;
/* _w = response minus fitted probability, computed lazily through a view. */
/* Uses the dsn_in parameter rather than a macro variable leaked from the caller. */
data _xtemp/view=_xtemp;
merge &dsn_in score ;
_w=&response - 1/(1+exp(-score));
run;
/* Accumulate alpha * x * _w over all observations into the b_ coefficients. */
data _x&outest;
array x{*} intercept &covars;
array _x{*} b_intercept &covars2;
retain b_intercept &covars2;
modify _x&outest;
do until (eof);
set _xtemp end=eof;
do i=1 to dim(x);
_x[i]=_x[i]+&alpha*x[i]*_w;
end;
end;
/* Write the updated coefficients back over the single existing record. */
replace;
run;
%mend;
/*
   %compare: stores in the macro variable MAXDIFF the largest absolute
   difference between corresponding coefficients of two data sets.

   Parameters:
     dsn1  data set holding INTERCEPT and the variables named in covars
     dsn2  data set holding B_INTERCEPT and the variables named in covars2

   covars and covars2 are resolved from the calling environment and must
   list the coefficients in the same order.
*/
%macro compare(dsn1, dsn2);
data _null_;
  merge &dsn1 &dsn2;
  array _lhs{*} intercept &covars;
  array _rhs{*} b_intercept &covars2;
  retain maxdiff 0;
  /* Running maximum of the element-wise absolute gap. */
  do _k=1 to dim(_lhs);
    _gap=abs(_lhs[_k]-_rhs[_k]);
    maxdiff=max(maxdiff, _gap);
  end;
  call symput('maxdiff', maxdiff);
run;
%mend;
/*
   %lr_sgd: trains a logistic regression by repeated gradient passes and
   writes the estimated coefficients to &outest.

   Parameters:
     dsn        training data set. It is rewritten in place with two extra
                variables, INTERCEPT (=1) and _W (=1).
     outest     name of the output coefficient data set (TYPE=PARMS layout
                with _TYPE_="PARMS", _NAME_="SCORE").
     response   name of the response variable.
     covars     blank-separated list of covariate names.
     alpha      initial step size. After every iteration it is multiplied
                by decay and floored at 0.00005.
     decay      multiplicative decay applied to alpha per iteration.
     criterion  iteration stops once the largest absolute coefficient
                change in an iteration falls below this value.
     maxiter    maximum number of iterations.

   Working coefficient names: if a macro variable covars2 is already visible
   it is used as is (backward compatible); otherwise it is derived locally
   by prefixing each name in covars with b_.

   Data sets written: &dsn (modified), &outest, _x&outest, plus whatever
   %logistic creates. Progress and the final status go to the log.
   The SOURCE, NOTES, MLOGIC and MPRINT options are switched off while the
   macro runs and restored to their previous settings on exit.
*/
%macro lr_sgd(dsn, outest, response, covars,
alpha=0.0005, decay=0.9,
criterion=0.00001, maxiter=1000);
%local i t0 t1 t00 t11 dt maxdiff status stopiter;
%local _opt_source _opt_notes _opt_mlogic _opt_mprint _k _word;
/* Remember the current option settings so they can be restored at the end. */
%let _opt_source=%sysfunc(getoption(source));
%let _opt_notes=%sysfunc(getoption(notes));
%let _opt_mlogic=%sysfunc(getoption(mlogic));
%let _opt_mprint=%sysfunc(getoption(mprint));
options nosource nonotes;
options nomlogic nomprint;
/* Derive the b_ coefficient names only when the caller did not supply them. */
%if %symexist(covars2)=0 %then %do;
  %local covars2;
  %let _k=1;
  %let _word=%scan(&covars, &_k, %str( ));
  %do %while(%length(&_word) > 0);
    %let covars2=&covars2 b_&_word;
    %let _k=%eval(&_k+1);
    %let _word=%scan(&covars, &_k, %str( ));
  %end;
%end;
%let t00=%sysfunc(datetime());
/* Add the constant term and a placeholder weight to the training data. */
data &dsn;
set &dsn;
intercept=1; _w=1;
run;
/* Start every coefficient at zero. */
data &outest;
retain _TYPE_ "PARMS" _NAME_ "SCORE";
array x{*} intercept &covars;
do i=1 to dim(x);
x[i]=0;
end;
drop i;
output;
run;
/* Working copy of the coefficients under the b_ names. */
data _x&outest;
retain _TYPE_ "PARMS" _NAME_ "SCORE";
array bx{*} b_intercept &covars2;
array x{*} intercept &covars;
set &outest;
do j=1 to dim(x); bx[j]=x[j]; end;
keep b_intercept &covars2 _TYPE_ _NAME_;
drop j;
run;
sasfile _x&outest load;
%let stopiter=&maxiter;
%let status=Not Converged.;
%do i=1 %to &maxiter;
%let t0=%sysfunc(datetime());
/* One pass: update _x&outest, then measure the change against &outest. */
%logistic(&dsn, &outest, &response, alpha=&alpha);
%compare(&outest, _x&outest);
/* Copy the updated coefficients back under the covariate names. */
data &outest;
retain _TYPE_ "PARMS" _NAME_ "SCORE";
array bx{*} b_intercept &covars2;
array x{*} intercept &covars;
set _x&outest;
do j=1 to dim(x); x[j]=bx[j]; end;
keep intercept &covars _TYPE_ _NAME_;
drop j;
run;
/* Shrink the step size, but never below 0.00005. */
%let alpha=%sysevalf(&alpha * &decay);
%let alpha=%sysfunc(max(0.00005, &alpha));
%let t1=%sysfunc(datetime());
%let dt=%sysfunc(round(&t1-&t0, 0.001));
%put Iteration &i, time used &dt, converge criteria is &maxdiff;
%if %sysevalf(&maxdiff<&criterion) %then %do;
%let stopiter=&i;
/* Push the loop index past its bound to leave the loop. */
%let i=%eval(&maxiter+1);
%let status=Converged.;
%end;
%end;
sasfile _x&outest close;
%let t11=%sysfunc(datetime());
%let dt=%sysfunc(round(&t11-&t00, 0.01));
%put Total Time is &dt sec.;
%put Total Iteration is &stopiter, convergence status is &status;
%put At Final Iteration, max difference is &maxdiff;
/* Put the session options back the way the caller had them. */
options &_opt_mlogic &_opt_mprint &_opt_notes &_opt_source;
%mend;
This post was kindly contributed by SAS Programming for Data Mining Applications - go there to comment and to read the full post. |