1   // ========================================================================
2   // Copyright 2008 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // Licensed under the Apache License, Version 2.0 (the "License");
5   // you may not use this file except in compliance with the License.
6   // You may obtain a copy of the License at 
7   // http://www.apache.org/licenses/LICENSE-2.0
8   // Unless required by applicable law or agreed to in writing, software
9   // distributed under the License is distributed on an "AS IS" BASIS,
10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  // See the License for the specific language governing permissions and
12  // limitations under the License.
13  // ========================================================================
14  
15  package org.mortbay.jetty.servlet;
16  
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.sql.Blob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;

import javax.naming.InitialContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import javax.sql.DataSource;

import org.mortbay.jetty.Handler;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.webapp.WebAppContext;
import org.mortbay.log.Log;
44  
45  
46  
47  /**
48   * JDBCSessionIdManager
49   *
50   * SessionIdManager implementation that uses a database to store in-use session ids, 
51   * to support distributed sessions.
52   * 
53   */
54  public class JDBCSessionIdManager extends AbstractSessionIdManager
55  {    
56      protected HashSet<String> _sessionIds = new HashSet();
57      protected String _driverClassName;
58      protected String _connectionUrl;
59      protected DataSource _datasource;
60      protected String _jndiName;
61      protected String _sessionIdTable = "JettySessionIds";
62      protected String _sessionTable = "JettySessions";
63      protected Timer _timer; //scavenge timer
64      protected TimerTask _task; //scavenge task
65      protected long _lastScavengeTime;
66      protected long _scavengeIntervalMs = 1000 * 60 * 10; //10mins
67      
68      
69      protected String _createSessionIdTable;
70      protected String _createSessionTable;
71                                              
72      protected String _selectExpiredSessions;
73      protected String _deleteOldExpiredSessions;
74  
75      protected String _insertId;
76      protected String _deleteId;
77      protected String _queryId;
78      
79      protected DatabaseAdaptor _dbAdaptor;
80  
81      
82      /**
83       * DatabaseAdaptor
84       *
85       * Handles differences between databases.
86       * 
87       * Postgres uses the getBytes and setBinaryStream methods to access
88       * a "bytea" datatype, which can be up to 1Gb of binary data. MySQL
89       * is happy to use the "blob" type and getBlob() methods instead.
90       * 
91       * TODO if the differences become more major it would be worthwhile
92       * refactoring this class.
93       */
94      public class DatabaseAdaptor 
95      {
96          String _dbName;
97          boolean _isLower;
98          boolean _isUpper;
99          
100         
101         public DatabaseAdaptor (DatabaseMetaData dbMeta)
102         throws SQLException
103         {
104             _dbName = dbMeta.getDatabaseProductName().toLowerCase(); 
105             Log.debug ("Using database "+_dbName);
106             _isLower = dbMeta.storesLowerCaseIdentifiers();
107             _isUpper = dbMeta.storesUpperCaseIdentifiers();
108         }
109         
110         /**
111          * Convert a camel case identifier into either upper or lower
112          * depending on the way the db stores identifiers.
113          * 
114          * @param identifier
115          * @return
116          */
117         public String convertIdentifier (String identifier)
118         {
119             if (_isLower)
120                 return identifier.toLowerCase();
121             if (_isUpper)
122                 return identifier.toUpperCase();
123             
124             return identifier;
125         }
126         
127         public String getBlobType ()
128         {
129             if (_dbName.startsWith("postgres"))
130                 return "bytea";
131             
132             return "blob";
133         }
134         
135         public InputStream getBlobInputStream (ResultSet result, String columnName)
136         throws SQLException
137         {
138             if (_dbName.startsWith("postgres"))
139             {
140                 byte[] bytes = result.getBytes(columnName);
141                 return new ByteArrayInputStream(bytes);
142             }
143             
144             Blob blob = result.getBlob(columnName);
145             return blob.getBinaryStream();
146         }
147     }
148     
149     
150     
151     public JDBCSessionIdManager(Server server)
152     {
153         super(server);
154     }
155     
156     public JDBCSessionIdManager(Server server, Random random)
157     {
158        super(server, random);
159     }
160 
161     /**
162      * Configure jdbc connection information via a jdbc Driver
163      * 
164      * @param driverClassName
165      * @param connectionUrl
166      */
167     public void setDriverInfo (String driverClassName, String connectionUrl)
168     {
169         _driverClassName=driverClassName;
170         _connectionUrl=connectionUrl;
171     }
172     
173     public String getDriverClassName()
174     {
175         return _driverClassName;
176     }
177     
178     public String getConnectionUrl ()
179     {
180         return _connectionUrl;
181     }
182     
183     public void setDatasourceName (String jndi)
184     {
185         _jndiName=jndi;
186     }
187     
188     public String getDatasourceName ()
189     {
190         return _jndiName;
191     }
192    
193     
194     public void setScavengeInterval (long sec)
195     {
196         if (sec<=0)
197             sec=60;
198 
199         long old_period=_scavengeIntervalMs;
200         long period=sec*1000;
201       
202         _scavengeIntervalMs=period;
203         
204         //add a bit of variability into the scavenge time so that not all
205         //nodes with the same scavenge time sync up
206         long tenPercent = _scavengeIntervalMs/10;
207         if ((System.currentTimeMillis()%2) == 0)
208             _scavengeIntervalMs += tenPercent;
209         
210         if (Log.isDebugEnabled()) Log.debug("Scavenging every "+_scavengeIntervalMs+" ms");
211         if (_timer!=null && (period!=old_period || _task==null))
212         {
213             synchronized (this)
214             {
215                 if (_task!=null)
216                     _task.cancel();
217                 _task = new TimerTask()
218                 {
219                     public void run()
220                     {
221                         scavenge();
222                     }   
223                 };
224                 _timer.schedule(_task,_scavengeIntervalMs,_scavengeIntervalMs);
225             }
226         }  
227     }
228     
229     public long getScavengeInterval ()
230     {
231         return _scavengeIntervalMs/1000;
232     }
233     
234     
235     public void addSession(HttpSession session)
236     {
237         if (session == null)
238             return;
239         
240         synchronized (_sessionIds)
241         {
242             String id = ((JDBCSessionManager.Session)session).getClusterId();            
243             try
244             {
245                 insert(id);
246                 _sessionIds.add(id);
247             }
248             catch (Exception e)
249             {
250                 Log.warn("Problem storing session id="+id, e);
251             }
252         }
253     }
254     
255     public void removeSession(HttpSession session)
256     {
257         if (session == null)
258             return;
259         
260         removeSession(((JDBCSessionManager.Session)session).getClusterId());
261     }
262     
263     
264     
265     public void removeSession (String id)
266     {
267 
268         if (id == null)
269             return;
270         
271         synchronized (_sessionIds)
272         {  
273             if (Log.isDebugEnabled())
274                 Log.debug("Removing session id="+id);
275             try
276             {               
277                 _sessionIds.remove(id);
278                 delete(id);
279             }
280             catch (Exception e)
281             {
282                 Log.warn("Problem removing session id="+id, e);
283             }
284         }
285         
286     }
287     
288 
289     /** 
290      * Get the session id without any node identifier suffix.
291      * 
292      * @see org.mortbay.jetty.SessionIdManager#getClusterId(java.lang.String)
293      */
294     public String getClusterId(String nodeId)
295     {
296         int dot=nodeId.lastIndexOf('.');
297         return (dot>0)?nodeId.substring(0,dot):nodeId;
298     }
299     
300 
301     /** 
302      * Get the session id, including this node's id as a suffix.
303      * 
304      * @see org.mortbay.jetty.SessionIdManager#getNodeId(java.lang.String, javax.servlet.http.HttpServletRequest)
305      */
306     public String getNodeId(String clusterId, HttpServletRequest request)
307     {
308         if (_workerName!=null)
309             return clusterId+'.'+_workerName;
310 
311         return clusterId;
312     }
313 
314 
315     public boolean idInUse(String id)
316     {
317         if (id == null)
318             return false;
319         
320         String clusterId = getClusterId(id);
321         
322         synchronized (_sessionIds)
323         {
324             if (_sessionIds.contains(clusterId))
325                 return true; //optimisation - if this session is one we've been managing, we can check locally
326             
327             //otherwise, we need to go to the database to check
328             try
329             {
330                 return exists(clusterId);
331             }
332             catch (Exception e)
333             {
334                 Log.warn("Problem checking inUse for id="+clusterId, e);
335                 return false;
336             }
337         }
338     }
339 
340     /** 
341      * Invalidate the session matching the id on all contexts.
342      * 
343      * @see org.mortbay.jetty.SessionIdManager#invalidateAll(java.lang.String)
344      */
345     public void invalidateAll(String id)
346     {            
347         //take the id out of the list of known sessionids for this node
348         removeSession(id);
349         
350         synchronized (_sessionIds)
351         {
352             //tell all contexts that may have a session object with this id to
353             //get rid of them
354             Handler[] contexts = _server.getChildHandlersByClass(WebAppContext.class);
355             for (int i=0; contexts!=null && i<contexts.length; i++)
356             {
357                 AbstractSessionManager manager = ((AbstractSessionManager)((WebAppContext)contexts[i]).getSessionHandler().getSessionManager());
358                 if (manager instanceof JDBCSessionManager)
359                 {
360                     ((JDBCSessionManager)manager).invalidateSession(id);
361                 }
362             }
363         }
364     }
365 
366 
367     /** 
368      * Start up the id manager.
369      * 
370      * Makes necessary database tables and starts a Session
371      * scavenger thread.
372      * 
373      * @see org.mortbay.jetty.servlet.AbstractSessionIdManager#doStart()
374      */
375     public void doStart()
376     {
377         try
378         {            
379             initializeDatabase();
380             prepareTables();        
381             super.doStart();
382             if (Log.isDebugEnabled()) Log.debug("Scavenging interval = "+getScavengeInterval()+" sec");
383             _timer=new Timer("JDBCSessionScavenger", true);
384             setScavengeInterval(getScavengeInterval());
385         }
386         catch (Exception e)
387         {
388             Log.warn("Problem initialising JettySessionIds table", e);
389         }
390     }
391     
392     /** 
393      * Stop the scavenger.
394      * 
395      * @see org.mortbay.component.AbstractLifeCycle#doStop()
396      */
397     public void doStop () 
398     throws Exception
399     {
400         synchronized(this)
401         {
402             if (_task!=null)
403                 _task.cancel();
404             if (_timer!=null)
405                 _timer.cancel();
406             _timer=null;
407         }
408         super.doStop();
409     }
410     
411     /**
412      * Get a connection from the driver or datasource.
413      * 
414      * @return
415      * @throws SQLException
416      */
417     protected Connection getConnection ()
418     throws SQLException
419     {
420         if (_datasource != null)
421             return _datasource.getConnection();
422         else
423             return DriverManager.getConnection(_connectionUrl);
424     }
425 
426     
427     private void initializeDatabase ()
428     throws Exception
429     {
430         if (_jndiName!=null)
431         {
432             InitialContext ic = new InitialContext();
433             _datasource = (DataSource)ic.lookup(_jndiName);
434         }
435         else if (_driverClassName!=null && _connectionUrl!=null)
436         {
437             Class.forName(_driverClassName);
438         }
439         else
440             throw new IllegalStateException("No database configured for sessions");
441     }
442     
443     
444     
445     /**
446      * Set up the tables in the database
447      * @throws SQLException
448      */
449     private void prepareTables()
450     throws SQLException
451     {
452         _createSessionIdTable = "create table "+_sessionIdTable+" (id varchar(60), primary key(id))";
453         _selectExpiredSessions = "select * from "+_sessionTable+" where expiryTime >= ? and expiryTime <= ?";
454         _deleteOldExpiredSessions = "delete from "+_sessionTable+" where expiryTime >0 and expiryTime <= ?";
455 
456         _insertId = "insert into "+_sessionIdTable+" (id)  values (?)";
457         _deleteId = "delete from "+_sessionIdTable+" where id = ?";
458         _queryId = "select * from "+_sessionIdTable+" where id = ?";
459 
460         Connection connection = null;
461         try
462         {
463             //make the id table
464             connection = getConnection();
465             connection.setAutoCommit(true);
466             DatabaseMetaData metaData = connection.getMetaData();
467             _dbAdaptor = new DatabaseAdaptor(metaData);
468 
469             //checking for table existence is case-sensitive, but table creation is not
470             String tableName = _dbAdaptor.convertIdentifier(_sessionIdTable);
471             ResultSet result = metaData.getTables(null, null, tableName, null);
472             if (!result.next())
473             {
474                 //table does not exist, so create it
475                 connection.createStatement().executeUpdate(_createSessionIdTable);
476             }
477             
478             //make the session table if necessary
479             tableName = _dbAdaptor.convertIdentifier(_sessionTable);   
480             result = metaData.getTables(null, null, tableName, null);
481             if (!result.next())
482             {
483                 //table does not exist, so create it
484                 String blobType = _dbAdaptor.getBlobType();
485                 _createSessionTable = "create table "+_sessionTable+" (rowId varchar(60), sessionId varchar(60), "+
486                                            " contextPath varchar(60), virtualHost varchar(60), lastNode varchar(60), accessTime bigint, "+
487                                            " lastAccessTime bigint, createTime bigint, cookieTime bigint, "+
488                                            " lastSavedTime bigint, expiryTime bigint, map "+blobType+", primary key(rowId))";
489                 connection.createStatement().executeUpdate(_createSessionTable);
490             }
491             
492             //make some indexes on the JettySessions table
493             String index1 = "idx_"+_sessionTable+"_expiry";
494             String index2 = "idx_"+_sessionTable+"_session";
495             
496             result = metaData.getIndexInfo(null, null, tableName, false, false);
497             boolean index1Exists = false;
498             boolean index2Exists = false;
499             while (result.next())
500             {
501                 String idxName = result.getString("INDEX_NAME");
502                 if (index1.equalsIgnoreCase(idxName))
503                     index1Exists = true;
504                 else if (index2.equalsIgnoreCase(idxName))
505                     index2Exists = true;
506             }
507             if (!(index1Exists && index2Exists))
508             {
509                 Statement statement = connection.createStatement();
510                 if (!index1Exists)
511                     statement.executeUpdate("create index "+index1+" on "+_sessionTable+" (expiryTime)");
512                 if (!index2Exists)
513                     statement.executeUpdate("create index "+index2+" on "+_sessionTable+" (sessionId, contextPath)");
514             }
515         }
516         finally
517         {
518             if (connection != null)
519                 connection.close();
520         }
521     }
522     
523     /**
524      * Insert a new used session id into the table.
525      * 
526      * @param id
527      * @throws SQLException
528      */
529     private void insert (String id)
530     throws SQLException 
531     {
532         Connection connection = null;
533         try
534         {
535             connection = getConnection();
536             connection.setAutoCommit(true);            
537             PreparedStatement query = connection.prepareStatement(_queryId);
538             query.setString(1, id);
539             ResultSet result = query.executeQuery();
540             //only insert the id if it isn't in the db already 
541             if (!result.next())
542             {
543                 PreparedStatement statement = connection.prepareStatement(_insertId);
544                 statement.setString(1, id);
545                 statement.executeUpdate();
546             }
547         }
548         finally
549         {
550             if (connection != null)
551                 connection.close();
552         }
553     }
554     
555     /**
556      * Remove a session id from the table.
557      * 
558      * @param id
559      * @throws SQLException
560      */
561     private void delete (String id)
562     throws SQLException
563     {
564         Connection connection = null;
565         try
566         {
567             connection = getConnection();
568             connection.setAutoCommit(true);
569             PreparedStatement statement = connection.prepareStatement(_deleteId);
570             statement.setString(1, id);
571             statement.executeUpdate();
572         }
573         finally
574         {
575             if (connection != null)
576                 connection.close();
577         }
578     }
579     
580     
581     /**
582      * Check if a session id exists.
583      * 
584      * @param id
585      * @return
586      * @throws SQLException
587      */
588     private boolean exists (String id)
589     throws SQLException
590     {
591         Connection connection = null;
592         try
593         {
594             connection = getConnection();
595             connection.setAutoCommit(true);
596             PreparedStatement statement = connection.prepareStatement(_queryId);
597             statement.setString(1, id);
598             ResultSet result = statement.executeQuery();
599             if (result.next())
600                 return true;
601             else
602                 return false;
603         }
604         finally
605         {
606             if (connection != null)
607                 connection.close();
608         }
609     }
610     
611     /**
612      * Look for sessions in the database that have expired.
613      * 
614      * We do this in the SessionIdManager and not the SessionManager so
615      * that we only have 1 scavenger, otherwise if there are n SessionManagers
616      * there would be n scavengers, all contending for the database.
617      * 
618      * We look first for sessions that expired in the previous interval, then
619      * for sessions that expired previously - these are old sessions that no
620      * node is managing any more and have become stuck in the database.
621      */
622     private void scavenge ()
623     {
624         Connection connection = null;
625         List expiredSessionIds = new ArrayList();
626         try
627         {            
628             if (Log.isDebugEnabled()) Log.debug("Scavenge sweep started at "+System.currentTimeMillis());
629             if (_lastScavengeTime > 0)
630             {
631                 connection = getConnection();
632                 connection.setAutoCommit(true);
633                 //"select sessionId from JettySessions where expiryTime > (lastScavengeTime - scanInterval) and expiryTime < lastScavengeTime";
634                 PreparedStatement statement = connection.prepareStatement(_selectExpiredSessions);
635                 long lowerBound = (_lastScavengeTime - _scavengeIntervalMs);
636                 long upperBound = _lastScavengeTime;
637                 if (Log.isDebugEnabled()) Log.debug("Searching for sessions expired between "+lowerBound + " and "+upperBound);
638                 statement.setLong(1, lowerBound);
639                 statement.setLong(2, upperBound);
640                 ResultSet result = statement.executeQuery();
641                 while (result.next())
642                 {
643                     String sessionId = result.getString("sessionId");
644                     expiredSessionIds.add(sessionId);
645                     if (Log.isDebugEnabled()) Log.debug("Found expired sessionId="+sessionId);
646                 }
647 
648 
649                 //tell the SessionManagers to expire any sessions with a matching sessionId in memory
650                 Handler[] contexts = _server.getChildHandlersByClass(WebAppContext.class);
651                 for (int i=0; contexts!=null && i<contexts.length; i++)
652                 {
653                     AbstractSessionManager manager = ((AbstractSessionManager)((WebAppContext)contexts[i]).getSessionHandler().getSessionManager());
654                     if (manager instanceof JDBCSessionManager)
655                     {
656                         ((JDBCSessionManager)manager).expire(expiredSessionIds);
657                     }
658                 }
659 
660                 //find all sessions that have expired at least a couple of scanIntervals ago and just delete them
661                 upperBound = _lastScavengeTime - (2 * _scavengeIntervalMs);
662                 if (upperBound > 0)
663                 {
664                     if (Log.isDebugEnabled()) Log.debug("Deleting old expired sessions expired before "+upperBound);
665                     statement = connection.prepareStatement(_deleteOldExpiredSessions);
666                     statement.setLong(1, upperBound);
667                     statement.executeUpdate();
668                 }
669             }
670         }
671         catch (Exception e)
672         {
673             Log.warn("Problem selecting expired sessions", e);
674         }
675         finally
676         {           
677             _lastScavengeTime=System.currentTimeMillis();
678             if (Log.isDebugEnabled()) Log.debug("Scavenge sweep ended at "+_lastScavengeTime);
679             if (connection != null)
680             {
681                 try
682                 {
683                 connection.close();
684                 }
685                 catch (SQLException e)
686                 {
687                     Log.warn(e);
688                 }
689             }
690         }
691     }
692 }