From 7bf8937313a5e74816f30bb0c8def746d1f93de2 Mon Sep 17 00:00:00 2001
From: Mark Syms <mark.syms@citrix.com>
Date: Fri, 21 Nov 2025 16:08:59 +0000
Subject: [PATCH] CA-421013: ensure cbt log removed on supporter after disable

When CBT is disabled, the CBT log is deleted on the coordinator. On
shared LVM SRs the supporters must also be told about this change so
that they can remove all information relating to the CBT log volume
from the kernel device-mapper table. Otherwise a subsequent CBT enable
for this VDI fails if the device-mapper table still contains entries
for the previous, now deleted, volume.

Signed-off-by: Mark Syms <mark.syms@citrix.com>
---
 drivers/LVHDSR.py    | 72 ++++++++++++++++++++------------------------
 drivers/VDI.py       |  6 ++++
 tests/test_LVHDSR.py | 32 ++++++++++++++++++++
 3 files changed, 70 insertions(+), 40 deletions(-)

diff --git a/drivers/LVHDSR.py b/drivers/LVHDSR.py
index 7b745af..fe9544a 100755
--- a/drivers/LVHDSR.py
+++ b/drivers/LVHDSR.py
@@ -1199,20 +1199,30 @@ class LVHDSR(SR.SR):
         delattr(self, "vdiInfo")
         delattr(self, "allVDIs")
 
-    def _updateSlavesPreClone(self, hostRefs, origOldLV):
-        masterRef = util.get_this_host_ref(self.session)
-        args = {"vgName": self.vgname,
-                "action1": "deactivateNoRefcount",
-                "lvName1": origOldLV}
-        for hostRef in hostRefs:
-            if hostRef == masterRef:
+    def call_on_slave(self, args, host_refs, message: str):
+        """Call the PLUGIN_ON_SLAVE "multi" function on each of host_refs.
+
+        This host is skipped. `message` prefixes the per-host log line.
+        Raises Exception if any plugin call returns a falsy result.
+        """
+        master_ref = util.get_this_host_ref(self.session)
+        for host_ref in host_refs:
+            if host_ref == master_ref:
                 continue
-            util.SMlog("Deactivate VDI on %s" % hostRef)
-            rv = self.session.xenapi.host.call_plugin(hostRef, self.PLUGIN_ON_SLAVE, "multi", args)
+            util.SMlog(f"{message} on slave {host_ref}")
+            rv = self.session.xenapi.host.call_plugin(
+                host_ref, self.PLUGIN_ON_SLAVE, "multi", args)
             util.SMlog("call-plugin returned: %s" % rv)
             if not rv:
                 raise Exception('plugin %s failed' % self.PLUGIN_ON_SLAVE)
 
