1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package org.mortbay.jetty.servlet;
16
17 import java.io.ByteArrayInputStream;
18 import java.io.InputStream;
19 import java.io.OutputStream;
20 import java.sql.Blob;
21 import java.sql.Connection;
22 import java.sql.DatabaseMetaData;
23 import java.sql.DriverManager;
24 import java.sql.PreparedStatement;
25 import java.sql.ResultSet;
26 import java.sql.SQLException;
27 import java.sql.Statement;
28 import java.util.ArrayList;
29 import java.util.HashSet;
30 import java.util.List;
31 import java.util.Random;
32 import java.util.Timer;
33 import java.util.TimerTask;
34
35 import javax.naming.InitialContext;
36 import javax.servlet.http.HttpServletRequest;
37 import javax.servlet.http.HttpSession;
38 import javax.sql.DataSource;
39
40 import org.mortbay.jetty.Handler;
41 import org.mortbay.jetty.Server;
42 import org.mortbay.jetty.webapp.WebAppContext;
43 import org.mortbay.log.Log;
44
45
46
47
48
49
50
51
52
53
54 public class JDBCSessionIdManager extends AbstractSessionIdManager
55 {
56 protected HashSet<String> _sessionIds = new HashSet();
57 protected String _driverClassName;
58 protected String _connectionUrl;
59 protected DataSource _datasource;
60 protected String _jndiName;
61 protected String _sessionIdTable = "JettySessionIds";
62 protected String _sessionTable = "JettySessions";
63 protected Timer _timer;
64 protected TimerTask _task;
65 protected long _lastScavengeTime;
66 protected long _scavengeIntervalMs = 1000 * 60 * 10;
67
68
69 protected String _createSessionIdTable;
70 protected String _createSessionTable;
71
72 protected String _selectExpiredSessions;
73 protected String _deleteOldExpiredSessions;
74
75 protected String _insertId;
76 protected String _deleteId;
77 protected String _queryId;
78
79 protected DatabaseAdaptor _dbAdaptor;
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94 public class DatabaseAdaptor
95 {
96 String _dbName;
97 boolean _isLower;
98 boolean _isUpper;
99
100
101 public DatabaseAdaptor (DatabaseMetaData dbMeta)
102 throws SQLException
103 {
104 _dbName = dbMeta.getDatabaseProductName().toLowerCase();
105 Log.debug ("Using database "+_dbName);
106 _isLower = dbMeta.storesLowerCaseIdentifiers();
107 _isUpper = dbMeta.storesUpperCaseIdentifiers();
108 }
109
110
111
112
113
114
115
116
117 public String convertIdentifier (String identifier)
118 {
119 if (_isLower)
120 return identifier.toLowerCase();
121 if (_isUpper)
122 return identifier.toUpperCase();
123
124 return identifier;
125 }
126
127 public String getBlobType ()
128 {
129 if (_dbName.startsWith("postgres"))
130 return "bytea";
131
132 return "blob";
133 }
134
135 public InputStream getBlobInputStream (ResultSet result, String columnName)
136 throws SQLException
137 {
138 if (_dbName.startsWith("postgres"))
139 {
140 byte[] bytes = result.getBytes(columnName);
141 return new ByteArrayInputStream(bytes);
142 }
143
144 Blob blob = result.getBlob(columnName);
145 return blob.getBinaryStream();
146 }
147 }
148
149
150
151 public JDBCSessionIdManager(Server server)
152 {
153 super(server);
154 }
155
156 public JDBCSessionIdManager(Server server, Random random)
157 {
158 super(server, random);
159 }
160
161
162
163
164
165
166
167 public void setDriverInfo (String driverClassName, String connectionUrl)
168 {
169 _driverClassName=driverClassName;
170 _connectionUrl=connectionUrl;
171 }
172
173 public String getDriverClassName()
174 {
175 return _driverClassName;
176 }
177
178 public String getConnectionUrl ()
179 {
180 return _connectionUrl;
181 }
182
183 public void setDatasourceName (String jndi)
184 {
185 _jndiName=jndi;
186 }
187
188 public String getDatasourceName ()
189 {
190 return _jndiName;
191 }
192
193
194 public void setScavengeInterval (long sec)
195 {
196 if (sec<=0)
197 sec=60;
198
199 long old_period=_scavengeIntervalMs;
200 long period=sec*1000;
201
202 _scavengeIntervalMs=period;
203
204
205
206 long tenPercent = _scavengeIntervalMs/10;
207 if ((System.currentTimeMillis()%2) == 0)
208 _scavengeIntervalMs += tenPercent;
209
210 if (Log.isDebugEnabled()) Log.debug("Scavenging every "+_scavengeIntervalMs+" ms");
211 if (_timer!=null && (period!=old_period || _task==null))
212 {
213 synchronized (this)
214 {
215 if (_task!=null)
216 _task.cancel();
217 _task = new TimerTask()
218 {
219 public void run()
220 {
221 scavenge();
222 }
223 };
224 _timer.schedule(_task,_scavengeIntervalMs,_scavengeIntervalMs);
225 }
226 }
227 }
228
229 public long getScavengeInterval ()
230 {
231 return _scavengeIntervalMs/1000;
232 }
233
234
235 public void addSession(HttpSession session)
236 {
237 if (session == null)
238 return;
239
240 synchronized (_sessionIds)
241 {
242 String id = ((JDBCSessionManager.Session)session).getClusterId();
243 try
244 {
245 insert(id);
246 _sessionIds.add(id);
247 }
248 catch (Exception e)
249 {
250 Log.warn("Problem storing session id="+id, e);
251 }
252 }
253 }
254
255 public void removeSession(HttpSession session)
256 {
257 if (session == null)
258 return;
259
260 removeSession(((JDBCSessionManager.Session)session).getClusterId());
261 }
262
263
264
265 public void removeSession (String id)
266 {
267
268 if (id == null)
269 return;
270
271 synchronized (_sessionIds)
272 {
273 if (Log.isDebugEnabled())
274 Log.debug("Removing session id="+id);
275 try
276 {
277 _sessionIds.remove(id);
278 delete(id);
279 }
280 catch (Exception e)
281 {
282 Log.warn("Problem removing session id="+id, e);
283 }
284 }
285
286 }
287
288
289
290
291
292
293
294 public String getClusterId(String nodeId)
295 {
296 int dot=nodeId.lastIndexOf('.');
297 return (dot>0)?nodeId.substring(0,dot):nodeId;
298 }
299
300
301
302
303
304
305
306 public String getNodeId(String clusterId, HttpServletRequest request)
307 {
308 if (_workerName!=null)
309 return clusterId+'.'+_workerName;
310
311 return clusterId;
312 }
313
314
315 public boolean idInUse(String id)
316 {
317 if (id == null)
318 return false;
319
320 String clusterId = getClusterId(id);
321
322 synchronized (_sessionIds)
323 {
324 if (_sessionIds.contains(clusterId))
325 return true;
326
327
328 try
329 {
330 return exists(clusterId);
331 }
332 catch (Exception e)
333 {
334 Log.warn("Problem checking inUse for id="+clusterId, e);
335 return false;
336 }
337 }
338 }
339
340
341
342
343
344
345 public void invalidateAll(String id)
346 {
347
348 removeSession(id);
349
350 synchronized (_sessionIds)
351 {
352
353
354 Handler[] contexts = _server.getChildHandlersByClass(WebAppContext.class);
355 for (int i=0; contexts!=null && i<contexts.length; i++)
356 {
357 AbstractSessionManager manager = ((AbstractSessionManager)((WebAppContext)contexts[i]).getSessionHandler().getSessionManager());
358 if (manager instanceof JDBCSessionManager)
359 {
360 ((JDBCSessionManager)manager).invalidateSession(id);
361 }
362 }
363 }
364 }
365
366
367
368
369
370
371
372
373
374
375 public void doStart()
376 {
377 try
378 {
379 initializeDatabase();
380 prepareTables();
381 super.doStart();
382 if (Log.isDebugEnabled()) Log.debug("Scavenging interval = "+getScavengeInterval()+" sec");
383 _timer=new Timer("JDBCSessionScavenger", true);
384 setScavengeInterval(getScavengeInterval());
385 }
386 catch (Exception e)
387 {
388 Log.warn("Problem initialising JettySessionIds table", e);
389 }
390 }
391
392
393
394
395
396
397 public void doStop ()
398 throws Exception
399 {
400 synchronized(this)
401 {
402 if (_task!=null)
403 _task.cancel();
404 if (_timer!=null)
405 _timer.cancel();
406 _timer=null;
407 }
408 super.doStop();
409 }
410
411
412
413
414
415
416
417 protected Connection getConnection ()
418 throws SQLException
419 {
420 if (_datasource != null)
421 return _datasource.getConnection();
422 else
423 return DriverManager.getConnection(_connectionUrl);
424 }
425
426
427 private void initializeDatabase ()
428 throws Exception
429 {
430 if (_jndiName!=null)
431 {
432 InitialContext ic = new InitialContext();
433 _datasource = (DataSource)ic.lookup(_jndiName);
434 }
435 else if (_driverClassName!=null && _connectionUrl!=null)
436 {
437 Class.forName(_driverClassName);
438 }
439 else
440 throw new IllegalStateException("No database configured for sessions");
441 }
442
443
444
445
446
447
448
449 private void prepareTables()
450 throws SQLException
451 {
452 _createSessionIdTable = "create table "+_sessionIdTable+" (id varchar(60), primary key(id))";
453 _selectExpiredSessions = "select * from "+_sessionTable+" where expiryTime >= ? and expiryTime <= ?";
454 _deleteOldExpiredSessions = "delete from "+_sessionTable+" where expiryTime >0 and expiryTime <= ?";
455
456 _insertId = "insert into "+_sessionIdTable+" (id) values (?)";
457 _deleteId = "delete from "+_sessionIdTable+" where id = ?";
458 _queryId = "select * from "+_sessionIdTable+" where id = ?";
459
460 Connection connection = null;
461 try
462 {
463
464 connection = getConnection();
465 connection.setAutoCommit(true);
466 DatabaseMetaData metaData = connection.getMetaData();
467 _dbAdaptor = new DatabaseAdaptor(metaData);
468
469
470 String tableName = _dbAdaptor.convertIdentifier(_sessionIdTable);
471 ResultSet result = metaData.getTables(null, null, tableName, null);
472 if (!result.next())
473 {
474
475 connection.createStatement().executeUpdate(_createSessionIdTable);
476 }
477
478
479 tableName = _dbAdaptor.convertIdentifier(_sessionTable);
480 result = metaData.getTables(null, null, tableName, null);
481 if (!result.next())
482 {
483
484 String blobType = _dbAdaptor.getBlobType();
485 _createSessionTable = "create table "+_sessionTable+" (rowId varchar(60), sessionId varchar(60), "+
486 " contextPath varchar(60), virtualHost varchar(60), lastNode varchar(60), accessTime bigint, "+
487 " lastAccessTime bigint, createTime bigint, cookieTime bigint, "+
488 " lastSavedTime bigint, expiryTime bigint, map "+blobType+", primary key(rowId))";
489 connection.createStatement().executeUpdate(_createSessionTable);
490 }
491
492
493 String index1 = "idx_"+_sessionTable+"_expiry";
494 String index2 = "idx_"+_sessionTable+"_session";
495
496 result = metaData.getIndexInfo(null, null, tableName, false, false);
497 boolean index1Exists = false;
498 boolean index2Exists = false;
499 while (result.next())
500 {
501 String idxName = result.getString("INDEX_NAME");
502 if (index1.equalsIgnoreCase(idxName))
503 index1Exists = true;
504 else if (index2.equalsIgnoreCase(idxName))
505 index2Exists = true;
506 }
507 if (!(index1Exists && index2Exists))
508 {
509 Statement statement = connection.createStatement();
510 if (!index1Exists)
511 statement.executeUpdate("create index "+index1+" on "+_sessionTable+" (expiryTime)");
512 if (!index2Exists)
513 statement.executeUpdate("create index "+index2+" on "+_sessionTable+" (sessionId, contextPath)");
514 }
515 }
516 finally
517 {
518 if (connection != null)
519 connection.close();
520 }
521 }
522
523
524
525
526
527
528
529 private void insert (String id)
530 throws SQLException
531 {
532 Connection connection = null;
533 try
534 {
535 connection = getConnection();
536 connection.setAutoCommit(true);
537 PreparedStatement query = connection.prepareStatement(_queryId);
538 query.setString(1, id);
539 ResultSet result = query.executeQuery();
540
541 if (!result.next())
542 {
543 PreparedStatement statement = connection.prepareStatement(_insertId);
544 statement.setString(1, id);
545 statement.executeUpdate();
546 }
547 }
548 finally
549 {
550 if (connection != null)
551 connection.close();
552 }
553 }
554
555
556
557
558
559
560
561 private void delete (String id)
562 throws SQLException
563 {
564 Connection connection = null;
565 try
566 {
567 connection = getConnection();
568 connection.setAutoCommit(true);
569 PreparedStatement statement = connection.prepareStatement(_deleteId);
570 statement.setString(1, id);
571 statement.executeUpdate();
572 }
573 finally
574 {
575 if (connection != null)
576 connection.close();
577 }
578 }
579
580
581
582
583
584
585
586
587
588 private boolean exists (String id)
589 throws SQLException
590 {
591 Connection connection = null;
592 try
593 {
594 connection = getConnection();
595 connection.setAutoCommit(true);
596 PreparedStatement statement = connection.prepareStatement(_queryId);
597 statement.setString(1, id);
598 ResultSet result = statement.executeQuery();
599 if (result.next())
600 return true;
601 else
602 return false;
603 }
604 finally
605 {
606 if (connection != null)
607 connection.close();
608 }
609 }
610
611
612
613
614
615
616
617
618
619
620
621
622 private void scavenge ()
623 {
624 Connection connection = null;
625 List expiredSessionIds = new ArrayList();
626 try
627 {
628 if (Log.isDebugEnabled()) Log.debug("Scavenge sweep started at "+System.currentTimeMillis());
629 if (_lastScavengeTime > 0)
630 {
631 connection = getConnection();
632 connection.setAutoCommit(true);
633
634 PreparedStatement statement = connection.prepareStatement(_selectExpiredSessions);
635 long lowerBound = (_lastScavengeTime - _scavengeIntervalMs);
636 long upperBound = _lastScavengeTime;
637 if (Log.isDebugEnabled()) Log.debug("Searching for sessions expired between "+lowerBound + " and "+upperBound);
638 statement.setLong(1, lowerBound);
639 statement.setLong(2, upperBound);
640 ResultSet result = statement.executeQuery();
641 while (result.next())
642 {
643 String sessionId = result.getString("sessionId");
644 expiredSessionIds.add(sessionId);
645 if (Log.isDebugEnabled()) Log.debug("Found expired sessionId="+sessionId);
646 }
647
648
649
650 Handler[] contexts = _server.getChildHandlersByClass(WebAppContext.class);
651 for (int i=0; contexts!=null && i<contexts.length; i++)
652 {
653 AbstractSessionManager manager = ((AbstractSessionManager)((WebAppContext)contexts[i]).getSessionHandler().getSessionManager());
654 if (manager instanceof JDBCSessionManager)
655 {
656 ((JDBCSessionManager)manager).expire(expiredSessionIds);
657 }
658 }
659
660
661 upperBound = _lastScavengeTime - (2 * _scavengeIntervalMs);
662 if (upperBound > 0)
663 {
664 if (Log.isDebugEnabled()) Log.debug("Deleting old expired sessions expired before "+upperBound);
665 statement = connection.prepareStatement(_deleteOldExpiredSessions);
666 statement.setLong(1, upperBound);
667 statement.executeUpdate();
668 }
669 }
670 }
671 catch (Exception e)
672 {
673 Log.warn("Problem selecting expired sessions", e);
674 }
675 finally
676 {
677 _lastScavengeTime=System.currentTimeMillis();
678 if (Log.isDebugEnabled()) Log.debug("Scavenge sweep ended at "+_lastScavengeTime);
679 if (connection != null)
680 {
681 try
682 {
683 connection.close();
684 }
685 catch (SQLException e)
686 {
687 Log.warn(e);
688 }
689 }
690 }
691 }
692 }