os_firewall_manage_iptables.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # vim: expandtab:tabstop=4:shiftwidth=4
  4. # pylint: disable=fixme, missing-docstring
  5. import subprocess
  6. DOCUMENTATION = '''
  7. ---
  8. module: os_firewall_manage_iptables
  9. short_description: This module manages iptables rules for a given chain
  10. author: Jason DeTiberus
  11. requirements: [ ]
  12. '''
  13. EXAMPLES = '''
  14. '''
  15. class IpTablesError(Exception):
  16. def __init__(self, msg, cmd, exit_code, output):
  17. super(IpTablesError, self).__init__(msg)
  18. self.msg = msg
  19. self.cmd = cmd
  20. self.exit_code = exit_code
  21. self.output = output
  22. class IpTablesAddRuleError(IpTablesError):
  23. pass
  24. class IpTablesRemoveRuleError(IpTablesError):
  25. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  26. super(IpTablesRemoveRuleError, self).__init__(msg, cmd, exit_code,
  27. output)
  28. self.chain = chain
  29. class IpTablesSaveError(IpTablesError):
  30. pass
  31. class IpTablesCreateChainError(IpTablesError):
  32. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  33. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code,
  34. output)
  35. self.chain = chain
  36. class IpTablesCreateJumpRuleError(IpTablesError):
  37. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  38. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  39. output)
  40. self.chain = chain