+    def _updateSlavesPreClone(self, hostRefs, origOldLV):
+        args = {"vgName": self.vgname,
+                "action1": "deactivateNoRefcount",
+                "lvName1": origOldLV}
+        message = "Deactivate VDI"
+        self.call_on_slave(args, hostRefs, message)
+
     def _updateSlavesOnClone(self, hostRefs, origOldLV, origLV,
             baseUuid, baseLV):
         """We need to reactivate the original LV on each slave (note that the
@@ -1226,17 +1236,8 @@ class LVHDSR(SR.SR):
                 "lvName2": baseLV,
                 "uuid2": baseUuid}
 
-        masterRef = util.get_this_host_ref(self.session)
-        for hostRef in hostRefs:
-            if hostRef == masterRef:
-                continue
-            util.SMlog("Updating %s, %s, %s on slave %s" % \
-                    (origOldLV, origLV, baseLV, hostRef))
-            rv = self.session.xenapi.host.call_plugin(
-                hostRef, self.PLUGIN_ON_SLAVE, "multi", args)
-            util.SMlog("call-plugin returned: %s" % rv)
-            if not rv:
-                raise Exception('plugin %s failed' % self.PLUGIN_ON_SLAVE)
+        message = f"Updating {origOldLV}, {origLV}, {baseLV}"
+        self.call_on_slave(args, hostRefs, message)
 
     def _updateSlavesOnCBTClone(self, hostRefs, cbtlog):
         """Reactivate and refresh CBT log file on slaves"""
@@ -1246,16 +1247,8 @@ class LVHDSR(SR.SR):
                 "action2": "refresh",
                 "lvName2": cbtlog}
 
-        masterRef = util.get_this_host_ref(self.session)
-        for hostRef in hostRefs:
-            if hostRef == masterRef:
-                continue
-            util.SMlog("Updating %s on slave %s" % (cbtlog, hostRef))
-            rv = self.session.xenapi.host.call_plugin(
-                hostRef, self.PLUGIN_ON_SLAVE, "multi", args)
-            util.SMlog("call-plugin returned: %s" % rv)
-            if not rv:
-                raise Exception('plugin %s failed' % self.PLUGIN_ON_SLAVE)
+        message = f"Updating {cbtlog}"
+        self.call_on_slave(args, hostRefs, message)
 
     def _updateSlavesOnRemove(self, hostRefs, baseUuid, baseLV):
         """Tell the slave we deleted the base image"""
@@ -1264,16 +1257,8 @@ class LVHDSR(SR.SR):
                 "uuid1": baseUuid,
                 "ns1": lvhdutil.NS_PREFIX_LVM + self.uuid}
 
-        masterRef = util.get_this_host_ref(self.session)
-        for hostRef in hostRefs:
-            if hostRef == masterRef:
-                continue
-            util.SMlog("Cleaning locks for %s on slave %s" % (baseLV, hostRef))
-            rv = self.session.xenapi.host.call_plugin(
-                hostRef, self.PLUGIN_ON_SLAVE, "multi", args)
-            util.SMlog("call-plugin returned: %s" % rv)
-            if not rv:
-                raise Exception('plugin %s failed' % self.PLUGIN_ON_SLAVE)
+        message = f"Cleaning locks for {baseLV}"
+        self.call_on_slave(args, hostRefs, message)
 
     def _cleanup(self, skipLockCleanup=False):
         """delete stale refcounter, flag, and lock files"""
@@ -2205,6 +2190,24 @@ class LVHDVDI(VDI.VDI):
         newname = os.path.basename(newpath)
         self.sr.lvmCache.rename(oldname, newname)
 
+    def update_slaves_on_cbt_disable(self, cbtlog):
+        """Tell other hosts this VDI is attached on to deactivate cbtlog.
+
+        When CBT is disabled the log LV is deleted on this host; other hosts
+        must drop their device-mapper state for it or a later enable fails.
+        """
+        args = {
+            "vgName": self.sr.vgname,
+            "action1": "deactivateNoRefcount",
+            # lvName1 is an LV name; callers may pass the full log path.
+            "lvName1": os.path.basename(cbtlog)
+        }
+
+        host_refs = util.get_hosts_attached_on(self.session, [self.uuid])
+
+        message = f"Deactivating {cbtlog}"
+        self.sr.call_on_slave(args, host_refs, message)
+
     def _activate_cbt_log(self, lv_name):
         self.sr.lvmCache.refresh()
         if not self.sr.lvmCache.is_active(lv_name):
diff --git a/drivers/VDI.py b/drivers/VDI.py
index d371bd1..030ada0 100755
--- a/drivers/VDI.py
+++ b/drivers/VDI.py
@@ -574,6 +574,10 @@ class VDI(object):
             return False
         return True
 
+    def update_slaves_on_cbt_disable(self, cbtlog):
+        # No-op hook run on CBT deactivation; override to notify other hosts.
+        pass
+
     def configure_blocktracking(self, sr_uuid, vdi_uuid, enable):
         """Function for configuring blocktracking"""
         import blktap2
@@ -629,6 +633,8 @@ class VDI(object):
                     if self._cbt_log_exists(parent_path):
                         self._cbt_op(parent, cbtutil.set_cbt_child,
                                      parent_path, uuid.UUID(int=0))
+                    if disk_state:
+                        self.update_slaves_on_cbt_disable(logpath)
                 except Exception as error:
                     raise xs_errors.XenError('CBTDeactivateFailed', str(error))
                 finally:
diff --git a/tests/test_LVHDSR.py b/tests/test_LVHDSR.py
index 600f137..5e4d0fa 100644
--- a/tests/test_LVHDSR.py
+++ b/tests/test_LVHDSR.py
@@ -469,6 +469,41 @@ class TestLVHDVDI(unittest.TestCase, Stubs):
         self.assertIsNotNone(snap)
         self.assertEqual(self.mock_cbtutil.set_cbt_child.call_count, 3)
 
+    @mock.patch('LVHDSR.Lock', autospec=True)
+    @mock.patch('SR.XenAPI')
+    def test_update_slaves_on_cbt_disable(self, mock_xenapi, mock_lock):
+        """
+        Ensure supporter hosts with the VDI attached, but not this host,
+        are told to deactivate the CBT log when CBT is disabled.
+        """
+        # Arrange
+        xapi_session = mock_xenapi.xapi_local.return_value
+
+        vdi_uuid = str(uuid.uuid4())
+        self.get_dummy_vdi(vdi_uuid)
+        self.get_dummy_vhd(vdi_uuid, False)
+
+        sr = self.create_LVHDSR()
+        sr.isMaster = True
+
+        vdi = sr.vdi(vdi_uuid)
+        vdi.vdi_type = vhdutil.VDI_TYPE_VHD
+
+        # 'ref1' is this host, so it must be skipped.
+        self.mock_sr_util.get_this_host_ref.return_value = 'ref1'
+        self.mock_sr_util.get_hosts_attached_on.return_value = ['ref1', 'ref2']
+
+        # Act
+        log_file_path = "test_log_path"
+        vdi.update_slaves_on_cbt_disable(log_file_path)
+
+        # Assert
+        xapi_session.xenapi.host.call_plugin.assert_called_once_with(
+            'ref2', 'on-slave', 'multi',
+            {'vgName': sr.vgname,
+             'action1': 'deactivateNoRefcount',
+             'lvName1': log_file_path})
+
     @mock.patch('LVHDSR.Lock', autospec=True)
     @mock.patch('SR.XenAPI')
     def test_snapshot_secondary_success(self, mock_xenapi, mock_lock):
-- 
2.51.1

