diff options
Diffstat (limited to 'main.go')
| -rw-r--r-- | main.go | 138 |
1 files changed, 57 insertions, 81 deletions
@@ -14,27 +14,21 @@ import ( "database/sql" "encoding/json" "encoding/base64" - "sync" "time" "archive/zip" - - "github.com/satori/go.uuid" + "crypto/rand" + "crypto/cipher" + "crypto/aes" ) -var store = decode(os.Getenv("SESSION_KEY")) +var store = b64decodeAndInitNonce(os.Getenv("SESSION_KEY")) + +var nonce []byte = nil var db *sql.DB const baseDocDir string = "docs" -// Use uuid for session ids to prevent spoofing session cookies -var sessionIds sync.Map - -type Session struct { - sync.RWMutex - ids []*string -} - type Doc struct { Name string Size string @@ -42,11 +36,6 @@ type Doc struct { Link string } -type UserSession struct { - User - SessionId string -} - var fmap template.FuncMap = template.FuncMap { "add": func(i, j int) int { return i + j @@ -56,16 +45,55 @@ var fmap template.FuncMap = template.FuncMap { }, } -func decode(b string) []byte { +func b64decodeAndInitNonce(b string) []byte { dst := make([]byte, base64.StdEncoding.DecodedLen(len(b))) - n, err := base64.Decode(dst, []byte(b)) + n, err := base64.StdEncoding.Decode(dst, []byte(b)) if err != nil { panic(err) } dst = dst[:n] + + if nonce == nil { + nonce = make([]byte, 12) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + panic(err) + } + } + return dst } +func encrypt(b []byte) []byte { + blk, err := aes.NewCipher(store) + if err != nil { + panic(err) + } + + aesgcm, err := cipher.NewGCM(blk) + if err != nil { + panic(err) + } + return aesgcm.Seal(nil, nonce, b, nil) +} + +func decrypt(b []byte) ([]byte, error) { + blk, err := aes.NewCipher(store) + if err != nil { + return nil, err + } + + aesgcm, err := cipher.NewGCM(blk) + if err != nil { + return nil, err + } + + plain, err := aesgcm.Open(nil, nonce, b, nil) + if err != nil { + return nil, err + } + return plain, nil +} + func serveTemplate(w http.ResponseWriter, r *http.Request, data interface{}, view ...string) { var nav string = "templates/nav.html" if u, err := checkSession(w, r); u != nil && err == nil { @@ -95,7 +123,7 @@ func serveSimple(w http.ResponseWriter, r *http.Request, data interface{}, view } } -func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) { +func checkSession(w http.ResponseWriter, r *http.Request) (*User, error) { cookie, err := r.Cookie("session") if err != nil { if err == http.ErrNoCookie { @@ -108,37 +136,18 @@ func checkSession(w http.ResponseWriter, r *http.Request) (*UserSession, error) if err != nil { return nil, err } - var user UserSession - err = json.Unmarshal(ub64, &user) + plain, err := decrypt(ub64) if err != nil { return nil, err } - sessionIds.Range(func(key, value any) bool { - fmt.Printf("%s %v\n", key, value.(*Session).ids) - return true - }) - if session, ok := sessionIds.Load(user.User.Username); !ok || !contains(session.(*Session), user.SessionId) { - http.SetCookie(w, &http.Cookie{ - Name: "session", - Value: "", - MaxAge: -1, - }) - return nil, nil + var user User + err = json.Unmarshal(plain, &user) + if err != nil { + return nil, err } return &user, nil } -func contains(session *Session, id string) bool { - session.Lock() - defer session.Unlock() - for _, sessionId := range session.ids { - if *sessionId == id { - return true - } - } - return false -} - func sendError(w http.ResponseWriter, r *http.Request, s string, status int) { log.Println(s) w.WriteHeader(status) @@ -217,8 +226,7 @@ func index(w http.ResponseWriter, r *http.Request) { serveTemplate(w, r, data, "templates/user.html") return } else if err != nil { - sendError(w, r, err.Error(), http.StatusInternalServerError) - return + log.Println(err) } sendFlash(w, r, "redirect", "/") http.Redirect(w, r, "/login", http.StatusSeeOther) @@ -337,26 +345,6 @@ func logout(w http.ResponseWriter, r *http.Request) { Value: "", MaxAge: -1, }) - if result, ok := sessionIds.Load(u.Username); ok { - session := result.(*Session) - session.Lock() - toRemove := -1 - for i, id := range session.ids { - if *id == u.SessionId { - toRemove = i - } - } - if toRemove != -1 { - if len(session.ids) == 1 { - session.Unlock() - sessionIds.Delete(u.Username) - } else { - session.ids[toRemove] = session.ids[len(session.ids)-1] - session.ids = session.ids[:len(session.ids)-1] - session.Unlock() - } - } - } } http.Redirect(w, r, "/", http.StatusSeeOther) default: @@ -384,26 +372,14 @@ func login(w http.ResponseWriter, r *http.Request) { return } user.Pass = "" - sessionId := uuid.NewV4().String() - sess := Session {ids: []*string {&sessionId}} - fmt.Println("store") - if result, ok := sessionIds.LoadOrStore(user.Username, &sess); ok { - session := result.(*Session) - session.Lock() - session.ids = append(session.ids, &sessionId) - session.Unlock() - } - us := UserSession { - user, - sessionId, - } - jsonData, err := json.Marshal(us) + jsonData, err := json.Marshal(user) if err != nil { sendFlash(w, r, "error", err.Error()) http.Redirect(w, r, "/login", http.StatusSeeOther) return } - bStr := base64.StdEncoding.EncodeToString(jsonData) + ciphertext := encrypt(jsonData) + bStr := base64.StdEncoding.EncodeToString(ciphertext) cookie := http.Cookie { Name: "session", Value: bStr, |
