diff --git a/main.go b/main.go index a1bf1da..9659407 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "crypto/sha256" + "crypto/subtle" "fmt" "io" "log" @@ -11,22 +13,61 @@ import ( ) const ( - port = ":8080" - filesDir = "./files" - maxFileSize = 20 * 1024 * 1024 // 20MiB + filesDir = "./files" + port = ":8080" ) -var url string = os.Getenv("URL") - -func main() { - http.HandleFunc("/upload", uploadHandler) - http.Handle("/tree/", http.StripPrefix("/tree", http.FileServer(http.Dir(filesDir)))) - http.HandleFunc("/", fileHandler) - log.Printf("Server running on port %s\n", port) - log.Fatal(http.ListenAndServe(port, nil)) +type application struct { + auth struct { + username string + password string + } + url string } -func fileHandler(w http.ResponseWriter, r *http.Request) { +func main() { + app := new(application) + + app.auth.username = os.Getenv("AUTH_USERNAME") + app.auth.password = os.Getenv("AUTH_PASSWORD") + app.url = os.Getenv("URL") + + if app.auth.username == "" { + log.Fatal("basic auth username must be provided") + } + + if app.auth.password == "" { + log.Fatal("basic auth password must be provided") + } + + mux := http.NewServeMux() + mux.HandleFunc("/", app.fileHandler) + mux.HandleFunc( + "/tree/", + app.basicAuth(app.treeHandler), + ) + mux.HandleFunc("/upload", app.uploadHandler) + + srv := &http.Server{ + Addr: port, + Handler: mux, + IdleTimeout: time.Minute, + ReadTimeout: 10 * time.Second, + WriteTimeout: 60 * time.Second, + } + + log.Printf("starting server on %s", srv.Addr) + + if err := srv.ListenAndServe(); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} + +func (app *application) treeHandler(w http.ResponseWriter, r *http.Request) { + http.StripPrefix("/tree/", http.FileServer(http.Dir(filesDir))).ServeHTTP(w, r) +} + +func (app *application) fileHandler(w http.ResponseWriter, r *http.Request) { name := filepath.Clean(r.URL.Path) path := filepath.Join(filesDir, name) @@ -42,7 +83,7 @@ func fileHandler(w http.ResponseWriter, r *http.Request) { } } -func uploadHandler(w http.ResponseWriter, r *http.Request) { +func (app *application) uploadHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { return } @@ -52,11 +93,9 @@ func uploadHandler(w http.ResponseWriter, r *http.Request) { return } - r.Body = http.MaxBytesReader(w, r.Body, maxFileSize) - file, _, err := r.FormFile("file") if err != nil { - http.Error(w, "Error retrieving the file or file size exceeds limit", http.StatusBadRequest) + http.Error(w, "Error retrieving the file", http.StatusBadRequest) return } defer file.Close() @@ -81,14 +120,44 @@ func uploadHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error copying the file", http.StatusInternalServerError) } - if url == "" { + if app.url == "" { fmt.Fprintf(w, "http://localhost%s/%d\n", port, time) } else { - fmt.Fprintf(w, "http://%s/%d\n", url, time) + fmt.Fprintf(w, "http://%s/%d\n", app.url, time) } } func checkAuth(w http.ResponseWriter, r *http.Request) bool { - authKey, _ := os.ReadFile(".key") + authKey, err := os.ReadFile(".key") + if err != nil { + http.Error(w, "Couldn't find your .key", http.StatusNotFound) + } return r.Header.Get("X-Auth")+"\n" == string(authKey) } + +func (app *application) basicAuth(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if ok { + // hash password received + usernameHash := sha256.Sum256([]byte(username)) + passwordHash := sha256.Sum256([]byte(password)) + + // hash our password + expectedUsernameHash := sha256.Sum256([]byte(app.auth.username)) + expectedPasswordHash := sha256.Sum256([]byte(app.auth.password)) + + // compare hashes + usernameMatch := (subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) == 1) + passwordMatch := (subtle.ConstantTimeCompare(passwordHash[:], expectedPasswordHash[:]) == 1) + + if usernameMatch && passwordMatch { + next.ServeHTTP(w, r) + return + } + } + + w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + }) +}