os_firewall_manage_iptables.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # pylint: disable=fixme, missing-docstring
  4. import subprocess
  5. DOCUMENTATION = '''
  6. ---
  7. module: os_firewall_manage_iptables
  8. short_description: This module manages iptables rules for a given chain
  9. author: Jason DeTiberus
  10. requirements: [ ]
  11. '''
  12. EXAMPLES = '''
  13. '''
  14. class IpTablesError(Exception):
  15. def __init__(self, msg, cmd, exit_code, output):
  16. super(IpTablesError, self).__init__(msg)
  17. self.msg = msg
  18. self.cmd = cmd
  19. self.exit_code = exit_code
  20. self.output = output
  21. class IpTablesAddRuleError(IpTablesError):
  22. pass
  23. class IpTablesRemoveRuleError(IpTablesError):
  24. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  25. super(IpTablesRemoveRuleError, self).__init__(msg, cmd, exit_code,
  26. output)
  27. self.chain = chain
  28. class IpTablesSaveError(IpTablesError):
  29. pass
  30. class IpTablesCreateChainError(IpTablesError):
  31. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  32. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code,
  33. output)
  34. self.chain = chain
  35. class IpTablesCreateJumpRuleError(IpTablesError):
  36. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  37. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  38. output)
  39. self.chain = chain
  40. # TODO: implement rollbacks for any events that were successful and an
  41. # exception was thrown later. For example, when the chain is created
  42. # successfully, but the add/remove rule fails.
  43. class IpTablesManager(object): # pylint: disable=too-many-instance-attributes
  44. def __init__(self, module):
  45. self.module = module
  46. self.ip_version = module.params['ip_version']
  47. self.check_mode = module.check_mode
  48. self.chain = module.params['chain']
  49. self.create_jump_rule = module.params['create_jump_rule']
  50. self.jump_rule_chain = module.params['jump_rule_chain']
  51. self.cmd = self.gen_cmd()
  52. self.save_cmd = self.gen_save_cmd()
  53. self.output = []
  54. self.changed = False
  55. def save(self):
  56. try:
  57. self.output.append(subprocess.check_output(self.save_cmd, stderr=subprocess.STDOUT))
  58. except subprocess.CalledProcessError as ex:
  59. raise IpTablesSaveError(
  60. msg="Failed to save iptables rules",
  61. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  62. def verify_chain(self):
  63. if not self.chain_exists():
  64. self.create_chain()
  65. if self.create_jump_rule and not self.jump_rule_exists():
  66. self.create_jump()
  67. def add_rule(self, port, proto):
  68. rule = self.gen_rule(port, proto)
  69. if not self.rule_exists(rule):
  70. self.verify_chain()
  71. if self.check_mode:
  72. self.changed = True
  73. self.output.append("Create rule for %s %s" % (proto, port))
  74. else:
  75. cmd = self.cmd + ['-A'] + rule
  76. try:
  77. self.output.append(subprocess.check_output(cmd))
  78. self.changed = True
  79. self.save()
  80. except subprocess.CalledProcessError as ex:
  81. raise IpTablesCreateChainError(
  82. chain=self.chain,
  83. msg="Failed to create rule for "
  84. "%s %s" % (proto, port),
  85. cmd=ex.cmd, exit_code=ex.returncode,
  86. output=ex.output)
  87. def remove_rule(self, port, proto):
  88. rule = self.gen_rule(port, proto)
  89. if self.rule_exists(rule):
  90. if self.check_mode:
  91. self.changed = True
  92. self.output.append("Remove rule for %s %s" % (proto, port))
  93. else:
  94. cmd = self.cmd + ['-D'] + rule
  95. try:
  96. self.output.append(subprocess.check_output(cmd))
  97. self.changed = True
  98. self.save()
  99. except subprocess.CalledProcessError as ex:
  100. raise IpTablesRemoveRuleError(
  101. chain=self.chain,
  102. msg="Failed to remove rule for %s %s" % (proto, port),
  103. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  104. def rule_exists(self, rule):
  105. check_cmd = self.cmd + ['-C'] + rule
  106. return True if subprocess.call(check_cmd) == 0 else False
  107. @staticmethod
  108. def port_as_argument(port):
  109. if isinstance(port, int):
  110. return str(port)
  111. if isinstance(port, basestring): # noqa: F405
  112. return port.replace('-', ":")
  113. return port
  114. def gen_rule(self, port, proto):
  115. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  116. '-m', proto, '--dport', IpTablesManager.port_as_argument(port), '-j', 'ACCEPT']
  117. def create_jump(self):
  118. if self.check_mode:
  119. self.changed = True
  120. self.output.append("Create jump rule for chain %s" % self.chain)
  121. else:
  122. try:
  123. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  124. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  125. # break the input rules into rows and columns
  126. input_rules = [s.split() for s in to_native(output).split('\n')]
  127. # Find the last numbered rule
  128. last_rule_num = None
  129. last_rule_target = None
  130. for rule in input_rules[:-1]:
  131. if rule:
  132. try:
  133. last_rule_num = int(rule[0])
  134. except ValueError:
  135. continue
  136. last_rule_target = rule[1]
  137. # Naively assume that if the last row is a REJECT or DROP rule,
  138. # then we can insert our rule right before it, otherwise we
  139. # assume that we can just append the rule.
  140. if (last_rule_num and last_rule_target and last_rule_target in ['REJECT', 'DROP']):
  141. # insert rule
  142. cmd = self.cmd + ['-I', self.jump_rule_chain,
  143. str(last_rule_num)]
  144. else:
  145. # append rule
  146. cmd = self.cmd + ['-A', self.jump_rule_chain]
  147. cmd += ['-j', self.chain]
  148. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  149. self.changed = True
  150. self.output.append(output)
  151. self.save()
  152. except subprocess.CalledProcessError as ex:
  153. if '--line-numbers' in ex.cmd:
  154. raise IpTablesCreateJumpRuleError(
  155. chain=self.chain,
  156. msg=("Failed to query existing " +
  157. self.jump_rule_chain +
  158. " rules to determine jump rule location"),
  159. cmd=ex.cmd, exit_code=ex.returncode,
  160. output=ex.output)
  161. else:
  162. raise IpTablesCreateJumpRuleError(
  163. chain=self.chain,
  164. msg=("Failed to create jump rule for chain " +
  165. self.chain),
  166. cmd=ex.cmd, exit_code=ex.returncode,
  167. output=ex.output)
  168. def create_chain(self):
  169. if self.check_mode:
  170. self.changed = True
  171. self.output.append("Create chain %s" % self.chain)
  172. else:
  173. try:
  174. cmd = self.cmd + ['-N', self.chain]
  175. self.output.append(subprocess.check_output(cmd, stderr=subprocess.STDOUT))
  176. self.changed = True
  177. self.output.append("Successfully created chain %s" %
  178. self.chain)
  179. self.save()
  180. except subprocess.CalledProcessError as ex:
  181. raise IpTablesCreateChainError(
  182. chain=self.chain,
  183. msg="Failed to create chain: %s" % self.chain,
  184. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output
  185. )
  186. def jump_rule_exists(self):
  187. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  188. return True if subprocess.call(cmd) == 0 else False
  189. def chain_exists(self):
  190. cmd = self.cmd + ['-L', self.chain]
  191. return True if subprocess.call(cmd) == 0 else False
  192. def gen_cmd(self):
  193. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  194. # Include -w (wait for xtables lock) in default arguments.
  195. default_args = ['-w']
  196. return ["/usr/sbin/%s" % cmd] + default_args
  197. def gen_save_cmd(self): # pylint: disable=no-self-use
  198. return ['/usr/libexec/iptables/iptables.init', 'save']
  199. def main():
  200. module = AnsibleModule( # noqa: F405
  201. argument_spec=dict(
  202. name=dict(required=True),
  203. action=dict(required=True, choices=['add', 'remove',
  204. 'verify_chain']),
  205. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  206. create_jump_rule=dict(required=False, type='bool', default=True),
  207. jump_rule_chain=dict(required=False, default='INPUT'),
  208. protocol=dict(required=False, choices=['tcp', 'udp']),
  209. port=dict(required=False, type='str'),
  210. ip_version=dict(required=False, default='ipv4',
  211. choices=['ipv4', 'ipv6']),
  212. ),
  213. supports_check_mode=True
  214. )
  215. action = module.params['action']
  216. protocol = module.params['protocol']
  217. port = module.params['port']
  218. if action in ['add', 'remove']:
  219. if not protocol:
  220. error = "protocol is required when action is %s" % action
  221. module.fail_json(msg=error)
  222. if not port:
  223. error = "port is required when action is %s" % action
  224. module.fail_json(msg=error)
  225. iptables_manager = IpTablesManager(module)
  226. try:
  227. if action == 'add':
  228. iptables_manager.add_rule(port, protocol)
  229. elif action == 'remove':
  230. iptables_manager.remove_rule(port, protocol)
  231. elif action == 'verify_chain':
  232. iptables_manager.verify_chain()
  233. except IpTablesError as ex:
  234. module.fail_json(msg=ex.msg)
  235. return module.exit_json(changed=iptables_manager.changed,
  236. output=iptables_manager.output)
  237. # pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import, wrong-import-position
  238. # import module snippets
  239. from ansible.module_utils.basic import * # noqa: F403,E402
  240. from ansible.module_utils._text import to_native # noqa: E402
  241. if __name__ == '__main__':
  242. main()