From c6ea06153a0c04dd1f2e83fff4523d20a7a4c6cd Mon Sep 17 00:00:00 2001
From: igophper <34326532+igophper@users.noreply.github.com>
Date: Sun, 17 Nov 2024 22:52:02 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E8=A7=A3=E5=86=B3=E7=BC=96=E8=BE=91?=
 =?UTF-8?q?=E7=AB=AF=E5=8F=A3=E8=A7=84=E5=88=99=E4=B8=8D=E8=83=BD=E4=BF=9D?=
 =?UTF-8?q?=E5=AD=98=E7=9A=84=E9=97=AE=E9=A2=98=20(#7100)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/app/service/firewall.go | 47 +++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 5 deletions(-)

diff --git a/backend/app/service/firewall.go b/backend/app/service/firewall.go
index 2d99d2818..657c31141 100644
--- a/backend/app/service/firewall.go
+++ b/backend/app/service/firewall.go
@@ -323,21 +323,58 @@ func (u *FirewallService) OperateForwardRule(req dto.ForwardRuleOperate) error {
 	}
 
 	rules, _ := client.ListForward()
+	i := 0
+	for _, rule := range rules {
+		shouldKeep := true
+		for i := range req.Rules {
+			reqRule := &req.Rules[i]
+			if reqRule.TargetIP == "" {
+				reqRule.TargetIP = "127.0.0.1"
+			}
+
+			if reqRule.Operation == "remove" {
+				for _, proto := range strings.Split(reqRule.Protocol, "/") {
+					if reqRule.Port == rule.Port &&
+						reqRule.TargetPort == rule.TargetPort &&
+						reqRule.TargetIP == rule.TargetIP &&
+						proto == rule.Protocol {
+						shouldKeep = false
+						break
+					}
+				}
+			}
+		}
+		if shouldKeep {
+			rules[i] = rule
+			i++
+		}
+	}
+	rules = rules[:i]
+
 	for _, rule := range rules {
 		for _, reqRule := range req.Rules {
 			if reqRule.Operation == "remove" {
 				continue
 			}
-			if reqRule.TargetIP == "" {
-				reqRule.TargetIP = "127.0.0.1"
-			}
-			if reqRule.Port == rule.Port && reqRule.TargetPort == rule.TargetPort && reqRule.TargetIP == rule.TargetIP {
-				return constant.ErrRecordExist
+
+			for _, proto := range strings.Split(reqRule.Protocol, "/") {
+				if reqRule.Port == rule.Port &&
+					reqRule.TargetPort == rule.TargetPort &&
+					reqRule.TargetIP == rule.TargetIP &&
+					proto == rule.Protocol {
+					return constant.ErrRecordExist
+				}
 			}
 		}
 	}
 
 	sort.SliceStable(req.Rules, func(i, j int) bool {
+		if req.Rules[i].Operation == "remove" && req.Rules[j].Operation != "remove" {
+			return true
+		}
+		if req.Rules[i].Operation != "remove" && req.Rules[j].Operation == "remove" {
+			return false
+		}
 		n1, _ := strconv.Atoi(req.Rules[i].Num)
 		n2, _ := strconv.Atoi(req.Rules[j].Num)
 		return n1 > n2