aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
authorJulian Hurst <ark@mansus.space>2023-01-20 00:46:23 +0100
committerJulian Hurst <ark@mansus.space>2023-01-20 00:46:23 +0100
commit9b3e906a0ab6592eda40ae043bfff26ca07fe80d (patch)
tree44773c0a039902f6451ed2cf44f5587572b0f83f /main.go
parentde64b0433c42306c4d6d55a926e53d64f96ec51c (diff)
downloaddocspace-multi.tar.gz
Multisession shitmulti
Diffstat (limited to 'main.go')
-rw-r--r--main.go84
1 files changed, 71 insertions, 13 deletions
diff --git a/main.go b/main.go
index e61172c..82514e0 100644
--- a/main.go
+++ b/main.go
@@ -21,6 +21,8 @@ import (
"github.com/satori/go.uuid"
)
+var store = decode(os.Getenv("SESSION_KEY"))
+
var db *sql.DB
const baseDocDir string = "docs"
@@ -28,6 +30,11 @@ const baseDocDir string = "docs"
// Use uuid for session ids to prevent spoofing session cookies
var sessionIds sync.Map
+type Session struct {
+ sync.RWMutex
+ ids []*string
+}
+
type Doc struct {
Name string
Size string
@@ -49,6 +56,16 @@ var fmap template.FuncMap = template.FuncMap {
},
}
+func decode(b string) []byte {
+ dst := make([]byte, base64.StdEncoding.DecodedLen(len(b)))
+ n, err := base64.Decode(dst, []byte(b))
+ if err != nil {
+ panic(err)
+ }
+ dst = dst[:n]
+ return dst
+}
+
func serveTemplate(w http.ResponseWriter, r *http.Request, data interface{}, view ...string) {
var nav string = "templates/nav.html"
if u, err := checkSession(w, r); u != nil && err == nil {
@@ -78,7 +95,7 @@ func serveSimple(w http.ResponseWriter, r *http.Request, data interface{}, view
}
}
-func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) {
+func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) {
cookie, err := r.Cookie("session")
if err != nil {
if err == http.ErrNoCookie {
@@ -96,7 +113,11 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) {
if err != nil {
return nil, err
}
- if sessionId, ok := sessionIds.Load(user.User.User); !ok || sessionId != user.SessionId {
+ sessionIds.Range(func(key, value any) bool {
+ fmt.Printf("%s %v\n", key, value.(*Session).ids)
+ return true
+ })
+ if session, ok := sessionIds.Load(user.User.Username); !ok || !contains(session.(*Session), user.SessionId) {
http.SetCookie(w, &http.Cookie{
Name: "session",
Value: "",
@@ -104,7 +125,18 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) {
})
return nil, nil
}
- return &user.User, nil
+ return &user, nil
+}
+
+func contains(session *Session, id string) bool {
+ session.Lock()
+ defer session.Unlock()
+ for _, sessionId := range session.ids {
+ if *sessionId == id {
+ return true
+ }
+ }
+ return false
}
func sendError(w http.ResponseWriter, r *http.Request, s string, status int) {
@@ -140,12 +172,12 @@ func humanize(i int64) string {
func index(w http.ResponseWriter, r *http.Request) {
u, err := checkSession(w, r)
if u != nil && err == nil {
- userDocPath := filepath.Join(baseDocDir, u.User)
+ userDocPath := filepath.Join(baseDocDir, u.Username)
err := os.Mkdir(userDocPath, 0750)
if err != nil && !os.IsExist(err) {
sendError(w, r, err.Error(), http.StatusInternalServerError)
}
- files, err := os.ReadDir(filepath.Join(baseDocDir, u.User))
+ files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username))
if err != nil {
sendError(w, r, err.Error(), http.StatusInternalServerError)
}
@@ -171,7 +203,7 @@ func index(w http.ResponseWriter, r *http.Request) {
file.Name(),
humanize(info.Size()),
info.ModTime(),
- path.Join(baseDocDir, u.User, file.Name()),
+ path.Join(baseDocDir, u.Username, file.Name()),
})
}
flasherr := consumeFlash(w, r, "error")
@@ -305,7 +337,26 @@ func logout(w http.ResponseWriter, r *http.Request) {
Value: "",
MaxAge: -1,
})
- sessionIds.Delete(u.User)
+ if result, ok := sessionIds.Load(u.Username); ok {
+ session := result.(*Session)
+ session.Lock()
+ toRemove := -1
+ for i, id := range session.ids {
+ if *id == u.SessionId {
+ toRemove = i
+ }
+ }
+ if toRemove != -1 {
+ if len(session.ids) == 1 {
+ session.Unlock()
+ sessionIds.Delete(u.Username)
+ } else {
+ session.ids[toRemove] = session.ids[len(session.ids)-1]
+ session.ids = session.ids[:len(session.ids)-1]
+ session.Unlock()
+ }
+ }
+ }
}
http.Redirect(w, r, "/", http.StatusSeeOther)
default:
@@ -334,7 +385,14 @@ func login(w http.ResponseWriter, r *http.Request) {
}
user.Pass = ""
sessionId := uuid.NewV4().String()
- sessionIds.Store(user.User, sessionId)
+ sess := Session {ids: []*string {&sessionId}}
+ fmt.Println("store")
+ if result, ok := sessionIds.LoadOrStore(user.Username, &sess); ok {
+ session := result.(*Session)
+ session.Lock()
+ session.ids = append(session.ids, &sessionId)
+ session.Unlock()
+ }
us := UserSession {
user,
sessionId,
@@ -371,7 +429,7 @@ func handleFileServer(dir, prefix string) http.HandlerFunc {
if u != nil && err == nil {
dir := filepath.Dir(r.URL.Path)
username := filepath.Base(dir)
- if u.User == username {
+ if u.Username == username {
hdlr(w, r)
return
}
@@ -397,7 +455,7 @@ func download(w http.ResponseWriter, r *http.Request) {
wr := zip.NewWriter(w)
defer wr.Close()
for _, sel := range selection {
- if filepath.Base(filepath.Dir(sel)) == u.User {
+ if filepath.Base(filepath.Dir(sel)) == u.Username {
wrc, err := wr.Create(filepath.Base(sel))
if err != nil {
sendError(w, r, err.Error(), http.StatusInternalServerError)
@@ -416,13 +474,13 @@ func download(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Disposition", contentDisposition)
wr := zip.NewWriter(w)
defer wr.Close()
- files, err := os.ReadDir(filepath.Join(baseDocDir, u.User))
+ files, err := os.ReadDir(filepath.Join(baseDocDir, u.Username))
if err != nil {
sendError(w, r, err.Error(), http.StatusInternalServerError)
return
}
for _, file := range files {
- filePath := path.Join(baseDocDir, u.User, file.Name())
+ filePath := path.Join(baseDocDir, u.Username, file.Name())
wrc, err := wr.Create(filepath.Base(filePath))
if err != nil {
sendError(w, r, err.Error(), http.StatusInternalServerError)
@@ -448,7 +506,7 @@ func upload(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
u, err := checkSession(w, r)
if u != nil && err == nil {
- userDocPath := filepath.Join(baseDocDir, u.User)
+ userDocPath := filepath.Join(baseDocDir, u.Username)
err := os.Mkdir(userDocPath, 0750)
if err != nil && !os.IsExist(err) {
sendError(w, r, err.Error(), http.StatusInternalServerError)