An efficient macro for Stump – two terminal nodes tree

This post was kindly contributed by SAS Programming for Data Mining Applications - go there to comment and to read the full post.


In this post, I post an improved SAS macro of the single partition split algorithm in Chapter 2 of “Pharmaceutical Statistics Using SAS: A Practical Guide” by Alex Dmitrienko, Christy Chuang-Stein, Ralph B. D’Agostino.
The single partition split algorithm is a simplified version of a decision stump. It is a weak classifier, usually used as the base weak learner for boosting algorithms. For each independent variable, this classifier seeks to separate the space into 2 subspaces such that each subspace has higher purity of the response classes, say 1 and 0. In the examples, the Gini index is used to measure purity/impurity.
The SAS example macro %SPLIT in the book (found @ Here) is for illustration purposes only, and is so inefficient that it is practically unusable by industrial standards.
I modified this macro and made it usable in real business applications where millions of observations and hundreds of variables are more than common.

Improved Code:



%macro Dsplit(dsn, p);
/************************************************/
/* Exhaustive single-split (stump) search with  */
/* the Gini index, one predictor at a time.     */
/*                                              */
/* dsn: Name of input SAS data set. It must     */
/*        contain y, w and the independent      */
/*        variables, which should be named      */
/*        as X1, X2,....,Xp and be continuous   */
/*        numeric variables                     */
/*   p: Number of independent variables         */
/*                                              */
/* Output: WORK.OUTSPLIT, one row per predictor */
/*        (varname, cutoff, gini).              */
/*        WORK.SORT, _YSUM and Y_PRED are       */
/*        scratch and are overwritten on every  */
/*        pass; Y_PRED ends up describing Xp.   */
/*                                              */
/* NOTE(review): the running sum _y1 below is   */
/* unweighted, while PROC MEANS weights ysum by */
/* w. The two agree only when w is 1. Confirm   */
/* that callers pass unit weights.              */
/* NOTE(review): ntotal counts non-missing y;   */
/* the arrays assume y has no missing values.   */
/************************************************/
%local i ntotal ysum ginimin xmin _notes;
/* remember the caller's NOTES setting so it can be restored on exit */
%let _notes=%sysfunc(getoption(notes));
options nonotes;
%do i=1 %to &p;
proc sort data=&dsn.(keep=y w x&i)  out=work.sort; by x&i;
run;

proc means data=work.sort noprint;
     var y;
     weight w;
     output out=_ysum  n(y)=ntotal  sum(y)=ysum;
run;
data _null_;
     set _ysum;
     /* SYMPUTX strips padding and keeps more digits than SYMPUT */
     call symputx('ntotal', ntotal);
     call symputx('ysum', ysum);
run;
data y_pred;
     set work.sort  end=eof;
     /* _p: class-1 proportion left/right of each candidate cut   */
     /* _g: [,1] Gini of the cut, [,2] cut value (midpoint)       */
     /* _x, _yv, _wv: buffered rows, replayed when scoring at EOF */
     array _p[&ntotal, 2] _temporary_;
     array _g[&ntotal, 2] _temporary_;
     array _x[&ntotal]    _temporary_;
     array _yv[&ntotal]   _temporary_;
     array _wv[&ntotal]   _temporary_;
     retain _y1  _oldx 0;
     keep gini x&i y_pred y w;

     if _n_=1 then _oldx=x&i;
     _x[_n_]=x&i;  _yv[_n_]=y;  _wv[_n_]=w;

     if ^eof then do;
        /* candidate cut between row _n_ and row _n_+1 */
        _y1+y;
        _p[_n_, 1]=_y1/_n_;
        _p[_n_, 2]=(&ysum-_y1)/(&ntotal-_n_);
        ppn1=_n_/&ntotal;  ppn2=1-ppn1;
        _g[_n_, 1]=2*(ppn1*(1-_p[_n_, 1])*_p[_n_, 1]+ppn2*(1-_p[_n_, 2])*_p[_n_, 2]);
        if _n_>1 then _g[_n_-1, 2]=(_oldx+x&i)/2;
        _oldx=x&i;
     end;
     else do;
        /* guard: a one-row input has no previous row to cut against */
        if _n_>1 then _g[_n_-1, 2]=(_oldx+x&i)/2;

        /* pass 1: emit every candidate cut and remember the purest one */
        ginimin=2;
        do i=1 to &ntotal-1;
           gini=_g[i, 1]; x&i=_g[i, 2];
           output;
           if gini lt ginimin then do;
              ginimin=gini; xmin=x&i;
              p1_LH=_p[i, 1]; p0_LH=1-p1_LH;
              p1_RH=_p[i, 2]; p0_RH=1-p1_RH;
              c_L=(p1_LH>0.5); c_R=(p1_RH>0.5);
           end;
        end;

        /* pass 2: score every buffered row with the winning cut.   */
        /* Both passes write to Y_PRED, so columns a pass does not  */
        /* set carry leftover values.                               */
        do i=1 to &ntotal;
           y=_yv[i];  w=_wv[i];
           if _x[i]<=xmin then y_pred=c_L;
           if _x[i]>xmin  then y_pred=c_R;
           output y_pred;
        end;

        call symputx('ginimin', ginimin);
        call symputx('xmin', xmin);
     end;