# TODO: implement rollback of steps that succeeded before a later step
# raised an exception. For example, the chain is created successfully but
# adding or removing the rule then fails.
  44. class IpTablesManager(object): # pylint: disable=too-many-instance-attributes
  45. def __init__(self, module):
  46. self.module = module
  47. self.ip_version = module.params['ip_version']
  48. self.check_mode = module.check_mode
  49. self.chain = module.params['chain']
  50. self.create_jump_rule = module.params['create_jump_rule']
  51. self.jump_rule_chain = module.params['jump_rule_chain']
  52. self.cmd = self.gen_cmd()
  53. self.save_cmd = self.gen_save_cmd()
  54. self.output = []
  55. self.changed = False
  56. def save(self):
  57. try:
  58. self.output.append(subprocess.check_output(self.save_cmd, stderr=subprocess.STDOUT))
  59. except subprocess.CalledProcessError as ex:
  60. raise IpTablesSaveError(
  61. msg="Failed to save iptables rules",
  62. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  63. def verify_chain(self):
  64. if not self.chain_exists():
  65. self.create_chain()
  66. if self.create_jump_rule and not self.jump_rule_exists():
  67. self.create_jump()
  68. def add_rule(self, port, proto):
  69. rule = self.gen_rule(port, proto)
  70. if not self.rule_exists(rule):
  71. self.verify_chain()
  72. if self.check_mode:
  73. self.changed = True
  74. self.output.append("Create rule for %s %s" % (proto, port))
  75. else:
  76. cmd = self.cmd + ['-A'] + rule
  77. try:
  78. self.output.append(subprocess.check_output(cmd))
  79. self.changed = True
  80. self.save()
  81. except subprocess.CalledProcessError as ex:
  82. raise IpTablesCreateChainError(
  83. chain=self.chain,
  84. msg="Failed to create rule for "
  85. "%s %s" % (proto, port),
  86. cmd=ex.cmd, exit_code=ex.returncode,
  87. output=ex.output)
  88. def remove_rule(self, port, proto):
  89. rule = self.gen_rule(port, proto)
  90. if self.rule_exists(rule):
  91. if self.check_mode:
  92. self.changed = True
  93. self.output.append("Remove rule for %s %s" % (proto, port))
  94. else:
  95. cmd = self.cmd + ['-D'] + rule
  96. try:
  97. self.output.append(subprocess.check_output(cmd))
  98. self.changed = True
  99. self.save()
  100. except subprocess.CalledProcessError as ex:
  101. raise IpTablesRemoveRuleError(
  102. chain=self.chain,
  103. msg="Failed to remove rule for %s %s" % (proto, port),
  104. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  105. def rule_exists(self, rule):
  106. check_cmd = self.cmd + ['-C'] + rule
  107. return True if subprocess.call(check_cmd) == 0 else False
  108. @staticmethod
  109. def port_as_argument(port):
  110. if isinstance(port, int):
  111. return str(port)
  112. if isinstance(port, basestring): # noqa: F405
  113. return port.replace('-', ":")
  114. return port
  115. def gen_rule(self, port, proto):
  116. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  117. '-m', proto, '--dport', IpTablesManager.port_as_argument(port), '-j', 'ACCEPT']
  118. def create_jump(self):
  119. if self.check_mode:
  120. self.changed = True
  121. self.output.append("Create jump rule for chain %s" % self.chain)
  122. else:
  123. try:
  124. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  125. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  126. # break the input rules into rows and columns
  127. input_rules = [s.split() for s in to_native(output).split('\n')]
  128. # Find the last numbered rule
  129. last_rule_num = None
  130. last_rule_target = None
  131. for rule in input_rules[:-1]:
  132. if rule:
  133. try:
  134. last_rule_num = int(rule[0])
  135. except ValueError:
  136. continue
  137. last_rule_target = rule[1]
  138. # Naively assume that if the last row is a REJECT or DROP rule,
  139. # then we can insert our rule right before it, otherwise we
  140. # assume that we can just append the rule.
  141. if (last_rule_num and last_rule_target and last_rule_target in ['REJECT', 'DROP']):
  142. # insert rule
  143. cmd = self.cmd + ['-I', self.jump_rule_chain,
  144. str(last_rule_num)]
  145. else:
  146. # append rule
  147. cmd = self.cmd + ['-A', self.jump_rule_chain]
  148. cmd += ['-j', self.chain]
  149. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  150. self.changed = True
  151. self.output.append(output)
  152. self.save()
  153. except subprocess.CalledProcessError as ex:
  154. if '--line-numbers' in ex.cmd:
  155. raise IpTablesCreateJumpRuleError(
  156. chain=self.chain,
  157. msg=("Failed to query existing " +
  158. self.jump_rule_chain +
  159. " rules to determine jump rule location"),
  160. cmd=ex.cmd, exit_code=ex.returncode,
  161. output=ex.output)
  162. else:
  163. raise IpTablesCreateJumpRuleError(
  164. chain=self.chain,
  165. msg=("Failed to create jump rule for chain " +
  166. self.chain),
  167. cmd=ex.cmd, exit_code=ex.returncode,
  168. output=ex.output)
  169. def create_chain(self):
  170. if self.check_mode:
  171. self.changed = True
  172. self.output.append("Create chain %s" % self.chain)
  173. else:
  174. try:
  175. cmd = self.cmd + ['-N', self.chain]
  176. self.output.append(subprocess.check_output(cmd, stderr=subprocess.STDOUT))
  177. self.changed = True
  178. self.output.append("Successfully created chain %s" %
  179. self.chain)
  180. self.save()
  181. except subprocess.CalledProcessError as ex:
  182. raise IpTablesCreateChainError(
  183. chain=self.chain,
  184. msg="Failed to create chain: %s" % self.chain,
  185. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output
  186. )
  187. def jump_rule_exists(self):
  188. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  189. return True if subprocess.call(cmd) == 0 else False
  190. def chain_exists(self):
  191. cmd = self.cmd + ['-L', self.chain]
  192. return True if subprocess.call(cmd) == 0 else False
  193. def gen_cmd(self):
  194. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  195. # Include -w (wait for xtables lock) in default arguments.
  196. default_args = ['-w']
  197. return ["/usr/sbin/%s" % cmd] + default_args
  198. def gen_save_cmd(self): # pylint: disable=no-self-use
  199. return ['/usr/libexec/iptables/iptables.init', 'save']
  200. def main():
  201. module = AnsibleModule( # noqa: F405
  202. argument_spec=dict(
  203. name=dict(required=True),
  204. action=dict(required=True, choices=['add', 'remove',
  205. 'verify_chain']),
  206. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  207. create_jump_rule=dict(required=False, type='bool', default=True),
  208. jump_rule_chain=dict(required=False, default='INPUT'),
  209. protocol=dict(required=False, choices=['tcp', 'udp']),
  210. port=dict(required=False, type='str'),
  211. ip_version=dict(required=False, default='ipv4',
  212. choices=['ipv4', 'ipv6']),
  213. ),
  214. supports_check_mode=True
  215. )
  216. action = module.params['action']
  217. protocol = module.params['protocol']
  218. port = module.params['port']
  219. if action in ['add', 'remove']:
  220. if not protocol:
  221. error = "protocol is required when action is %s" % action
  222. module.fail_json(msg=error)
  223. if not port:
  224. error = "port is required when action is %s" % action
  225. module.fail_json(msg=error)
  226. iptables_manager = IpTablesManager(module)
  227. try:
  228. if action == 'add':
  229. iptables_manager.add_rule(port, protocol)
  230. elif action == 'remove':
  231. iptables_manager.remove_rule(port, protocol)
  232. elif action == 'verify_chain':
  233. iptables_manager.verify_chain()
  234. except IpTablesError as ex:
  235. module.fail_json(msg=ex.msg)
  236. return module.exit_json(changed=iptables_manager.changed,
  237. output=iptables_manager.output)
  238. # pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import, wrong-import-position
  239. # import module snippets
  240. from ansible.module_utils.basic import * # noqa: F403,E402
  241. from ansible.module_utils._text import to_native # noqa: E402
  242. if __name__ == '__main__':
  243. main()