Flatten binary tree

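Problem: LeetCode 114, "Flatten Binary Tree to Linked List". Flatten the tree in place so that every node's left child is None and following right pointers visits the nodes in pre-order.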

I understand the solution, but it's a bit difficult for me to come up with it on the spot. I'd appreciate a review.

from typing import Optional
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Flatten a binary tree in place into a pre-order, right-linked list.

    After flattening, every node's ``left`` is ``None`` and following
    ``right`` pointers from the root visits the nodes in pre-order.
    """

    def flatten(self, root: Optional["TreeNode"]) -> None:
        """Flatten the tree rooted at *root* in place; always returns ``None``.

        The annotation is quoted so this class can be defined even when
        ``TreeNode`` is not in scope (in this file its definition is only a
        comment).
        """
        self.dfs(root)

    def dfs(self, root):
        """Flatten the subtree rooted at *root* in place and return its tail.

        The tail is the last node in pre-order, i.e. the final node of the
        resulting right-linked list. ``None`` is returned for an empty subtree.

        The name is kept for backward compatibility, but the walk is
        iterative and uses O(1) extra space, so deep (e.g. fully skewed)
        trees cannot exceed Python's recursion limit.
        """
        tail = None
        node = root
        while node is not None:
            if node.left is not None:
                # In pre-order, node's right subtree comes right after the
                # last node of its left subtree. Hang the right subtree off
                # the left subtree's rightmost node. If that node itself has
                # a left child, it is spliced the same way when the walk
                # reaches it.
                pred = node.left
                while pred.right is not None:
                    pred = pred.right
                pred.right = node.right
                # Move the left subtree into the right slot.
                node.right = node.left
                node.left = None
            tail = node
            node = node.right
        return tail