package main import ( "fmt" "log" "os" "io" "path" "path/filepath" "sort" "net/http" "html/template" "flag" "database/sql" "encoding/json" "encoding/base64" "sync" "time" "archive/zip" "github.com/satori/go.uuid" ) var store = decode(os.Getenv("SESSION_KEY")) var db *sql.DB const baseDocDir string = "docs" // Use uuid for session ids to prevent spoofing session cookies var sessionIds sync.Map type Session struct { sync.RWMutex ids []*string } type Doc struct { Name string Size string ModTime time.Time Link string } type UserSession struct { User SessionId string } var fmap template.FuncMap = template.FuncMap { "add": func(i, j int) int { return i + j }, "formatmodtime": func(i time.Time) string { return i.Format("2006-01-02") }, } func decode(b string) []byte { dst := make([]byte, base64.StdEncoding.DecodedLen(len(b))) n, err := base64.Decode(dst, []byte(b)) if err != nil { panic(err) } dst = dst[:n] return dst } func serveTemplate(w http.ResponseWriter, r *http.Request, data interface{}, view ...string) { var nav string = "templates/nav.html" if u, err := checkSession(w, r); u != nil && err == nil { nav = "templates/nav_logged.html" } views := []string {"templates/base.html", nav} views = append(views, view...) t, err := template.New("base.html").Funcs(fmap).ParseFiles(views...) if err != nil { log.Fatal(err) } if err := t.Execute(w, data); err != nil { log.Fatal(err) } } func serveSimple(w http.ResponseWriter, r *http.Request, data interface{}, view string, xviews ...string) { views := []string {view} views = append(views, xviews...) fp := filepath.Base(views[len(views)-1]) t, err := template.New(fp).Funcs(fmap).ParseFiles(views...) if err != nil { log.Fatal(err) } if err := t.Execute(w, data); err != nil { log.Fatal(err) } } func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) { cookie, err := r.Cookie("session") if err != nil { if err == http.ErrNoCookie { return nil, nil } else { return nil, err } } ub64, err := base64.StdEncoding.DecodeString(cookie.Value) if err != nil { return nil, err } var user UserSession err = json.Unmarshal(ub64, &user) if err != nil { return nil, err } sessionIds.Range(func(key, value any) bool { fmt.Printf("%s %v\n", key, value.(*Session).ids) return true }) if session, ok := sessionIds.Load(user.User.Username); !ok || !contains(session.(*Session), user.SessionId) { http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", MaxAge: -1, }) return nil, nil } return &user, nil } func contains(session *Session, id string) bool { session.Lock() defer session.Unlock() for _, sessionId := range session.ids { if *sessionId == id { return true } } return false } func sendError(w http.ResponseWriter, r *http.Request, s string, status int) { log.Println(s) w.WriteHeader(status) w.Write([]byte(s)) //http.Error(w, s, status) //view := fmt.Sprintf("templates/%d.html", status) //t, err := template.ParseFiles("templates/base.html", view) //if err != nil { // log.Fatal(err) //} //if err := t.Execute(w, nil); err != nil { // log.Fatal(err) //} } func sendInvalidMethod(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMethodNotAllowed) } func humanize(i int64) string { var sizes [4]string = [4]string {"O", "K", "M", "G"} j := i s := 0 for j > 1024 && s < len(sizes) { j = j / 1024.0 s++ } return fmt.Sprintf("%v%s", j, sizes[s]) } func index(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil { userDocPath := filepath.Join(baseDocDir, u.Username) err := os.Mkdir(userDocPath, 0750) if err != nil && !os.IsExist(err) { sendError(w, r, err.Error(), http.StatusInternalServerError) } files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } var docs []Doc sort.Slice(files, func(i, j int) bool { info1, err := files[i].Info() if err != nil { return false } info2, err := files[j].Info() if err != nil { return false } return info1.ModTime().After(info2.ModTime()) }) for _, file := range files { info, err := file.Info() if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } docs = append(docs, Doc { file.Name(), humanize(info.Size()), info.ModTime(), path.Join(baseDocDir, u.Username, file.Name()), }) } flasherr := consumeFlash(w, r, "error") data := struct { Docs []Doc Error string }{ docs, flasherr, } serveTemplate(w, r, data, "templates/user.html") return } else if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } sendFlash(w, r, "redirect", "/") http.Redirect(w, r, "/login", http.StatusSeeOther) } func admin(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil && u.IsAdmin { serveTemplate(w, r, nil, "templates/admin.html") } else if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } else { sendError(w, r, "Unauthorized", http.StatusUnauthorized) } } func adminUsers(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil && u.IsAdmin { users, err := GetUsers(db) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } serveTemplate(w, r, struct { Users []User }{ users, }, "templates/admin/users.html") } else if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } else { sendError(w, r, "Unauthorized", http.StatusUnauthorized) } } func createuser(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: err := consumeFlash(w, r, "error") data := struct { Error string }{ err, } serveTemplate(w, r, data, "templates/createuser.html") case http.MethodPost: u := r.FormValue("user") email := r.FormValue("email") pass := r.FormValue("pass") cpass := r.FormValue("cpass") if len(pass) < 10 { sendFlash(w, r, "error", "Le mot de passe doit avoir une longeur supérieure ou égale à 10 caractères.") http.Redirect(w, r, "/createuser", http.StatusSeeOther) return } if pass != cpass { sendFlash(w, r, "error", "Le mot de passe et la confirmation du mot de passe ne sont pas les mêmes.") http.Redirect(w, r, "/createuser", http.StatusSeeOther) return } user := User{-1, u, email, pass, false} user, err := CreateUser(db, user) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(w, r, "/", http.StatusSeeOther) default: sendInvalidMethod(w, r) } } func consumeFlash(w http.ResponseWriter, r *http.Request, name string) string { cookie, err := r.Cookie(name) if err != nil { if err == http.ErrNoCookie { return "" } else { return err.Error() } } http.SetCookie(w, &http.Cookie{ Name: name, Value: "", MaxAge: -1, }) s, err := base64.StdEncoding.DecodeString(cookie.Value) if err != nil { return "" } return string(s) } func sendFlash(w http.ResponseWriter, r *http.Request, name, s string) { str := base64.StdEncoding.EncodeToString([]byte(s)) cookie := http.Cookie { Name: name, Value: str, MaxAge: 0, // Only https on qutebrowser //Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, } http.SetCookie(w, &cookie) } func logout(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: u, err := checkSession(w, r) if u != nil && err == nil { http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", MaxAge: -1, }) if result, ok := sessionIds.Load(u.Username); ok { session := result.(*Session) session.Lock() toRemove := -1 for i, id := range session.ids { if *id == u.SessionId { toRemove = i } } if toRemove != -1 { if len(session.ids) == 1 { session.Unlock() sessionIds.Delete(u.Username) } else { session.ids[toRemove] = session.ids[len(session.ids)-1] session.ids = session.ids[:len(session.ids)-1] session.Unlock() } } } } http.Redirect(w, r, "/", http.StatusSeeOther) default: sendInvalidMethod(w, r) } } func login(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: err := consumeFlash(w, r, "error") data := struct { Error string }{ err, } serveTemplate(w, r, data, "templates/login.html") case http.MethodPost: u := r.FormValue("user") pass := r.FormValue("pass") user, err := CheckUserPass(db, User{-1, u, "", pass, false}) if err != nil { sendFlash(w, r, "error", fmt.Sprintf("Incorrect login credentials")) http.Redirect(w, r, "/login", http.StatusSeeOther) return } user.Pass = "" sessionId := uuid.NewV4().String() sess := Session {ids: []*string {&sessionId}} fmt.Println("store") if result, ok := sessionIds.LoadOrStore(user.Username, &sess); ok { session := result.(*Session) session.Lock() session.ids = append(session.ids, &sessionId) session.Unlock() } us := UserSession { user, sessionId, } jsonData, err := json.Marshal(us) if err != nil { sendFlash(w, r, "error", err.Error()) http.Redirect(w, r, "/login", http.StatusSeeOther) return } bStr := base64.StdEncoding.EncodeToString(jsonData) cookie := http.Cookie { Name: "session", Value: bStr, MaxAge: 0, // Only https on qutebrowser //Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, } http.SetCookie(w, &cookie) redirectflash := consumeFlash(w, r, "redirect") http.Redirect(w, r, redirectflash, http.StatusSeeOther) default: sendInvalidMethod(w, r) } } func handleFileServer(dir, prefix string) http.HandlerFunc { fs := http.FileServer(http.Dir(baseDocDir)) hdlr := http.StripPrefix(prefix, fs).ServeHTTP return func(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil { dir := filepath.Dir(r.URL.Path) username := filepath.Base(dir) if u.Username == username { hdlr(w, r) return } } http.Redirect(w, r, "/login", http.StatusSeeOther) } } func download(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil { switch r.Method { case http.MethodPost: r.ParseForm() selection := r.Form["selection"] if len(selection) == 0 { sendFlash(w, r, "error", "Aucun fichier sélectionné") http.Redirect(w, r, "/", http.StatusSeeOther) return } contentDisposition := fmt.Sprintf("attachment; filename=\"Documents.zip\"") w.Header().Set("Content-Disposition", contentDisposition) wr := zip.NewWriter(w) defer wr.Close() for _, sel := range selection { if filepath.Base(filepath.Dir(sel)) == u.Username { wrc, err := wr.Create(filepath.Base(sel)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } f, err := os.Open(sel) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } io.Copy(wrc, f) } } case http.MethodGet: contentDisposition := fmt.Sprintf("attachment; filename=\"Documents.zip\"") w.Header().Set("Content-Disposition", contentDisposition) wr := zip.NewWriter(w) defer wr.Close() files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } for _, file := range files { filePath := path.Join(baseDocDir, u.Username, file.Name()) wrc, err := wr.Create(filepath.Base(filePath)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } f, err := os.Open(filePath) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } io.Copy(wrc, f) } default: sendInvalidMethod(w, r) } } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } } func upload(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodPost: u, err := checkSession(w, r) if u != nil && err == nil { userDocPath := filepath.Join(baseDocDir, u.Username) err := os.Mkdir(userDocPath, 0750) if err != nil && !os.IsExist(err) { sendError(w, r, err.Error(), http.StatusInternalServerError) } rd, err := r.MultipartReader() if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } for { part, err := rd.NextPart() if err == io.EOF { break } else if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } docPath := filepath.Join(userDocPath, part.FileName()) file, err := os.Create(docPath) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } defer file.Close() _, err = io.Copy(file, part) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } } http.Redirect(w, r, "/", http.StatusSeeOther) } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } default: sendInvalidMethod(w, r) } } func main() { p := flag.Int("p", 8080, "the port to bind to") flag.Parse() var err error db, err = InitAndGetDB("sqlite3", "./db/test.db") if err != nil { panic(err) } defer db.Close() http.HandleFunc("/docs/", handleFileServer(baseDocDir, "/docs/")) //http.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs")))) http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) http.HandleFunc("/", index) http.HandleFunc("/createuser", createuser) http.HandleFunc("/login", login) http.HandleFunc("/logout", logout) http.HandleFunc("/upload", upload) http.HandleFunc("/download", download) http.HandleFunc("/imgs", imgs) http.HandleFunc("/admin", admin) http.HandleFunc("/admin/users", adminUsers) log.Printf("Serving http://localhost:%d\n", *p) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *p), nil)) }