diff options
| author | Julian Hurst <ark@mansus.space> | 2023-01-20 00:46:23 +0100 |
|---|---|---|
| committer | Julian Hurst <ark@mansus.space> | 2023-01-20 00:46:23 +0100 |
| commit | 9b3e906a0ab6592eda40ae043bfff26ca07fe80d (patch) | |
| tree | 44773c0a039902f6451ed2cf44f5587572b0f83f | |
| parent | de64b0433c42306c4d6d55a926e53d64f96ec51c (diff) | |
| download | docspace-multi.tar.gz | |
Multisession shitmulti
| -rw-r--r-- | db.go | 6 | ||||
| -rw-r--r-- | imgs.go | 2 | ||||
| -rw-r--r-- | main.go | 84 |
3 files changed, 75 insertions, 17 deletions
@@ -10,7 +10,7 @@ import ( type User struct { Id int - User string + Username string Email string Pass string IsAdmin bool @@ -46,7 +46,7 @@ func CheckUserPass(db *sql.DB, user User) (User, error) { func GetUser(db *sql.DB, user User) (User, error) { // Support passing the email as the username - rows, err := db.Query("SELECT * FROM users WHERE user = ? or email = ?", user.User, user.User) + rows, err := db.Query("SELECT * FROM users WHERE user = ? or email = ?", user.Username, user.Username) if err != nil { return user, err } @@ -89,7 +89,7 @@ func CreateUser(db *sql.DB, user User) (User, error) { if err != nil { return user, err } - _, err = db.Exec(`INSERT INTO users (user, email, pass, isAdmin) VALUES (?, ?, ?, ?)`, user.User, user.Email, hash, user.IsAdmin) + _, err = db.Exec(`INSERT INTO users (user, email, pass, isAdmin) VALUES (?, ?, ?, ?)`, user.Username, user.Email, hash, user.IsAdmin) if err != nil { return user, err } @@ -26,7 +26,7 @@ func imgs(w http.ResponseWriter, r *http.Request) { sendError(w, r, err.Error(), http.StatusInternalServerError) return } - userDocPath := filepath.Join(baseDocDir, u.User) + userDocPath := filepath.Join(baseDocDir, u.Username) err = os.Mkdir(userDocPath, 0750) if err != nil && !os.IsExist(err) { sendError(w, r, err.Error(), http.StatusInternalServerError) @@ -21,6 +21,8 @@ import ( "github.com/satori/go.uuid" ) +var store = decode(os.Getenv("SESSION_KEY")) + var db *sql.DB const baseDocDir string = "docs" @@ -28,6 +30,11 @@ const baseDocDir string = "docs" // Use uuid for session ids to prevent spoofing session cookies var sessionIds sync.Map +type Session struct { + sync.RWMutex + ids []*string +} + type Doc struct { Name string Size string @@ -49,6 +56,16 @@ var fmap template.FuncMap = template.FuncMap { }, } +func decode(b string) []byte { + dst := make([]byte, base64.StdEncoding.DecodedLen(len(b))) + n, err := base64.Decode(dst, []byte(b)) + if err != nil { + panic(err) + } + dst = dst[:n] + return dst +} + func serveTemplate(w http.ResponseWriter, r *http.Request, data interface{}, view ...string) { var nav string = "templates/nav.html" if u, err := checkSession(w, r); u != nil && err == nil { @@ -78,7 +95,7 @@ func serveSimple(w http.ResponseWriter, r *http.Request, data interface{}, view } } -func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) { +func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) { cookie, err := r.Cookie("session") if err != nil { if err == http.ErrNoCookie { @@ -96,7 +113,11 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) { if err != nil { return nil, err } - if sessionId, ok := sessionIds.Load(user.User.User); !ok || sessionId != user.SessionId { + sessionIds.Range(func(key, value any) bool { + fmt.Printf("%s %v\n", key, value.(*Session).ids) + return true + }) + if session, ok := sessionIds.Load(user.User.Username); !ok || !contains(session.(*Session), user.SessionId) { http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", @@ -104,7 +125,18 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) { }) return nil, nil } - return &user.User, nil + return &user, nil +} + +func contains(session *Session, id string) bool { + session.Lock() + defer session.Unlock() + for _, sessionId := range session.ids { + if *sessionId == id { + return true + } + } + return false } func sendError(w http.ResponseWriter, r *http.Request, s string, status int) { @@ -140,12 +172,12 @@ func humanize(i int64) string { func index(w http.ResponseWriter, r *http.Request) { u, err := checkSession(w, r) if u != nil && err == nil { - userDocPath := filepath.Join(baseDocDir, u.User) + userDocPath := filepath.Join(baseDocDir, u.Username) err := os.Mkdir(userDocPath, 0750) if err != nil && !os.IsExist(err) { sendError(w, r, err.Error(), http.StatusInternalServerError) } - files, err := os.ReadDir(filepath.Join(baseDocDir, u.User)) + files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) } @@ -171,7 +203,7 @@ func index(w http.ResponseWriter, r *http.Request) { file.Name(), humanize(info.Size()), info.ModTime(), - path.Join(baseDocDir, u.User, file.Name()), + path.Join(baseDocDir, u.Username, file.Name()), }) } flasherr := consumeFlash(w, r, "error") @@ -305,7 +337,26 @@ func logout(w http.ResponseWriter, r *http.Request) { Value: "", MaxAge: -1, }) - sessionIds.Delete(u.User) + if result, ok := sessionIds.Load(u.Username); ok { + session := result.(*Session) + session.Lock() + toRemove := -1 + for i, id := range session.ids { + if *id == u.SessionId { + toRemove = i + } + } + if toRemove != -1 { + if len(session.ids) == 1 { + session.Unlock() + sessionIds.Delete(u.Username) + } else { + session.ids[toRemove] = session.ids[len(session.ids)-1] + session.ids = session.ids[:len(session.ids)-1] + session.Unlock() + } + } + } } http.Redirect(w, r, "/", http.StatusSeeOther) default: @@ -334,7 +385,14 @@ func login(w http.ResponseWriter, r *http.Request) { } user.Pass = "" sessionId := uuid.NewV4().String() - sessionIds.Store(user.User, sessionId) + sess := Session {ids: []*string {&sessionId}} + fmt.Println("store") + if result, ok := sessionIds.LoadOrStore(user.Username, &sess); ok { + session := result.(*Session) + session.Lock() + session.ids = append(session.ids, &sessionId) + session.Unlock() + } us := UserSession { user, sessionId, @@ -371,7 +429,7 @@ func handleFileServer(dir, prefix string) http.HandlerFunc { if u != nil && err == nil { dir := filepath.Dir(r.URL.Path) username := filepath.Base(dir) - if u.User == username { + if u.Username == username { hdlr(w, r) return } @@ -397,7 +455,7 @@ func download(w http.ResponseWriter, r *http.Request) { wr := zip.NewWriter(w) defer wr.Close() for _, sel := range selection { - if filepath.Base(filepath.Dir(sel)) == u.User { + if filepath.Base(filepath.Dir(sel)) == u.Username { wrc, err := wr.Create(filepath.Base(sel)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) @@ -416,13 +474,13 @@ func download(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Disposition", contentDisposition) wr := zip.NewWriter(w) defer wr.Close() - files, err := os.ReadDir(filepath.Join(baseDocDir, u.User)) + files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) return } for _, file := range files { - filePath := path.Join(baseDocDir, u.User, file.Name()) + filePath := path.Join(baseDocDir, u.Username, file.Name()) wrc, err := wr.Create(filepath.Base(filePath)) if err != nil { sendError(w, r, err.Error(), http.StatusInternalServerError) @@ -448,7 +506,7 @@ func upload(w http.ResponseWriter, r *http.Request) { case http.MethodPost: u, err := checkSession(w, r) if u != nil && err == nil { - userDocPath := filepath.Join(baseDocDir, u.User) + userDocPath := filepath.Join(baseDocDir, u.Username) err := os.Mkdir(userDocPath, 0750) if err != nil && !os.IsExist(err) { sendError(w, r, err.Error(), http.StatusInternalServerError) |
