View Javadoc

1   package org.mortbay.servlet;
2   
3   import java.io.IOException;
4   import java.util.HashSet;
5   import java.util.Queue;
6   import java.util.StringTokenizer;
7   import java.util.concurrent.ConcurrentHashMap;
8   import java.util.concurrent.Semaphore;
9   import java.util.concurrent.TimeUnit;
10  
11  import javax.servlet.Filter;
12  import javax.servlet.FilterChain;
13  import javax.servlet.FilterConfig;
14  import javax.servlet.ServletContext;
15  import javax.servlet.ServletException;
16  import javax.servlet.ServletRequest;
17  import javax.servlet.ServletResponse;
18  import javax.servlet.http.HttpServletRequest;
19  import javax.servlet.http.HttpServletResponse;
20  import javax.servlet.http.HttpSession;
21  import javax.servlet.http.HttpSessionBindingEvent;
22  import javax.servlet.http.HttpSessionBindingListener;
23  
24  import org.mortbay.log.Log;
25  import org.mortbay.thread.Timeout;
26  import org.mortbay.util.ArrayQueue;
27  import org.mortbay.util.ajax.Continuation;
28  import org.mortbay.util.ajax.ContinuationSupport;
29  
30  /**
31   * Denial of Service filter
32   * 
33   * <p>
34   * This filter is based on the {@link QoSFilter}. it is useful for limiting
35   * exposure to abuse from request flooding, whether malicious, or as a result of
36   * a misconfigured client.
37   * <p>
38   * The filter keeps track of the number of requests from a connection per
39   * second. If a limit is exceeded, the request is either rejected, delayed, or
40   * throttled.
41   * <p>
42   * When a request is throttled, it is placed in a priority queue. Priority is
43   * given first to authenticated users and users with an HttpSession, then
44   * connections which can be identified by their IP addresses. Connections with
45   * no way to identify them are given lowest priority.
46   * <p>
47   * The {@link #extractUserId(ServletRequest request)} function should be
48   * implemented, in order to uniquely identify authenticated users.
49   * <p>
50   * The following init parameters control the behavior of the filter:
51   * 
52   * maxRequestsPerSec    the maximum number of requests from a connection per
53   *                      second. Requests in excess of this are first delayed, 
54   *                      then throttled.
55   * 
56   * delayMs              is the delay given to all requests over the rate limit, 
57   *                      before they are considered at all. -1 means just reject request, 
58   *                      0 means no delay, otherwise it is the delay.
59   * 
60   * maxWaitMs            how long to blocking wait for the throttle semaphore.
61   * 
62   * throttledRequests    is the number of requests over the rate limit able to be
63   *                      considered at once.
64   * 
65   * throttleMs           how long to async wait for semaphore.
66   * 
67   * maxRequestMs         how long to allow this request to run.
68   * 
69   * maxIdleTrackerMs     how long to keep track of request rates for a connection, 
70   *                      before deciding that the user has gone away, and discarding it
71   * 
72   * insertHeaders        if true , insert the DoSFilter headers into the response. Defaults to true.
73   * 
74   * trackSessions        if true, usage rate is tracked by session if a session exists. Defaults to true.
75   * 
76   * remotePort           if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
77   * 
78   * ipWhitelist          a comma-separated list of IP addresses that will not be rate limited
79   */
80  
81  public class DoSFilter implements Filter
82  {
83      final static String __TRACKER = "DoSFilter.Tracker";
84      final static String __THROTTLED = "DoSFilter.Throttled";
85  
86      final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
87      final static int __DEFAULT_DELAY_MS = 100;
88      final static int __DEFAULT_THROTTLE = 5;
89      final static int __DEFAULT_WAIT_MS=50;
90      final static long __DEFAULT_THROTTLE_MS = 30000L;
91      final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
92      final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
93  
94      final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
95      final static String DELAY_MS_INIT_PARAM = "delayMs";
96      final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
97      final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
98      final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
99      final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
100     final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
101     final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
102     final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
103     final static String REMOTE_PORT_INIT_PARAM="remotePort";
104     final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
105 
106     final static int USER_AUTH = 2;
107     final static int USER_SESSION = 2;
108     final static int USER_IP = 1;
109     final static int USER_UNKNOWN = 0;
110 
111     ServletContext _context;
112 
113     protected long _delayMs;
114     protected long _throttleMs;
115     protected long _waitMs;
116     protected long _maxRequestMs;
117     protected long _maxIdleTrackerMs;
118     protected boolean _insertHeaders;
119     protected boolean _trackSessions;
120     protected boolean _remotePort;
121     protected Semaphore _passes;
122     protected Queue<Continuation>[] _queue;
123 
124     protected int _maxRequestsPerSec;
125     protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
126     private HashSet<String> _whitelist; 
127     
128     private final Timeout _requestTimeoutQ = new Timeout();
129     private final Timeout _trackerTimeoutQ = new Timeout();
130 
131     private Thread _timerThread;
132     private volatile boolean _running;
133 
134     public void init(FilterConfig filterConfig)
135     {
136         _context = filterConfig.getServletContext();
137 
138         _queue = new Queue[getMaxPriority() + 1];
139         for (int p = 0; p < _queue.length; p++)
140             _queue[p] = new ArrayQueue<Continuation>();
141 
142         int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
143         if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
144             baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
145         _maxRequestsPerSec = baseRateLimit;
146 
147         long delay = __DEFAULT_DELAY_MS;
148         if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
149             delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
150         _delayMs = delay;
151 
152         int passes = __DEFAULT_THROTTLE;
153         if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
154             passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
155         _passes = new Semaphore(passes,true);
156 
157         long wait = __DEFAULT_WAIT_MS;
158         if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
159             wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
160         _waitMs = wait;
161 
162         long suspend = __DEFAULT_THROTTLE_MS;
163         if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
164             suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
165         _throttleMs = suspend;
166 
167         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
168         if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
169             maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
170         _maxRequestMs = maxRequestMs;
171 
172         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
173         if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
174             maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
175         _maxIdleTrackerMs = maxIdleTrackerMs;
176         
177         String whitelistString = "";
178         if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
179             whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
180         
181         // empty 
182         if (whitelistString.length() == 0 )
183             _whitelist = new HashSet<String>();
184         else
185         {
186             StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
187             _whitelist = new HashSet<String>(tokenizer.countTokens());
188             while (tokenizer.hasMoreTokens())
189                 _whitelist.add(tokenizer.nextToken().trim());
190             
191             Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
192         }
193 
194         String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
195         _insertHeaders = tmp==null || Boolean.parseBoolean(tmp); 
196         
197         tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
198         _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
199         
200         tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
201         _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
202 
203         _requestTimeoutQ.setNow();
204         _requestTimeoutQ.setDuration(_maxRequestMs);
205         
206         _trackerTimeoutQ.setNow();
207         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
208         
209         _running=true;
210         _timerThread = (new Thread()
211         {
212             public void run()
213             {
214                 try
215                 {
216                     while (_running)
217                     {
218                         synchronized (_requestTimeoutQ)
219                         {
220                             _requestTimeoutQ.setNow();
221                             _requestTimeoutQ.tick();
222 
223                             _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
224                             _trackerTimeoutQ.tick();
225                         }
226                         try
227                         {
228                             Thread.sleep(100);
229                         }
230                         catch (InterruptedException e)
231                         {
232                             Log.ignore(e);
233                         }
234                     }
235                 }
236                 finally
237                 {
238                     Log.info("DoSFilter timer exited");
239                 }
240             }
241         });
242         _timerThread.start();
243     }
244     
245 
246     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
247     {
248         final HttpServletRequest srequest = (HttpServletRequest)request;
249         final HttpServletResponse sresponse = (HttpServletResponse)response;
250         
251         final long now=_requestTimeoutQ.getNow();
252         
253         // Look for the rate tracker for this request
254         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
255             
256         if (tracker==null)
257         {
258             // This is the first time we have seen this request.
259             
260             // get a rate tracker associated with this request, and record one hit
261             tracker = getRateTracker(request);
262             
263             // Calculate the rate and check it is over the allowed limit
264             final boolean overRateLimit = tracker.isRateExceeded(now);
265 
266             // pass it through if  we are not currently over the rate limit
267             if (!overRateLimit)
268             {
269                 doFilterChain(filterchain,srequest,sresponse);
270                 return;
271             }   
272             
273             // We are over the limit.
274             Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
275             
276             // So either reject it, delay it or throttle it
277             switch((int)_delayMs)
278             {
279                 case -1: 
280                 {
281                     // Reject this request
282                     if (_insertHeaders)
283                         ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
284                     ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
285                     return;
286                 }
287                 case 0:
288                 {
289                     // fall through to throttle code
290                     request.setAttribute(__TRACKER,tracker);
291                     break;
292                 }
293                 default:
294                 {
295                     // insert a delay before throttling the request
296                     if (_insertHeaders)
297                         ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
298                     Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
299                     request.setAttribute(__TRACKER,tracker);
300                     continuation.suspend(_delayMs);
301                     // can fall through if this was a waiting continuation
302                 }
303             }
304         }
305 
306         // Throttle the request
307         boolean accepted = false;
308         try
309         {
310             // check if we can afford to accept another request at this time
311             accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
312 
313             if (!accepted)
314             {
315                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
316 
317                 final Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
318                 
319                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
320                 if (throttled!=Boolean.TRUE && _throttleMs>0)
321                 {
322                     int priority = getPriority(request,tracker);
323                     request.setAttribute(__THROTTLED,Boolean.TRUE);
324                     if (_insertHeaders)
325                         ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
326                     synchronized (this)
327                     {
328                         _queue[priority].add(continuation);
329                         continuation.reset();
330                         if(continuation.suspend(_throttleMs))
331                         {
332                                 // handle waiting continuation strangeness
333                             // continuation was waiting and was resumed.
334                             _passes.acquire();
335                             accepted = true;
336                         }
337                         // can fall through if this was a waiting continuation
338                     }
339                 }
340                 // else were we resumed?
341                 else if (continuation.isResumed())
342                 {
343                     // we were resumed and somebody stole our pass, so we wait for the next one.
344                     _passes.acquire();
345                     accepted = true;
346                 }
347             }
348             
349             // if we were accepted (either immediately or after throttle)
350             if (accepted)       
351                 // call the chain
352                 doFilterChain(filterchain,srequest,sresponse);
353             else                
354             {
355                 // fail the request
356                 if (_insertHeaders)
357                     ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
358                 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
359             }
360         }
361         catch (InterruptedException e)
362         {
363             _context.log("DoS",e);
364             ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
365         }
366         finally
367         {
368             if (accepted)
369             {
370                 // wake up the next highest priority request.
371                 synchronized (_queue)
372                 {
373                     for (int p = _queue.length; p-- > 0;)
374                     {
375                         Continuation continuation = _queue[p].poll();
376 
377                         if (continuation != null)
378                         {
379                             continuation.resume();
380                             break;
381                         }
382                     }
383                 }
384                 _passes.release();
385             }
386         }
387     }
388 
389     /**
390      * @param chain
391      * @param request
392      * @param response
393      * @throws IOException
394      * @throws ServletException
395      */
396     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) 
397         throws IOException, ServletException
398     {
399         final Thread thread=Thread.currentThread();
400         
401         final Timeout.Task _timeout = new Timeout.Task()
402         {
403             public void expired()
404             {
405                 // take drastic measures to return this response and stop this thread.
406                 if( !response.isCommitted() )
407                 {
408                     response.setHeader("Connection", "close");
409                 }
410 
411                 try 
412                 {
413                     try
414                     {
415                         response.getWriter().close();
416                     }
417                     catch (IllegalStateException e)
418                     {
419                         response.getOutputStream().close();
420                     }
421                 }
422                 catch (IOException e)
423                 {
424                     Log.warn(e);
425                 }
426                 
427                 // interrupt the handling thread
428                 thread.interrupt();
429             }
430         };
431         
432         try
433         {
434             synchronized (_requestTimeoutQ)
435             {
436                 _requestTimeoutQ.schedule(_timeout);
437             }
438             chain.doFilter(request,response);
439         }
440         finally
441         {
442             synchronized (_requestTimeoutQ)
443             {
444                 _timeout.cancel();
445             }
446         }
447     }
448     
449     /**
450      * Get priority for this request, based on user type
451      * 
452      * @param request
453      * @param tracker
454      * @return priority
455      */
456     protected int getPriority(ServletRequest request, RateTracker tracker)
457     {
458         if (extractUserId(request)!=null)
459             return USER_AUTH;
460         if (tracker!=null)
461             return tracker.getType();
462         return USER_UNKNOWN;
463     }
464 
465     /**
466      * @return the maximum priority that we can assign to a request
467      */
468     protected int getMaxPriority()
469     {
470         return USER_AUTH;
471     }
472 
473     /**
474      * Return a request rate tracker associated with this connection; keeps
475      * track of this connection's request rate. If this is not the first request
476      * from this connection, return the existing object with the stored stats.
477      * If it is the first request, then create a new request tracker.
478      * 
479      * Assumes that each connection has an identifying characteristic, and goes
480      * through them in order, taking the first that matches: user id (logged
481      * in), session id, client IP address. Unidentifiable connections are lumped
482      * into one.
483      * 
484      * When a session expires, its rate tracker is automatically deleted.
485      * 
486      * @param request
487      * @return the request rate tracker for the current connection
488      */
489     public RateTracker getRateTracker(ServletRequest request)
490     {
491         HttpServletRequest srequest = (HttpServletRequest)request;
492 
493         String loadId;
494         final int type;
495         
496         loadId = extractUserId(request);
497         HttpSession session=srequest.getSession(false);
498         if (_trackSessions && session!=null && !session.isNew())
499         {
500             loadId=session.getId();
501             type = USER_SESSION;
502         }
503         else
504         {
505             loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
506             type = USER_IP;
507         }
508 
509         RateTracker tracker=_rateTrackers.get(loadId);
510         
511         if (tracker==null)
512         {
513             RateTracker t;
514             if (_whitelist.contains(request.getRemoteAddr()))
515             {
516                 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
517             }
518             else
519             {
520                 t = new RateTracker(loadId,type,_maxRequestsPerSec);
521             }
522             
523             tracker=_rateTrackers.putIfAbsent(loadId,t);
524             if (tracker==null)
525                 tracker=t;
526             
527             if (type == USER_IP)
528             {
529                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
530                 synchronized (_trackerTimeoutQ)
531                 {
532                     _trackerTimeoutQ.schedule(tracker);
533                 }
534             }
535             else if (session!=null)
536                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
537                 session.setAttribute(__TRACKER,tracker);
538         }
539 
540         return tracker;
541     }
542 
543     public void destroy()
544     {
545         _running=false;
546         _timerThread.interrupt();
547         synchronized (_requestTimeoutQ)
548         {
549             _requestTimeoutQ.cancelAll();
550             _trackerTimeoutQ.cancelAll();
551         }
552     }
553 
554     /**
555      * Returns the user id, used to track this connection.
556      * This SHOULD be overridden by subclasses.
557      * 
558      * @param request
559      * @return a unique user id, if logged in; otherwise null.
560      */
561     protected String extractUserId(ServletRequest request)
562     {
563         return null;
564     }
565 
566     /**
567      * A RateTracker is associated with a connection, and stores request rate
568      * data.
569      */
570     class RateTracker extends Timeout.Task implements HttpSessionBindingListener
571     {
572         protected final String _id;
573         protected final int _type;
574         protected final long[] _timestamps;
575         protected int _next;
576         
577         public RateTracker(String id, int type,int maxRequestsPerSecond)
578         {
579             _id = id;
580             _type = type;
581             _timestamps=new long[maxRequestsPerSecond];
582             _next=0;
583         }
584 
585         /**
586          * @return the current calculated request rate over the last second
587          */
588         public boolean isRateExceeded(long now)
589         {
590             final long last;
591             synchronized (this)
592             {
593                 last=_timestamps[_next];
594                 _timestamps[_next]=now;
595                 _next= (_next+1)%_timestamps.length;
596             }
597 
598             boolean exceeded=last!=0 && (now-last)<1000L;
599             // System.err.println("rateExceeded? "+last+" "+(now-last)+" "+exceeded);
600             return exceeded;
601         }
602 
603 
604         public String getId()
605         {
606             return _id;
607         }
608 
609         public int getType()
610         {
611             return _type;
612         }
613 
614         
615         public void valueBound(HttpSessionBindingEvent event)
616         {
617         }
618 
619         public void valueUnbound(HttpSessionBindingEvent event)
620         {
621             _rateTrackers.remove(_id);
622         }
623         
624         public void expired()
625         {
626             long now = _trackerTimeoutQ.getNow();
627             int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length; 
628             long last=_timestamps[latestIndex];
629             boolean hasRecentRequest = last != 0 && (now-last)<1000L;
630             
631             if (hasRecentRequest)
632                 reschedule();
633             else
634                 _rateTrackers.remove(_id);
635         }
636     }
637     
638     class FixedRateTracker extends RateTracker
639     {
640         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
641         {
642             super(id,type,numRecentRequestsTracked);
643         }
644 
645         public boolean isRateExceeded(long now)
646         {
647             // rate limit is never exceeded, but we keep track of the request timestamps
648             // so that we know whether there was recent activity on this tracker
649             // and whether it should be expired
650             synchronized (this)
651             {
652                 _timestamps[_next]=now;
653                 _next= (_next+1)%_timestamps.length;
654             }
655 
656             return false;
657         }        
658     }
659 }