os_firewall_manage_iptables.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # vim: expandtab:tabstop=4:shiftwidth=4
  4. # pylint: disable=fixme, missing-docstring
import subprocess
from subprocess import call, check_output
  6. DOCUMENTATION = '''
  7. ---
  8. module: os_firewall_manage_iptables
  9. short_description: This module manages iptables rules for a given chain
  10. author: Jason DeTiberus
  11. requirements: [ ]
  12. '''
  13. EXAMPLES = '''
  14. '''
  15. class IpTablesError(Exception):
  16. def __init__(self, msg, cmd, exit_code, output):
  17. super(IpTablesError, self).__init__(msg)
  18. self.msg = msg
  19. self.cmd = cmd
  20. self.exit_code = exit_code
  21. self.output = output
  22. class IpTablesAddRuleError(IpTablesError):
  23. pass
  24. class IpTablesRemoveRuleError(IpTablesError):
  25. pass
  26. class IpTablesSaveError(IpTablesError):
  27. pass
  28. class IpTablesCreateChainError(IpTablesError):
  29. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  30. super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code,
  31. output)
  32. self.chain = chain
  33. class IpTablesCreateJumpRuleError(IpTablesError):
  34. def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name
  35. super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code,
  36. output)
  37. self.chain = chain
# TODO: roll back steps that already succeeded when a later step raises an
# exception. For example, when the chain is created successfully but adding
# or removing the rule fails.
  41. class IpTablesManager(object): # pylint: disable=too-many-instance-attributes
  42. def __init__(self, module):
  43. self.module = module
  44. self.ip_version = module.params['ip_version']
  45. self.check_mode = module.check_mode
  46. self.chain = module.params['chain']
  47. self.create_jump_rule = module.params['create_jump_rule']
  48. self.jump_rule_chain = module.params['jump_rule_chain']
  49. self.cmd = self.gen_cmd()
  50. self.save_cmd = self.gen_save_cmd()
  51. self.output = []
  52. self.changed = False
  53. def save(self):
  54. try:
  55. self.output.append(check_output(self.save_cmd,
  56. stderr=subprocess.STDOUT))
  57. except subprocess.CalledProcessError as ex:
  58. raise IpTablesSaveError(
  59. msg="Failed to save iptables rules",
  60. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  61. def verify_chain(self):
  62. if not self.chain_exists():
  63. self.create_chain()
  64. if self.create_jump_rule and not self.jump_rule_exists():
  65. self.create_jump()
  66. def add_rule(self, port, proto):
  67. rule = self.gen_rule(port, proto)
  68. if not self.rule_exists(rule):
  69. self.verify_chain()
  70. if self.check_mode:
  71. self.changed = True
  72. self.output.append("Create rule for %s %s" % (proto, port))
  73. else:
  74. cmd = self.cmd + ['-A'] + rule
  75. try:
  76. self.output.append(check_output(cmd))
  77. self.changed = True
  78. self.save()
  79. except subprocess.CalledProcessError as ex:
  80. raise IpTablesCreateChainError(
  81. chain=self.chain,
  82. msg="Failed to create rule for "
  83. "%s %s" % (proto, port),
  84. cmd=ex.cmd, exit_code=ex.returncode,
  85. output=ex.output)
  86. def remove_rule(self, port, proto):
  87. rule = self.gen_rule(port, proto)
  88. if self.rule_exists(rule):
  89. if self.check_mode:
  90. self.changed = True
  91. self.output.append("Remove rule for %s %s" % (proto, port))
  92. else:
  93. cmd = self.cmd + ['-D'] + rule
  94. try:
  95. self.output.append(check_output(cmd))
  96. self.changed = True
  97. self.save()
  98. except subprocess.CalledProcessError as ex:
  99. raise IpTablesRemoveRuleError(
  100. chain=self.chain,
  101. msg="Failed to remove rule for %s %s" % (proto, port),
  102. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output)
  103. def rule_exists(self, rule):
  104. check_cmd = self.cmd + ['-C'] + rule
  105. return True if call(check_cmd) == 0 else False
  106. def gen_rule(self, port, proto):
  107. return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW',
  108. '-m', proto, '--dport', str(port), '-j', 'ACCEPT']
  109. def create_jump(self):
  110. if self.check_mode:
  111. self.changed = True
  112. self.output.append("Create jump rule for chain %s" % self.chain)
  113. else:
  114. try:
  115. cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers']
  116. output = check_output(cmd, stderr=subprocess.STDOUT)
  117. # break the input rules into rows and columns
  118. input_rules = [s.split() for s in output.split('\n')]
  119. # Find the last numbered rule
  120. last_rule_num = None
  121. last_rule_target = None
  122. for rule in input_rules[:-1]:
  123. if rule:
  124. try:
  125. last_rule_num = int(rule[0])
  126. except ValueError:
  127. continue
  128. last_rule_target = rule[1]
  129. # Naively assume that if the last row is a REJECT or DROP rule,
  130. # then we can insert our rule right before it, otherwise we
  131. # assume that we can just append the rule.
  132. if (last_rule_num and last_rule_target
  133. and last_rule_target in ['REJECT', 'DROP']):
  134. # insert rule
  135. cmd = self.cmd + ['-I', self.jump_rule_chain,
  136. str(last_rule_num)]
  137. else:
  138. # append rule
  139. cmd = self.cmd + ['-A', self.jump_rule_chain]
  140. cmd += ['-j', self.chain]
  141. output = check_output(cmd, stderr=subprocess.STDOUT)
  142. self.changed = True
  143. self.output.append(output)
  144. self.save()
  145. except subprocess.CalledProcessError as ex:
  146. if '--line-numbers' in ex.cmd:
  147. raise IpTablesCreateJumpRuleError(
  148. chain=self.chain,
  149. msg=("Failed to query existing " +
  150. self.jump_rule_chain +
  151. " rules to determine jump rule location"),
  152. cmd=ex.cmd, exit_code=ex.returncode,
  153. output=ex.output)
  154. else:
  155. raise IpTablesCreateJumpRuleError(
  156. chain=self.chain,
  157. msg=("Failed to create jump rule for chain " +
  158. self.chain),
  159. cmd=ex.cmd, exit_code=ex.returncode,
  160. output=ex.output)
  161. def create_chain(self):
  162. if self.check_mode:
  163. self.changed = True
  164. self.output.append("Create chain %s" % self.chain)
  165. else:
  166. try:
  167. cmd = self.cmd + ['-N', self.chain]
  168. self.output.append(check_output(cmd,
  169. stderr=subprocess.STDOUT))
  170. self.changed = True
  171. self.output.append("Successfully created chain %s" %
  172. self.chain)
  173. self.save()
  174. except subprocess.CalledProcessError as ex:
  175. raise IpTablesCreateChainError(
  176. chain=self.chain,
  177. msg="Failed to create chain: %s" % self.chain,
  178. cmd=ex.cmd, exit_code=ex.returncode, output=ex.output
  179. )
  180. def jump_rule_exists(self):
  181. cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain]
  182. return True if call(cmd) == 0 else False
  183. def chain_exists(self):
  184. cmd = self.cmd + ['-L', self.chain]
  185. return True if call(cmd) == 0 else False
  186. def gen_cmd(self):
  187. cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables'
  188. return ["/usr/sbin/%s" % cmd]
  189. def gen_save_cmd(self): # pylint: disable=no-self-use
  190. return ['/usr/libexec/iptables/iptables.init', 'save']
  191. def main():
  192. module = AnsibleModule(
  193. argument_spec=dict(
  194. name=dict(required=True),
  195. action=dict(required=True, choices=['add', 'remove',
  196. 'verify_chain']),
  197. chain=dict(required=False, default='OS_FIREWALL_ALLOW'),
  198. create_jump_rule=dict(required=False, type='bool', default=True),
  199. jump_rule_chain=dict(required=False, default='INPUT'),
  200. protocol=dict(required=False, choices=['tcp', 'udp']),
  201. port=dict(required=False, type='int'),
  202. ip_version=dict(required=False, default='ipv4',
  203. choices=['ipv4', 'ipv6']),
  204. ),
  205. supports_check_mode=True
  206. )
  207. action = module.params['action']
  208. protocol = module.params['protocol']
  209. port = module.params['port']
  210. if action in ['add', 'remove']:
  211. if not protocol:
  212. error = "protocol is required when action is %s" % action
  213. module.fail_json(msg=error)
  214. if not port:
  215. error = "port is required when action is %s" % action
  216. module.fail_json(msg=error)
  217. iptables_manager = IpTablesManager(module)
  218. try:
  219. if action == 'add':
  220. iptables_manager.add_rule(port, protocol)
  221. elif action == 'remove':
  222. iptables_manager.remove_rule(port, protocol)
  223. elif action == 'verify_chain':
  224. iptables_manager.verify_chain()
  225. except IpTablesError as ex:
  226. module.fail_json(msg=ex.msg)
  227. return module.exit_json(changed=iptables_manager.changed,
  228. output=iptables_manager.output)
  229. # pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import
  230. # import module snippets
  231. from ansible.module_utils.basic import *
  232. if __name__ == '__main__':
  233. main()