1 package org.mortbay.servlet;
2
3 import java.io.IOException;
4 import java.util.HashSet;
5 import java.util.Queue;
6 import java.util.StringTokenizer;
7 import java.util.concurrent.ConcurrentHashMap;
8 import java.util.concurrent.Semaphore;
9 import java.util.concurrent.TimeUnit;
10
11 import javax.servlet.Filter;
12 import javax.servlet.FilterChain;
13 import javax.servlet.FilterConfig;
14 import javax.servlet.ServletContext;
15 import javax.servlet.ServletException;
16 import javax.servlet.ServletRequest;
17 import javax.servlet.ServletResponse;
18 import javax.servlet.http.HttpServletRequest;
19 import javax.servlet.http.HttpServletResponse;
20 import javax.servlet.http.HttpSession;
21 import javax.servlet.http.HttpSessionBindingEvent;
22 import javax.servlet.http.HttpSessionBindingListener;
23
24 import org.mortbay.log.Log;
25 import org.mortbay.thread.Timeout;
26 import org.mortbay.util.ArrayQueue;
27 import org.mortbay.util.ajax.Continuation;
28 import org.mortbay.util.ajax.ContinuationSupport;
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 public class DoSFilter implements Filter
82 {
83 final static String __TRACKER = "DoSFilter.Tracker";
84 final static String __THROTTLED = "DoSFilter.Throttled";
85
86 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
87 final static int __DEFAULT_DELAY_MS = 100;
88 final static int __DEFAULT_THROTTLE = 5;
89 final static int __DEFAULT_WAIT_MS=50;
90 final static long __DEFAULT_THROTTLE_MS = 30000L;
91 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
92 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
93
94 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
95 final static String DELAY_MS_INIT_PARAM = "delayMs";
96 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
97 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
98 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
99 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
100 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
101 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
102 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
103 final static String REMOTE_PORT_INIT_PARAM="remotePort";
104 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
105
106 final static int USER_AUTH = 2;
107 final static int USER_SESSION = 2;
108 final static int USER_IP = 1;
109 final static int USER_UNKNOWN = 0;
110
111 ServletContext _context;
112
113 protected long _delayMs;
114 protected long _throttleMs;
115 protected long _waitMs;
116 protected long _maxRequestMs;
117 protected long _maxIdleTrackerMs;
118 protected boolean _insertHeaders;
119 protected boolean _trackSessions;
120 protected boolean _remotePort;
121 protected Semaphore _passes;
122 protected Queue<Continuation>[] _queue;
123
124 protected int _maxRequestsPerSec;
125 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
126 private HashSet<String> _whitelist;
127
128 private final Timeout _requestTimeoutQ = new Timeout();
129 private final Timeout _trackerTimeoutQ = new Timeout();
130
131 private Thread _timerThread;
132 private volatile boolean _running;
133
134 public void init(FilterConfig filterConfig)
135 {
136 _context = filterConfig.getServletContext();
137
138 _queue = new Queue[getMaxPriority() + 1];
139 for (int p = 0; p < _queue.length; p++)
140 _queue[p] = new ArrayQueue<Continuation>();
141
142 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
143 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
144 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
145 _maxRequestsPerSec = baseRateLimit;
146
147 long delay = __DEFAULT_DELAY_MS;
148 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
149 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
150 _delayMs = delay;
151
152 int passes = __DEFAULT_THROTTLE;
153 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
154 passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
155 _passes = new Semaphore(passes,true);
156
157 long wait = __DEFAULT_WAIT_MS;
158 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
159 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
160 _waitMs = wait;
161
162 long suspend = __DEFAULT_THROTTLE_MS;
163 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
164 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
165 _throttleMs = suspend;
166
167 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
168 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
169 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
170 _maxRequestMs = maxRequestMs;
171
172 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
173 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
174 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
175 _maxIdleTrackerMs = maxIdleTrackerMs;
176
177 String whitelistString = "";
178 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
179 whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
180
181
182 if (whitelistString.length() == 0 )
183 _whitelist = new HashSet<String>();
184 else
185 {
186 StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
187 _whitelist = new HashSet<String>(tokenizer.countTokens());
188 while (tokenizer.hasMoreTokens())
189 _whitelist.add(tokenizer.nextToken().trim());
190
191 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
192 }
193
194 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
195 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
196
197 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
198 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
199
200 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
201 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
202
203 _requestTimeoutQ.setNow();
204 _requestTimeoutQ.setDuration(_maxRequestMs);
205
206 _trackerTimeoutQ.setNow();
207 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
208
209 _running=true;
210 _timerThread = (new Thread()
211 {
212 public void run()
213 {
214 try
215 {
216 while (_running)
217 {
218 synchronized (_requestTimeoutQ)
219 {
220 _requestTimeoutQ.setNow();
221 _requestTimeoutQ.tick();
222
223 _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
224 _trackerTimeoutQ.tick();
225 }
226 try
227 {
228 Thread.sleep(100);
229 }
230 catch (InterruptedException e)
231 {
232 Log.ignore(e);
233 }
234 }
235 }
236 finally
237 {
238 Log.info("DoSFilter timer exited");
239 }
240 }
241 });
242 _timerThread.start();
243 }
244
245
246 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
247 {
248 final HttpServletRequest srequest = (HttpServletRequest)request;
249 final HttpServletResponse sresponse = (HttpServletResponse)response;
250
251 final long now=_requestTimeoutQ.getNow();
252
253
254 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
255
256 if (tracker==null)
257 {
258
259
260
261 tracker = getRateTracker(request);
262
263
264 final boolean overRateLimit = tracker.isRateExceeded(now);
265
266
267 if (!overRateLimit)
268 {
269 doFilterChain(filterchain,srequest,sresponse);
270 return;
271 }
272
273
274 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
275
276
277 switch((int)_delayMs)
278 {
279 case -1:
280 {
281
282 if (_insertHeaders)
283 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
284 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
285 return;
286 }
287 case 0:
288 {
289
290 request.setAttribute(__TRACKER,tracker);
291 break;
292 }
293 default:
294 {
295
296 if (_insertHeaders)
297 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
298 Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
299 request.setAttribute(__TRACKER,tracker);
300 continuation.suspend(_delayMs);
301
302 }
303 }
304 }
305
306
307 boolean accepted = false;
308 try
309 {
310
311 accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
312
313 if (!accepted)
314 {
315
316
317 final Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
318
319 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
320 if (throttled!=Boolean.TRUE && _throttleMs>0)
321 {
322 int priority = getPriority(request,tracker);
323 request.setAttribute(__THROTTLED,Boolean.TRUE);
324 if (_insertHeaders)
325 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
326 synchronized (this)
327 {
328 _queue[priority].add(continuation);
329 continuation.reset();
330 if(continuation.suspend(_throttleMs))
331 {
332
333
334 _passes.acquire();
335 accepted = true;
336 }
337
338 }
339 }
340
341 else if (continuation.isResumed())
342 {
343
344 _passes.acquire();
345 accepted = true;
346 }
347 }
348
349
350 if (accepted)
351
352 doFilterChain(filterchain,srequest,sresponse);
353 else
354 {
355
356 if (_insertHeaders)
357 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
358 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
359 }
360 }
361 catch (InterruptedException e)
362 {
363 _context.log("DoS",e);
364 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
365 }
366 finally
367 {
368 if (accepted)
369 {
370
371 synchronized (_queue)
372 {
373 for (int p = _queue.length; p-- > 0;)
374 {
375 Continuation continuation = _queue[p].poll();
376
377 if (continuation != null)
378 {
379 continuation.resume();
380 break;
381 }
382 }
383 }
384 _passes.release();
385 }
386 }
387 }
388
389
390
391
392
393
394
395
396 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
397 throws IOException, ServletException
398 {
399 final Thread thread=Thread.currentThread();
400
401 final Timeout.Task _timeout = new Timeout.Task()
402 {
403 public void expired()
404 {
405
406 if( !response.isCommitted() )
407 {
408 response.setHeader("Connection", "close");
409 }
410
411 try
412 {
413 try
414 {
415 response.getWriter().close();
416 }
417 catch (IllegalStateException e)
418 {
419 response.getOutputStream().close();
420 }
421 }
422 catch (IOException e)
423 {
424 Log.warn(e);
425 }
426
427
428 thread.interrupt();
429 }
430 };
431
432 try
433 {
434 synchronized (_requestTimeoutQ)
435 {
436 _requestTimeoutQ.schedule(_timeout);
437 }
438 chain.doFilter(request,response);
439 }
440 finally
441 {
442 synchronized (_requestTimeoutQ)
443 {
444 _timeout.cancel();
445 }
446 }
447 }
448
449
450
451
452
453
454
455
456 protected int getPriority(ServletRequest request, RateTracker tracker)
457 {
458 if (extractUserId(request)!=null)
459 return USER_AUTH;
460 if (tracker!=null)
461 return tracker.getType();
462 return USER_UNKNOWN;
463 }
464
465
466
467
468 protected int getMaxPriority()
469 {
470 return USER_AUTH;
471 }
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489 public RateTracker getRateTracker(ServletRequest request)
490 {
491 HttpServletRequest srequest = (HttpServletRequest)request;
492
493 String loadId;
494 final int type;
495
496 loadId = extractUserId(request);
497 HttpSession session=srequest.getSession(false);
498 if (_trackSessions && session!=null && !session.isNew())
499 {
500 loadId=session.getId();
501 type = USER_SESSION;
502 }
503 else
504 {
505 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
506 type = USER_IP;
507 }
508
509 RateTracker tracker=_rateTrackers.get(loadId);
510
511 if (tracker==null)
512 {
513 RateTracker t;
514 if (_whitelist.contains(request.getRemoteAddr()))
515 {
516 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
517 }
518 else
519 {
520 t = new RateTracker(loadId,type,_maxRequestsPerSec);
521 }
522
523 tracker=_rateTrackers.putIfAbsent(loadId,t);
524 if (tracker==null)
525 tracker=t;
526
527 if (type == USER_IP)
528 {
529
530 synchronized (_trackerTimeoutQ)
531 {
532 _trackerTimeoutQ.schedule(tracker);
533 }
534 }
535 else if (session!=null)
536
537 session.setAttribute(__TRACKER,tracker);
538 }
539
540 return tracker;
541 }
542
543 public void destroy()
544 {
545 _running=false;
546 _timerThread.interrupt();
547 synchronized (_requestTimeoutQ)
548 {
549 _requestTimeoutQ.cancelAll();
550 _trackerTimeoutQ.cancelAll();
551 }
552 }
553
554
555
556
557
558
559
560
561 protected String extractUserId(ServletRequest request)
562 {
563 return null;
564 }
565
566
567
568
569
570 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
571 {
572 protected final String _id;
573 protected final int _type;
574 protected final long[] _timestamps;
575 protected int _next;
576
577 public RateTracker(String id, int type,int maxRequestsPerSecond)
578 {
579 _id = id;
580 _type = type;
581 _timestamps=new long[maxRequestsPerSecond];
582 _next=0;
583 }
584
585
586
587
588 public boolean isRateExceeded(long now)
589 {
590 final long last;
591 synchronized (this)
592 {
593 last=_timestamps[_next];
594 _timestamps[_next]=now;
595 _next= (_next+1)%_timestamps.length;
596 }
597
598 boolean exceeded=last!=0 && (now-last)<1000L;
599
600 return exceeded;
601 }
602
603
604 public String getId()
605 {
606 return _id;
607 }
608
609 public int getType()
610 {
611 return _type;
612 }
613
614
615 public void valueBound(HttpSessionBindingEvent event)
616 {
617 }
618
619 public void valueUnbound(HttpSessionBindingEvent event)
620 {
621 _rateTrackers.remove(_id);
622 }
623
624 public void expired()
625 {
626 long now = _trackerTimeoutQ.getNow();
627 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
628 long last=_timestamps[latestIndex];
629 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
630
631 if (hasRecentRequest)
632 reschedule();
633 else
634 _rateTrackers.remove(_id);
635 }
636 }
637
638 class FixedRateTracker extends RateTracker
639 {
640 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
641 {
642 super(id,type,numRecentRequestsTracked);
643 }
644
645 public boolean isRateExceeded(long now)
646 {
647
648
649
650 synchronized (this)
651 {
652 _timestamps[_next]=now;
653 _next= (_next+1)%_timestamps.length;
654 }
655
656 return false;
657 }
658 }
659 }