diff --git a/main.go b/main.go index 69bfc61..f9faa9a 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "log" "matrix-bot/bot" "os" @@ -25,6 +26,13 @@ var deviceId string var pickleKeyString string var recoveryKey string var cryptoDBPath string +var credentialsPath string + +type storedCredentials struct { + UserID string `json:"user_id"` + AccessToken string `json:"access_token"` + DeviceID string `json:"device_id"` +} func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) { // remember to use a secure key for the pickle key in production @@ -92,17 +100,53 @@ func envOrDefault(key string, def string) string { return val } +func loadStoredCredentials(path string) (*storedCredentials, error) { + if path == "" { + return nil, nil + } + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var creds storedCredentials + if err := json.Unmarshal(data, &creds); err != nil { + return nil, err + } + return &creds, nil +} + +func saveStoredCredentials(path string, creds *storedCredentials) error { + if path == "" { + return nil + } + dir := filepath.Dir(path) + if dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + } + data, err := json.MarshalIndent(creds, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + func loadConfig() { homeserver = envOrFatal("MATRIX_HOMESERVER") username = envOrDefault("MATRIX_USERNAME", "") password = envOrDefault("MATRIX_PASSWORD", "") roomID = envOrDefault("MATRIX_ROOM_ID", "") - userId = envOrFatal("MATRIX_USER_ID") - accessToken = envOrFatal("MATRIX_ACCESS_TOKEN") - deviceId = envOrFatal("MATRIX_DEVICE_ID") + userId = envOrDefault("MATRIX_USER_ID", "") + accessToken = envOrDefault("MATRIX_ACCESS_TOKEN", "") + deviceId = envOrDefault("MATRIX_DEVICE_ID", "") pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY") recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY") cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db") + credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json") } func main() { @@ -110,10 +154,64 @@ func main() { loadConfig() bot.Load() + stored, err := loadStoredCredentials(credentialsPath) + if err != nil { + log.Fatal(err) + } + if stored != nil { + if accessToken == "" { + accessToken = stored.AccessToken + } + if deviceId == "" { + deviceId = stored.DeviceID + } + if userId == "" { + userId = stored.UserID + } + if accessToken != "" || deviceId != "" || userId != "" { + log.Println("Loaded credentials from", credentialsPath) + } + } + client, err := mautrix.NewClient(homeserver, id.UserID(userId), accessToken) if err != nil { log.Fatal(err) } + + if accessToken == "" || deviceId == "" { + if username == "" || password == "" { + log.Fatal("missing MATRIX_USERNAME or MATRIX_PASSWORD for credential bootstrap") + } + log.Println("Logging in to Matrix to bootstrap credentials") + _, err = client.Login(context.Background(), &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: username, + }, + Password: password, + StoreCredentials: true, + }) + if err != nil { + log.Fatal(err) + } + accessToken = client.AccessToken + userId = client.UserID.String() + deviceId = client.DeviceID.String() + if err := saveStoredCredentials(credentialsPath, &storedCredentials{ + UserID: userId, + AccessToken: accessToken, + DeviceID: deviceId, + }); err != nil { + log.Fatal(err) + } + log.Println("Saved credentials to", credentialsPath) + } + + if userId == "" { + log.Fatal("missing MATRIX_USER_ID and no stored credentials") + } + client.DeviceID = id.DeviceID(deviceId) syncer := mautrix.NewDefaultSyncer()