1   // ========================================================================
2   // Copyright 2008 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // Licensed under the Apache License, Version 2.0 (the "License");
5   // you may not use this file except in compliance with the License.
6   // You may obtain a copy of the License at 
7   // http://www.apache.org/licenses/LICENSE-2.0
8   // Unless required by applicable law or agreed to in writing, software
9   // distributed under the License is distributed on an "AS IS" BASIS,
10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  // See the License for the specific language governing permissions and
12  // limitations under the License.
13  // ========================================================================
14  
15  package org.mortbay.jetty.servlet;
16  
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.sql.Blob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;

import javax.naming.InitialContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import javax.sql.DataSource;

import org.mortbay.jetty.Handler;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.SessionManager;
import org.mortbay.jetty.handler.ContextHandler;
import org.mortbay.log.Log;
44  
45  
46  
47  /**
48   * JDBCSessionIdManager
49   *
50   * SessionIdManager implementation that uses a database to store in-use session ids, 
51   * to support distributed sessions.
52   * 
53   */
54  public class JDBCSessionIdManager extends AbstractSessionIdManager
55  {    
56      protected HashSet<String> _sessionIds = new HashSet();
57      protected String _driverClassName;
58      protected String _connectionUrl;
59      protected DataSource _datasource;
60      protected String _jndiName;
61      protected String _sessionIdTable = "JettySessionIds";
62      protected String _sessionTable = "JettySessions";
63      protected Timer _timer; //scavenge timer
64      protected TimerTask _task; //scavenge task
65      protected long _lastScavengeTime;
66      protected long _scavengeIntervalMs = 1000 * 60 * 10; //10mins
67      
68      
69      protected String _createSessionIdTable;
70      protected String _createSessionTable;
71                                              
72      protected String _selectExpiredSessions;
73      protected String _deleteOldExpiredSessions;
74  
75      protected String _insertId;
76      protected String _deleteId;
77      protected String _queryId;
78      
79      protected DatabaseAdaptor _dbAdaptor;
80  
81      
82      /**
83       * DatabaseAdaptor
84       *
85       * Handles differences between databases.
86       * 
87       * Postgres uses the getBytes and setBinaryStream methods to access
88       * a "bytea" datatype, which can be up to 1Gb of binary data. MySQL
89       * is happy to use the "blob" type and getBlob() methods instead.
90       * 
91       * TODO if the differences become more major it would be worthwhile
92       * refactoring this class.
93       */
94      public class DatabaseAdaptor 
95      {
96          String _dbName;
97          boolean _isLower;
98          boolean _isUpper;
99          
100         
101         public DatabaseAdaptor (DatabaseMetaData dbMeta)
102         throws SQLException
103         {
104             _dbName = dbMeta.getDatabaseProductName().toLowerCase(); 
105             Log.debug ("Using database "+_dbName);
106             _isLower = dbMeta.storesLowerCaseIdentifiers();
107             _isUpper = dbMeta.storesUpperCaseIdentifiers();
108         }
109         
110         /**
111          * Convert a camel case identifier into either upper or lower
112          * depending on the way the db stores identifiers.
113          * 
114          * @param identifier
115          * @return
116          */
117         public String convertIdentifier (String identifier)
118         {
119             if (_isLower)
120                 return identifier.toLowerCase();
121             if (_isUpper)
122                 return identifier.toUpperCase();
123             
124             return identifier;
125         }
126         
127         public String getBlobType ()
128         {
129             if (_dbName.startsWith("postgres"))
130                 return "bytea";
131             
132             return "blob";
133         }
134         
135         public InputStream getBlobInputStream (ResultSet result, String columnName)
136         throws SQLException
137         {
138             if (_dbName.startsWith("postgres"))
139             {
140                 byte[] bytes = result.getBytes(columnName);
141                 return new ByteArrayInputStream(bytes);
142             }
143             
144             Blob blob = result.getBlob(columnName);
145             return blob.getBinaryStream();
146         }
147     }
148     
149     
150     
151     public JDBCSessionIdManager(Server server)
152     {
153         super(server);
154     }
155     
156     public JDBCSessionIdManager(Server server, Random random)
157     {
158        super(server, random);
159     }
160 
161     /**
162      * Configure jdbc connection information via a jdbc Driver
163      * 
164      * @param driverClassName
165      * @param connectionUrl
166      */
167     public void setDriverInfo (String driverClassName, String connectionUrl)
168     {
169         _driverClassName=driverClassName;
170         _connectionUrl=connectionUrl;
171     }
172     
173     public String getDriverClassName()
174     {
175         return _driverClassName;
176     }
177     
178     public String getConnectionUrl ()
179     {
180         return _connectionUrl;
181     }
182     
183     public void setDatasourceName (String jndi)
184     {
185         _jndiName=jndi;
186     }
187     
188     public String getDatasourceName ()
189     {
190         return _jndiName;
191     }
192    
193     
194     public void setScavengeInterval (long sec)
195     {
196         if (sec<=0)
197             sec=60;
198 
199         long old_period=_scavengeIntervalMs;
200         long period=sec*1000;
201       
202         _scavengeIntervalMs=period;
203         
204         //add a bit of variability into the scavenge time so that not all
205         //nodes with the same scavenge time sync up
206         long tenPercent = _scavengeIntervalMs/10;
207         if ((System.currentTimeMillis()%2) == 0)
208             _scavengeIntervalMs += tenPercent;
209         
210         if (Log.isDebugEnabled()) Log.debug("Scavenging every "+_scavengeIntervalMs+" ms");
211         if (_timer!=null && (period!=old_period || _task==null))
212         {
213             synchronized (this)
214             {
215                 if (_task!=null)
216                     _task.cancel();
217                 _task = new TimerTask()
218                 {
219                     public void run()
220                     {
221                         scavenge();
222                     }   
223                 };
224                 _timer.schedule(_task,_scavengeIntervalMs,_scavengeIntervalMs);
225             }
226         }  
227     }
228     
229     public long getScavengeInterval ()
230     {
231         return _scavengeIntervalMs/1000;
232     }
233     
234     
235     public void addSession(HttpSession session)
236     {
237         if (session == null)
238             return;
239         
240         synchronized (_sessionIds)
241         {
242             String id = ((JDBCSessionManager.Session)session).getClusterId();            
243             try
244             {
245                 insert(id);
246                 _sessionIds.add(id);
247             }
248             catch (Exception e)
249             {
250                 Log.warn("Problem storing session id="+id, e);
251             }
252         }
253     }
254     
255     public void removeSession(HttpSession session)
256     {
257         if (session == null)
258             return;
259         
260         removeSession(((JDBCSessionManager.Session)session).getClusterId());
261     }
262     
263     
264     
265     public void removeSession (String id)
266     {
267 
268         if (id == null)
269             return;
270         
271         synchronized (_sessionIds)
272         {  
273             if (Log.isDebugEnabled())
274                 Log.debug("Removing session id="+id);
275             try
276             {               
277                 _sessionIds.remove(id);
278                 delete(id);
279             }
280             catch (Exception e)
281             {
282                 Log.warn("Problem removing session id="+id, e);
283             }
284         }
285         
286     }
287     
288 
289     /** 
290      * Get the session id without any node identifier suffix.
291      * 
292      * @see org.mortbay.jetty.SessionIdManager#getClusterId(java.lang.String)
293      */
294     public String getClusterId(String nodeId)
295     {
296         int dot=nodeId.lastIndexOf('.');
297         return (dot>0)?nodeId.substring(0,dot):nodeId;
298     }
299     
300 
301     /** 
302      * Get the session id, including this node's id as a suffix.
303      * 
304      * @see org.mortbay.jetty.SessionIdManager#getNodeId(java.lang.String, javax.servlet.http.HttpServletRequest)
305      */
306     public String getNodeId(String clusterId, HttpServletRequest request)
307     {
308         if (_workerName!=null)
309             return clusterId+'.'+_workerName;
310 
311         return clusterId;
312     }
313 
314 
315     public boolean idInUse(String id)
316     {
317         if (id == null)
318             return false;
319         
320         String clusterId = getClusterId(id);
321         
322         synchronized (_sessionIds)
323         {
324             if (_sessionIds.contains(clusterId))
325                 return true; //optimisation - if this session is one we've been managing, we can check locally
326             
327             //otherwise, we need to go to the database to check
328             try
329             {
330                 return exists(clusterId);
331             }
332             catch (Exception e)
333             {
334                 Log.warn("Problem checking inUse for id="+clusterId, e);
335                 return false;
336             }
337         }
338     }
339 
340     /** 
341      * Invalidate the session matching the id on all contexts.
342      * 
343      * @see org.mortbay.jetty.SessionIdManager#invalidateAll(java.lang.String)
344      */
345     public void invalidateAll(String id)
346     {            
347         //take the id out of the list of known sessionids for this node
348         removeSession(id);
349         
350         synchronized (_sessionIds)
351         {
352             //tell all contexts that may have a session object with this id to
353             //get rid of them
354             Handler[] contexts = _server.getChildHandlersByClass(ContextHandler.class);
355             for (int i=0; contexts!=null && i<contexts.length; i++)
356             {
357                 SessionManager manager = (SessionManager)
358                     ((SessionHandler)((ContextHandler)contexts[i]).getChildHandlerByClass(SessionHandler.class)).getSessionManager();
359                         
360                 if (manager instanceof JDBCSessionManager)
361                 {
362                     ((JDBCSessionManager)manager).invalidateSession(id);
363                 }
364             }
365         }
366     }
367 
368 
369     /** 
370      * Start up the id manager.
371      * 
372      * Makes necessary database tables and starts a Session
373      * scavenger thread.
374      * 
375      * @see org.mortbay.jetty.servlet.AbstractSessionIdManager#doStart()
376      */
377     public void doStart()
378     {
379         try
380         {            
381             initializeDatabase();
382             prepareTables();        
383             super.doStart();
384             if (Log.isDebugEnabled()) Log.debug("Scavenging interval = "+getScavengeInterval()+" sec");
385             _timer=new Timer("JDBCSessionScavenger", true);
386             setScavengeInterval(getScavengeInterval());
387         }
388         catch (Exception e)
389         {
390             Log.warn("Problem initialising JettySessionIds table", e);
391         }
392     }
393     
394     /** 
395      * Stop the scavenger.
396      * 
397      * @see org.mortbay.component.AbstractLifeCycle#doStop()
398      */
399     public void doStop () 
400     throws Exception
401     {
402         synchronized(this)
403         {
404             if (_task!=null)
405                 _task.cancel();
406             if (_timer!=null)
407                 _timer.cancel();
408             _timer=null;
409         }
410         super.doStop();
411     }
412     
413     /**
414      * Get a connection from the driver or datasource.
415      * 
416      * @return
417      * @throws SQLException
418      */
419     protected Connection getConnection ()
420     throws SQLException
421     {
422         if (_datasource != null)
423             return _datasource.getConnection();
424         else
425             return DriverManager.getConnection(_connectionUrl);
426     }
427 
428     
429     private void initializeDatabase ()
430     throws Exception
431     {
432         if (_jndiName!=null)
433         {
434             InitialContext ic = new InitialContext();
435             _datasource = (DataSource)ic.lookup(_jndiName);
436         }
437         else if (_driverClassName!=null && _connectionUrl!=null)
438         {
439             Class.forName(_driverClassName);
440         }
441         else
442             throw new IllegalStateException("No database configured for sessions");
443     }
444     
445     
446     
447     /**
448      * Set up the tables in the database
449      * @throws SQLException
450      */
451     private void prepareTables()
452     throws SQLException
453     {
454         _createSessionIdTable = "create table "+_sessionIdTable+" (id varchar(60), primary key(id))";
455         _selectExpiredSessions = "select * from "+_sessionTable+" where expiryTime >= ? and expiryTime <= ?";
456         _deleteOldExpiredSessions = "delete from "+_sessionTable+" where expiryTime >0 and expiryTime <= ?";
457 
458         _insertId = "insert into "+_sessionIdTable+" (id)  values (?)";
459         _deleteId = "delete from "+_sessionIdTable+" where id = ?";
460         _queryId = "select * from "+_sessionIdTable+" where id = ?";
461 
462         Connection connection = null;
463         try
464         {
465             //make the id table
466             connection = getConnection();
467             connection.setAutoCommit(true);
468             DatabaseMetaData metaData = connection.getMetaData();
469             _dbAdaptor = new DatabaseAdaptor(metaData);
470 
471             //checking for table existence is case-sensitive, but table creation is not
472             String tableName = _dbAdaptor.convertIdentifier(_sessionIdTable);
473             ResultSet result = metaData.getTables(null, null, tableName, null);
474             if (!result.next())
475             {
476                 //table does not exist, so create it
477                 connection.createStatement().executeUpdate(_createSessionIdTable);
478             }
479             
480             //make the session table if necessary
481             tableName = _dbAdaptor.convertIdentifier(_sessionTable);   
482             result = metaData.getTables(null, null, tableName, null);
483             if (!result.next())
484             {
485                 //table does not exist, so create it
486                 String blobType = _dbAdaptor.getBlobType();
487                 _createSessionTable = "create table "+_sessionTable+" (rowId varchar(60), sessionId varchar(60), "+
488                                            " contextPath varchar(60), virtualHost varchar(60), lastNode varchar(60), accessTime bigint, "+
489                                            " lastAccessTime bigint, createTime bigint, cookieTime bigint, "+
490                                            " lastSavedTime bigint, expiryTime bigint, map "+blobType+", primary key(rowId))";
491                 connection.createStatement().executeUpdate(_createSessionTable);
492             }
493             
494             //make some indexes on the JettySessions table
495             String index1 = "idx_"+_sessionTable+"_expiry";
496             String index2 = "idx_"+_sessionTable+"_session";
497             
498             result = metaData.getIndexInfo(null, null, tableName, false, false);
499             boolean index1Exists = false;
500             boolean index2Exists = false;
501             while (result.next())
502             {
503                 String idxName = result.getString("INDEX_NAME");
504                 if (index1.equalsIgnoreCase(idxName))
505                     index1Exists = true;
506                 else if (index2.equalsIgnoreCase(idxName))
507                     index2Exists = true;
508             }
509             if (!(index1Exists && index2Exists))
510             {
511                 Statement statement = connection.createStatement();
512                 if (!index1Exists)
513                     statement.executeUpdate("create index "+index1+" on "+_sessionTable+" (expiryTime)");
514                 if (!index2Exists)
515                     statement.executeUpdate("create index "+index2+" on "+_sessionTable+" (sessionId, contextPath)");
516             }
517         }
518         finally
519         {
520             if (connection != null)
521                 connection.close();
522         }
523     }
524     
525     /**
526      * Insert a new used session id into the table.
527      * 
528      * @param id
529      * @throws SQLException
530      */
531     private void insert (String id)
532     throws SQLException 
533     {
534         Connection connection = null;
535         try
536         {
537             connection = getConnection();
538             connection.setAutoCommit(true);            
539             PreparedStatement query = connection.prepareStatement(_queryId);
540             query.setString(1, id);
541             ResultSet result = query.executeQuery();
542             //only insert the id if it isn't in the db already 
543             if (!result.next())
544             {
545                 PreparedStatement statement = connection.prepareStatement(_insertId);
546                 statement.setString(1, id);
547                 statement.executeUpdate();
548             }
549         }
550         finally
551         {
552             if (connection != null)
553                 connection.close();
554         }
555     }
556     
557     /**
558      * Remove a session id from the table.
559      * 
560      * @param id
561      * @throws SQLException
562      */
563     private void delete (String id)
564     throws SQLException
565     {
566         Connection connection = null;
567         try
568         {
569             connection = getConnection();
570             connection.setAutoCommit(true);
571             PreparedStatement statement = connection.prepareStatement(_deleteId);
572             statement.setString(1, id);
573             statement.executeUpdate();
574         }
575         finally
576         {
577             if (connection != null)
578                 connection.close();
579         }
580     }
581     
582     
583     /**
584      * Check if a session id exists.
585      * 
586      * @param id
587      * @return
588      * @throws SQLException
589      */
590     private boolean exists (String id)
591     throws SQLException
592     {
593         Connection connection = null;
594         try
595         {
596             connection = getConnection();
597             connection.setAutoCommit(true);
598             PreparedStatement statement = connection.prepareStatement(_queryId);
599             statement.setString(1, id);
600             ResultSet result = statement.executeQuery();
601             if (result.next())
602                 return true;
603             else
604                 return false;
605         }
606         finally
607         {
608             if (connection != null)
609                 connection.close();
610         }
611     }
612     
613     /**
614      * Look for sessions in the database that have expired.
615      * 
616      * We do this in the SessionIdManager and not the SessionManager so
617      * that we only have 1 scavenger, otherwise if there are n SessionManagers
618      * there would be n scavengers, all contending for the database.
619      * 
620      * We look first for sessions that expired in the previous interval, then
621      * for sessions that expired previously - these are old sessions that no
622      * node is managing any more and have become stuck in the database.
623      */
624     private void scavenge ()
625     {
626         Connection connection = null;
627         List expiredSessionIds = new ArrayList();
628         try
629         {            
630             if (Log.isDebugEnabled()) Log.debug("Scavenge sweep started at "+System.currentTimeMillis());
631             if (_lastScavengeTime > 0)
632             {
633                 connection = getConnection();
634                 connection.setAutoCommit(true);
635                 //"select sessionId from JettySessions where expiryTime > (lastScavengeTime - scanInterval) and expiryTime < lastScavengeTime";
636                 PreparedStatement statement = connection.prepareStatement(_selectExpiredSessions);
637                 long lowerBound = (_lastScavengeTime - _scavengeIntervalMs);
638                 long upperBound = _lastScavengeTime;
639                 if (Log.isDebugEnabled()) Log.debug("Searching for sessions expired between "+lowerBound + " and "+upperBound);
640                 statement.setLong(1, lowerBound);
641                 statement.setLong(2, upperBound);
642                 ResultSet result = statement.executeQuery();
643                 while (result.next())
644                 {
645                     String sessionId = result.getString("sessionId");
646                     expiredSessionIds.add(sessionId);
647                     if (Log.isDebugEnabled()) Log.debug("Found expired sessionId="+sessionId);
648                 }
649 
650 
651                 //tell the SessionManagers to expire any sessions with a matching sessionId in memory
652                 Handler[] contexts = _server.getChildHandlersByClass(ContextHandler.class);
653                 for (int i=0; contexts!=null && i<contexts.length; i++)
654                 {
655                     SessionManager manager = (SessionManager)
656                         ((SessionHandler)((ContextHandler)contexts[i]).getChildHandlerByClass(SessionHandler.class)).getSessionManager();
657                             
658                     if (manager instanceof JDBCSessionManager)
659                     {
660                         ((JDBCSessionManager)manager).expire(expiredSessionIds);
661                     }
662                 }
663 
664                 //find all sessions that have expired at least a couple of scanIntervals ago and just delete them
665                 upperBound = _lastScavengeTime - (2 * _scavengeIntervalMs);
666                 if (upperBound > 0)
667                 {
668                     if (Log.isDebugEnabled()) Log.debug("Deleting old expired sessions expired before "+upperBound);
669                     statement = connection.prepareStatement(_deleteOldExpiredSessions);
670                     statement.setLong(1, upperBound);
671                     statement.executeUpdate();
672                 }
673             }
674         }
675         catch (Exception e)
676         {
677             Log.warn("Problem selecting expired sessions", e);
678         }
679         finally
680         {           
681             _lastScavengeTime=System.currentTimeMillis();
682             if (Log.isDebugEnabled()) Log.debug("Scavenge sweep ended at "+_lastScavengeTime);
683             if (connection != null)
684             {
685                 try
686                 {
687                 connection.close();
688                 }
689                 catch (SQLException e)
690                 {
691                     Log.warn(e);
692                 }
693             }
694         }
695     }
696 }