package main import ( "context" "encoding/json" "errors" "log" "matrix-bot/bot" "os" "path/filepath" "strings" "sync" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var homeserver string var username string var password string var roomID string var userId string var accessToken string var deviceId string var pickleKeyString string var recoveryKey string var cryptoDBPath string var credentialsPath string var cryptoResetOnMismatch bool type storedCredentials struct { UserID string `json:"user_id"` AccessToken string `json:"access_token"` DeviceID string `json:"device_id"` } func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) { // remember to use a secure key for the pickle key in production pickleKey := []byte(pickleKeyString) if cryptoDBPath != "" { dir := filepath.Dir(cryptoDBPath) if dir != "." { if err := os.MkdirAll(dir, 0o755); err != nil { return nil, err } } } helper, err := cryptohelper.NewCryptoHelper(cli, pickleKey, cryptoDBPath) if err != nil { return nil, err } err = helper.Init(context.Background()) if err != nil { return nil, err } return helper, nil } func verifyWithRecoveryKey(machine *crypto.OlmMachine) (err error) { ctx := context.Background() keyId, keyData, err := machine.SSSS.GetDefaultKeyData(ctx) if err != nil { return } key, err := keyData.VerifyRecoveryKey(keyId, recoveryKey) if err != nil { return } err = machine.FetchCrossSigningKeysFromSSSS(ctx, key) if err != nil { return } err = machine.SignOwnDevice(ctx, machine.OwnIdentity()) if err != nil { return } err = machine.SignOwnMasterKey(ctx) return } func envOrFatal(key string) string { val := os.Getenv(key) if val == "" { log.Fatalf("missing required env var: %s", key) } return val } func envOrDefault(key string, def string) string { val := os.Getenv(key) if val == "" { return def } return val } func loadStoredCredentials(path string) (*storedCredentials, error) { if path == "" { return nil, nil } data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, err } var creds storedCredentials if err := json.Unmarshal(data, &creds); err != nil { return nil, err } return &creds, nil } func saveStoredCredentials(path string, creds *storedCredentials) error { if path == "" { return nil } dir := filepath.Dir(path) if dir != "." { if err := os.MkdirAll(dir, 0o755); err != nil { return err } } data, err := json.MarshalIndent(creds, "", " ") if err != nil { return err } return os.WriteFile(path, data, 0o600) } func loadConfig() { homeserver = envOrFatal("MATRIX_HOMESERVER") username = envOrDefault("MATRIX_USERNAME", "") password = envOrDefault("MATRIX_PASSWORD", "") roomID = envOrDefault("MATRIX_ROOM_ID", "") userId = envOrDefault("MATRIX_USER_ID", "") accessToken = envOrDefault("MATRIX_ACCESS_TOKEN", "") deviceId = envOrDefault("MATRIX_DEVICE_ID", "") pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY") recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY") cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db") credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json") cryptoResetOnMismatch = strings.EqualFold(envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", ""), "true") || envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", "") == "1" } func isOlmAccountMismatch(err error) bool { if err == nil { return false } return strings.Contains(err.Error(), "olm account is not marked as shared") } func resetCryptoState() error { var resetErr error if cryptoDBPath != "" { if err := os.Remove(cryptoDBPath); err != nil && !errors.Is(err, os.ErrNotExist) { resetErr = err } } if credentialsPath != "" { if err := os.Remove(credentialsPath); err != nil && !errors.Is(err, os.ErrNotExist) { resetErr = err } } return resetErr } func main() { var err error loadConfig() bot.Load() stored, err := loadStoredCredentials(credentialsPath) if err != nil { log.Fatal(err) } if stored != nil { if accessToken == "" { accessToken = stored.AccessToken } if deviceId == "" { deviceId = stored.DeviceID } if userId == "" { userId = stored.UserID } if accessToken != "" || deviceId != "" || userId != "" { log.Println("Loaded credentials from", credentialsPath) } } client, err := mautrix.NewClient(homeserver, id.UserID(userId), accessToken) if err != nil { log.Fatal(err) } if accessToken == "" || deviceId == "" { if username == "" || password == "" { log.Fatal("missing MATRIX_USERNAME or MATRIX_PASSWORD for credential bootstrap") } log.Println("Logging in to Matrix to bootstrap credentials") _, err = client.Login(context.Background(), &mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: username, }, Password: password, StoreCredentials: true, }) if err != nil { log.Fatal(err) } accessToken = client.AccessToken userId = client.UserID.String() deviceId = client.DeviceID.String() if err := saveStoredCredentials(credentialsPath, &storedCredentials{ UserID: userId, AccessToken: accessToken, DeviceID: deviceId, }); err != nil { log.Fatal(err) } log.Println("Saved credentials to", credentialsPath) } if userId == "" { log.Fatal("missing MATRIX_USER_ID and no stored credentials") } client.DeviceID = id.DeviceID(deviceId) syncer := mautrix.NewDefaultSyncer() client.Syncer = syncer cryptoHelper, err := setupCryptoHelper(client) if err != nil { if isOlmAccountMismatch(err) { log.Println("Detected olm account mismatch with server keys") if cryptoResetOnMismatch { if resetErr := resetCryptoState(); resetErr != nil { log.Fatal(resetErr) } log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.") } log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.") } log.Fatal(err) } client.Crypto = cryptoHelper syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) { // Ignore our own messages if evt.Sender.String() == userId { return } content := evt.Content.AsMessage() if content.MsgType != event.MsgText { return } log.Printf("Message from %s: %s\n", evt.Sender, content.Body) response := bot.HandleCommand(content.Body, evt.Sender.String(), evt.RoomID.String(), ctx, client, evt, &event.RelatesTo{ InReplyTo: &event.InReplyTo{EventID: evt.ID}, }) if response == nil { return } for _, resp := range response { switch r := resp.(type) { case event.MessageEventContent: _, err := client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, r) if err != nil { log.Println("Send error:", err) } default: log.Println("Unknown response type") } } }) ready := make(chan struct{}) var once sync.Once syncer.OnSync(func(ctx context.Context, resp *mautrix.RespSync, since string) bool { once.Do(func() { close(ready) }) return true }) go func() { if err := client.Sync(); err != nil { log.Fatal(err) } }() log.Println("Waiting for initial sync...") <-ready log.Println("Sync complete") if err := verifyWithRecoveryKey(cryptoHelper.Machine()); err != nil { if isOlmAccountMismatch(err) { log.Println("Detected olm account mismatch with server keys") if cryptoResetOnMismatch { if resetErr := resetCryptoState(); resetErr != nil { log.Fatal(resetErr) } log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.") } log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.") } log.Fatal(err) } log.Println("Bot is running...") select {} }