run;
data _giniout&i;
     length varname $ 8;
     varname="x&i";
     cutoff=&xmin;
     gini=&ginimin;
run;
%end;
data outsplit;
     set %do i=1 %to &p;
            _giniout&i
         %end;;
run;
proc datasets library=work nolist;
     delete _giniout:;
quit;
options &_notes;
%mend;

/* weak classifiers for Boost Algorithm */
%macro gini(_y0wsum, _y1wsum, i, nobs);
/***************************************************/
/* One pass over WORK.SORTED to find the cut on Xi  */
/* with the smallest weighted Gini value.           */
/*                                                  */
/* _y0wsum: total weight of rows with y not equal 1 */
/* _y1wsum: total weight of rows with y equal 1     */
/*       i: index of the predictor (Xi)             */
/*    nobs: accepted but not referenced here        */
/*                                                  */
/* Output: _GINIOUTi with a single row holding the  */
/*   best cut, its Gini value, and the class        */
/*   proportions and majority class on each side.   */
/*                                                  */
/* NOTE(review): the right-hand weight is taken as  */
/* 1 minus the running weight, so w presumably sums */
/* to 1 and SORTED is ordered by Xi. Confirm with   */
/* the caller.                                      */
/***************************************************/
data _giniout&i.(keep=varname  mingini cut_val   p0_LH  p1_LH  c_L  p0_RH  p1_RH  c_R);
     length varname $ 8;
     set sorted  end=lastrow;

     /* running weighted totals and the most recent Gini value */
     retain _cum0  _cum1  _cumw  _gini  0;
     /* statistics of the best cut seen so far */
     retain p0_LH  p1_LH  p0_RH  p1_RH  c_L  c_R  0;
     /* best Gini, its row number, x at that row, x at the row after it */
     retain _bestgini  _bestrow  _xbest  _xnext;

     if _n_ > 1 then do;
        _cum0 + (y ne 1)*w;
        _cum1 + (y eq 1)*w;
        _cumw + w;
     end;
     else do;
        _cum0 = (y ne 1)*w;
        _cum1 = (y eq 1)*w;
        _cumw = w;
        _bestgini = 2;
        _bestrow  = 1;
        _xbest    = x&i;
        _xnext    = x&i;
     end;

     /* the last row has nothing to its right, so it is not a candidate */
     if not lastrow then do;
        _wright = 1 - _cumw;
        pL0 = _cum0/_cumw;
        pR0 = (&_y0wsum - _cum0)/_wright;
        pL1 = _cum1/_cumw;
        pR1 = (&_y1wsum - _cum1)/_wright;
        _gini = pL1*pL0*_cumw + pR1*pR0*_wright;
     end;

     if _gini < _bestgini then do;
        _bestgini = _gini;
        _bestrow  = _n_;
        _xbest    = x&i;
        p0_LH = pL0;
        p1_LH = pL1;
        p0_RH = pR0;
        p1_RH = pR1;
        c_L = (p1_LH > 0.5);
        c_R = (p1_RH > 0.5);
     end;

     /* x of the row right after the current best, for the midpoint */
     if _n_ = _bestrow + 1 then _xnext = x&i;

     if lastrow then do;
        cut_val = (_xbest + _xnext)/2;
        mingini = _bestgini;
        varname = "x&i";
        output;
     end;
