os_firewall_manage_iptables.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # vim: expandtab:tabstop=4:shiftwidth=4
  4. # pylint: disable=fixme, missing-docstring
  5. import subprocess
  6. DOCUMENTATION = '''
  7. ---
  8. module: os_firewall_manage_iptables
  9. short_description: This module manages iptables rules for a given chain
  10. author: Jason DeTiberus
  11. requirements: [ ]
  12. '''
  13. EXAMPLES = '''
  14. '''
  15. class IpTablesError(Exception):
  16. def __init__(self, msg, cmd, exit_code, output):
  17. super(IpTablesError, self).__init__(msg)
  18. self.msg = msg
  19. self.cmd = cmd
  20. self.exit_code = exit_code
  21. self.output = output
  22. class IpTablesAddRuleError(IpTablesError):
  23. pass
  24. class IpTablesRemoveRuleError(IpTablesError):
  25. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  26. super(IpTablesRemoveRuleError, self).__init__(msg, cmd, exit_code,
  27. output)
  28. self.chain = chain
  29. class IpTablesSaveError(IpTablesError):
  30. pass
  31. class IpTablesCreateChainError(IpTablesError):
  32. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  33. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code,
  34. output)
  35. self.chain = chain
  36. class IpTablesCreateJumpRuleError(IpTablesError):
  37. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  38. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  39. output)
  40. self.chain = chain
