You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
3.7 KiB
Python
149 lines
3.7 KiB
Python
# Tensorboard Image Extractor Copyright (C) 2021 Otthorn
|
|
# License: GNU GPL v3 or later
|
|
|
|
import argparse
|
|
import io
|
|
|
|
import tensorboard.compat.proto.event_pb2 as event_pb2
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
|
|
def read_event(data):
    """
    Split the first record off the front of an event byte stream.

    The first 8 bytes of ``data`` are interpreted as a little-endian
    unsigned integer giving the payload length.  The 4 bytes that follow the
    length and the 4 bytes that follow the payload are skipped without being
    inspected (presumably the checksums of the record framing -- to be
    confirmed against the file format).

    Returns a tuple ``(remaining_data, event_str)``: the bytes left after the
    record that was consumed, then the raw payload of that record.
    """
    payload_length = int.from_bytes(data[:8], "little")

    # 8 bytes of length + 4 skipped bytes precede the payload.
    payload_start = 12
    payload_end = payload_start + payload_length

    event_str = data[payload_start:payload_end]
    # 4 more bytes are skipped after the payload before the next record.
    remaining = data[payload_end + 4:]

    return remaining, event_str
|
|
|
|
|
|
def read_file(input_path):
    """
    Read a file and return its raw content.

    The file at ``input_path`` is opened in binary mode and its whole content
    is returned as a ``bytes`` object.

    If the file does not exist, an error message is printed and the program
    is terminated by raising ``SystemExit`` with status 1.
    """
    try:
        with open(input_path, "rb") as f:
            return f.read()
    except FileNotFoundError:
        print(f"Input file {input_path} is not a valid path.")
        # Raise SystemExit directly instead of calling the bare ``exit()``:
        # that helper is injected by the ``site`` module (so it is not always
        # available) and, called without argument, reports a success status.
        raise SystemExit(1) from None
|
|
|
|
|
|
def decode_image(img):
    """
    Open the encoded image carried by ``img`` with PIL.

    The bytes found in ``img.encoded_image_string`` are wrapped in an
    in-memory buffer and handed to ``Image.open``; the resulting image
    object is returned.
    """
    return Image.open(io.BytesIO(img.encoded_image_string))
|
|
|
|
|
|
def main(args):
    """
    Extract the images stored in an event file.

    Every image summary found in ``args.input`` (optionally restricted to the
    tag ``args.tag``) is either written to the current directory as
    ``img_<tag>_<step>.png`` or, when ``args.gif`` is set, collected and
    saved as a single animated gif at ``args.output``.

    ``args`` is the namespace produced by the command line parser.  The
    attributes read here are ``input``, ``output``, ``gif``, ``tag``,
    ``second_per_frame`` and ``do_not_loop``.  Note that ``do_not_loop`` is
    ``True`` when the ``--do-not-loop`` flag was NOT given (the option is a
    ``store_false`` with a ``True`` default), i.e. when the gif must loop.
    """
    data = read_file(args.input)
    original_length = len(data)

    img_list = []

    # The context manager closes the progress bar even if parsing fails.
    with tqdm(total=original_length) as pbar:
        while data:
            length_before = len(data)
            data, event_str = read_event(data)
            # Progress is measured in bytes consumed from the input file.
            pbar.update(length_before - len(data))

            event = event_pb2.Event()
            event.ParseFromString(event_str)

            if not event.HasField("summary"):
                continue

            for value in event.summary.value:
                if not value.HasField("image"):
                    continue

                # Read the tag by name: ListFields() only lists the fields
                # that are set, so its first entry is not the tag when the
                # tag is empty.
                tag = value.tag

                # if args.tag is None process everything, else process only
                # the given tag
                if args.tag is not None and args.tag != tag:
                    continue

                img_d = decode_image(value.image)

                if args.gif:
                    # save an image list for the gif
                    img_list.append(img_d)
                else:
                    # sanitize tag so it can be used in a file name
                    tag = tag.replace("/", "_").replace(" ", "_")
                    filename = f"img_{tag}_{event.step}.png"
                    print(f"Saving as: {filename}")
                    img_d.save(filename, format="png")

    if args.gif:
        if not img_list:
            print("No image found in the input file, no gif was saved.")
            return

        # save as an animated gif
        print("[DEBUG] saving animated gif")
        # The first frame is the image being saved; only the remaining frames
        # are appended, otherwise the first frame would appear twice.
        first_frame, *other_frames = img_list
        save_options = {
            "save_all": True,
            "append_images": other_frames,
            "duration": args.second_per_frame,
        }
        if args.do_not_loop:
            # loop=0 means "loop forever"; leaving the option out produces a
            # gif that does not loop, which is what --do-not-loop asks for.
            save_options["loop"] = 0
        first_frame.save(args.output, **save_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line interface: build the parser, validate the combination of
    # options, then hand the parsed namespace to main().
    parser = argparse.ArgumentParser(
        description="Tensorboard image dumper and gif creator"
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        help="Input file, must be a tensorboard event file",
        required=True,
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        help="Output file for the gif, must have a .gif extension",
    )
    parser.add_argument(
        "--gif",
        default=False,
        action="store_true",
        help="Save the output as an animated gif",
    )
    # NOTE: store_false with a True default, so args.do_not_loop is True when
    # the flag is absent (the gif loops) and False when the flag is given.
    parser.add_argument(
        "--do-not-loop",
        default=True,
        action="store_false",
        help="Prevent the gif from looping",
    )
    parser.add_argument(
        "--second-per-frame",
        "-spf",
        type=int,
        default=60,
        help="Time between each frame (in milliseconds)",
    )
    parser.add_argument(
        "--tag",
        "-t",
        type=str,
        help="Select a single tag for the output",
    )

    args = parser.parse_args()

    # Fail early with a usage message instead of crashing when saving the gif
    # after the whole input file has already been processed.
    if args.gif and not args.output:
        parser.error("--output is required when --gif is set")

    main(args)
|