From 4257a7d00ecf55eca565c0cfe87789443370da81 Mon Sep 17 00:00:00 2001
From: Jacob Vosmaer <contact@jacobvosmaer.nl>
Date: Thu, 8 Oct 2015 16:58:11 +0200
Subject: [PATCH] Create cached archive while streaming to client

---
 githandler.go | 33 +++++++++++++++++++++++++++++----
 main_test.go  | 19 +++++++++++++++++++
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/githandler.go b/githandler.go
index a506188b8c4..a808a35ac47 100644
--- a/githandler.go
+++ b/githandler.go
@@ -11,6 +11,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"log"
 	"net/http"
 	"os"
@@ -216,13 +217,27 @@ func handleGetArchive(env gitEnv, format string, repoPath string, w http.Respons
 	w.Header().Add("Content-Transfer-Encoding", "binary")
 	w.Header().Add("Cache-Control", "private")
 
-	if f, err := os.Open(env.ArchivePath); err == nil {
-		defer f.Close()
+	if cachedArchive, err := os.Open(env.ArchivePath); err == nil {
+		defer cachedArchive.Close()
 		log.Printf("Serving cached file %q", env.ArchivePath)
-		http.ServeContent(w, r, archiveFilename, time.Unix(0, 0), f)
+		http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive)
 		return
 	}
 
+	// Prepare tempfile to create cached archive
+	cacheDir := path.Dir(env.ArchivePath)
+	if err := os.MkdirAll(cacheDir, 0700); err != nil {
+		fail500(w, "handleGetArchive create archive cache directory", err)
+		return
+	}
+	tempFile, err := ioutil.TempFile(cacheDir, archiveFilename)
+	if err != nil {
+		fail500(w, "handleGetArchive create tempfile for archive", err)
+		return
+	}
+	defer tempFile.Close()
+	defer os.Remove(tempFile.Name())
+
 	var compressCmd *exec.Cmd
 	var archiveFormat string
 	switch format {
@@ -277,7 +292,7 @@ func handleGetArchive(env gitEnv, format string, repoPath string, w http.Respons
 
 	// Start writing the response
 	w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
-	if _, err := io.Copy(w, stdout); err != nil {
+	if _, err := io.Copy(w, io.TeeReader(stdout, tempFile)); err != nil {
 		logContext("handleGetArchive read from subprocess", err)
 		return
 	}
@@ -291,6 +306,16 @@ func handleGetArchive(env gitEnv, format string, repoPath string, w http.Respons
 			return
 		}
 	}
+
+	// Finalize cached archive
+	if err := tempFile.Close(); err != nil {
+		logContext("handleGetArchive close cached archive", err)
+		return
+	}
+	if err := os.Link(tempFile.Name(), env.ArchivePath); err != nil {
+		logContext("handleGetArchive link (finalize) cached archive", err)
+		return
+	}
 }
 
 func handlePostRPC(env gitEnv, rpc string, repoPath string, w http.ResponseWriter, r *http.Request) {
diff --git a/main_test.go b/main_test.go
index 699d62acbe8..029b9429dd3 100644
--- a/main_test.go
+++ b/main_test.go
@@ -220,6 +220,25 @@ func TestDownloadCacheHit(t *testing.T) {
 	}
 }
 
+func TestDownloadCacheCreate(t *testing.T) {
+	prepareDownloadDir(t)
+
+	// Prepare test server and backend
+	archiveName := "foobar.zip"
+	ts := testAuthServer(200, archiveOkBody(t, archiveName))
+	defer ts.Close()
+	defer cleanUpProcessGroup(startServerOrFail(t, ts))
+
+	downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("http://%s/api/v3/projects/123/repository/archive.zip", servAddr))
+	downloadCmd.Dir = scratchDir
+	runOrFail(t, downloadCmd)
+
+	compareCmd := exec.Command("cmp", path.Join(cacheDir, archiveName), path.Join(scratchDir, archiveName))
+	if err := compareCmd.Run(); err != nil {
+		t.Fatalf("Comparison between downloaded file and cache item failed: %s", err)
+	}
+}
+
 func prepareDownloadDir(t *testing.T) {
 	if err := os.RemoveAll(scratchDir); err != nil {
 		t.Fatal(err)
-- 
2.30.9