# TODO: implement rollbacks for steps that succeeded before a later step
# raised an exception. For example, when the chain is created successfully
# but adding or removing the rule then fails.
  44. class IpTablesManager(object): # pylint: disable=too-many-instance-attributes
  45. def __init__(self, module):
  46. self.module = module
  47. self.ip_version = module.params['ip_version']
  48. self.check_mode = module.check_mode
  49. self.chain = module.params['chain']
  50. self.create_jump_rule = module.params['create_jump_rule']
  51. self.jump_rule_chain = module.params['jump_rule_chain']
  52. self.cmd = self.gen_cmd()
  53. self.save_cmd = self.gen_save_cmd()
  54. self.output = []
  55. self.changed = False
  56. def save(self):
  57. try:
  58. self.output.append(subprocess.check_output(self.save_cmd, stderr=subprocess.STDOUT))
  59. except subprocess.CalledProcessError as ex:
  60. raise IpTablesSaveError(
  61. msg="Failed to save iptables rules",
  62. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  63. def verify_chain(self):
  64. if not self.chain_exists():
  65. self.create_chain()
  66. if self.create_jump_rule and not self.jump_rule_exists():
  67. self.create_jump()
  68. def add_rule(self, port, proto):
  69. rule = self.gen_rule(port, proto)
  70. if not self.rule_exists(rule):
  71. self.verify_chain()
  72. if self.check_mode:
  73. self.changed = True
  74. self.output.append("Create rule for %s %s" % (proto, port))
  75. else:
  76. cmd = self.cmd + ['-A'] + rule
  77. try:
  78. self.output.append(subprocess.check_output(cmd))
  79. self.changed = True
  80. self.save()
  81. except subprocess.CalledProcessError as ex:
  82. raise IpTablesCreateChainError(
  83. chain=self.chain,
  84. msg="Failed to create rule for "
  85. "%s %s" % (proto, port),
  86. cmd=ex.cmd, exit_code=ex.returncode,
  87. output=ex.output)
  88. def remove_rule(self, port, proto):
  89. rule = self.gen_rule(port, proto)
  90. if self.rule_exists(rule):
  91. if self.check_mode:
  92. self.changed = True
  93. self.output.append("Remove rule for %s %s" % (proto, port))
  94. else:
  95. cmd = self.cmd + ['-D'] + rule
  96. try:
  97. self.output.append(subprocess.check_output(cmd))
  98. self.changed = True
  99. self.save()
  100. except subprocess.CalledProcessError as ex:
  101. raise IpTablesRemoveRuleError(
  102. chain=self.chain,
  103. msg="Failed to remove rule for %s %s" % (proto, port),
  104. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  105. def rule_exists(self, rule):
  106. check_cmd = self.cmd + ['-C'] + rule
  107. return True if subprocess.call(check_cmd) == 0 else False
  108. def gen_rule(self, port, proto):
  109. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  110. '-m', proto, '--dport', str(port), '-j', 'ACCEPT']
  111. def create_jump(self):
  112. if self.check_mode:
  113. self.changed = True
  114. self.output.append("Create jump rule for chain %s" % self.chain)
  115. else:
  116. try:
  117. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  118. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  119. # break the input rules into rows and columns
  120. input_rules = [s.split() for s in to_native(output).split('\n')]
  121. # Find the last numbered rule
  122. last_rule_num = None
  123. last_rule_target = None
  124. for rule in input_rules[:-1]:
  125. if rule:
  126. try:
  127. last_rule_num = int(rule[0])
  128. except ValueError:
  129. continue
  130. last_rule_target = rule[1]
  131. # Naively assume that if the last row is a REJECT or DROP rule,
  132. # then we can insert our rule right before it, otherwise we
  133. # assume that we can just append the rule.
  134. if (last_rule_num and last_rule_target and last_rule_target in ['REJECT', 'DROP']):
  135. # insert rule
  136. cmd = self.cmd + ['-I', self.jump_rule_chain,
  137. str(last_rule_num)]
  138. else:
  139. # append rule
  140. cmd = self.cmd + ['-A', self.jump_rule_chain]
  141. cmd += ['-j', self.chain]
  142. output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  143. self.changed = True
  144. self.output.append(output)
  145. self.save()
  146. except subprocess.CalledProcessError as ex:
  147. if '--line-numbers' in ex.cmd:
  148. raise IpTablesCreateJumpRuleError(
  149. chain=self.chain,
  150. msg=("Failed to query existing " +
  151. self.jump_rule_chain +
  152. " rules to determine jump rule location"),
  153. cmd=ex.cmd, exit_code=ex.returncode,
  154. output=ex.output)
  155. else:
  156. raise IpTablesCreateJumpRuleError(
  157. chain=self.chain,
  158. msg=("Failed to create jump rule for chain " +
  159. self.chain),
  160. cmd=ex.cmd, exit_code=ex.returncode,
  161. output=ex.output)
  162. def create_chain(self):
  163. if self.check_mode:
  164. self.changed = True
  165. self.output.append("Create chain %s" % self.chain)
  166. else:
  167. try:
  168. cmd = self.cmd + ['-N', self.chain]
  169. self.output.append(subprocess.check_output(cmd, stderr=subprocess.STDOUT))
  170. self.changed = True
  171. self.output.append("Successfully created chain %s" %
  172. self.chain)
  173. self.save()
  174. except subprocess.CalledProcessError as ex:
  175. raise IpTablesCreateChainError(
  176. chain=self.chain,
  177. msg="Failed to create chain: %s" % self.chain,
  178. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output
  179. )
  180. def jump_rule_exists(self):
  181. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  182. return True if subprocess.call(cmd) == 0 else False
  183. def chain_exists(self):
  184. cmd = self.cmd + ['-L', self.chain]
  185. return True if subprocess.call(cmd) == 0 else False
  186. def gen_cmd(self):
  187. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  188. return ["/usr/sbin/%s" % cmd]
  189. def gen_save_cmd(self): # pylint: disable=no-self-use
  190. return ['/usr/libexec/iptables/iptables.init', 'save']
  191. def main():
  192. module = AnsibleModule( # noqa: F405
  193. argument_spec=dict(
  194. name=dict(required=True),
  195. action=dict(required=True, choices=['add', 'remove',
  196. 'verify_chain']),
  197. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  198. create_jump_rule=dict(required=False, type='bool', default=True),
  199. jump_rule_chain=dict(required=False, default='INPUT'),
  200. protocol=dict(required=False, choices=['tcp', 'udp']),
  201. port=dict(required=False, type='int'),
  202. ip_version=dict(required=False, default='ipv4',
  203. choices=['ipv4', 'ipv6']),
  204. ),
  205. supports_check_mode=True
  206. )
  207. action = module.params['action']
  208. protocol = module.params['protocol']
  209. port = module.params['port']
  210. if action in ['add', 'remove']:
  211. if not protocol:
  212. error = "protocol is required when action is %s" % action
  213. module.fail_json(msg=error)
  214. if not port:
  215. error = "port is required when action is %s" % action
  216. module.fail_json(msg=error)
  217. iptables_manager = IpTablesManager(module)
  218. try:
  219. if action == 'add':
  220. iptables_manager.add_rule(port, protocol)
  221. elif action == 'remove':
  222. iptables_manager.remove_rule(port, protocol)
  223. elif action == 'verify_chain':
  224. iptables_manager.verify_chain()
  225. except IpTablesError as ex:
  226. module.fail_json(msg=ex.msg)
  227. return module.exit_json(changed=iptables_manager.changed,
  228. output=iptables_manager.output)
  229. # pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import, wrong-import-position
  230. # import module snippets
  231. from ansible.module_utils.basic import * # noqa: F403,E402
  232. from ansible.module_utils._text import to_native # noqa: E402
  233. if __name__ == '__main__':
  234. main()