aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
authorJulian Hurst <ark@mansus.space>2023-01-12 15:50:01 +0100
committerJulian Hurst <ark@mansus.space>2023-01-12 15:50:01 +0100
commit4dc031b0a52ca5bfe6108327b63f3847f41dc1c1 (patch)
tree61e32c79fcc4c079e0bcebbbdd08de1ae91c5261 /main.go
downloaddocspace-4dc031b0a52ca5bfe6108327b63f3847f41dc1c1.tar.gz
Initial commit
Diffstat (limited to 'main.go')
-rw-r--r--main.go359
1 files changed, 359 insertions, 0 deletions
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..26c76c0
--- /dev/null
+++ b/main.go
@@ -0,0 +1,359 @@
+package main
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "io"
+ "path"
+ "path/filepath"
+ "sort"
+ "errors"
+ "net/http"
+ "html/template"
+ "flag"
+ "database/sql"
+ "encoding/json"
+ "encoding/base64"
+)
+
+var db *sql.DB
+
+const baseDocDir string = "docs"
+
+type Doc struct {
+ Name string
+ Link string
+}
+
+func serveTemplate(w http.ResponseWriter, r *http.Request, view string, data interface{}) {
+ var nav string = "templates/nav.html"
+ if u, err := checkSession(w, r); u != nil && err == nil {
+ nav = "templates/nav_logged.html"
+ }
+ t, err := template.New("base.html").Funcs(template.FuncMap {
+ "add": func(i, j int) int {
+ return i + j
+ },
+ }).ParseFiles("templates/base.html", nav, view)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if err := t.Execute(w, data); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) {
+ cookie, err := r.Cookie("session")
+ if err != nil {
+ if err == http.ErrNoCookie {
+ return nil, nil
+ } else {
+ return nil, err
+ }
+ }
+ ub64, err := base64.StdEncoding.DecodeString(cookie.Value)
+ if err != nil {
+ return nil, err
+ }
+ var user User
+ err = json.Unmarshal(ub64, &user)
+ if err != nil {
+ return nil, err
+ }
+ return &user, nil
+}
+
+func sendError(w http.ResponseWriter, r *http.Request, s string, status int) {
+ log.Println(s)
+ w.WriteHeader(status)
+ w.Write([]byte(s))
+ //http.Error(w, s, status)
+ //view := fmt.Sprintf("templates/%d.html", status)
+ //t, err := template.ParseFiles("templates/base.html", view)
+ //if err != nil {
+ // log.Fatal(err)
+ //}
+ //if err := t.Execute(w, nil); err != nil {
+ // log.Fatal(err)
+ //}
+}
+
+func sendInvalidMethod(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+}
+
+func index(w http.ResponseWriter, r *http.Request) {
+ u, err := checkSession(w, r)
+ if u != nil && err == nil {
+ userDocPath := filepath.Join(baseDocDir, u.User)
+ err := os.Mkdir(userDocPath, 0750)
+ if err != nil && !os.IsExist(err) {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ files, err := os.ReadDir(filepath.Join(baseDocDir, u.User))
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ var docs []Doc
+ sort.Slice(files, func(i, j int) bool {
+ info1, err := files[i].Info()
+ if err != nil {
+ return false
+ }
+ info2, err := files[j].Info()
+ if err != nil {
+ return false
+ }
+ return info1.ModTime().After(info2.ModTime())
+ })
+ for _, file := range files {
+ docs = append(docs, Doc {
+ file.Name(),
+ path.Join(baseDocDir, u.User, file.Name()),
+ })
+ }
+ data := struct {
+ Docs []Doc
+ }{
+ docs,
+ }
+ serveTemplate(w, r, "templates/user.html", data)
+ return
+ }
+ serveTemplate(w, r, "templates/index.html", nil)
+}
+
+func admin(w http.ResponseWriter, r *http.Request) {
+ u, err := checkSession(w, r)
+ if u != nil && err == nil && u.IsAdmin {
+ serveTemplate(w, r, "templates/admin.html", nil)
+ } else if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ } else {
+ sendError(w, r, "Unauthorized", http.StatusUnauthorized)
+ }
+}
+
+func adminUsers(w http.ResponseWriter, r *http.Request) {
+ u, err := checkSession(w, r)
+ if u != nil && err == nil && u.IsAdmin {
+ users, err := GetUsers(db)
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ serveTemplate(w, r, "templates/admin/users.html", struct {
+ Users []User
+ }{
+ users,
+ })
+ } else if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ } else {
+ sendError(w, r, "Unauthorized", http.StatusUnauthorized)
+ }
+}
+
+func createuser(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ err := consumeFlashError(w, r)
+ data := struct {
+ Error string
+ }{
+ err,
+ }
+ serveTemplate(w, r, "templates/createuser.html", data)
+ case http.MethodPost:
+ u := r.FormValue("user")
+ email := r.FormValue("email")
+ pass := r.FormValue("pass")
+ cpass := r.FormValue("cpass")
+ if pass != cpass {
+ sendFlashError(w, r, "/createuser", errors.New("Le mot de passe et la confirmation du mot de passe ne sont pas les mêmes."))
+ return
+ }
+ user := User{-1, u, email, pass, false}
+ user, err := CreateUser(db, user)
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ default:
+ sendInvalidMethod(w, r)
+ }
+}
+
+func consumeFlashError(w http.ResponseWriter, r *http.Request) string {
+ cookie, err := r.Cookie("flasherror")
+ if err != nil {
+ if err == http.ErrNoCookie {
+ return ""
+ } else {
+ return err.Error()
+ }
+ }
+ http.SetCookie(w, &http.Cookie{
+ Name: "flasherror",
+ Value: "",
+ MaxAge: -1,
+ })
+ s, err := base64.StdEncoding.DecodeString(cookie.Value)
+ if err != nil {
+ return ""
+ }
+ return string(s)
+}
+
+func sendFlashError(w http.ResponseWriter, r *http.Request, url string, err error) {
+ str := base64.StdEncoding.EncodeToString([]byte(err.Error()))
+ cookie := http.Cookie {
+ Name: "flasherror",
+ Value: str,
+ MaxAge: 0,
+ Secure: true,
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ }
+ http.SetCookie(w, &cookie)
+ http.Redirect(w, r, url, http.StatusSeeOther)
+}
+
+func logout(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ http.SetCookie(w, &http.Cookie{
+ Name: "session",
+ Value: "",
+ MaxAge: -1,
+ })
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ default:
+ sendInvalidMethod(w, r)
+ }
+}
+
+func login(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ err := consumeFlashError(w, r)
+ data := struct {
+ Error string
+ }{
+ err,
+ }
+ serveTemplate(w, r, "templates/login.html", data)
+ case http.MethodPost:
+ u := r.FormValue("user")
+ pass := r.FormValue("pass")
+ user, err := CheckUserPass(db, User{-1, u, "", pass, false})
+ if err != nil {
+ sendFlashError(w, r, "/login", err)
+ return
+ }
+ user.Pass = ""
+ jsonData, err := json.Marshal(user)
+ if err != nil {
+ sendFlashError(w, r, "/login", err)
+ return
+ }
+ bStr := base64.StdEncoding.EncodeToString(jsonData)
+ cookie := http.Cookie {
+ Name: "session",
+ Value: bStr,
+ MaxAge: 0,
+ Secure: true,
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ }
+ http.SetCookie(w, &cookie)
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ default:
+ sendInvalidMethod(w, r)
+ }
+}
+
+func handleFileServer(dir, prefix string) http.HandlerFunc {
+ fs := http.FileServer(http.Dir(baseDocDir))
+ hdlr := http.StripPrefix(prefix, fs).ServeHTTP
+ return func(w http.ResponseWriter, r *http.Request) {
+ u, err := checkSession(w, r)
+ if u != nil && err == nil {
+ dir := filepath.Dir(r.URL.Path)
+ username := filepath.Base(dir)
+ if u.User == username {
+ hdlr(w, r)
+ return
+ }
+ }
+ sendError(w, r, "Unauthorized", http.StatusUnauthorized)
+ }
+}
+
+func upload(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodPost:
+ u, err := checkSession(w, r)
+ if u != nil && err == nil {
+ userDocPath := filepath.Join(baseDocDir, u.User)
+ err := os.Mkdir(userDocPath, 0750)
+ if err != nil && !os.IsExist(err) {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ rd, err := r.MultipartReader()
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ for {
+ part, err := rd.NextPart()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ docPath := filepath.Join(userDocPath, part.FileName())
+ file, err := os.Create(docPath)
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ defer file.Close()
+ _, err = io.Copy(file, part)
+ if err != nil {
+ sendError(w, r, err.Error(), http.StatusInternalServerError)
+ }
+ }
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ } else {
+ sendError(w, r, "Unauthorized", http.StatusUnauthorized)
+ }
+ default:
+ sendInvalidMethod(w, r)
+ }
+}
+
+func main() {
+ p := flag.Int("p", 8080, "the port to bind to")
+ flag.Parse()
+ var err error
+ db, err = InitAndGetDB("sqlite3", "./db/test.db")
+ if err != nil {
+ panic(err)
+ }
+ defer db.Close()
+
+ http.HandleFunc("/docs/", handleFileServer(baseDocDir, "/docs/"))
+ //http.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.Dir("docs"))))
+ http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
+ http.HandleFunc("/", index)
+ http.HandleFunc("/createuser", createuser)
+ http.HandleFunc("/login", login)
+ http.HandleFunc("/logout", logout)
+ http.HandleFunc("/upload", upload)
+ http.HandleFunc("/admin", admin)
+ http.HandleFunc("/admin/users", adminUsers)
+ log.Printf("Serving http://localhost:%d\n", *p)
+ log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *p), nil))
+}