run;
%mend;

%macro stump_gini(dsn, p, outdsn);
/***************************************************/
/* Gini stump search over X1..Xp. For each          */
/* predictor the data are sorted, the weighted      */
/* class totals are computed, and %gini is called   */
/* to find the best cut.                            */
/*                                                  */
/*    dsn: Name of input SAS data set. It must      */
/*          contain y, w and the independent        */
/*          variables, which should be named        */
/*          as X1, X2,....,Xp and be continuous     */
/*          numeric variables                       */
/*      p: Number of independent variables          */
/* outdsn: Name of output SAS data set, sorted by   */
/*          mingini. Used for subsequent scoring.   */
/*          Must not be named _giniout.....         */
/*          because those scratch data sets are     */
/*          deleted at the end.                     */
/*                                                  */
/* WORK.SORTED, SORTEDV and _YWSUM are scratch and  */
/* are left behind holding the last predictor.      */
/***************************************************/
%local i;

%do i=1 %to &p;
    proc sort data=&dsn.(keep=x&i  y  w)  out=sorted  sortsize=max;
      by x&i;
    run;
    /* view adds 0/1 indicators so PROC MEANS can total each class */
    data sortedv/view=sortedv;
      set sorted;
      y1=(y=1); y0=(y^=1);
    run;
    proc means data=sortedv(keep=y0 y1 w)   noprint;
      var y0  y1;
      weight w;
      output out=_ywsum(keep=_y0wsum  _y1wsum  _FREQ_)
             sum(y0)=_y0wsum  sum(y1)=_y1wsum;
    run;
    data _null_;
      set _ywsum;
      /* format the totals explicitly: implicit conversion would */
      /* round them to BEST12. before %gini subtracts from them  */
      call execute('%gini(' || strip(put(_y0wsum, best32.)) || ','
                            || strip(put(_y1wsum, best32.)) || ','
                            || "&i"                         || ','
                            || strip(put(_FREQ_,  best32.)) || ')'
                  );
    run;
%end;
data &outdsn;
     set %do i=1 %to &p;
            _giniout&i
         %end;;
run;
proc sort data=&outdsn; by mingini; run;
proc datasets library=work nolist;
     delete %do i=1 %to &p;
            _giniout&i
            %end;;
run;quit;
%mend;


%macro css(_ywsum, i, nobs);
/***************************************************/
/* One pass over WORK.SORTED to find the cut on Xi  */
/* with the smallest value of the criterion cssk    */
/* computed below, for a numeric response y.        */
/*                                                  */
/* _ywsum: weighted total of y over all rows        */
/*      i: index of the predictor (Xi)              */
/*   nobs: accepted but not referenced here         */
/*                                                  */
/* Output: _REGOUTi with a single row holding the   */
/*   best cut, its criterion value and the weighted */
/*   mean of y on each side of the cut.             */
/*                                                  */
/* NOTE(review): the right-hand weight is taken as  */
/* 1 minus the running weight, so w presumably sums */
/* to 1 and SORTED is ordered by Xi. Confirm with   */
/* the caller.                                      */
/***************************************************/
data _regout&i.(keep=varname  mincss cut_val  ypred_L  ypred_R);
     length varname $ 8;
     set sorted  end=lastrow;

     /* running weighted sum of y and running weight */
     retain _cumyw  _cumw  0;
     /* predictions on each side of the best cut so far */
     retain ypred_L  ypred_R  0;
     /* best criterion, its row number, x at that row, x at the row after it */
     retain _bestcss  _bestrow  _xbest  _xnext;

     if _n_ > 1 then do;
        _cumyw + y*w;
        _cumw  + w;
     end;
     else do;
        _cumyw = y*w;
        _cumw  = w;
        _bestcss = constant('BIG');
        _bestrow = 1;
        _xbest   = x&i;
        _xnext   = x&i;
        ypred_L = _cumyw/_cumw;
        ypred_R = (&_ywsum-_cumyw)/(1-_cumw);
     end;

     /* the last row puts everything on the left, so it uses its own formula */
     if lastrow then _css = 1 - _cumyw**2;
     else _css = 1 - _cumyw/_cumw*_cumyw - (&_ywsum-_cumyw)/(1-_cumw)*(&_ywsum-_cumyw);

     if _css < _bestcss then do;
        _bestcss = _css;
        _bestrow = _n_;
        _xbest   = x&i;
        ypred_L = _cumyw/_cumw;
        ypred_R = (&_ywsum-_cumyw)/(1-_cumw);
     end;

     /* x of the row right after the current best, for the midpoint */
     if _n_ = _bestrow + 1 then _xnext = x&i;

     if lastrow then do;
        cut_val = (_xbest + _xnext)/2;
        mincss  = _bestcss;
        varname = "x&i";
        output;
     end;
