001    package railo.runtime.net.http;
002    
003    import java.io.BufferedReader;
004    import java.io.IOException;
005    import java.io.InputStream;
006    import java.io.Serializable;
007    import java.util.Enumeration;
008    import java.util.HashMap;
009    
010    import javax.servlet.RequestDispatcher;
011    import javax.servlet.ServletInputStream;
012    import javax.servlet.http.HttpServletRequest;
013    import javax.servlet.http.HttpServletRequestWrapper;
014    
015    import railo.commons.io.IOUtil;
016    import railo.commons.lang.StringUtil;
017    import railo.runtime.PageContext;
018    import railo.runtime.engine.ThreadLocalPageContext;
019    import railo.runtime.util.EnumerationWrapper;
020    
021    /**
022     * extends a existing {@link HttpServletRequest} with the possibility to reread the input as many you want.
023     */
024    public final class HTTPServletRequestWrap extends HttpServletRequestWrapper implements Serializable {
025    
026            private boolean firstRead=true;
027            private byte[] barr;
028            private static final int MIN_STORAGE_SIZE=1*1024*1024;
029            private static final int MAX_STORAGE_SIZE=50*1024*1024;
030            private static final int SPACIG=1024*1024;
031            
032            private String servlet_path;
033            private String request_uri;
034            private String context_path;
035            private String path_info;
036            private String query_string;
037            private HashMap<String, Object> disconnectedData;
038            private boolean disconnected;
039            private HttpServletRequest req;
040            //private Request _request;
041    
042            /**
043             * Constructor of the class
044             * @param req 
045             * @param max how many is possible to re read
046             */
047            public HTTPServletRequestWrap(HttpServletRequest req) {
048                    super(req);
049                    this.req=pure(req);
050                    
051                    if((servlet_path=attrAsString("javax.servlet.include.servlet_path"))!=null){
052                            request_uri=attrAsString("javax.servlet.include.request_uri");
053                            context_path=attrAsString("javax.servlet.include.context_path");
054                            path_info=attrAsString("javax.servlet.include.path_info");
055                            query_string = attrAsString("javax.servlet.include.query_string");
056                    }
057                    
058                    //forward
059                    /*else if((servlet_path=attr("javax.servlet.forward.servlet_path"))!=null){
060                            request_uri=attr("javax.servlet.forward.request_uri");
061                            context_path=attr("javax.servlet.forward.context_path");
062                            path_info=attr("javax.servlet.forward.path_info");
063                            query_string = attr("javax.servlet.forward.query_string");
064                    }*/
065                    
066                    else {
067                            servlet_path=super.getServletPath();
068                            request_uri=super.getRequestURI();
069                            context_path=super.getContextPath();
070                            path_info=super.getPathInfo();
071                            query_string = super.getQueryString();
072                    }
073                    /*Enumeration names = req.getAttributeNames();
074                    while(names.hasMoreElements()){
075                            String key=(String)names.nextElement();
076                            print.out(key+"+"+req.getAttribute(key));
077                    }
078                    
079    
080                    print.out("super:"+req.getClass().getName());
081                    print.out("servlet_path:"+servlet_path);
082                    print.out("request_uri:"+request_uri);
083                    print.out("path_info:"+path_info);
084                    print.out("query_string:"+query_string);
085                    
086                    print.out("servlet_path."+req.getServletPath());
087                    print.out("request_uri."+req.getRequestURI());
088                    print.out("path_info."+req.getPathInfo());
089                    print.out("query_string."+req.getQueryString());
090                    */
091            }
092            
093            private String attrAsString(String key) {
094                    Object res = getAttribute(key);
095                    if(res==null) return null;
096                    return res.toString();
097            }
098            
099            public static HttpServletRequest pure(HttpServletRequest req) {
100                    HttpServletRequest req2;
101                    while(req instanceof HTTPServletRequestWrap){
102                            req2 =  ((HTTPServletRequestWrap)req).getOriginalRequest();
103                            if(req2==req) break;
104                            req=req2;
105                    }
106                    return req;
107            }
108    
109            @Override
110            public String getContextPath() {
111                    return context_path;
112            }
113            
114            @Override
115            public String getPathInfo() {
116                    return path_info;
117            }
118            
119            @Override
120            public StringBuffer getRequestURL() {
121                    return new StringBuffer(isSecure()?"https":"http").
122                            append("://").
123                            append(getServerName()).
124                            append(':').
125                            append(getServerPort()).
126                            append(request_uri.startsWith("/")?request_uri:"/"+request_uri);
127            }
128            
129            @Override
130            public String getQueryString() {
131                    return query_string;
132            }
133            @Override
134            public String getRequestURI() {
135                    return request_uri;
136            }
137            
138            @Override
139            public String getServletPath() {
140                    return servlet_path;
141            }
142            
143            @Override
144            public RequestDispatcher getRequestDispatcher(String realpath) {
145                    return new RequestDispatcherWrap(this,realpath);
146            }
147            
148            public RequestDispatcher getOriginalRequestDispatcher(String realpath) {
149                    if(disconnected) return null;
150                    return req.getRequestDispatcher(realpath);
151            }
152    
153            @Override
154            public void removeAttribute(String name) {
155                    if(disconnected) disconnectedData.remove(name); 
156                    else req.removeAttribute(name);
157            }
158    
159            @Override
160            public void setAttribute(String name, Object value) {
161                    if(disconnected) disconnectedData.put(name, value);
162                    else req.setAttribute(name, value);
163            }
164            
165            /*public void setAttributes(Request request) {
166                    this._request=request;
167            }*/
168    
169    
170            @Override
171            public Object getAttribute(String name) {
172                    if(disconnected) return disconnectedData.get(name);
173                    return req.getAttribute(name);
174            }
175    
176            public Enumeration getAttributeNames() {
177                    if(disconnected) {
178                            return new EnumerationWrapper(disconnectedData);
179                    }
180                    return req.getAttributeNames();
181                    
182            }
183    
184            @Override
185            public ServletInputStream getInputStream() throws IOException {
186                    //if(ba rr!=null) throw new IllegalStateException();
187                    if(barr==null) {
188                            if(!firstRead) {
189                                    PageContext pc = ThreadLocalPageContext.get();
190                                    if(pc!=null) {
191                                            return pc.formScope().getInputStream();
192                                    }
193                                    return new ServletInputStreamDummy(new byte[]{});       //throw new IllegalStateException();
194                            }
195                            
196                            firstRead=false;
197                            if(isToBig(getContentLength())) {
198                                    return super.getInputStream();
199                            }
200                            InputStream is=null;
201                            try {
202                                    barr=IOUtil.toBytes(is=super.getInputStream());
203                                    
204                                    //Resource res = ResourcesImpl.getFileResourceProvider().getResource("/Users/mic/Temp/multipart.txt");
205                                    //IOUtil.copy(new ByteArrayInputStream(barr), res, true);
206                                    
207                            }
208                            catch(Throwable t) {
209                                    barr=null;
210                                    return new ServletInputStreamDummy(new byte[]{});        
211                            }
212                            finally {
213                                    IOUtil.closeEL(is);
214                            }
215                    }
216                    
217                    return new ServletInputStreamDummy(barr);       
218            }
219            
220            private boolean isToBig(int contentLength) {
221                    if(contentLength<MIN_STORAGE_SIZE) return false;
222                    if(contentLength>MAX_STORAGE_SIZE) return true;
223                    Runtime rt = Runtime.getRuntime();
224                    long av = rt.maxMemory()-rt.totalMemory()+rt.freeMemory();
225                    return (av-SPACIG)<contentLength;
226            }
227    
228            /* *
229             * with this method it is possibiliy to rewrite the input as many you want
230             * @return input stream from request
231             * @throws IOException
232             * /
233            public ServletInputStream getStoredInputStream() throws IOException {
234                    if(firstRead || barr!=null) return getInputStream();
235                    return new ServletInputStreamDummy(new byte[]{});        
236            }*/
237    
238            @Override
239            public BufferedReader getReader() throws IOException {
240                    String enc = getCharacterEncoding();
241                    if(StringUtil.isEmpty(enc))enc="iso-8859-1";
242                    return IOUtil.toBufferedReader(IOUtil.getReader(getInputStream(), enc));
243            }
244            
245            public void clear() {
246                    barr=null;
247            }
248    
249            
250    
251    
252            public HttpServletRequest getOriginalRequest() {
253                    return req;
254            }
255    
256            public void disconnect() {
257                    if(disconnected) return;
258                    Enumeration<String> names = req.getAttributeNames();
259                    disconnectedData=new HashMap<String, Object>();
260                    String k;
261                    while(names.hasMoreElements()){
262                            k=names.nextElement();
263                            disconnectedData.put(k, req.getAttribute(k));
264                    }
265                    disconnected=true;
266                    req=null;
267            }
268    }