1   package org.mortbay.jetty.security;
2   
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

import org.mortbay.io.Buffer;
import org.mortbay.io.Buffers;
import org.mortbay.io.nio.NIOBuffer;
import org.mortbay.io.nio.SelectChannelEndPoint;
import org.mortbay.io.nio.SelectorManager;
import org.mortbay.log.Log;
19  
20  /* ------------------------------------------------------------ */
21  /**
22   * SslHttpChannelEndPoint.
23   * 
24   * @author Nik Gonzalez <ngonzalez@exist.com>
25   * @author Greg Wilkins <gregw@mortbay.com>
26   */
27  public class SslHttpChannelEndPoint extends SelectChannelEndPoint
28  {
29      private static final ByteBuffer[] __NO_BUFFERS={};
30      private static final ByteBuffer __EMPTY=ByteBuffer.allocate(0);
31  
32      private Buffers _buffers;
33      
34      private SSLEngine _engine;
35      private ByteBuffer _inBuffer;
36      private NIOBuffer _inNIOBuffer;
37      private ByteBuffer _outBuffer;
38      private NIOBuffer _outNIOBuffer;
39  
40      private NIOBuffer[] _reuseBuffer=new NIOBuffer[2];    
41      private ByteBuffer[] _gather=new ByteBuffer[2];
42  
43      // ssl
44      protected SSLSession _session;
45      
46      /* ------------------------------------------------------------ */
47      public SslHttpChannelEndPoint(Buffers buffers,SocketChannel channel, SelectorManager.SelectSet selectSet, SelectionKey key, SSLEngine engine)
48              throws SSLException, IOException
49      {
50          super(channel,selectSet,key);
51          _buffers=buffers;
52          
53          // ssl
54          _engine=engine;
55          _engine.setUseClientMode(false);
56          _session=engine.getSession();
57  
58          // TODO pool buffers and use only when needed.
59          _outNIOBuffer=(NIOBuffer)buffers.getBuffer(_session.getPacketBufferSize());
60          _outBuffer=_outNIOBuffer.getByteBuffer();
61          _inNIOBuffer=(NIOBuffer)buffers.getBuffer(_session.getPacketBufferSize());
62          _inBuffer=_inNIOBuffer.getByteBuffer();
63          
64      }
65  
66      /* ------------------------------------------------------------ */
67      /* (non-Javadoc)
68       * @see org.mortbay.io.nio.SelectChannelEndPoint#idleExpired()
69       */
70      protected void idleExpired()
71      {
72          try
73          {
74              _selectSet.getManager().dispatch(new Runnable()
75              {
76                  public void run() 
77                  { 
78                      try 
79                      {
80                          close(); 
81                      }
82                      catch(Exception e)
83                      {
84                          Log.ignore(e);
85                      }
86                  }
87              });
88          }
89          catch(Exception e)
90          {
91              Log.ignore(e);
92          }
93      }
94  
95  
96  
97      /* ------------------------------------------------------------ */
98      public void close() throws IOException
99      {
100         _engine.closeOutbound();
101         try
102         {   
103             int tries=0;
104             loop: while (isOpen() && !_engine.isOutboundDone() && tries++<100)
105             {
106                 if (_outNIOBuffer.length()>0)
107                 {
108                     flush();
109                     Thread.sleep(10); // TODO yuck
110                 }
111                 
112                 switch(_engine.getHandshakeStatus())
113                 {
114                     case FINISHED:
115                     case NOT_HANDSHAKING:
116                         break loop;
117                         
118                     case NEED_UNWRAP:
119                         if(!fill(__EMPTY))
120                             Thread.yield(); 
121                             Thread.sleep(10); // TODO yuck
122                         break;
123                         
124                     case NEED_TASK:
125                     {
126                         Runnable task;
127                         while ((task=_engine.getDelegatedTask())!=null)
128                         {
129                             task.run();
130                         }
131                         break;
132                     }
133                         
134                     case NEED_WRAP:
135                     {
136                         if (_outNIOBuffer.length()>0)
137                             flush();
138                         
139                         SSLEngineResult result=null;
140                         try
141                         {
142                             _outNIOBuffer.compact();
143                             int put=_outNIOBuffer.putIndex();
144                             _outBuffer.position(put);
145                             result=_engine.wrap(__NO_BUFFERS,_outBuffer);
146                             _outNIOBuffer.setPutIndex(put+result.bytesProduced());
147                         }
148                         finally
149                         {
150                             _outBuffer.position(0);
151                         }
152                         
153                         flush();
154                         
155                         break;
156                     }
157                 }
158             }
159             
160         }
161         catch(IOException e)
162         {
163             Log.ignore(e);
164         }
165         catch (InterruptedException e)
166         {
167             Log.ignore(e);
168         }
169         finally
170         {
171             super.close();
172             
173             if (_inNIOBuffer!=null)
174                 _buffers.returnBuffer(_inNIOBuffer);
175             if (_outNIOBuffer!=null)
176                 _buffers.returnBuffer(_outNIOBuffer);
177             if (_reuseBuffer[0]!=null)
178                 _buffers.returnBuffer(_reuseBuffer[0]);
179             if (_reuseBuffer[1]!=null)
180                 _buffers.returnBuffer(_reuseBuffer[1]);
181         }
182         
183         
184     }
185 
186     /* ------------------------------------------------------------ */
187     /* 
188      */
189     public int fill(Buffer buffer) throws IOException
190     {
191         synchronized(buffer)
192         {
193             ByteBuffer bbuf=extractInputBuffer(buffer);
194             int size=buffer.length();
195 
196             try
197             {
198                 fill(bbuf);
199 
200                 loop: while (_inBuffer.remaining()>0)
201                 {
202                     if (_outNIOBuffer.length()>0)
203                         flush();
204                     
205                     switch(_engine.getHandshakeStatus())
206                     {
207                         case FINISHED:
208                         case NOT_HANDSHAKING:
209                             break loop;
210 
211                         case NEED_UNWRAP:
212                             if(!fill(bbuf))
213                                 break loop;
214                             break;
215 
216                         case NEED_TASK:
217                         {
218                             Runnable task;
219                             while ((task=_engine.getDelegatedTask())!=null)
220                             {
221                                 task.run();
222                             }
223                             break;
224                         }
225 
226                         case NEED_WRAP:
227                         {
228                             SSLEngineResult result=null;
229                             synchronized(_outBuffer)
230                             {
231                                 try
232                                 {
233                                     _outNIOBuffer.compact();
234                                     int put=_outNIOBuffer.putIndex();
235                                     _outBuffer.position();
236                                     result=_engine.wrap(__NO_BUFFERS,_outBuffer);
237                                     _outNIOBuffer.setPutIndex(put+result.bytesProduced());
238                                 }
239                                 finally
240                                 {
241                                     _outBuffer.position(0);
242                                 }
243                             }
244 
245                             flush();
246 
247                             break;
248                         }
249                     }
250                 }
251             }
252             catch(SSLException e)
253             {
254                 Log.warn(e.toString());
255                 Log.debug(e);
256                 throw e;
257             }
258             finally
259             {
260                 buffer.setPutIndex(bbuf.position());
261                 bbuf.position(0);
262             }
263 
264             return buffer.length()-size; 
265         }
266     }
267 
268     /* ------------------------------------------------------------ */
269     public int flush(Buffer buffer) throws IOException
270     {
271         return flush(buffer,null,null);
272     }
273 
274 
275     /* ------------------------------------------------------------ */
276     /*     
277      */
278     public int flush(Buffer header, Buffer buffer, Buffer trailer) throws IOException
279     {
280         if (_outNIOBuffer.length()>0)
281         {
282             flush();
283             if (_outNIOBuffer.length()>0)
284                 return 0;
285         }
286 
287         SSLEngineResult result=null;
288 
289         if (header!=null && buffer!=null)
290         {
291             _gather[0]=extractOutputBuffer(header,0);
292             synchronized(_gather[0])
293             {
294                 _gather[0].position(header.getIndex());
295                 _gather[0].limit(header.putIndex());
296 
297                 _gather[1]=extractOutputBuffer(buffer,1);
298 
299                 synchronized(_gather[1])
300                 {
301                     _gather[1].position(buffer.getIndex());
302                     _gather[1].limit(buffer.putIndex());
303 
304                     synchronized(_outBuffer)
305                     {
306                         int consumed=0;
307                         try
308                         {
309                             _outNIOBuffer.clear();
310                             _outBuffer.position(0);
311                             _outBuffer.limit(_outBuffer.capacity());
312                             result=_engine.wrap(_gather,_outBuffer);
313                             _outNIOBuffer.setGetIndex(0);
314                             _outNIOBuffer.setPutIndex(result.bytesProduced());
315                             consumed=result.bytesConsumed();
316                         }
317                         finally
318                         {
319                             _outBuffer.position(0);
320 
321                             if (consumed>0 && header!=null)
322                             {
323                                 int len=consumed<header.length()?consumed:header.length();
324                                 header.skip(len);
325                                 consumed-=len;
326                                 _gather[0].position(0);
327                                 _gather[0].limit(_gather[0].capacity());
328                             }
329                             if (consumed>0 && buffer!=null)
330                             {
331                                 int len=consumed<buffer.length()?consumed:buffer.length();
332                                 buffer.skip(len);
333                                 consumed-=len;
334                                 _gather[1].position(0);
335                                 _gather[1].limit(_gather[1].capacity());
336                             }
337                             assert consumed==0;
338                         }
339                     }
340                 }
341             }
342         }
343         else
344         {
345             _gather[0]=extractOutputBuffer(header,0);
346             synchronized(_gather[0])
347             {
348                 _gather[0].position(header.getIndex());
349                 _gather[0].limit(header.putIndex());
350 
351                 int consumed=0;
352                 synchronized(_outBuffer)
353                 {
354                     try
355                     {
356                         _outNIOBuffer.clear();
357                         _outBuffer.position(0);
358                         _outBuffer.limit(_outBuffer.capacity());
359                         result=_engine.wrap(_gather[0],_outBuffer);
360                         _outNIOBuffer.setGetIndex(0);
361                         _outNIOBuffer.setPutIndex(result.bytesProduced());
362                         consumed=result.bytesConsumed();
363                     }
364                     finally
365                     {
366                         _outBuffer.position(0);
367 
368                         if (consumed>0 && header!=null)
369                         {
370                             int len=consumed<header.length()?consumed:header.length();
371                             header.skip(len);
372                             consumed-=len;
373                             _gather[0].position(0);
374                             _gather[0].limit(_gather[0].capacity());
375                         }
376                         assert consumed==0;
377                     }
378                 }
379             }
380         }
381 
382         flush();
383 
384         return result.bytesConsumed();
385     }
386 
387     
388     /* ------------------------------------------------------------ */
389     public void flush() throws IOException
390     {
391         while (_outNIOBuffer.length()>0)
392         {
393             int flushed=super.flush(_outNIOBuffer);
394             if (flushed==0)
395             {
396                 Thread.yield();
397                 flushed=super.flush(_outNIOBuffer);
398                 if (flushed==0)
399                     return;
400             }
401         }
402     }
403 
404     /* ------------------------------------------------------------ */
405     private ByteBuffer extractInputBuffer(Buffer buffer)
406     {
407         assert buffer instanceof NIOBuffer;
408         NIOBuffer nbuf=(NIOBuffer)buffer;
409         ByteBuffer bbuf=nbuf.getByteBuffer();
410         bbuf.position(buffer.putIndex());
411         return bbuf;
412     }
413     
414     /* ------------------------------------------------------------ */
415     private ByteBuffer extractOutputBuffer(Buffer buffer,int n)
416     {
417         ByteBuffer src=null;
418         NIOBuffer nBuf=null;
419 
420         if (buffer.buffer() instanceof NIOBuffer)
421         {
422             nBuf=(NIOBuffer)buffer.buffer();
423             return nBuf.getByteBuffer();
424         }
425         else
426         {
427             if (_reuseBuffer[n]==null)
428                 _reuseBuffer[n] = (NIOBuffer)_buffers.getBuffer(_session.getApplicationBufferSize());
429             NIOBuffer buf = _reuseBuffer[n];
430             buf.clear();
431             buf.put(buffer);
432             return buf.getByteBuffer();
433         }
434     }
435 
436     /* ------------------------------------------------------------ */
437     private boolean fill(ByteBuffer buffer) throws IOException
438     {
439         int in_len=0;
440 
441         if (_inNIOBuffer.hasContent())
442             _inNIOBuffer.compact();
443         else 
444             _inNIOBuffer.clear();
445 
446         while (_inNIOBuffer.space()>0)
447         {
448             int len=super.fill(_inNIOBuffer);
449             if (len<=0)
450             {
451                 if (len<0)
452                     _engine.closeInbound();
453                 if (len==0 || in_len>0)
454                     break;
455                 return false;
456             }
457             in_len+=len;
458         }
459         
460 
461         if (_inNIOBuffer.length()==0)
462             return false;
463 
464         SSLEngineResult result;
465         try
466         {
467             _inBuffer.position(_inNIOBuffer.getIndex());
468             _inBuffer.limit(_inNIOBuffer.putIndex());
469             result=_engine.unwrap(_inBuffer,buffer);
470             _inNIOBuffer.skip(result.bytesConsumed());
471         }
472         finally
473         {
474             _inBuffer.position(0);
475             _inBuffer.limit(_inBuffer.capacity());
476         }
477 
478         if (result != null)
479         {
480             switch(result.getStatus())
481             {
482                 case OK:
483                     break;
484                 case CLOSED:
485                     throw new IOException("sslEngine closed");
486                     
487                 case BUFFER_OVERFLOW:
488                     Log.debug("unwrap {}",result);
489                     break;
490                     
491                 case BUFFER_UNDERFLOW:
492                     Log.debug("unwrap {}",result);
493                     break;
494                     
495                 default:
496                     Log.warn("unwrap "+result);
497                 throw new IOException(result.toString());
498             }
499         }
500         
501         return (result.bytesProduced()+result.bytesConsumed())>0;
502     }
503 
504     /* ------------------------------------------------------------ */
505     public boolean isBufferingInput()
506     {
507         return _inNIOBuffer.hasContent();
508     }
509 
510     /* ------------------------------------------------------------ */
511     public boolean isBufferingOutput()
512     {
513         return _outNIOBuffer.hasContent();
514     }
515 
516     /* ------------------------------------------------------------ */
517     public boolean isBufferred()
518     {
519         return true;
520     }
521 
522     /* ------------------------------------------------------------ */
523     public SSLEngine getSSLEngine()
524     {
525         return _engine;
526     }
527 }