aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Hurst <ark@mansus.space>2023-01-20 01:27:43 +0100
committerJulian Hurst <ark@mansus.space>2023-01-20 01:27:43 +0100
commite7a3649280d20d9b4f68572721684a01049ec40f (patch)
treee4f01e310cad6b4c38daa31206ff7cd70c663a26
parent9b3e906a0ab6592eda40ae043bfff26ca07fe80d (diff)
downloaddocspace-e7a3649280d20d9b4f68572721684a01049ec40f.tar.gz
Client side session cookies only (no server sessions)
-rw-r--r--imgs.go4
-rw-r--r--main.go138
2 files changed, 59 insertions, 83 deletions
diff --git a/imgs.go b/imgs.go
index e0495a0..78705ca 100644
--- a/imgs.go
+++ b/imgs.go
@@ -2,6 +2,7 @@ package main
import (
"fmt"
+ "log"
"os"
"mime"
"net/http"
@@ -89,8 +90,7 @@ func imgs(w http.ResponseWriter, r *http.Request) {
}
return
} else if err != nil {
- sendError(w, r, err.Error(), http.StatusInternalServerError)
- return
+ log.Println(err)
}
sendFlash(w, r, "redirect", "/imgs")
http.Redirect(w, r, "/login", http.StatusSeeOther)
diff --git a/main.go b/main.go
index 82514e0..6f779e4 100644
--- a/main.go
+++ b/main.go
@@ -14,27 +14,21 @@ import (
"database/sql"
"encoding/json"
"encoding/base64"
- "sync"
"time"
"archive/zip"
-
- "github.com/satori/go.uuid"
+ "crypto/rand"
+ "crypto/cipher"
+ "crypto/aes"
)
-var store = decode(os.Getenv("SESSION_KEY"))
+var store = b64decodeAndInitNonce(os.Getenv("SESSION_KEY"))
+
+var nonce []byte = nil
var db *sql.DB
const baseDocDir string = "docs"
-// Use uuid for session ids to prevent spoofing session cookies
-var sessionIds sync.Map
-
-type Session struct {
- sync.RWMutex
- ids []*string
-}
-
type Doc struct {
Name string
Size string
@@ -42,11 +36,6 @@ type Doc struct {
Link string
}
-type UserSession struct {
- User
- SessionId string
-}
-
var fmap template.FuncMap = template.FuncMap {
"add": func(i, j int) int {
return i + j
@@ -56,16 +45,55 @@ var fmap template.FuncMap = template.FuncMap {
},
}
-func decode(b string) []byte {
+func b64decodeAndInitNonce(b string) []byte {
dst := make([]byte, base64.StdEncoding.DecodedLen(len(b)))
- n, err := base64.Decode(dst, []byte(b))
+ n, err := base64.StdEncoding.Decode(dst, []byte(b))
if err != nil {
panic(err)
}
dst = dst[:n]
+
+ if nonce == nil {
+ nonce = make([]byte, 12)
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ panic(err)
+ }
+ }
+
return dst
}
+func encrypt(b []byte) []byte {
+ blk, err := aes.NewCipher(store)
+ if err != nil {
+ panic(err)
+ }
+
+ aesgcm, err := cipher.NewGCM(blk)
+ if err != nil {
+ panic(err)
+ }
+ return aesgcm.Seal(nil, nonce, b, nil)
+}
+
+func decrypt(b []byte) ([]byte, error) {
+ blk, err := aes.NewCipher(store)
+ if err != nil {
+ return nil, err
+ }
+
+ aesgcm, err := cipher.NewGCM(blk)
+ if err != nil {
+ return nil, err
+ }
+
+ plain, err := aesgcm.Open(nil, nonce, b, nil)
+ if err != nil {
+ return nil, err
+ }
+ return plain, nil
+}
+
func serveTemplate(w http.ResponseWriter, r *http.Request, data interface{}, view ...string) {
var nav string = "templates/nav.html"
if u, err := checkSession(w, r); u != nil && err == nil {
@@ -95,7 +123,7 @@ func serveSimple(w http.ResponseWriter, r *http.Request, data interface{}, view
}
}
-func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) {
+func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) {
cookie, err := r.Cookie("session")
if err != nil {
if err == http.ErrNoCookie {
@@ -108,37 +136,18 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error)
if err != nil {
return nil, err
}
- var user UserSession
- err = json.Unmarshal(ub64, &user)
+ plain, err := decrypt(ub64)
if err != nil {
return nil, err
}
- sessionIds.Range(func(key, value any) bool {
- fmt.Printf("%s %v\n", key, value.(*Session).ids)
- return true
- })
- if session, ok := sessionIds.Load(user.User.Username); !ok || !contains(session.(*Session), user.SessionId) {
- http.SetCookie(w, &http.Cookie{
- Name: "session",
- Value: "",
- MaxAge: -1,
- })
- return nil, nil
+ var user User
+ err = json.Unmarshal(plain, &user)
+ if err != nil {
+ return nil, err
}
return &user, nil
}
-func contains(session *Session, id string) bool {
- session.Lock()
- defer session.Unlock()
- for _, sessionId := range session.ids {
- if *sessionId == id {
- return true
- }
- }
- return false
-}
-
func sendError(w http.ResponseWriter, r *http.Request, s string, status int) {
log.Println(s)
w.WriteHeader(status)
@@ -217,8 +226,7 @@ func index(w http.ResponseWriter, r *http.Request) {
serveTemplate(w, r, data, "templates/user.html")
return
} else if err != nil {
- sendError(w, r, err.Error(), http.StatusInternalServerError)
- return
+ log.Println(err)
}
sendFlash(w, r, "redirect", "/")
http.Redirect(w, r, "/login", http.StatusSeeOther)
@@ -337,26 +345,6 @@ func logout(w http.ResponseWriter, r *http.Request) {
Value: "",
MaxAge: -1,
})
- if result, ok := sessionIds.Load(u.Username); ok {
- session := result.(*Session)
- session.Lock()
- toRemove := -1
- for i, id := range session.ids {
- if *id == u.SessionId {
- toRemove = i
- }
- }
- if toRemove != -1 {
- if len(session.ids) == 1 {
- session.Unlock()
- sessionIds.Delete(u.Username)
- } else {
- session.ids[toRemove] = session.ids[len(session.ids)-1]
- session.ids = session.ids[:len(session.ids)-1]
- session.Unlock()
- }
- }
- }
}
http.Redirect(w, r, "/", http.StatusSeeOther)
default:
@@ -384,26 +372,14 @@ func login(w http.ResponseWriter, r *http.Request) {
return
}
user.Pass = ""
- sessionId := uuid.NewV4().String()
- sess := Session {ids: []*string {&sessionId}}
- fmt.Println("store")
- if result, ok := sessionIds.LoadOrStore(user.Username, &sess); ok {
- session := result.(*Session)
- session.Lock()
- session.ids = append(session.ids, &sessionId)
- session.Unlock()
- }
- us := UserSession {
- user,
- sessionId,
- }
- jsonData, err := json.Marshal(us)
+ jsonData, err := json.Marshal(user)
if err != nil {
sendFlash(w, r, "error", err.Error())
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
- bStr := base64.StdEncoding.EncodeToString(jsonData)
+ ciphertext := encrypt(jsonData)
+ bStr := base64.StdEncoding.EncodeToString(ciphertext)
cookie := http.Cookie {
Name: "session",
Value: bStr,