Commit a43b2e94 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Put multipart rewrite work in functions

parent c28f06f7
...@@ -37,6 +37,13 @@ var ( ...@@ -37,6 +37,13 @@ var (
) )
) )
type rewriter struct {
writer *multipart.Writer
tempPath string
filter MultipartFormProcessor
directories []string
}
func init() { func init() {
prometheus.MustRegister(multipartUploadRequests) prometheus.MustRegister(multipartUploadRequests)
prometheus.MustRegister(multipartFileUploadBytes) prometheus.MustRegister(multipartFileUploadBytes)
...@@ -56,10 +63,14 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -56,10 +63,14 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
multipartUploadRequests.Inc() multipartUploadRequests.Inc()
var directories []string rew := &rewriter{
writer: writer,
tempPath: tempPath,
filter: filter,
}
cleanup = func() { cleanup = func() {
for _, dir := range directories { for _, dir := range rew.directories {
os.RemoveAll(dir) os.RemoveAll(dir)
} }
} }
...@@ -83,60 +94,77 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -83,60 +94,77 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
} }
// Copy form field // Copy form field
if filename := p.FileName(); filename != "" { if p.FileName() != "" {
multipartFiles.Inc() err = rew.handleFilePart(name, p)
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return cleanup, fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, fmt.Errorf("mkdir for tempfile: %v", err)
}
tempDir, err := ioutil.TempDir(tempPath, "multipart-")
if err != nil {
return cleanup, fmt.Errorf("create tempdir: %v", err)
}
directories = append(directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return cleanup, fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
writer.WriteField(name+".path", file.Name())
writer.WriteField(name+".name", filename)
written, err := io.Copy(file, p)
if err != nil {
return cleanup, fmt.Errorf("copy from multipart to tempfile: %v", err)
}
multipartFileUploadBytes.Add(float64(written))
file.Close()
if err := filter.ProcessFile(name, file.Name(), writer); err != nil {
return cleanup, err
}
} else { } else {
np, err := writer.CreatePart(p.Header) err = rew.copyPart(name, p)
if err != nil { }
return cleanup, fmt.Errorf("create multipart field: %v", err)
} if err != nil {
return cleanup, err
_, err = io.Copy(np, p)
if err != nil {
return cleanup, fmt.Errorf("duplicate multipart field: %v", err)
}
if err := filter.ProcessField(name, writer); err != nil {
return cleanup, fmt.Errorf("process multipart field: %v", err)
}
} }
} }
return cleanup, nil return cleanup, nil
} }
func (rew *rewriter) handleFilePart(name string, p *multipart.Part) error {
multipartFiles.Inc()
filename := p.FileName()
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(rew.tempPath, 0700); err != nil {
return fmt.Errorf("mkdir for tempfile: %v", err)
}
tempDir, err := ioutil.TempDir(rew.tempPath, "multipart-")
if err != nil {
return fmt.Errorf("create tempdir: %v", err)
}
rew.directories = append(rew.directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
rew.writer.WriteField(name+".path", file.Name())
rew.writer.WriteField(name+".name", filename)
written, err := io.Copy(file, p)
if err != nil {
return fmt.Errorf("copy from multipart to tempfile: %v", err)
}
multipartFileUploadBytes.Add(float64(written))
file.Close()
if err := rew.filter.ProcessFile(name, file.Name(), rew.writer); err != nil {
return err
}
return nil
}
func (rew *rewriter) copyPart(name string, p *multipart.Part) error {
np, err := rew.writer.CreatePart(p.Header)
if err != nil {
return fmt.Errorf("create multipart field: %v", err)
}
if _, err := io.Copy(np, p); err != nil {
return fmt.Errorf("duplicate multipart field: %v", err)
}
if err := rew.filter.ProcessField(name, rew.writer); err != nil {
return fmt.Errorf("process multipart field: %v", err)
}
return nil
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment