001    package railo.runtime.net.http;
002    
003    import java.io.BufferedReader;
004    import java.io.IOException;
005    import java.io.InputStream;
006    import java.io.Serializable;
007    import java.util.Enumeration;
008    import java.util.HashMap;
009    
010    import javax.servlet.RequestDispatcher;
011    import javax.servlet.ServletInputStream;
012    import javax.servlet.http.HttpServletRequest;
013    import javax.servlet.http.HttpServletRequestWrapper;
014    
015    import railo.commons.io.IOUtil;
016    import railo.commons.lang.StringUtil;
017    import railo.runtime.PageContext;
018    import railo.runtime.engine.ThreadLocalPageContext;
019    import railo.runtime.type.scope.FormUpload;
020    import railo.runtime.util.EnumerationWrapper;
021    
022    /**
023     * extends a existing {@link HttpServletRequest} with the possibility to reread the input as many you want.
024     */
025    public final class HTTPServletRequestWrap extends HttpServletRequestWrapper implements Serializable {
026    
027            private boolean firstRead=true;
028            private byte[] barr;
029            private static final int MIN_STORAGE_SIZE=1*1024*1024;
030            private static final int MAX_STORAGE_SIZE=50*1024*1024;
031            private static final int SPACIG=1024*1024;
032            
033            private String servlet_path;
034            private String request_uri;
035            private String context_path;
036            private String path_info;
037            private String query_string;
038            private HashMap<String, Object> disconnectedData;
039            private boolean disconnected;
040            private HttpServletRequest req;
041            //private Request _request;
042    
043            /**
044             * Constructor of the class
045             * @param req 
046             * @param max how many is possible to re read
047             */
048            public HTTPServletRequestWrap(HttpServletRequest req) {
049                    super(req);
050                    this.req=pure(req);
051                    
052                    if((servlet_path=attrAsString("javax.servlet.include.servlet_path"))!=null){
053                            request_uri=attrAsString("javax.servlet.include.request_uri");
054                            context_path=attrAsString("javax.servlet.include.context_path");
055                            path_info=attrAsString("javax.servlet.include.path_info");
056                            query_string = attrAsString("javax.servlet.include.query_string");
057                    }
058                    
059                    //forward
060                    /*else if((servlet_path=attr("javax.servlet.forward.servlet_path"))!=null){
061                            request_uri=attr("javax.servlet.forward.request_uri");
062                            context_path=attr("javax.servlet.forward.context_path");
063                            path_info=attr("javax.servlet.forward.path_info");
064                            query_string = attr("javax.servlet.forward.query_string");
065                    }*/
066                    
067                    else {
068                            servlet_path=super.getServletPath();
069                            request_uri=super.getRequestURI();
070                            context_path=super.getContextPath();
071                            path_info=super.getPathInfo();
072                            query_string = super.getQueryString();
073                    }
074                    /*Enumeration names = req.getAttributeNames();
075                    while(names.hasMoreElements()){
076                            String key=(String)names.nextElement();
077                            print.out(key+"+"+req.getAttribute(key));
078                    }
079                    
080    
081                    print.out("super:"+req.getClass().getName());
082                    print.out("servlet_path:"+servlet_path);
083                    print.out("request_uri:"+request_uri);
084                    print.out("path_info:"+path_info);
085                    print.out("query_string:"+query_string);
086                    
087                    print.out("servlet_path."+req.getServletPath());
088                    print.out("request_uri."+req.getRequestURI());
089                    print.out("path_info."+req.getPathInfo());
090                    print.out("query_string."+req.getQueryString());
091                    */
092            }
093            
094            private String attrAsString(String key) {
095                    Object res = getAttribute(key);
096                    if(res==null) return null;
097                    return res.toString();
098            }
099            
100            public static HttpServletRequest pure(HttpServletRequest req) {
101                    HttpServletRequest req2;
102                    while(req instanceof HTTPServletRequestWrap){
103                            req2 = (HttpServletRequest) ((HTTPServletRequestWrap)req).getOriginalRequest();
104                            if(req2==req) break;
105                            req=req2;
106                    }
107                    return req;
108            }
109    
110            /**
111             * @see javax.servlet.http.HttpServletRequestWrapper#getContextPath()
112             */
113            public String getContextPath() {
114                    return context_path;
115            }
116            
117            /**
118             * @see javax.servlet.http.HttpServletRequestWrapper#getPathInfo()
119             */
120            public String getPathInfo() {
121                    return path_info;
122            }
123            
124            /**
125             * @see javax.servlet.http.HttpServletRequestWrapper#getRequestURL()
126             */
127            public StringBuffer getRequestURL() {
128                    return new StringBuffer(isSecure()?"https":"http").
129                            append("://").
130                            append(getServerName()).
131                            append(':').
132                            append(getServerPort()).
133                            append(request_uri.startsWith("/")?request_uri:"/"+request_uri);
134            }
135            
136            /**
137             * @see javax.servlet.http.HttpServletRequestWrapper#getQueryString()
138             */
139            public String getQueryString() {
140                    return query_string;
141            }
142            /**
143             * @see javax.servlet.http.HttpServletRequestWrapper#getRequestURI()
144             */
145            public String getRequestURI() {
146                    return request_uri;
147            }
148            
149            /**
150             * @see javax.servlet.http.HttpServletRequestWrapper#getServletPath()
151             */
152            public String getServletPath() {
153                    return servlet_path;
154            }
155            
156            /**
157             * @see javax.servlet.ServletRequestWrapper#getRequestDispatcher(java.lang.String)
158             */
159            public RequestDispatcher getRequestDispatcher(String realpath) {
160                    return new RequestDispatcherWrap(this,realpath);
161            }
162            
163            public RequestDispatcher getOriginalRequestDispatcher(String realpath) {
164                    if(disconnected) return null;
165                    return req.getRequestDispatcher(realpath);
166            }
167    
168            /**
169             * @see javax.servlet.ServletRequestWrapper#removeAttribute(java.lang.String)
170             */
171            public void removeAttribute(String name) {
172                    if(disconnected) disconnectedData.remove(name); 
173                    else req.removeAttribute(name);
174            }
175    
176            /**
177             * @see javax.servlet.ServletRequestWrapper#setAttribute(java.lang.String, java.lang.Object)
178             */
179            public void setAttribute(String name, Object value) {
180                    if(disconnected) disconnectedData.put(name, value);
181                    else req.setAttribute(name, value);
182            }
183            
184            /*public void setAttributes(Request request) {
185                    this._request=request;
186            }*/
187    
188    
189            /**
190             * @see javax.servlet.ServletRequestWrapper#getAttribute(java.lang.String)
191             */
192            public Object getAttribute(String name) {
193                    if(disconnected) return disconnectedData.get(name);
194                    return req.getAttribute(name);
195            }
196    
197            public Enumeration getAttributeNames() {
198                    if(disconnected) return new EnumerationWrapper(disconnectedData);
199                    return req.getAttributeNames();
200                    
201            }
202    
203            /**
204             * this method still throws a error if want read input stream a second time
205             * this is done to be compatibility with servletRequest class
206             * @see javax.servlet.ServletRequestWrapper#getInputStream()
207             */
208            public ServletInputStream getInputStream() throws IOException {
209                    //if(ba rr!=null) throw new IllegalStateException();
210                    if(barr==null) {
211                            if(!firstRead) {
212                                    PageContext pc = ThreadLocalPageContext.get();
213                                    if(pc!=null) {
214                                            return ((FormUpload)pc.formScope()).getInputStream();
215                                    }
216                                    return new ServletInputStreamDummy(new byte[]{});       //throw new IllegalStateException();
217                            }
218                            
219                            firstRead=false;
220                            if(isToBig(getContentLength())) {
221                                    return super.getInputStream();
222                            }
223                            InputStream is=null;
224                            try {
225                                    barr=IOUtil.toBytes(is=super.getInputStream());
226                                    
227                                    //Resource res = ResourcesImpl.getFileResourceProvider().getResource("/Users/mic/Temp/multipart.txt");
228                                    //IOUtil.copy(new ByteArrayInputStream(barr), res, true);
229                                    
230                            }
231                            catch(Throwable t) {
232                                    barr=null;
233                                    return new ServletInputStreamDummy(new byte[]{});        
234                            }
235                            finally {
236                                    IOUtil.closeEL(is);
237                            }
238                    }
239                    
240                    return new ServletInputStreamDummy(barr);       
241            }
242            
243            private boolean isToBig(int contentLength) {
244                    if(contentLength<MIN_STORAGE_SIZE) return false;
245                    if(contentLength>MAX_STORAGE_SIZE) return true;
246                    Runtime rt = Runtime.getRuntime();
247                    long av = rt.maxMemory()-rt.totalMemory()+rt.freeMemory();
248                    return (av-SPACIG)<contentLength;
249            }
250    
251            /* *
252             * with this method it is possibiliy to rewrite the input as many you want
253             * @return input stream from request
254             * @throws IOException
255             * /
256            public ServletInputStream getStoredInputStream() throws IOException {
257                    if(firstRead || barr!=null) return getInputStream();
258                    return new ServletInputStreamDummy(new byte[]{});        
259            }*/
260    
261            /**
262             *
263             * @see javax.servlet.ServletRequestWrapper#getReader()
264             */
265            public BufferedReader getReader() throws IOException {
266                    String enc = getCharacterEncoding();
267                    if(StringUtil.isEmpty(enc))enc="iso-8859-1";
268                    return IOUtil.toBufferedReader(IOUtil.getReader(getInputStream(), enc));
269            }
270            
271            public void clear() {
272                    barr=null;
273            }
274    
275            
276    
277    
278            public HttpServletRequest getOriginalRequest() {
279                    return req;
280            }
281    
282            public void disconnect() {
283                    if(disconnected) return;
284                    Enumeration<String> names = req.getAttributeNames();
285                    disconnectedData=new HashMap<String, Object>();
286                    String k;
287                    while(names.hasMoreElements()){
288                            k=names.nextElement();
289                            disconnectedData.put(k, req.getAttribute(k));
290                    }
291                    disconnected=true;
292                    req=null;
293            }
294    }