run;
%mend;

%macro stump_css(dsn, p, outdsn);
/***************************************************/
/* Regression stump search over X1..Xp. For each    */
/* predictor the data are sorted, the weighted      */
/* total of y is computed, and %css is called to    */
/* find the best cut.                               */
/*                                                  */
/*    dsn: Name of input SAS data set. It must      */
/*          contain y, w and the independent        */
/*          variables, which should be named        */
/*          as X1, X2,....,Xp and be continuous     */
/*          numeric variables                       */
/*      p: Number of independent variables          */
/* outdsn: Name of output SAS data set, sorted by   */
/*          mincss. Used for subsequent scoring.    */
/*          Must not be named _regout.....          */
/*          because those scratch data sets are     */
/*          deleted at the end.                     */
/*                                                  */
/* WORK.SORTED and _YWSUM are scratch and are left  */
/* behind holding the last predictor.               */
/***************************************************/
%local i _src;
/* remember the caller's SOURCE setting so it can be restored on exit */
%let _src=%sysfunc(getoption(source));
options nosource;
%do i=1 %to &p;
    proc sort data=&dsn.(keep=x&i  y  w)  out=sorted  sortsize=max;
      by x&i;
    run;
    proc means data=sorted(keep=y w)   noprint;
      var y;
      weight w;
      output out=_ywsum(keep=_ywsum  _FREQ_)  sum(y)=_ywsum;
    run;
    data _null_;
      set _ywsum;
      /* format the total explicitly: implicit conversion would */
      /* round it to BEST12. before %css subtracts from it      */
      call execute('%css(' || strip(put(_ywsum, best32.)) || ','
                           || "&i"                        || ','
                           || strip(put(_FREQ_, best32.)) || ')'
                  );
    run;
%end;
data &outdsn;
     set %do i=1 %to &p;
            _regout&i
         %end;;
run;
proc sort data=&outdsn; by mincss; run;
/* drop the per-variable scratch results, as stump_gini does */
proc datasets library=work nolist;
     delete %do i=1 %to &p;
            _regout&i
            %end;;
run;quit;
options &_src;
%mend;

Comparing the time used and results.

The example data used is the AUC Small training data, 200 sets of predictors for 15,000 ratings, from AusDM2009 competition, and can be found @ Here .

Using the example code from the book, it takes 1563 seconds on a regular Windows desktop (Core2Duo E6750 2.67GHz, 4GB Memory, 7K2 rpm HDD with 8MB cache) to process 5 continuous numerical variables, whereas the improved macro takes only 1 second to process the same amount of data. This improvement is critical, since weak classifiers like this one won’t be used alone, but as the base for more time-consuming boosting algorithms. The original macro is practically unusable for any boosting algorithm on data sets with hundreds of observations or more.

Original Macro:

Improved Macro:
Comparing the results from original macro and the improved macro, they both select X3 with the same partition cut off point and the same Gini Index.

Reference:
Pharmaceutical Statistics Using SAS: A Practical Guide by Alex Dmitrienko, Christy Chuang-Stein, Ralph B. D’Agostino, SAS Publishing 2007

Pharmaceutical Statistics Using SAS: A Practical Guide (SAS Press)

This post was kindly contributed by SAS Programming for Data Mining Applications - go there to comment and to read the full post.